Compare commits

..

3 Commits

Author SHA1 Message Date
Theodor S. Midtlien
d65927275d Refactor shell package and use getent for user/group lookup 2026-06-11 17:40:45 +02:00
Theodor S. Midtlien
064f7bf0fd WIP TOFU socket ownership 2026-06-10 17:40:17 +02:00
Theodor S. Midtlien
644615fed6 WIP test 2026-06-10 17:29:00 +02:00
70 changed files with 698 additions and 2738 deletions

View File

@@ -1,301 +0,0 @@
package cmd
import (
"context"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"os"
"path/filepath"
"slices"
"strings"
"github.com/goccy/go-yaml"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/proto"
)
const (
KubernetesDNSSuffix = "netbird-kubeapi-proxy"
)
var kubernetesCmd = &cobra.Command{
Use: "kubernetes",
Short: "Kubernetes cluster commands.",
Long: "Kubernetes cluster commands.",
}
var kubernetesListCmd = &cobra.Command{
Use: "list",
RunE: kubernetesList,
Short: "List Kubernetes clusters.",
Long: "List Kubernetes clusters by discovering NetBird peers running netbird-kubeapi-proxy.",
}
var kubernetesWriteKubeconfigCmd = &cobra.Command{
Use: "write-kubeconfig",
RunE: kubernetesWriteKubeconfig,
Args: cobra.ExactArgs(1),
Short: "Write kubeconfig for a Kubernetes cluster.",
Long: "Updates kubeconfig in place to allow token-less access to the Kubernetes cluster through NetBird.",
}
func init() {
kubernetesWriteKubeconfigCmd.Flags().String("kubeconfig", "", "path to kubeconfig file")
}
func kubernetesList(cmd *cobra.Command, _ []string) error {
conn, err := getClient(cmd)
if err != nil {
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
statusResp, err := client.Status(cmd.Context(), &proto.StatusRequest{GetFullPeerStatus: true})
if err != nil {
return err
}
kcs, err := getKubernetesClusters(cmd.Context(), statusResp.FullStatus.Peers, "")
if err != nil {
return err
}
if len(kcs) == 0 {
cmd.Println("No Kubernetes clusters available.")
return nil
}
cmd.Println("Available Kubernetes clusters:")
for _, k := range kcs {
cmd.Printf("\n - Name: %s\n FQDN: %s\n Version: %s\n", k.name, k.url.Host, k.version)
}
return nil
}
func kubernetesWriteKubeconfig(cmd *cobra.Command, args []string) error {
kubeconfigPath, err := resolveKubeconfigPath(cmd)
if err != nil {
return err
}
conn, err := getClient(cmd)
if err != nil {
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
statusResp, err := client.Status(cmd.Context(), &proto.StatusRequest{GetFullPeerStatus: true})
if err != nil {
return err
}
clusterName := args[0]
kcs, err := getKubernetesClusters(cmd.Context(), statusResp.FullStatus.Peers, clusterName)
if err != nil {
return err
}
if len(kcs) == 0 {
return fmt.Errorf("kubernetes cluster named %s not found", clusterName)
}
if len(kcs) > 1 {
return fmt.Errorf("too many Kubernetes clusters returned")
}
err = writeKubeconfig(kubeconfigPath, kcs[0])
if err != nil {
return err
}
return nil
}
type kubernetesCluster struct {
name string
url *url.URL
version string
}
func getKubernetesClusters(ctx context.Context, peers []*proto.PeerState, nameFilter string) ([]kubernetesCluster, error) {
transport := http.DefaultTransport.(*http.Transport).Clone()
transport.TLSClientConfig = &tls.Config{
InsecureSkipVerify: true,
}
httpClient := &http.Client{
Transport: transport,
}
resolver := net.Resolver{
// Required so both DNS records are returned.
// https://github.com/golang/go/issues/17093
PreferGo: true,
}
kcs := []kubernetesCluster{}
attempted := map[string]struct{}{}
for _, peer := range peers {
fqdns, err := resolver.LookupAddr(ctx, peer.IP)
if err != nil {
return nil, err
}
for _, fqdn := range fqdns {
if _, ok := attempted[fqdn]; ok {
continue
}
attempted[fqdn] = struct{}{}
comps := strings.Split(fqdn, ".")
if len(comps) < 2 {
continue
}
if comps[1] != KubernetesDNSSuffix {
continue
}
if nameFilter != "" && nameFilter != comps[0] {
continue
}
clusterURL, clusterVersion, err := fingerprintClusters(ctx, httpClient, fqdn)
if err != nil {
log.Debugf("could not fingerprint Kubernetes cluster %s %q", fqdn, err)
continue
}
kc := kubernetesCluster{
name: comps[0],
url: clusterURL,
version: clusterVersion,
}
if nameFilter != "" {
return []kubernetesCluster{kc}, nil
}
kcs = append(kcs, kc)
}
}
return kcs, nil
}
func fingerprintClusters(ctx context.Context, httpClient *http.Client, fqdn string) (*url.URL, string, error) {
clusterURL, err := url.Parse("https://" + fqdn)
if err != nil {
return nil, "", err
}
versionURL, err := clusterURL.Parse("/version")
if err != nil {
return nil, "", err
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, versionURL.String(), nil)
if err != nil {
return nil, "", err
}
resp, err := httpClient.Do(req)
if err != nil {
return nil, "", err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, "", fmt.Errorf("expected %d response but got %s", http.StatusOK, resp.Status)
}
b, err := io.ReadAll(resp.Body)
if err != nil {
return nil, "", err
}
versionData := map[string]string{}
err = json.Unmarshal(b, &versionData)
if err != nil {
return nil, "", err
}
version, ok := versionData["gitVersion"]
if !ok {
return nil, "", errors.New("no version found in response")
}
return clusterURL, version, nil
}
func resolveKubeconfigPath(cmd *cobra.Command) (string, error) {
if cmd.Flags().Changed("kubeconfig") {
path, err := cmd.Flags().GetString("kubeconfig")
if err != nil {
return "", err
}
return path, nil
}
if env := os.Getenv("KUBECONFIG"); env != "" {
return env, nil
}
home, err := os.UserHomeDir()
if err != nil {
return "", fmt.Errorf("could not determine home directory: %w", err)
}
return filepath.Join(home, ".kube", "config"), nil
}
func writeKubeconfig(kubeconfigPath string, kc kubernetesCluster) error {
b, err := os.ReadFile(kubeconfigPath)
if err != nil && !errors.Is(err, os.ErrNotExist) {
return err
}
var cfg map[string]any
if err := yaml.Unmarshal(b, &cfg); err != nil {
return err
}
if cfg == nil {
cfg = map[string]any{
"apiVersion": "v1",
"kind": "Config",
}
}
cfg["clusters"] = appendWithName(cfg["clusters"], map[string]any{
"name": kc.name,
"cluster": map[string]any{
"server": kc.url.String(),
"insecure-skip-tls-verify": true,
},
})
cfg["users"] = appendWithName(cfg["users"], map[string]any{
"name": "netbird",
"user": map[string]any{
"token": "none",
},
})
cfg["contexts"] = appendWithName(cfg["contexts"], map[string]any{
"name": kc.name,
"context": map[string]any{
"cluster": kc.name,
"user": "netbird",
"namespace": "default",
},
})
cfg["current-context"] = kc.name
out, err := yaml.Marshal(cfg)
if err != nil {
return err
}
if err := os.WriteFile(kubeconfigPath, out, 0o600); err != nil {
return err
}
return nil
}
func appendWithName(data any, add map[string]any) any {
if data == nil {
return []any{add}
}
v, ok := data.([]any)
if !ok {
return []any{add}
}
i := slices.IndexFunc(v, func(item any) bool {
m, ok := item.(map[string]any)
if !ok {
return false
}
return m["name"] == add["name"]
})
if i == -1 {
return append(v, add)
}
v[i] = add
return v
}

View File

@@ -1,120 +0,0 @@
package cmd
import (
"net/http"
"net/http/httptest"
"net/url"
"os"
"path/filepath"
"testing"
"github.com/spf13/cobra"
"github.com/stretchr/testify/require"
)
func TestFingerprintClusters(t *testing.T) {
t.Parallel()
srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
//nolint: errcheck
w.Write([]byte(`{"gitVersion": "foobar"}`))
}))
defer srv.Close()
clusterURL, clusterVersion, err := fingerprintClusters(t.Context(), srv.Client(), srv.Listener.Addr().String())
require.NoError(t, err)
require.Equal(t, srv.URL, clusterURL.String())
require.Equal(t, "foobar", clusterVersion)
}
func TestResolveKubeconfigPath(t *testing.T) {
home, err := os.UserHomeDir()
if err != nil {
t.Fatalf("could not determine home directory: %v", err)
}
defaultPath := filepath.Join(home, ".kube", "config")
path, err := resolveKubeconfigPath(&cobra.Command{})
require.NoError(t, err)
require.Equal(t, defaultPath, path)
flagPath := "flag-path"
cmd := &cobra.Command{}
cmd.Flags().String("kubeconfig", "", "")
err = cmd.Flags().Set("kubeconfig", flagPath)
require.NoError(t, err)
path, err = resolveKubeconfigPath(cmd)
require.NoError(t, err)
require.Equal(t, flagPath, path)
envPath := "env-path"
t.Setenv("KUBECONFIG", envPath)
path, err = resolveKubeconfigPath(&cobra.Command{})
require.NoError(t, err)
require.Equal(t, envPath, path)
}
func TestWriteKubeconfig(t *testing.T) {
t.Parallel()
tests := []struct {
name string
existing string
}{
{
name: "empty file",
},
{
name: "existing content",
existing: `apiVersion: v1
clusters:
- cluster:
insecure-skip-tls-verify: true
server: https://foobar.com
name: foo
current-context: test
kind: Config
users: []
`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
kubeconfigPath := filepath.Join(t.TempDir(), "config")
err := os.WriteFile(kubeconfigPath, []byte(tt.existing), 0o644)
require.NoError(t, err)
kc := kubernetesCluster{
name: "foo",
url: &url.URL{Scheme: "https", Host: "example.com"},
}
err = writeKubeconfig(kubeconfigPath, kc)
require.NoError(t, err)
b, err := os.ReadFile(kubeconfigPath)
require.NoError(t, err)
expected := `apiVersion: v1
clusters:
- cluster:
insecure-skip-tls-verify: true
server: https://example.com
name: foo
contexts:
- context:
cluster: foo
namespace: default
user: netbird
name: foo
current-context: foo
kind: Config
users:
- name: netbird
user:
token: none
`
require.Equal(t, expected, string(b))
})
}
}

View File

@@ -0,0 +1,36 @@
//go:build darwin || freebsd
package cmd
import (
"fmt"
"net"
"golang.org/x/sys/unix"
)
// peerUID returns the uid of the process on the other end of a unix socket
// connection, read via LOCAL_PEERCRED (xucred). Note: xucred carries the uid
// and group list but no pid, so audit on these platforms is uid-based.
func peerUID(c net.Conn) (int, error) {
uc, ok := c.(*net.UnixConn)
if !ok {
return 0, fmt.Errorf("connection is not a unix socket: %T", c)
}
raw, err := uc.SyscallConn()
if err != nil {
return 0, fmt.Errorf("raw conn: %w", err)
}
var cred *unix.Xucred
var credErr error
if err := raw.Control(func(fd uintptr) {
cred, credErr = unix.GetsockoptXucred(int(fd), unix.SOL_LOCAL, unix.LOCAL_PEERCRED)
}); err != nil {
return 0, fmt.Errorf("getsockopt control: %w", err)
}
if credErr != nil {
return 0, fmt.Errorf("LOCAL_PEERCRED: %w", credErr)
}
return int(cred.Uid), nil
}

View File

@@ -0,0 +1,35 @@
//go:build linux
package cmd
import (
"fmt"
"net"
"golang.org/x/sys/unix"
)
// peerUID returns the uid of the process on the other end of a unix socket
// connection, read from the kernel via SO_PEERCRED.
func peerUID(c net.Conn) (int, error) {
uc, ok := c.(*net.UnixConn)
if !ok {
return 0, fmt.Errorf("connection is not a unix socket: %T", c)
}
raw, err := uc.SyscallConn()
if err != nil {
return 0, fmt.Errorf("raw conn: %w", err)
}
var cred *unix.Ucred
var credErr error
if err := raw.Control(func(fd uintptr) {
cred, credErr = unix.GetsockoptUcred(int(fd), unix.SOL_SOCKET, unix.SO_PEERCRED)
}); err != nil {
return 0, fmt.Errorf("getsockopt control: %w", err)
}
if credErr != nil {
return 0, fmt.Errorf("SO_PEERCRED: %w", credErr)
}
return int(cred.Uid), nil
}

View File

@@ -0,0 +1,16 @@
//go:build !linux && !darwin && !freebsd
package cmd
import (
"fmt"
"net"
"runtime"
)
// peerUID is unimplemented on this platform, so the trust-on-first-use socket
// migration cannot run here. Configure --socket-owner explicitly, or use
// --disable-strict-socket. (Windows uses a TCP socket and never reaches this.)
func peerUID(net.Conn) (int, error) {
return 0, fmt.Errorf("peer credential check not supported on %s", runtime.GOOS)
}

View File

@@ -77,6 +77,8 @@ var (
updateSettingsDisabled bool
captureEnabled bool
networksDisabled bool
socketOwner string
strictSocketDisabled bool
rootCmd = &cobra.Command{
Use: "netbird",
@@ -169,11 +171,6 @@ func init() {
debugCmd.AddCommand(forCmd)
debugCmd.AddCommand(persistenceCmd)
// kubernetes commands
rootCmd.AddCommand(kubernetesCmd)
kubernetesCmd.AddCommand(kubernetesListCmd)
kubernetesCmd.AddCommand(kubernetesWriteKubeconfigCmd)
// profile commands
profileCmd.AddCommand(profileListCmd)
profileCmd.AddCommand(profileAddCmd)

View File

@@ -57,6 +57,9 @@ func init() {
installCmd.Flags().StringSliceVar(&serviceEnvVars, "service-env", nil, serviceEnvDesc)
reconfigureCmd.Flags().StringSliceVar(&serviceEnvVars, "service-env", nil, serviceEnvDesc)
serviceCmd.PersistentFlags().StringVar(&socketOwner, "socket-owner", "", "user to own the daemon control socket; restricts it to that user plus the netbird group (0660). If unset, the first client to connect claims ownership (trust-on-first-use)")
serviceCmd.PersistentFlags().BoolVar(&strictSocketDisabled, "disable-strict-socket", false, "leave the daemon control socket world-writable (0666) instead of restricting it; set via the (root-only) service command")
rootCmd.AddCommand(serviceCmd)
}

View File

@@ -4,10 +4,15 @@ package cmd
import (
"context"
"errors"
"fmt"
"net"
"os"
"os/exec"
"os/user"
"strconv"
"strings"
"sync"
"time"
"github.com/kardianos/service"
@@ -16,6 +21,7 @@ import (
"github.com/spf13/cobra"
"google.golang.org/grpc"
"github.com/netbirdio/netbird/client/internal/shell"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/server"
"github.com/netbirdio/netbird/client/system"
@@ -54,10 +60,36 @@ func (p *program) Start(svc service.Service) error {
go func() {
defer listen.Close()
srvListener := listen
if split[0] == "unix" {
if err := os.Chmod(split[1], 0666); err != nil {
log.Errorf("failed setting daemon permissions: %v", split[1])
return
owner := effectiveSocketOwner()
switch {
case strictSocketDisabled:
// Opt-out (root-only, via service.json): leave it world-writable.
if err := os.Chmod(split[1], 0666); err != nil {
log.Errorf("failed setting daemon permissions: %v", split[1])
return
}
case owner != "":
// Seeded owner (flag, MDM, or persisted TOFU result): restrict
// before serving so there is no open window.
uid, err := lookupUser(owner)
if err != nil {
log.Errorf("lookup socket owner %q: %v", owner, err)
return
}
if err := restrictSocket(split[1], uid); err != nil {
log.Errorf("restrict socket to %q: %v", owner, err)
return
}
default:
// Trust-on-first-use: open the socket now; tofuListener locks it
// to the first caller's uid on the first connection.
if err := os.Chmod(split[1], 0666); err != nil {
log.Errorf("failed setting daemon permissions: %v", split[1])
return
}
srvListener = &tofuListener{Listener: listen, path: split[1], owner: -1}
}
}
@@ -72,13 +104,180 @@ func (p *program) Start(svc service.Service) error {
p.serverInstanceMu.Unlock()
log.Printf("started daemon server: %v", split[1])
if err := p.serv.Serve(listen); err != nil {
if err := p.serv.Serve(srvListener); err != nil {
log.Errorf("failed to serve daemon requests: %v", err)
}
}()
return nil
}
func lookupUser(username string) (int, error) {
u, err := shell.LookupWithGetent(username)
if err != nil {
return -1, fmt.Errorf("lookup user %s: %w", username, err)
}
uid, err := strconv.Atoi(u.Uid)
if err != nil {
return -1, fmt.Errorf("parse uid %s: %w", u.Uid, err)
}
return uid, nil
}
// addGroup creates a system group if it doesn't already exist and returns the gid.
// Must run as root.
func addGroup(name string) (int, error) {
group, err := shell.LookupGroupWithGetent(name)
if err == nil {
gid, err := strconv.ParseInt(group.Gid, 10, 64)
return int(gid), err
}
// looup failed, create the group
groupadd, err := exec.LookPath("groupadd")
if err != nil {
// Fallback for Alpine/BusyBox systems.
if groupadd, err = exec.LookPath("addgroup"); err != nil {
return -1, errors.New("neither groupadd nor addgroup found")
}
}
// Use --system for a service/daemon group (no login, low GID).
out, err := exec.Command(groupadd, "--system", name).CombinedOutput()
if err != nil {
return -1, fmt.Errorf("create group %q: %w: %s", name, err, out)
}
if group, err := shell.LookupWithGetent(name); err == nil {
gid, err := strconv.ParseInt(group.Gid, 10, 64)
return int(gid), err
}
return -1, fmt.Errorf("lookup group %q: %w", name, err)
}
// restrictSocket locks the unix socket down to the owner uid plus the netbird
// group (0660). If the group cannot be created or applied, it fails closed to
// owner-only 0600 — it never leaves the socket world-writable.
func restrictSocket(path string, uid int) error {
// TODO: introduce flag to configure this (LDAP/AD usecase)
gid, err := addGroup("netbird")
if err != nil {
log.Errorf("create netbird group, failing closed to owner-only 0600: %v", err)
return chownChmod(path, uid, -1, 0600)
}
if err := chownChmod(path, uid, gid, 0660); err != nil {
log.Errorf("apply netbird group to socket, failing closed to owner-only 0600: %v", err)
return chownChmod(path, uid, -1, 0600)
}
return nil
}
// chownChmod sets ownership and mode on the socket. A gid of -1 leaves the
// group unchanged.
func chownChmod(path string, uid, gid int, mode os.FileMode) error {
if err := os.Chown(path, uid, gid); err != nil {
return fmt.Errorf("chown socket %s: %w", path, err)
}
if err := os.Chmod(path, mode); err != nil {
return fmt.Errorf("chmod socket %s: %w", path, err)
}
return nil
}
// tofuListener implements trust-on-first-use for the daemon control socket.
// The socket starts world-writable; the first caller's uid (read via the
// platform peer-credential mechanism) becomes the owner. On that first
// connection the socket is restricted (see restrictSocket) and the owner is
// persisted so the open window never reopens on later starts. Connections that
// raced in during the open window and are neither the owner nor root are
// dropped. Changing the socket mode does not disturb the already-open
// connection, so the first caller's request is served normally.
type tofuListener struct {
net.Listener
path string
mu sync.Mutex
owner int // -1 until claimed
}
func (l *tofuListener) Accept() (net.Conn, error) {
for {
c, err := l.Listener.Accept()
if err != nil {
return nil, err
}
uid, err := peerUID(c)
if err != nil {
log.Errorf("read peer credentials, dropping connection: %v", err)
_ = c.Close()
continue
}
l.mu.Lock()
if l.owner == -1 {
if err := restrictSocket(l.path, uid); err != nil {
l.mu.Unlock()
_ = c.Close()
// Refuse to serve on a socket we could not lock down.
return nil, fmt.Errorf("restrict socket on first connection: %w", err)
}
l.owner = uid
persistSocketOwner(uid)
log.Infof("control socket restricted to first caller (uid %d)", uid)
l.mu.Unlock()
return c, nil
}
owner := l.owner
l.mu.Unlock()
// New connects are already gated by the 0660 perms set above; this only
// drops anything that slipped in during the brief open window.
if uid != owner && uid != 0 {
log.Warnf("dropping non-owner connection (uid %d) during socket bootstrap", uid)
_ = c.Close()
continue
}
return c, nil
}
}
// effectiveSocketOwner returns the configured socket owner: the --socket-owner
// flag when set, otherwise the owner persisted by a previous TOFU migration.
func effectiveSocketOwner() string {
if socketOwner != "" {
return socketOwner
}
params, err := loadServiceParams()
if err != nil {
log.Errorf("load service params for socket owner: %v", err)
return ""
}
if params != nil {
return params.SocketOwner
}
return ""
}
// persistSocketOwner records the TOFU-selected owner (by username) so the next
// daemon start restricts the socket immediately, with no open window.
func persistSocketOwner(uid int) {
u, err := user.LookupId(strconv.Itoa(uid))
if err != nil {
log.Errorf("resolve uid %d to username for persistence: %v", uid, err)
return
}
params, err := loadServiceParams()
if err != nil {
log.Errorf("load service params to persist socket owner: %v", err)
return
}
if params == nil {
params = currentServiceParams()
}
params.SocketOwner = u.Username
if err := saveServiceParams(params); err != nil {
log.Errorf("persist socket owner: %v", err)
}
}
func (p *program) Stop(srv service.Service) error {
p.serverInstanceMu.Lock()
if p.serverInstance != nil {

View File

@@ -67,6 +67,14 @@ func buildServiceArguments() []string {
args = append(args, "--disable-networks")
}
if socketOwner != "" {
args = append(args, "--socket-owner", socketOwner)
}
if strictSocketDisabled {
args = append(args, "--disable-strict-socket")
}
return args
}
@@ -127,6 +135,8 @@ var installCmd = &cobra.Command{
return err
}
cmd.Printf("SUDO_UID: %s\n", os.Getenv("SUDO_UID"))
if err := loadAndApplyServiceParams(cmd); err != nil {
cmd.PrintErrf("Warning: failed to load saved service params: %v\n", err)
}

View File

@@ -30,6 +30,8 @@ type serviceParams struct {
DisableUpdateSettings bool `json:"disable_update_settings,omitempty"`
EnableCapture bool `json:"enable_capture,omitempty"`
DisableNetworks bool `json:"disable_networks,omitempty"`
SocketOwner string `json:"socket_owner,omitempty"`
DisableStrictSocket bool `json:"disable_strict_socket,omitempty"`
ServiceEnvVars map[string]string `json:"service_env_vars,omitempty"`
}
@@ -82,6 +84,8 @@ func currentServiceParams() *serviceParams {
DisableUpdateSettings: updateSettingsDisabled,
EnableCapture: captureEnabled,
DisableNetworks: networksDisabled,
SocketOwner: socketOwner,
DisableStrictSocket: strictSocketDisabled,
}
if len(serviceEnvVars) > 0 {
@@ -154,6 +158,14 @@ func applyServiceParams(cmd *cobra.Command, params *serviceParams) {
networksDisabled = params.DisableNetworks
}
if !serviceCmd.PersistentFlags().Changed("socket-owner") {
socketOwner = params.SocketOwner
}
if !serviceCmd.PersistentFlags().Changed("disable-strict-socket") {
strictSocketDisabled = params.DisableStrictSocket
}
applyServiceEnvParams(cmd, params)
}

View File

@@ -279,10 +279,6 @@ func (c *Client) Start(startCtx context.Context) error {
select {
case <-startCtx.Done():
// Cancel the client context before stopping: Engine.Start blocks on the
// signal stream while holding the engine mutex and only unblocks on
// cancellation. Stopping first would deadlock on that mutex.
cancel()
if stopErr := client.Stop(); stopErr != nil {
return fmt.Errorf("stop error after context done. Stop error: %w. Context done: %w", stopErr, startCtx.Err())
}

View File

@@ -1,168 +0,0 @@
package embed
import (
"context"
"net"
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/internals/server/config"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
mgmt "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
nbcache "github.com/netbirdio/netbird/management/server/cache"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/job"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/util"
)
const testSetupKey = "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"
// TestClientStartTimeoutRollback reproduces a deadlock between Engine.Start and
// Engine.Stop. The signal endpoint accepts gRPC connections but never serves the
// SignalExchange service, so Engine.Start parks in WaitStreamConnected while
// holding the engine mutex. When the Start context expires, the rollback path
// calls ConnectClient.Stop, which must not block forever acquiring that mutex.
func TestClientStartTimeoutRollback(t *testing.T) {
signalAddr := startBlackholeSignal(t)
mgmAddr := startManagement(t, signalAddr)
wgPort := 0
client, err := New(Options{
DeviceName: "embed-rollback-test",
SetupKey: testSetupKey,
ManagementURL: "http://" + mgmAddr,
WireguardPort: &wgPort,
})
require.NoError(t, err, "embed client creation must succeed")
startCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
startErr := make(chan error, 1)
go func() {
startErr <- client.Start(startCtx)
}()
select {
case err := <-startErr:
require.ErrorIs(t, err, context.DeadlineExceeded)
case <-time.After(60 * time.Second):
t.Fatal("client.Start did not return after its context expired: Engine.Stop deadlocked against Engine.Start waiting for the signal stream")
}
}
// startBlackholeSignal starts a gRPC server without the SignalExchange service
// registered. Connections succeed, but the signal stream can never be
// established, which keeps Engine.Start parked in WaitStreamConnected.
func startBlackholeSignal(t *testing.T) string {
t.Helper()
lis, err := net.Listen("tcp", "localhost:0")
require.NoError(t, err)
s := grpc.NewServer()
go func() {
if err := s.Serve(lis); err != nil {
t.Error(err)
}
}()
t.Cleanup(s.Stop)
return lis.Addr().String()
}
func startManagement(t *testing.T, signalAddr string) string {
t.Helper()
cfg := &config.Config{
Stuns: []*config.Host{},
TURNConfig: &config.TURNConfig{},
Relay: &config.Relay{
Addresses: []string{"127.0.0.1:1234"},
CredentialsTTL: util.Duration{Duration: time.Hour},
Secret: "222222222222222222",
},
Signal: &config.Host{
Proto: "http",
URI: signalAddr,
},
Datadir: t.TempDir(),
HttpConfig: nil,
}
lis, err := net.Listen("tcp", "localhost:0")
require.NoError(t, err)
s := grpc.NewServer()
testStore, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", cfg.Datadir)
require.NoError(t, err)
t.Cleanup(cleanUp)
eventStore := &activity.InMemoryEventStore{}
permissionsManager := permissions.NewManager(testStore)
peersManager := peers.NewManager(testStore, permissionsManager)
jobManager := job.NewJobManager(nil, testStore, peersManager)
cacheStore, err := nbcache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)
require.NoError(t, err)
iv, err := validator.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore, cacheStore)
require.NoError(t, err)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl)
settingsMockManager.EXPECT().
GetSettings(gomock.Any(), gomock.Any(), gomock.Any()).
Return(&types.Settings{}, nil).
AnyTimes()
settingsMockManager.EXPECT().
GetExtraSettings(gomock.Any(), gomock.Any()).
Return(&types.ExtraSettings{}, nil).
AnyTimes()
groupsManager := groups.NewManagerMock()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := mgmt.NewAccountRequestBuffer(context.Background(), testStore)
networkMapController := controller.NewController(context.Background(), testStore, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(testStore, peersManager), cfg)
accountManager, err := mgmt.BuildManager(context.Background(), cfg, testStore, networkMapController, jobManager, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore)
require.NoError(t, err)
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, cfg.TURNConfig, cfg.Relay, settingsMockManager, groupsManager)
require.NoError(t, err)
mgmtServer, err := nbgrpc.NewServer(cfg, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController, nil, nil)
require.NoError(t, err)
mgmtProto.RegisterManagementServiceServer(s, mgmtServer)
go func() {
if err := s.Serve(lis); err != nil {
t.Error(err)
}
}()
t.Cleanup(s.Stop)
return lis.Addr().String()
}

View File

@@ -3,7 +3,6 @@ package iptables
import (
"errors"
"fmt"
"maps"
"net"
"slices"
@@ -422,17 +421,12 @@ func (m *aclManager) updateState() {
currentState.Lock()
defer currentState.Unlock()
// Clone the maps so the persisted state holds a private snapshot. The
// live maps keep being mutated by subsequent rule operations while the
// state manager marshals the state from its periodic-save goroutine.
// Sharing them by reference races the two and aborts the process with a
// concurrent map iteration and write.
if m.v6 {
currentState.ACLEntries6 = maps.Clone(m.entries)
currentState.ACLIPsetStore6 = m.ipsetStore.clone()
currentState.ACLEntries6 = m.entries
currentState.ACLIPsetStore6 = m.ipsetStore
} else {
currentState.ACLEntries = maps.Clone(m.entries)
currentState.ACLIPsetStore = m.ipsetStore.clone()
currentState.ACLEntries = m.entries
currentState.ACLIPsetStore = m.ipsetStore
}
if err := m.stateManager.UpdateState(currentState); err != nil {

View File

@@ -4,7 +4,6 @@ package iptables
import (
"fmt"
"maps"
"net/netip"
"strconv"
"strings"
@@ -750,17 +749,11 @@ func (r *router) updateState() {
currentState.Lock()
defer currentState.Unlock()
// Clone the rule map so the persisted state holds a private snapshot. The
// live map keeps being mutated by subsequent rule operations while the
// state manager marshals the state from its periodic-save goroutine.
// Sharing it by reference races the two and aborts the process with a
// concurrent map iteration and write. The ipset counter guards itself
// during marshaling, so it can be shared directly.
if r.v6 {
currentState.RouteRules6 = maps.Clone(r.rules)
currentState.RouteRules6 = r.rules
currentState.RouteIPsetCounter6 = r.ipsetCounter
} else {
currentState.RouteRules = maps.Clone(r.rules)
currentState.RouteRules = r.rules
currentState.RouteIPsetCounter = r.ipsetCounter
}

View File

@@ -1,9 +1,6 @@
package iptables
import (
"encoding/json"
"maps"
)
import "encoding/json"
type ipList struct {
ips map[string]struct{}
@@ -22,14 +19,6 @@ func (s *ipList) addIP(ip string) {
s.ips[ip] = struct{}{}
}
// clone returns a deep copy of the ipList with its own ips map.
func (s *ipList) clone() *ipList {
if s == nil {
return nil
}
return &ipList{ips: maps.Clone(s.ips)}
}
// MarshalJSON implements json.Marshaler
func (s *ipList) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
@@ -66,19 +55,6 @@ func newIpsetStore() *ipsetStore {
}
}
// clone returns a deep copy of the ipsetStore with its own ipsets map and
// independent ipList entries.
func (s *ipsetStore) clone() *ipsetStore {
if s == nil {
return nil
}
cloned := &ipsetStore{ipsets: make(map[string]*ipList, len(s.ipsets))}
for name, list := range s.ipsets {
cloned.ipsets[name] = list.clone()
}
return cloned
}
func (s *ipsetStore) ipset(ipsetName string) (*ipList, bool) {
r, ok := s.ipsets[ipsetName]
return r, ok

View File

@@ -777,24 +777,13 @@ func (s *DefaultServer) applyHostConfig() {
// context is released rather than leaked until GC.
func (s *DefaultServer) registerFallback() {
originalNameservers := s.hostManager.getOriginalNameservers()
serverIP := s.service.RuntimeIP()
var servers []netip.AddrPort
for _, ns := range originalNameservers {
if ns == serverIP {
log.Debugf("skipping original nameserver %s as it is the same as the server IP %s", ns, serverIP)
continue
}
servers = append(servers, netip.AddrPortFrom(ns, DefaultPort))
}
if len(servers) == 0 {
if len(originalNameservers) == 0 {
log.Debugf("no fallback upstreams to register; clearing PriorityFallback handler")
s.clearFallback()
return
}
log.Infof("registering original nameservers %v as upstream handlers with priority %d", servers, PriorityFallback)
log.Infof("registering original nameservers %v as upstream handlers with priority %d", originalNameservers, PriorityFallback)
handler, err := newUpstreamResolver(
s.ctx,
@@ -808,6 +797,11 @@ func (s *DefaultServer) registerFallback() {
return
}
handler.selectedRoutes = s.selectedRoutes
var servers []netip.AddrPort
for _, ns := range originalNameservers {
servers = append(servers, netip.AddrPortFrom(ns, DefaultPort))
}
handler.addRace(servers)
prev := s.fallbackHandler

View File

@@ -880,25 +880,62 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
e.handleAutoUpdateVersion(update.NetworkMap.PeerConfig.AutoUpdate)
}
if err := e.updateNetbirdConfig(update.GetNetbirdConfig()); err != nil {
return err
}
if update.GetNetbirdConfig() != nil {
wCfg := update.GetNetbirdConfig()
err := e.updateTURNs(wCfg.GetTurns())
if err != nil {
return fmt.Errorf("update TURNs: %w", err)
}
// Posture checks are bound to the network map presence:
// NetworkMap != nil, checks present -> apply the received checks
// NetworkMap != nil, checks nil -> posture checks were removed, clear them
// NetworkMap == nil -> config-only update (e.g. relay token rotation),
// leave the previously applied checks untouched
nm := update.GetNetworkMap()
if nm == nil {
return nil
err = e.updateSTUNs(wCfg.GetStuns())
if err != nil {
return fmt.Errorf("update STUNs: %w", err)
}
var stunTurn []*stun.URI
stunTurn = append(stunTurn, e.STUNs...)
stunTurn = append(stunTurn, e.TURNs...)
e.stunTurn.Store(stunTurn)
err = e.handleRelayUpdate(wCfg.GetRelay())
if err != nil {
return err
}
err = e.handleFlowUpdate(wCfg.GetFlow())
if err != nil {
return fmt.Errorf("handle the flow configuration: %w", err)
}
if err := e.PopulateNetbirdConfig(wCfg, nil); err != nil {
log.Warnf("Failed to update DNS server config: %v", err)
}
// todo update signal
}
if err := e.updateChecksIfNew(update.Checks); err != nil {
return err
}
e.persistSyncResponse(update)
nm := update.GetNetworkMap()
if nm == nil {
return nil
}
// Persist sync response under the dedicated lock (syncRespMux), not under syncMsgMux.
// A non-nil syncStore is what marks persistence as enabled. Hold the lock for
// the whole Set so the store cannot be cleared (disabled / engine close)
// mid-call and have this write resurrect a file that was just removed.
e.syncRespMux.RLock()
if e.syncStore != nil {
if err := e.syncStore.Set(update); err != nil {
log.Errorf("failed to persist sync response: %v", err)
} else {
log.Debugf("sync response persisted with serial %d", nm.GetSerial())
}
}
e.syncRespMux.RUnlock()
// only apply new changes and ignore old ones
if err := e.updateNetworkMap(nm); err != nil {
@@ -910,64 +947,6 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
return nil
}
// updateNetbirdConfig applies the management-provided NetBird configuration:
// STUN/TURN and relay servers, flow logging and DNS settings. A nil config is a no-op,
// which is the case for sync updates carrying only a network map.
func (e *Engine) updateNetbirdConfig(wCfg *mgmProto.NetbirdConfig) error {
if wCfg == nil {
return nil
}
if err := e.updateTURNs(wCfg.GetTurns()); err != nil {
return fmt.Errorf("update TURNs: %w", err)
}
if err := e.updateSTUNs(wCfg.GetStuns()); err != nil {
return fmt.Errorf("update STUNs: %w", err)
}
var stunTurn []*stun.URI
stunTurn = append(stunTurn, e.STUNs...)
stunTurn = append(stunTurn, e.TURNs...)
e.stunTurn.Store(stunTurn)
if err := e.handleRelayUpdate(wCfg.GetRelay()); err != nil {
return err
}
if err := e.handleFlowUpdate(wCfg.GetFlow()); err != nil {
return fmt.Errorf("handle the flow configuration: %w", err)
}
if err := e.PopulateNetbirdConfig(wCfg, nil); err != nil {
log.Warnf("Failed to update DNS server config: %v", err)
}
// todo update signal
return nil
}
// persistSyncResponse stores the full sync response so it can be restored on the next
// startup. Persistence is enabled only when syncStore is set. The dedicated syncRespMux
// (not syncMsgMux) is held for the whole Set so the store cannot be cleared (disabled /
// engine close) mid-call and have this write resurrect a file that was just removed.
func (e *Engine) persistSyncResponse(update *mgmProto.SyncResponse) {
e.syncRespMux.RLock()
defer e.syncRespMux.RUnlock()
if e.syncStore == nil {
return
}
if err := e.syncStore.Set(update); err != nil {
log.Errorf("failed to persist sync response: %v", err)
return
}
log.Debugf("sync response persisted with serial %d", update.GetNetworkMap().GetSerial())
}
func (e *Engine) handleRelayUpdate(update *mgmProto.RelayConfig) error {
if update != nil {
// when we receive token we expect valid address list too

View File

@@ -700,13 +700,6 @@ func resolveURLsToIPs(urls []string) []net.IP {
// updateRouteSelectorFromManagement updates the route selector based on the isSelected status from the management server
func (m *DefaultManager) updateRouteSelectorFromManagement(clientRoutes route.HAMap) {
// An explicit user "deselect all" must not be overridden by management auto-apply.
// Auto-applying an exit node here would call SelectRoutes, which clears the
// deselect-all flag and re-enables every route the user turned off.
if m.routeSelector.IsDeselectAll() {
return
}
exitNodeInfo := m.collectExitNodeInfo(clientRoutes)
if len(exitNodeInfo.allIDs) == 0 {
return

View File

@@ -1,71 +0,0 @@
package routemanager
import (
"net/netip"
"testing"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal/routeselector"
"github.com/netbirdio/netbird/route"
)
func exitNodeRoutes(netID route.NetID, skipAutoApply bool) route.HAMap {
haID := route.HAUniqueID(string(netID) + "|0.0.0.0/0")
return route.HAMap{
haID: []*route.Route{
{
ID: "r-" + route.ID(netID),
NetID: netID,
Network: netip.MustParsePrefix("0.0.0.0/0"),
NetworkType: route.IPv4Network,
Enabled: true,
SkipAutoApply: skipAutoApply,
},
},
}
}
func TestUpdateRouteSelectorFromManagement(t *testing.T) {
t.Run("management auto-apply selects exit node without user selection", func(t *testing.T) {
m := &DefaultManager{routeSelector: routeselector.NewRouteSelector()}
routes := exitNodeRoutes("exit1", false)
m.updateRouteSelectorFromManagement(routes)
require.True(t, m.routeSelector.IsSelected("exit1"), "auto-apply exit node should be selected")
require.Len(t, m.routeSelector.FilterSelectedExitNodes(routes), 1, "selected exit node should pass the filter")
})
t.Run("management SkipAutoApply leaves exit node deselected", func(t *testing.T) {
m := &DefaultManager{routeSelector: routeselector.NewRouteSelector()}
routes := exitNodeRoutes("exit1", true)
m.updateRouteSelectorFromManagement(routes)
require.False(t, m.routeSelector.IsSelected("exit1"), "SkipAutoApply exit node should not be selected")
require.Empty(t, m.routeSelector.FilterSelectedExitNodes(routes), "deselected exit node should be filtered out")
})
t.Run("user selection is not overridden by management", func(t *testing.T) {
m := &DefaultManager{routeSelector: routeselector.NewRouteSelector()}
require.NoError(t, m.routeSelector.SelectRoutes([]route.NetID{"exit1"}, true, []route.NetID{"exit1"}))
routes := exitNodeRoutes("exit1", true)
m.updateRouteSelectorFromManagement(routes)
require.True(t, m.routeSelector.IsSelected("exit1"), "explicit user selection must survive a management sync that wants to skip auto-apply")
require.Len(t, m.routeSelector.FilterSelectedExitNodes(routes), 1, "user-selected exit node should pass the filter")
})
t.Run("deselect-all is preserved across a management sync", func(t *testing.T) {
m := &DefaultManager{routeSelector: routeselector.NewRouteSelector()}
m.routeSelector.DeselectAllRoutes()
routes := exitNodeRoutes("exit1", false)
m.updateRouteSelectorFromManagement(routes)
require.True(t, m.routeSelector.IsDeselectAll(), "an explicit deselect-all must not be cleared by management auto-apply")
require.Empty(t, m.routeSelector.FilterSelectedExitNodes(routes), "no routes should be selected while deselect-all is set")
})
}

View File

@@ -116,14 +116,6 @@ func (rs *RouteSelector) DeselectAllRoutes() {
clear(rs.selectedRoutes)
}
// IsDeselectAll reports whether the user has explicitly deselected all routes.
func (rs *RouteSelector) IsDeselectAll() bool {
rs.mu.RLock()
defer rs.mu.RUnlock()
return rs.deselectAll
}
// IsSelected checks if a specific route is selected.
func (rs *RouteSelector) IsSelected(routeID route.NetID) bool {
rs.mu.RLock()

View File

@@ -1,6 +1,6 @@
//go:build cgo && !osusergo && !windows
package server
package shell
import "os/user"
@@ -8,17 +8,22 @@ import "os/user"
// When CGO is enabled, os/user uses libc (getpwnam_r) which goes through
// the NSS stack natively. If it fails, the user truly doesn't exist and
// getent would also fail.
func lookupWithGetent(username string) (*user.User, error) {
func LookupWithGetent(username string) (*user.User, error) {
return user.Lookup(username)
}
// currentUserWithGetent with CGO delegates directly to os/user.Current.
func currentUserWithGetent() (*user.User, error) {
func CurrentUserWithGetent() (*user.User, error) {
return user.Current()
}
// LookupGroupWithGetent returns the resolved group from either a gid or groupname
func LookupGroupWithGetent(name string) (*user.Group, error) {
return user.LookupGroup(name)
}
// groupIdsWithFallback with CGO delegates directly to user.GroupIds.
// libc's getgrouplist handles NSS groups natively.
func groupIdsWithFallback(u *user.User) ([]string, error) {
func GroupIdsWithFallback(u *user.User) ([]string, error) {
return u.GroupIds()
}

View File

@@ -1,6 +1,6 @@
//go:build (!cgo || osusergo) && !windows
package server
package shell
import (
"os"
@@ -13,7 +13,7 @@ import (
// lookupWithGetent looks up a user by name, falling back to getent if os/user fails.
// Without CGO, os/user only reads /etc/passwd and misses NSS-provided users.
// getent goes through the host's NSS stack.
func lookupWithGetent(username string) (*user.User, error) {
func LookupWithGetent(username string) (*user.User, error) {
u, err := user.Lookup(username)
if err == nil {
return u, nil
@@ -22,7 +22,7 @@ func lookupWithGetent(username string) (*user.User, error) {
stdErr := err
log.Debugf("os/user.Lookup(%q) failed, trying getent: %v", username, err)
u, _, getentErr := runGetent(username)
u, _, getentErr := runGetentPasswd(username)
if getentErr != nil {
log.Debugf("getent fallback for %q also failed: %v", username, getentErr)
return nil, stdErr
@@ -31,8 +31,25 @@ func lookupWithGetent(username string) (*user.User, error) {
return u, nil
}
// LookupGroupWithGetent returns the resolved group from either a gid or groupname
func LookupGroupWithGetent(name string) (*user.Group, error) {
g, err := user.LookupGroup(name)
if err == nil {
return g, nil
}
stdErr := err
log.Debugf("os/user.LookupGroup(%q) failed, trying getent: %v", name, err)
g, getentErr := runGetentGroup(name)
if getentErr != nil {
log.Debugf("getent fallback for %q also failed: %v", name, getentErr)
return nil, stdErr
}
return g, nil
}
// currentUserWithGetent gets the current user, falling back to getent if os/user fails.
func currentUserWithGetent() (*user.User, error) {
func CurrentUserWithGetent() (*user.User, error) {
u, err := user.Current()
if err == nil {
return u, nil
@@ -42,7 +59,7 @@ func currentUserWithGetent() (*user.User, error) {
uid := strconv.Itoa(os.Getuid())
log.Debugf("os/user.Current() failed, trying getent with UID %s: %v", uid, err)
u, _, getentErr := runGetent(uid)
u, _, getentErr := runGetentPasswd(uid)
if getentErr != nil {
return nil, stdErr
}
@@ -57,7 +74,7 @@ func currentUserWithGetent() (*user.User, error) {
// only reads /etc/group and silently returns incomplete results for NSS users
// (no error, just missing groups). The id command goes through NSS and returns
// the full set.
func groupIdsWithFallback(u *user.User) ([]string, error) {
func GroupIdsWithFallback(u *user.User) ([]string, error) {
ids, err := runIdGroups(u.Username)
if err == nil {
return ids, nil

View File

@@ -1,4 +1,4 @@
package server
package shell
import (
"os/user"
@@ -15,7 +15,7 @@ func TestLookupWithGetent_CurrentUser(t *testing.T) {
current, err := user.Current()
require.NoError(t, err)
u, err := lookupWithGetent(current.Username)
u, err := LookupWithGetent(current.Username)
require.NoError(t, err)
assert.Equal(t, current.Username, u.Username)
assert.Equal(t, current.Uid, u.Uid)
@@ -23,7 +23,7 @@ func TestLookupWithGetent_CurrentUser(t *testing.T) {
}
func TestLookupWithGetent_NonexistentUser(t *testing.T) {
_, err := lookupWithGetent("nonexistent_user_xyzzy_12345")
_, err := LookupWithGetent("nonexistent_user_xyzzy_12345")
require.Error(t, err, "should fail for nonexistent user")
}
@@ -31,7 +31,7 @@ func TestCurrentUserWithGetent(t *testing.T) {
stdUser, err := user.Current()
require.NoError(t, err)
u, err := currentUserWithGetent()
u, err := CurrentUserWithGetent()
require.NoError(t, err)
assert.Equal(t, stdUser.Uid, u.Uid)
assert.Equal(t, stdUser.Username, u.Username)
@@ -41,7 +41,7 @@ func TestGroupIdsWithFallback_CurrentUser(t *testing.T) {
current, err := user.Current()
require.NoError(t, err)
groups, err := groupIdsWithFallback(current)
groups, err := GroupIdsWithFallback(current)
require.NoError(t, err)
require.NotEmpty(t, groups, "current user should have at least one group")
@@ -56,7 +56,7 @@ func TestGroupIdsWithFallback_CurrentUser(t *testing.T) {
func TestGetShellFromGetent_CurrentUser(t *testing.T) {
if runtime.GOOS == "windows" {
// Windows stub always returns empty, which is correct
shell := getShellFromGetent("1000")
shell := GetShellFromGetent("1000")
assert.Empty(t, shell, "Windows stub should return empty")
return
}
@@ -65,7 +65,7 @@ func TestGetShellFromGetent_CurrentUser(t *testing.T) {
require.NoError(t, err)
// getent may not be available on all systems (e.g., macOS without Homebrew getent)
shell := getShellFromGetent(current.Uid)
shell := GetShellFromGetent(current.Uid)
if shell == "" {
t.Log("getShellFromGetent returned empty, getent may not be available")
return
@@ -78,7 +78,7 @@ func TestLookupWithGetent_RootUser(t *testing.T) {
t.Skip("no root user on Windows")
}
u, err := lookupWithGetent("root")
u, err := LookupWithGetent("root")
if err != nil {
t.Skip("root user not available on this system")
}
@@ -91,20 +91,20 @@ func TestLookupWithGetent_RootUser(t *testing.T) {
// consistent and correct results when composed together.
func TestIntegration_FullLookupChain(t *testing.T) {
// Step 1: currentUserWithGetent must resolve the running user.
current, err := currentUserWithGetent()
current, err := CurrentUserWithGetent()
require.NoError(t, err, "currentUserWithGetent must resolve the running user")
require.NotEmpty(t, current.Uid)
require.NotEmpty(t, current.Username)
// Step 2: lookupWithGetent by the same username must return matching identity.
byName, err := lookupWithGetent(current.Username)
byName, err := LookupWithGetent(current.Username)
require.NoError(t, err)
assert.Equal(t, current.Uid, byName.Uid, "lookup by name should return same UID")
assert.Equal(t, current.Gid, byName.Gid, "lookup by name should return same GID")
assert.Equal(t, current.HomeDir, byName.HomeDir, "lookup by name should return same home")
// Step 3: groupIdsWithFallback must return at least the primary GID.
groups, err := groupIdsWithFallback(current)
groups, err := GroupIdsWithFallback(current)
require.NoError(t, err)
require.NotEmpty(t, groups, "user must have at least one group")
@@ -123,7 +123,7 @@ func TestIntegration_FullLookupChain(t *testing.T) {
// Step 4: getShellFromGetent should either return a valid shell path or empty
// (empty is OK when getent is not available, e.g. macOS without Homebrew getent).
if runtime.GOOS != "windows" {
shell := getShellFromGetent(current.Uid)
shell := GetShellFromGetent(current.Uid)
if shell != "" {
assert.True(t, shell[0] == '/', "shell should be an absolute path, got %q", shell)
}
@@ -138,10 +138,10 @@ func TestIntegration_LookupAndGroupsConsistency(t *testing.T) {
require.NoError(t, err)
// Simulate the SSH server flow: lookup user, then get their groups.
resolved, err := lookupWithGetent(current.Username)
resolved, err := LookupWithGetent(current.Username)
require.NoError(t, err)
groups, err := groupIdsWithFallback(resolved)
groups, err := GroupIdsWithFallback(resolved)
require.NoError(t, err)
require.NotEmpty(t, groups, "resolved user must have groups")
@@ -154,19 +154,3 @@ func TestIntegration_LookupAndGroupsConsistency(t *testing.T) {
}
}
}
// TestIntegration_ShellLookupChain tests the full shell resolution chain
// (getShellFromPasswd -> getShellFromGetent -> $SHELL -> default) on Unix.
func TestIntegration_ShellLookupChain(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("Unix shell lookup not applicable on Windows")
}
current, err := user.Current()
require.NoError(t, err)
// getUserShell is the top-level function used by the SSH server.
shell := getUserShell(current.Uid)
require.NotEmpty(t, shell, "getUserShell must always return a shell")
assert.True(t, shell[0] == '/', "shell should be an absolute path, got %q", shell)
}

View File

@@ -1,6 +1,6 @@
//go:build !windows
package server
package shell
import (
"context"
@@ -14,19 +14,25 @@ import (
const getentTimeout = 5 * time.Second
// getShellFromGetent gets a user's login shell via getent by UID.
// GetShellFromGetent gets a user's login shell via getent by UID.
// This is needed even with CGO because getShellFromPasswd reads /etc/passwd
// directly and won't find NSS-provided users there.
func getShellFromGetent(userID string) string {
_, shell, err := runGetent(userID)
func GetShellFromGetent(userID string) string {
_, shell, err := runGetentPasswd(userID)
if err != nil {
return ""
}
return shell
}
// runGetent executes `getent passwd <query>` and returns the user and login shell.
func runGetent(query string) (*user.User, string, error) {
// GetUserFromGetent returns the resolved group from either a uid or username
func GetUserFromGetent(user string) (*user.User, error) {
u, _, err := runGetentPasswd(user)
return u, err
}
// runGetentPasswd executes `getent passwd <query>` and returns the user and login shell.
func runGetentPasswd(query string) (*user.User, string, error) {
if !validateGetentInput(query) {
return nil, "", fmt.Errorf("invalid getent input: %q", query)
}
@@ -42,6 +48,23 @@ func runGetent(query string) (*user.User, string, error) {
return parseGetentPasswd(string(out))
}
// runGetentGroup executes `getent group <query>` and returns the group
func runGetentGroup(query string) (*user.Group, error) {
if !validateGetentInput(query) {
return nil, fmt.Errorf("invalid getent input: %q", query)
}
ctx, cancel := context.WithTimeout(context.Background(), getentTimeout)
defer cancel()
out, err := exec.CommandContext(ctx, "getent", "group", query).Output()
if err != nil {
return nil, fmt.Errorf("getent passwd%s: %w", query, err)
}
return parseGetentGroup(string(out))
}
// parseGetentPasswd parses getent passwd output: "name:x:uid:gid:gecos:home:shell"
func parseGetentPasswd(output string) (*user.User, string, error) {
fields := strings.SplitN(strings.TrimSpace(output), ":", 8)
@@ -67,6 +90,20 @@ func parseGetentPasswd(output string) (*user.User, string, error) {
}, shell, nil
}
// parseGetentGroup parses getent group output: "group:x:gid:user"
func parseGetentGroup(output string) (*user.Group, error) {
fields := strings.SplitN(strings.TrimSpace(output), ":", 8)
if len(fields) < 4 {
return nil, fmt.Errorf("unexpected getent output (need 4+ fields): %q", output)
}
if fields[0] == "" || fields[2] == "" {
return nil, fmt.Errorf("missing required fields in getent output: %q", output)
}
return &user.Group{Gid: fields[2], Name: fields[0]}, nil
}
// validateGetentInput checks that the input is safe to pass to getent or id.
// Allows POSIX usernames, numeric UIDs, and common NSS extensions
// (@ for Kerberos, $ for Samba, + for NIS compat).

View File

@@ -1,6 +1,6 @@
//go:build !windows
package server
package shell
import (
"os/exec"
@@ -195,7 +195,7 @@ func TestRunGetent_RootUser(t *testing.T) {
t.Skip("getent not available on this system")
}
u, shell, err := runGetent("root")
u, shell, err := runGetentPasswd("root")
require.NoError(t, err)
assert.Equal(t, "root", u.Username)
assert.Equal(t, "0", u.Uid)
@@ -208,7 +208,7 @@ func TestRunGetent_ByUID(t *testing.T) {
t.Skip("getent not available on this system")
}
u, _, err := runGetent("0")
u, _, err := runGetentPasswd("0")
require.NoError(t, err)
assert.Equal(t, "root", u.Username)
assert.Equal(t, "0", u.Uid)
@@ -219,15 +219,15 @@ func TestRunGetent_NonexistentUser(t *testing.T) {
t.Skip("getent not available on this system")
}
_, _, err := runGetent("nonexistent_user_xyzzy_12345")
_, _, err := runGetentPasswd("nonexistent_user_xyzzy_12345")
assert.Error(t, err)
}
func TestRunGetent_InvalidInput(t *testing.T) {
_, _, err := runGetent("")
_, _, err := runGetentPasswd("")
assert.Error(t, err)
_, _, err = runGetent("user\x00name")
_, _, err = runGetentPasswd("user\x00name")
assert.Error(t, err)
}
@@ -236,7 +236,7 @@ func TestRunGetent_NotAvailable(t *testing.T) {
t.Skip("getent is available, can't test missing case")
}
_, _, err := runGetent("root")
_, _, err := runGetentPasswd("root")
assert.Error(t, err, "should fail when getent is not installed")
}
@@ -283,7 +283,7 @@ func TestGetentResultsMatchStdlib(t *testing.T) {
current, err := user.Current()
require.NoError(t, err)
getentUser, _, err := runGetent(current.Username)
getentUser, _, err := runGetentPasswd(current.Username)
require.NoError(t, err)
assert.Equal(t, current.Username, getentUser.Username, "username should match")
@@ -300,7 +300,7 @@ func TestGetentResultsMatchStdlib_ByUID(t *testing.T) {
current, err := user.Current()
require.NoError(t, err)
getentUser, _, err := runGetent(current.Uid)
getentUser, _, err := runGetentPasswd(current.Uid)
require.NoError(t, err)
assert.Equal(t, current.Username, getentUser.Username, "username should match when looked up by UID")
@@ -356,7 +356,7 @@ func TestGetShellFromPasswd_CurrentUser(t *testing.T) {
assert.True(t, shell[0] == '/', "shell should be an absolute path, got %q", shell)
if _, err := exec.LookPath("getent"); err == nil {
_, getentShell, getentErr := runGetent(current.Uid)
_, getentShell, getentErr := runGetentPasswd(current.Uid)
if getentErr == nil && getentShell != "" {
assert.Equal(t, getentShell, shell, "shell from /etc/passwd should match getent")
}
@@ -400,7 +400,7 @@ func TestGetShellFromPasswd_MatchesGetentForKnownUsers(t *testing.T) {
continue
}
_, getentShell, err := runGetent(uid)
_, getentShell, err := runGetentPasswd(uid)
if err != nil {
continue
}

View File

@@ -1,26 +1,26 @@
//go:build windows
package server
package shell
import "os/user"
// lookupWithGetent on Windows just delegates to os/user.Lookup.
// Windows does not use NSS/getent; its user lookup works without CGO.
func lookupWithGetent(username string) (*user.User, error) {
func LookupWithGetent(username string) (*user.User, error) {
return user.Lookup(username)
}
// currentUserWithGetent on Windows just delegates to os/user.Current.
func currentUserWithGetent() (*user.User, error) {
func CurrentUserWithGetent() (*user.User, error) {
return user.Current()
}
// getShellFromGetent is a no-op on Windows; shell resolution uses PowerShell detection.
func getShellFromGetent(_ string) string {
func GetShellFromGetent(_ string) string {
return ""
}
// groupIdsWithFallback on Windows just delegates to u.GroupIds().
func groupIdsWithFallback(u *user.User) ([]string, error) {
func GroupIdsWithFallback(u *user.User) ([]string, error) {
return u.GroupIds()
}

View File

@@ -1,17 +1,14 @@
package server
package shell
import (
"bufio"
"fmt"
"net"
"os"
"os/exec"
"os/user"
"runtime"
"strconv"
"strings"
"github.com/gliderlabs/ssh"
log "github.com/sirupsen/logrus"
)
@@ -24,7 +21,7 @@ const (
// getUserShell returns the appropriate shell for the given user ID
// Handles all platform-specific logic and fallbacks consistently
func getUserShell(userID string) string {
func GetUserShell(userID string) string {
switch runtime.GOOS {
case "windows":
return getWindowsUserShell()
@@ -56,7 +53,7 @@ func getUnixUserShell(userID string) string {
return shell
}
if shell := getShellFromGetent(userID); shell != "" {
if shell := GetShellFromGetent(userID); shell != "" {
return shell
}
@@ -101,8 +98,8 @@ func getShellFromPasswd(userID string) string {
return ""
}
// prepareUserEnv prepares environment variables for user execution
func prepareUserEnv(user *user.User, shell string) []string {
// PrepareUserEnv prepares environment variables for user execution
func PrepareUserEnv(user *user.User, shell string) []string {
pathValue := "/usr/local/bin:/usr/bin:/bin:/usr/local/games:/usr/games"
if runtime.GOOS == "windows" {
pathValue = `C:\Windows\System32;C:\Windows;C:\Windows\System32\Wbem;C:\Windows\System32\WindowsPowerShell\v1.0`
@@ -119,7 +116,7 @@ func prepareUserEnv(user *user.User, shell string) []string {
// acceptEnv checks if environment variable from SSH client should be accepted
// This is a whitelist of variables that SSH clients can send to the server
func acceptEnv(envVar string) bool {
func AcceptEnv(envVar string) bool {
varName := envVar
if idx := strings.Index(envVar, "="); idx != -1 {
varName = envVar[:idx]
@@ -156,29 +153,3 @@ func acceptEnv(envVar string) bool {
return false
}
// prepareSSHEnv prepares SSH protocol-specific environment variables
// These variables provide information about the SSH connection itself
func prepareSSHEnv(session ssh.Session) []string {
remoteAddr := session.RemoteAddr()
localAddr := session.LocalAddr()
remoteHost, remotePort, err := net.SplitHostPort(remoteAddr.String())
if err != nil {
remoteHost = remoteAddr.String()
remotePort = "0"
}
localHost, localPort, err := net.SplitHostPort(localAddr.String())
if err != nil {
localHost = localAddr.String()
localPort = strconv.Itoa(InternalSSHPort)
}
return []string{
// SSH_CLIENT format: "client_ip client_port server_port"
fmt.Sprintf("SSH_CLIENT=%s %s %s", remoteHost, remotePort, localPort),
// SSH_CONNECTION format: "client_ip client_port server_ip server_port"
fmt.Sprintf("SSH_CONNECTION=%s %s %s %s", remoteHost, remotePort, localHost, localPort),
}
}

View File

@@ -0,0 +1,26 @@
package shell
import (
"os/user"
"runtime"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestIntegration_ShellLookupChain tests the full shell resolution chain
// (getShellFromPasswd -> getShellFromGetent -> $SHELL -> default) on Unix.
func TestIntegration_ShellLookupChain(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("Unix shell lookup not applicable on Windows")
}
current, err := user.Current()
require.NoError(t, err)
// getUserShell is the top-level function used by the SSH server.
shell := GetUserShell(current.Uid)
require.NotEmpty(t, shell, "getUserShell must always return a shell")
assert.True(t, shell[0] == '/', "shell should be an absolute path, got %q", shell)
}

View File

@@ -218,7 +218,7 @@ func (s *Server) Start() error {
// connectWithRetryRuns runs the client connection with a backoff strategy where we retry the operation as additional
// mechanism to keep the client connected even when the connection is lost.
// we cancel retry if the client receive a stop or down command.
// we cancel retry if the client receive a stop or down command, or if disable auto connect is configured.
func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}, giveUpChan chan struct{}) {
defer func() {
s.mutex.Lock()
@@ -226,6 +226,14 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profil
s.mutex.Unlock()
}()
if s.config.DisableAutoConnect {
if err := s.connect(ctx, s.config, s.statusRecorder, runningChan); err != nil {
log.Debugf("run client connection exited with error: %v", err)
}
log.Tracef("client connection exited")
return
}
backOff := getConnectWithBackoff(ctx)
go func() {
t := time.NewTicker(24 * time.Hour)
@@ -1655,7 +1663,6 @@ func (s *Server) connect(ctx context.Context, config *profilemanager.Config, sta
s.mutex.Unlock()
if err := client.Run(runningChan, s.logFile); err != nil {
log.Debugf("run client connection exited with error: %v", err)
return err
}
return nil

View File

@@ -19,6 +19,7 @@ import (
"github.com/creack/pty"
"github.com/gliderlabs/ssh"
"github.com/netbirdio/netbird/client/internal/shell"
log "github.com/sirupsen/logrus"
)
@@ -146,10 +147,10 @@ func (s *Server) createShellCommand(ctx context.Context, shell string, args []st
// prepareCommandEnv prepares environment variables for command execution on Unix
func (s *Server) prepareCommandEnv(_ *log.Entry, localUser *user.User, session ssh.Session) []string {
env := prepareUserEnv(localUser, getUserShell(localUser.Uid))
env := shell.PrepareUserEnv(localUser, shell.GetUserShell(localUser.Uid))
env = append(env, prepareSSHEnv(session)...)
for _, v := range session.Environ() {
if acceptEnv(v) {
if shell.AcceptEnv(v) {
env = append(env, v)
}
}

View File

@@ -247,10 +247,10 @@ func (s *Server) prepareCommandEnv(logger *log.Entry, localUser *user.User, sess
userEnv, err := s.getUserEnvironment(logger, username, domain)
if err != nil {
log.Debugf("failed to get user environment for %s\\%s, using fallback: %v", domain, username, err)
env := prepareUserEnv(localUser, getUserShell(localUser.Uid))
env := shell.PrepareUserEnv(localUser, shell.GetUserShell(localUser.Uid))
env = append(env, prepareSSHEnv(session)...)
for _, v := range session.Environ() {
if acceptEnv(v) {
if shell.AcceptEnv(v) {
env = append(env, v)
}
}
@@ -260,7 +260,7 @@ func (s *Server) prepareCommandEnv(logger *log.Entry, localUser *user.User, sess
env := userEnv
env = append(env, prepareSSHEnv(session)...)
for _, v := range session.Environ() {
if acceptEnv(v) {
if shell.AcceptEnv(v) {
env = append(env, v)
}
}
@@ -273,7 +273,7 @@ func (s *Server) handlePtyLogin(logger *log.Entry, session ssh.Session, privileg
return false
}
shell := getUserShell(privilegeResult.User.Uid)
shell := shell.GetUserShell(privilegeResult.User.Uid)
logger.Infof("starting interactive shell: %s", shell)
s.executeCommandWithPty(logger, session, nil, privilegeResult, ptyReq, nil)
@@ -384,7 +384,7 @@ func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, _
}
username, domain := s.parseUsername(localUser.Username)
shell := getUserShell(localUser.Uid)
shell := shell.GetUserShell(localUser.Uid)
req := PtyExecutionRequest{
Shell: shell,

View File

@@ -3,11 +3,15 @@ package server
import (
"errors"
"fmt"
"net"
"os"
"os/user"
"runtime"
"strconv"
"strings"
"github.com/gliderlabs/ssh"
"github.com/netbirdio/netbird/client/internal/shell"
log "github.com/sirupsen/logrus"
)
@@ -23,8 +27,8 @@ func isPlatformUnix() bool {
// Dependency injection variables for testing - allows mocking dynamic runtime checks
var (
getCurrentUser = currentUserWithGetent
lookupUser = lookupWithGetent
getCurrentUser = shell.CurrentUserWithGetent
lookupUser = shell.LookupWithGetent
getCurrentOS = func() string { return runtime.GOOS }
getIsProcessPrivileged = isCurrentProcessPrivileged
@@ -409,3 +413,29 @@ func isWindowsElevated() bool {
log.Debugf("Windows user switching not supported: not running as privileged user (current: %s)", currentUser.Uid)
return false
}
// prepareSSHEnv prepares SSH protocol-specific environment variables
// These variables provide information about the SSH connection itself
func prepareSSHEnv(session ssh.Session) []string {
remoteAddr := session.RemoteAddr()
localAddr := session.LocalAddr()
remoteHost, remotePort, err := net.SplitHostPort(remoteAddr.String())
if err != nil {
remoteHost = remoteAddr.String()
remotePort = "0"
}
localHost, localPort, err := net.SplitHostPort(localAddr.String())
if err != nil {
localHost = localAddr.String()
localPort = strconv.Itoa(InternalSSHPort)
}
return []string{
// SSH_CLIENT format: "client_ip client_port server_port"
fmt.Sprintf("SSH_CLIENT=%s %s %s", remoteHost, remotePort, localPort),
// SSH_CONNECTION format: "client_ip client_port server_ip server_port"
fmt.Sprintf("SSH_CONNECTION=%s %s %s %s", remoteHost, remotePort, localHost, localPort),
}
}

View File

@@ -15,6 +15,7 @@ import (
"strconv"
"github.com/gliderlabs/ssh"
"github.com/netbirdio/netbird/client/internal/shell"
log "github.com/sirupsen/logrus"
)
@@ -160,7 +161,7 @@ func (s *Server) parseUserCredentials(localUser *user.User) (uint32, uint32, []u
// getSupplementaryGroups retrieves supplementary group IDs for a user.
// Uses id/getent fallback for NSS users in CGO_ENABLED=0 builds.
func (s *Server) getSupplementaryGroups(u *user.User) ([]uint32, error) {
groupIDStrings, err := groupIdsWithFallback(u)
groupIDStrings, err := shell.GroupIdsWithFallback(u)
if err != nil {
return nil, fmt.Errorf("get group IDs for user %s: %w", u.Username, err)
}
@@ -196,7 +197,7 @@ func (s *Server) createExecutorCommand(logger *log.Entry, session ssh.Session, l
GID: gid,
Groups: groups,
WorkingDir: localUser.HomeDir,
Shell: getUserShell(localUser.Uid),
Shell: shell.GetUserShell(localUser.Uid),
Command: session.RawCommand(),
PTY: hasPty,
}
@@ -228,7 +229,7 @@ func (s *Server) createPtyCommand(privilegeResult PrivilegeCheckResult, ptyReq s
func (s *Server) createDirectPtyCommand(session ssh.Session, localUser *user.User, ptyReq ssh.Pty) *exec.Cmd {
log.Debugf("creating direct Pty command for user %s (no user switching needed)", localUser.Username)
shell := getUserShell(localUser.Uid)
shell := shell.GetUserShell(localUser.Uid)
args := s.getShellCommandArgs(shell, session.RawCommand())
cmd := s.createShellCommand(session.Context(), shell, args)
@@ -245,12 +246,12 @@ func (s *Server) preparePtyEnv(localUser *user.User, ptyReq ssh.Pty, session ssh
termType = "xterm-256color"
}
env := prepareUserEnv(localUser, getUserShell(localUser.Uid))
env := shell.PrepareUserEnv(localUser, shell.GetUserShell(localUser.Uid))
env = append(env, prepareSSHEnv(session)...)
env = append(env, fmt.Sprintf("TERM=%s", termType))
for _, v := range session.Environ() {
if acceptEnv(v) {
if shell.AcceptEnv(v) {
env = append(env, v)
}
}

View File

@@ -13,6 +13,8 @@ import (
"github.com/gliderlabs/ssh"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows"
"github.com/netbirdio/netbird/client/internal/shell"
)
// validateUsername validates Windows usernames according to SAM Account Name rules
@@ -104,7 +106,7 @@ func (s *Server) createExecutorCommand(logger *log.Entry, session ssh.Session, l
func (s *Server) createUserSwitchCommand(logger *log.Entry, session ssh.Session, localUser *user.User) (*exec.Cmd, func(), error) {
username, domain := s.parseUsername(localUser.Username)
shell := getUserShell(localUser.Uid)
sh := shell.GetUserShell(localUser.Uid)
rawCmd := session.RawCommand()
var command string
@@ -116,7 +118,7 @@ func (s *Server) createUserSwitchCommand(logger *log.Entry, session ssh.Session,
Username: username,
Domain: domain,
WorkingDir: localUser.HomeDir,
Shell: shell,
Shell: sh,
Command: command,
}

View File

@@ -99,9 +99,6 @@ func addFields(entry *logrus.Entry) {
if ctxAccountID, ok := entry.Context.Value(context.AccountIDKey).(string); ok {
entry.Data[context.AccountIDKey] = ctxAccountID
}
if ctxUserAgent, ok := entry.Context.Value(context.UserAgentKey).(string); ok {
entry.Data[context.UserAgentKey] = ctxUserAgent
}
if ctxInitiatorID, ok := entry.Context.Value(context.UserIDKey).(string); ok {
entry.Data[context.UserIDKey] = ctxInitiatorID
}

6
go.mod
View File

@@ -2,8 +2,6 @@ module github.com/netbirdio/netbird
go 1.25.5
toolchain go1.25.11
require (
cunicu.li/go-rosenpass v0.5.42
github.com/cenkalti/backoff/v4 v4.3.0
@@ -56,7 +54,6 @@ require (
github.com/fsnotify/fsnotify v1.9.0
github.com/gliderlabs/ssh v0.3.8
github.com/go-jose/go-jose/v4 v4.1.4
github.com/goccy/go-yaml v1.18.0
github.com/godbus/dbus/v5 v5.1.0
github.com/golang-jwt/jwt/v5 v5.3.1
github.com/golang/mock v1.6.0
@@ -214,9 +211,10 @@ require (
github.com/go-viper/mapstructure/v2 v2.5.0 // indirect
github.com/go-webauthn/webauthn v0.16.4 // indirect
github.com/go-webauthn/x v0.2.3 // indirect
github.com/goccy/go-yaml v1.18.0 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/golang-jwt/jwt/v4 v4.5.2 // indirect
github.com/google/btree v1.1.3 // indirect
github.com/google/btree v1.1.2 // indirect
github.com/google/go-querystring v1.1.0 // indirect
github.com/google/go-tpm v0.9.8 // indirect
github.com/google/s2a-go v0.1.9 // indirect

4
go.sum
View File

@@ -275,8 +275,8 @@ github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiu
github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg=
github.com/google/btree v1.1.3/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4=
github.com/google/btree v1.1.2 h1:xf4v41cLI2Z6FxbKm+8Bu+m8ifhj15JuZ9sa0jZCMUU=
github.com/google/btree v1.1.2/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4=
github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=

View File

@@ -488,195 +488,6 @@ func TestUpdate_AllowsPortChange(t *testing.T) {
assert.Equal(t, uint16(54321), updated.ListenPort, "explicit port change should be applied")
}
func TestUpdate_PreservesPortWhenCustomPortsNotSupported(t *testing.T) {
mgr, testStore, _ := setupL4Test(t, boolPtr(false))
ctx := context.Background()
existing := seedService(t, testStore, "tcp-svc", "tcp", testCluster, testCluster, 12345)
updated := &rpservice.Service{
ID: existing.ID,
AccountID: testAccountID,
Name: "tcp-svc-renamed",
Mode: "tcp",
Domain: testCluster,
ProxyCluster: testCluster,
ListenPort: 0,
Enabled: true,
Targets: []*rpservice.Target{
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 9090, Enabled: true},
},
}
_, err := mgr.persistServiceUpdate(ctx, testAccountID, updated)
require.NoError(t, err, "update must not be rejected by the custom-port capability check")
assert.Equal(t, uint16(12345), updated.ListenPort, "existing listen port should be preserved on unsupported cluster")
}
func TestUpdate_PreservesPortWhenCustomPortsUnknown(t *testing.T) {
mgr, testStore, _ := setupL4Test(t, nil)
ctx := context.Background()
existing := seedService(t, testStore, "tcp-svc", "tcp", testCluster, testCluster, 12345)
updated := &rpservice.Service{
ID: existing.ID,
AccountID: testAccountID,
Name: "tcp-svc-renamed",
Mode: "tcp",
Domain: testCluster,
ProxyCluster: testCluster,
ListenPort: 0,
Enabled: true,
Targets: []*rpservice.Target{
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 9090, Enabled: true},
},
}
_, err := mgr.persistServiceUpdate(ctx, testAccountID, updated)
require.NoError(t, err, "update must not be rejected when cluster capability is unknown")
assert.Equal(t, uint16(12345), updated.ListenPort, "existing listen port should be preserved when capability is unknown")
}
func TestUpdate_RejectsPortChangeWhenCustomPortsNotSupported(t *testing.T) {
mgr, testStore, _ := setupL4Test(t, boolPtr(false))
ctx := context.Background()
existing := seedService(t, testStore, "tcp-svc", "tcp", testCluster, testCluster, 12345)
updated := &rpservice.Service{
ID: existing.ID,
AccountID: testAccountID,
Name: "tcp-svc",
Mode: "tcp",
Domain: testCluster,
ProxyCluster: testCluster,
ListenPort: 54321,
Enabled: true,
Targets: []*rpservice.Target{
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 9090, Enabled: true},
},
}
_, err := mgr.persistServiceUpdate(ctx, testAccountID, updated)
require.Error(t, err, "explicit port change on update must be rejected on unsupported clusters")
assert.Contains(t, err.Error(), "custom ports not supported on target cluster")
}
func TestUpdate_TLSPortChangeAllowedWhenNotSupported(t *testing.T) {
mgr, testStore, _ := setupL4Test(t, boolPtr(false))
ctx := context.Background()
existing := seedService(t, testStore, "tls-svc", "tls", "app.example.com", testCluster, 443)
updated := &rpservice.Service{
ID: existing.ID,
AccountID: testAccountID,
Name: "tls-svc",
Mode: "tls",
Domain: "app.example.com",
ProxyCluster: testCluster,
ListenPort: 9999,
Enabled: true,
Targets: []*rpservice.Target{
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8443, Enabled: true},
},
}
_, err := mgr.persistServiceUpdate(ctx, testAccountID, updated)
require.NoError(t, err, "TLS port change uses SNI routing and is exempt from the custom-port check")
assert.Equal(t, uint16(9999), updated.ListenPort, "TLS port change should be applied")
}
func TestValidateL4PortDiffOnClusterDiff(t *testing.T) {
tests := []struct {
name string
mode string
customPorts *bool
newPort uint16
oldPort uint16
wantErr bool
}{
{"tcp port change unsupported", "tcp", boolPtr(false), 54321, 12345, true},
{"tcp port change unknown capability", "tcp", nil, 54321, 12345, true},
{"udp port change unsupported", "udp", boolPtr(false), 54321, 12345, true},
{"tcp first port assignment unsupported", "tcp", boolPtr(false), 54321, 0, true},
{"tcp port change supported", "tcp", boolPtr(true), 54321, 12345, false},
{"tcp port unchanged unsupported", "tcp", boolPtr(false), 12345, 12345, false},
{"tcp zero port unsupported", "tcp", boolPtr(false), 0, 12345, false},
{"tls port change unsupported", "tls", boolPtr(false), 9999, 443, false},
{"http mode ignored", "http", boolPtr(false), 54321, 12345, false},
{"empty mode ignored", "", boolPtr(false), 54321, 12345, false},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
newSvc := &rpservice.Service{Mode: tc.mode, ListenPort: tc.newPort, ProxyCluster: testCluster}
oldSvc := &rpservice.Service{Mode: tc.mode, ListenPort: tc.oldPort, ProxyCluster: testCluster}
err := validateL4PortDiffOnClusterDiff(tc.customPorts, newSvc, oldSvc)
if tc.wantErr {
assert.Error(t, err, "port diff should be rejected for %s", tc.name)
} else {
assert.NoError(t, err, "port diff should be allowed for %s", tc.name)
}
})
}
}
func TestUpdate_PortConflictRejected(t *testing.T) {
mgr, testStore, _ := setupL4Test(t, boolPtr(true))
ctx := context.Background()
seedService(t, testStore, "tcp-a", "tcp", "tcp-a."+testCluster, testCluster, 5432)
svcB := seedService(t, testStore, "tcp-b", "tcp", "tcp-b."+testCluster, testCluster, 6543)
updated := &rpservice.Service{
ID: svcB.ID,
AccountID: testAccountID,
Name: "tcp-b",
Mode: "tcp",
Domain: "tcp-b." + testCluster,
ProxyCluster: testCluster,
ListenPort: 5432,
Enabled: true,
Targets: []*rpservice.Target{
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 9090, Enabled: true},
},
}
_, err := mgr.persistServiceUpdate(ctx, testAccountID, updated)
require.Error(t, err, "updating to a port held by another service should be rejected")
assert.Contains(t, err.Error(), "already in use")
}
func TestUpdate_AutoAssignsWhenNoPort(t *testing.T) {
mgr, testStore, _ := setupL4Test(t, boolPtr(false))
ctx := context.Background()
existing := seedService(t, testStore, "tcp-svc", "tcp", testCluster, testCluster, 0)
updated := &rpservice.Service{
ID: existing.ID,
AccountID: testAccountID,
Name: "tcp-svc",
Mode: "tcp",
Domain: testCluster,
ProxyCluster: testCluster,
ListenPort: 0,
Enabled: true,
Targets: []*rpservice.Target{
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 9090, Enabled: true},
},
}
_, err := mgr.persistServiceUpdate(ctx, testAccountID, updated)
require.NoError(t, err)
assert.True(t, updated.ListenPort >= autoAssignPortMin && updated.ListenPort <= autoAssignPortMax,
"auto-assigned port %d should be in range [%d, %d]", updated.ListenPort, autoAssignPortMin, autoAssignPortMax)
assert.True(t, updated.PortAutoAssigned, "PortAutoAssigned should be set when update triggers auto-assignment")
}
func TestCreateServiceFromPeer_TCP(t *testing.T) {
mgr, _, _ := setupL4Test(t, boolPtr(false))
ctx := context.Background()

View File

@@ -338,7 +338,7 @@ func (m *Manager) persistNewService(ctx context.Context, accountID string, svc *
}
}
if err := m.ensureL4Port(ctx, transaction, svc, customPorts, false); err != nil {
if err := m.ensureL4Port(ctx, transaction, svc, customPorts); err != nil {
return err
}
@@ -367,11 +367,11 @@ func (m *Manager) clusterCustomPorts(ctx context.Context, svc *service.Service)
// ensureL4Port auto-assigns a listen port when needed and validates cluster support.
// customPorts must be pre-computed via clusterCustomPorts before entering a transaction.
func (m *Manager) ensureL4Port(ctx context.Context, tx store.Store, svc *service.Service, customPorts *bool, serviceUpdate bool) error {
func (m *Manager) ensureL4Port(ctx context.Context, tx store.Store, svc *service.Service, customPorts *bool) error {
if !service.IsL4Protocol(svc.Mode) {
return nil
}
if service.IsPortBasedProtocol(svc.Mode) && svc.ListenPort > 0 && !serviceUpdate && (customPorts == nil || !*customPorts) {
if service.IsPortBasedProtocol(svc.Mode) && svc.ListenPort > 0 && (customPorts == nil || !*customPorts) {
if svc.Source != service.SourceEphemeral {
return status.Errorf(status.InvalidArgument, "custom ports not supported on cluster %s", svc.ProxyCluster)
}
@@ -465,7 +465,7 @@ func (m *Manager) persistNewEphemeralService(ctx context.Context, accountID, pee
return err
}
if err := m.ensureL4Port(ctx, transaction, svc, customPorts, false); err != nil {
if err := m.ensureL4Port(ctx, transaction, svc, customPorts); err != nil {
return err
}
@@ -651,22 +651,12 @@ func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.St
m.preserveListenPort(service, existingService)
updateInfo.serviceEnabledChanged = existingService.Enabled != service.Enabled
// if the service is being updated, and we decide in the future to allow mode update,
// we should reconsider the currently assigned port if not 0 for clusters that don't support custom ports
if err := validateL4PortDiffOnClusterDiff(customPorts, service, existingService); err != nil {
if err := m.ensureL4Port(ctx, transaction, service, customPorts); err != nil {
return err
}
if err := m.ensureL4Port(ctx, transaction, service, customPorts, true); err != nil {
return err
}
// we can try carrying the previous service port into a new cluster, if this becomes a problem for multiple users,
// we should reconsider adding another check
if err := m.checkPortConflict(ctx, transaction, service); err != nil {
return err
}
if err := transaction.UpdateService(ctx, service); err != nil {
return fmt.Errorf("update service: %w", err)
}
@@ -674,21 +664,6 @@ func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.St
return nil
}
// validateL4PortDiffOnClusterDiff checks if custom L4 ports are configured and validates port changes across clusters.
// It ensures no port changes if custom ports are unsupported for a given cluster and protocol mode.
// Returns an error if validation fails, otherwise returns nil.
func validateL4PortDiffOnClusterDiff(customPorts *bool, newSVC, oldSVC *service.Service) error {
if !service.IsPortBasedProtocol(newSVC.Mode) || (customPorts != nil && *customPorts) {
return nil
}
if newSVC.ListenPort != 0 && newSVC.ListenPort != oldSVC.ListenPort {
return status.Errorf(status.InvalidArgument, "custom ports not supported on target cluster %s", newSVC.ProxyCluster)
}
return nil
}
// handleDomainChange validates the new domain is free inside the transaction
// and applies the pre-resolved cluster (computed outside the tx by
// resolveEffectiveCluster). It must NOT call clusterDeriver here: that talks

View File

@@ -8,8 +8,6 @@ import (
"strings"
"time"
"github.com/hashicorp/go-version"
nbversion "github.com/netbirdio/netbird/version"
log "github.com/sirupsen/logrus"
goproto "google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/timestamppb"
@@ -30,23 +28,6 @@ import (
"github.com/netbirdio/netbird/shared/sshauth"
)
const (
// deprecatedRemotePeersVersion is the version of Netbird that introduced the NetworkMap.RemotePeers field, deprecated in favor of RemotePeers.
deprecatedRemotePeersVersion = "0.29.3"
)
// precomputedDeprecatedRemotePeersConstraint is the parsed ">= 0.29.3" constraint,
// built once at init since the bound is a compile-time constant.
var precomputedDeprecatedRemotePeersConstraint version.Constraints
func init() {
constraint, err := version.NewConstraint(">= " + deprecatedRemotePeersVersion)
if err != nil {
panic("parse deprecated remote peers version constraint: " + err.Error())
}
precomputedDeprecatedRemotePeersConstraint = constraint
}
func toNetbirdConfig(config *nbconfig.Config, turnCredentials *Token, relayToken *Token, extraSettings *types.ExtraSettings) *proto.NetbirdConfig {
if config == nil {
return nil
@@ -174,11 +155,7 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
remotePeers := make([]*proto.RemotePeerConfig, 0, len(networkMap.Peers)+len(networkMap.OfflinePeers))
remotePeers = appendRemotePeerConfig(remotePeers, networkMap.Peers, dnsName, includeIPv6)
if !shouldSkipSendingDeprecatedRemotePeers(peer.Meta.WtVersion) {
response.RemotePeers = remotePeers
}
response.RemotePeers = remotePeers
response.NetworkMap.RemotePeers = remotePeers
response.RemotePeersIsEmpty = len(remotePeers) == 0
response.NetworkMap.RemotePeersIsEmpty = response.RemotePeersIsEmpty
@@ -269,19 +246,6 @@ func buildAuthorizedUsersProto(ctx context.Context, authorizedUsers map[string]m
return hashedUsers, machineUsers
}
func shouldSkipSendingDeprecatedRemotePeers(peerVersion string) bool {
if nbversion.IsDevelopmentVersion(peerVersion) {
return true
}
peerNBVersion, err := version.NewVersion(peerVersion)
if err != nil {
return false
}
return precomputedDeprecatedRemotePeersConstraint.Check(peerNBVersion)
}
func appendRemotePeerConfig(dst []*proto.RemotePeerConfig, peers []*nbpeer.Peer, dnsName string, includeIPv6 bool) []*proto.RemotePeerConfig {
for _, rPeer := range peers {
allowedIPs := []string{rPeer.IP.String() + "/32"}
@@ -399,6 +363,7 @@ func toProtocolFirewallRules(rules []*types.FirewallRule, includeIPv6, useSource
return result
}
// populateSourcePrefixes sets SourcePrefixes on fwRule and returns any
// additional rules needed (e.g. a v6 wildcard clone when the peer IP is unspecified).
func populateSourcePrefixes(fwRule *proto.FirewallRule, rule *types.FirewallRule, includeIPv6 bool) []*proto.FirewallRule {

View File

@@ -202,42 +202,6 @@ func TestBuildJWTConfig_Audiences(t *testing.T) {
}
}
// TestShouldSkipSendingDeprecatedRemotePeers covers the version gate that
// stops populating the deprecated top-level SyncResponse.RemotePeers field for
// peers new enough to read RemotePeers off the NetworkMap. Development builds
// are treated as latest and skip the field. The gate otherwise fails safe: a
// release version older than the boundary, or one that can't be parsed (empty,
// garbage, prereleases of the boundary) still receives the deprecated field so
// older/unknown clients keep working.
func TestShouldSkipSendingDeprecatedRemotePeers(t *testing.T) {
tests := []struct {
name string
peerVersion string
wantSkip bool
}{
{"exact boundary skips", "0.29.3", true},
{"newer patch skips", "0.29.4", true},
{"newer minor skips", "0.30.0", true},
{"newer major skips", "1.0.0", true},
{"v-prefixed newer skips", "v0.30.0", true},
{"development build skips", "development", true},
{"development build with commit skips", "development-abc123def456-dirty", true},
{"older patch keeps field", "0.29.2", false},
{"older minor keeps field", "0.28.0", false},
{"prerelease of boundary keeps field", "0.29.3-SNAPSHOT", false},
{"tagged dev prerelease keeps field", "v0.31.1-dev", false},
{"empty version keeps field", "", false},
{"garbage version keeps field", "not-a-version", false},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := shouldSkipSendingDeprecatedRemotePeers(tc.peerVersion)
assert.Equal(t, tc.wantSkip, got, "skip decision for peer version %q", tc.peerVersion)
})
}
}
// TestEncodeSessionExpiresAt pins the wire encoding the client's
// applySessionDeadline depends on:
//

View File

@@ -666,10 +666,8 @@ func (s *ProxyServiceServer) sender(conn *proxyConnection, errChan chan<- error)
case resp := <-conn.sendChan:
if err := conn.sendResponse(resp); err != nil {
errChan <- err
log.WithContext(conn.ctx).Tracef("Failed to send response to proxy %s: %v", conn.proxyID, err)
return
}
log.WithContext(conn.ctx).Tracef("Send response to proxy %s", conn.proxyID)
case <-conn.ctx.Done():
return
}

View File

@@ -12,7 +12,6 @@ const (
RoleKey = nbcontext.RoleKey
UserIDKey = nbcontext.UserIDKey
PeerIDKey = nbcontext.PeerIDKey
UserAgentKey = nbcontext.UserAgentKey
)
// RoleFromContext returns the role stored in ctx, or empty string and false if absent.

View File

@@ -21,8 +21,6 @@ const (
httpRequestCounterPrefix = "management.http.request.counter"
httpResponseCounterPrefix = "management.http.response.counter"
httpRequestDurationPrefix = "management.http.request.duration.ms"
RequestIDHeader = "X-Request-Id"
)
// WrappedResponseWriter is a wrapper for http.ResponseWriter that allows the
@@ -174,10 +172,6 @@ func (m *HTTPMiddleware) Handler(h http.Handler) http.Handler {
reqID := xid.New().String()
//nolint
ctx = context.WithValue(ctx, nbContext.RequestIDKey, reqID)
//nolint
ctx = context.WithValue(ctx, nbContext.UserAgentKey, r.UserAgent())
rw.Header().Set(RequestIDHeader, reqID)
log.WithContext(ctx).Tracef("HTTP request %v: %v %v", reqID, r.Method, r.URL)

View File

@@ -557,6 +557,7 @@ func (c *NetworkMapComponents) getRoutingPeerRoutes(peerID string) (enabledRoute
return enabledRoutes, disabledRoutes
}
func (c *NetworkMapComponents) filterRoutesByGroups(routes []*route.Route, groupListMap LookupMap) []*route.Route {
var filteredRoutes []*route.Route
for _, r := range routes {
@@ -627,14 +628,9 @@ func (c *NetworkMapComponents) getDefaultPermit(r *route.Route, includeIPv6 bool
rules := []*RouteFirewallRule{&rule}
isDefaultV4 := r.Network.Addr().Is4() && r.Network.Bits() == 0
if includeIPv6 && (r.IsDynamic() || isDefaultV4) {
if includeIPv6 && r.IsDynamic() {
ruleV6 := rule
ruleV6.SourceRanges = []string{"::/0"}
if isDefaultV4 {
ruleV6.Destination = "::/0"
ruleV6.RouteID = r.ID + "-v6-default"
}
rules = append(rules, &ruleV6)
}

View File

@@ -5,7 +5,6 @@ import (
"fmt"
"net"
"net/netip"
"slices"
"testing"
"time"
@@ -1030,48 +1029,6 @@ func TestComponents_RouteDefaultPermit(t *testing.T) {
assert.True(t, hasDefaultPermit, "route without ACG should have default permit rule with 0.0.0.0/0 source")
}
// TestComponents_ExitNodeDefaultPermitIPv6 verifies that a default exit node route
// (0.0.0.0/0) without AccessControlGroups also emits an IPv6 default permit rule
// (::/0 source and destination) for peers that support IPv6, mirroring the route
// the client installs. Without it, IPv6 traffic is routed to the exit node but
// dropped at the forward chain.
func TestComponents_ExitNodeDefaultPermitIPv6(t *testing.T) {
account, validatedPeers := scalableTestAccount(20, 2)
routingPeerID := "peer-5"
routingPeer := account.Peers[routingPeerID]
routingPeer.IPv6 = netip.MustParseAddr("fd00::5")
routingPeer.Meta.Capabilities = append(routingPeer.Meta.Capabilities, nbpeer.PeerCapabilityIPv6Overlay)
account.Routes["route-exit"] = &route.Route{
ID: "route-exit", Network: netip.MustParsePrefix("0.0.0.0/0"),
PeerID: routingPeerID, Peer: routingPeer.Key,
Enabled: true, Groups: []string{"group-all"}, PeerGroups: []string{"group-0"},
AccessControlGroups: []string{},
AccountID: "test-account",
}
nm := componentsNetworkMap(account, routingPeerID, validatedPeers)
require.NotNil(t, nm)
hasV4 := false
hasV6 := false
for _, rfr := range nm.RoutesFirewallRules {
switch rfr.Destination {
case "0.0.0.0/0":
if slices.Contains(rfr.SourceRanges, "0.0.0.0/0") {
hasV4 = true
}
case "::/0":
if slices.Contains(rfr.SourceRanges, "::/0") {
hasV6 = true
}
}
}
assert.True(t, hasV4, "exit node route should have an IPv4 default permit rule (0.0.0.0/0)")
assert.True(t, hasV6, "exit node route should have an IPv6 default permit rule (::/0)")
}
// ──────────────────────────────────────────────────────────────────────────────
// 15. MULTIPLE ROUTERS PER NETWORK
// ──────────────────────────────────────────────────────────────────────────────

View File

@@ -249,7 +249,6 @@ func runServer(cmd *cobra.Command, args []string) error {
Private: private,
MaxDialTimeout: maxDialTimeout,
MaxSessionIdleTimeout: maxSessionIdleTimeout,
MappingBatchWatchdog: envDurationOrDefault("NB_PROXY_MAPPING_BATCH_WATCHDOG", 0),
GeoDataDir: geoDataDir,
CrowdSecAPIURL: crowdsecAPIURL,
CrowdSecAPIKey: crowdsecAPIKey,

View File

@@ -28,10 +28,6 @@ import (
const deviceNamePrefix = "ingress-proxy-"
const clientStopTimeout = 30 * time.Second
const createProxyPeerTimeout = 30 * time.Second
// backendKey identifies a backend by its host:port from the target URL.
type backendKey string
@@ -166,7 +162,6 @@ type NetBird struct {
clientsMux sync.RWMutex
clients map[types.AccountID]*clientEntry
lifecycleMu sync.Map
initLogOnce sync.Once
statusNotifier statusNotifier
// readyHandler runs after the embedded client for an account reports
@@ -182,10 +177,6 @@ type NetBird struct {
// (i.e. when a new client was actually created, not when an existing one
// was reused). The duration covers keygen + gRPC CreateProxyPeer + embed.New.
OnAddPeer func(d time.Duration, err error)
// startClient runs the post-create client startup. Nil uses runClientStartup;
// tests override it to avoid a real embed client.Start.
startClient func(accountID types.AccountID, client *embed.Client)
}
// ClientDebugInfo contains debug information about a client.
@@ -209,20 +200,31 @@ type skipTLSVerifyContextKey struct{}
func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, key ServiceKey, authToken string, serviceID types.ServiceID) error {
si := serviceInfo{serviceID: serviceID}
if n.registerExistingClient(accountID, key, si) {
return nil
}
n.clientsMux.Lock()
lifecycle := n.accountLifecycle(accountID)
lifecycle.Lock()
transferred := false
defer func() {
if !transferred {
lifecycle.Unlock()
entry, exists := n.clients[accountID]
if exists {
entry.services[key] = si
started := entry.started
n.clientsMux.Unlock()
n.logger.WithFields(log.Fields{
"account_id": accountID,
"service_key": key,
}).Debug("registered service with existing client")
if started && n.statusNotifier != nil {
// Use a background context, not the caller's: the management
// connection notification must land even if the request /
// stream that triggered this registration is cancelled.
// Mirrors the async runClientStartup path.
if err := n.statusNotifier.NotifyStatus(context.Background(), accountID, serviceID, true); err != nil {
n.logger.WithFields(log.Fields{
"account_id": accountID,
"service_key": key,
}).WithError(err).Warn("failed to notify status for existing client")
}
}
}()
if n.registerExistingClient(accountID, key, si) {
return nil
}
@@ -232,10 +234,10 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, key Se
n.OnAddPeer(time.Since(createStart), err)
}
if err != nil {
n.clientsMux.Unlock()
return err
}
n.clientsMux.Lock()
n.clients[accountID] = entry
n.clientsMux.Unlock()
@@ -244,64 +246,17 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, key Se
"service_key": key,
}).Info("created new client for account")
transferred = true
go func() {
defer lifecycle.Unlock()
n.startClientStartup(accountID, entry.client)
}()
// Attempt to start the client in the background; if this fails we will
// retry on the first request via RoundTrip. runClientStartup uses its
// own background context so the caller's request-scoped ctx can't
// cancel the inbound bring-up.
go n.runClientStartup(accountID, entry.client)
return nil
}
func (n *NetBird) startClientStartup(accountID types.AccountID, client *embed.Client) {
if n.startClient != nil {
n.startClient(accountID, client)
return
}
n.runClientStartup(accountID, client)
}
// registerExistingClient registers the service against an already-present
// client for the account and returns true when it did. It notifies management
// of the new service when the client is already started.
func (n *NetBird) registerExistingClient(accountID types.AccountID, key ServiceKey, si serviceInfo) bool {
n.clientsMux.Lock()
entry, exists := n.clients[accountID]
if !exists {
n.clientsMux.Unlock()
return false
}
entry.services[key] = si
started := entry.started
n.clientsMux.Unlock()
n.logger.WithFields(log.Fields{
"account_id": accountID,
"service_key": key,
}).Debug("registered service with existing client")
if started && n.statusNotifier != nil {
if err := n.statusNotifier.NotifyStatus(context.Background(), accountID, si.serviceID, true); err != nil {
n.logger.WithFields(log.Fields{
"account_id": accountID,
"service_key": key,
}).WithError(err).Warn("failed to notify status for existing client")
}
}
return true
}
// accountLifecycle returns the per-account lifecycle mutex, serialising client
// creation against teardown so a slow client.Stop cannot race a new
// client.Start for the same account, without blocking clientsMux.
func (n *NetBird) accountLifecycle(accountID types.AccountID) *sync.Mutex {
mu, _ := n.lifecycleMu.LoadOrStore(accountID, &sync.Mutex{})
return mu.(*sync.Mutex)
}
// createClientEntry generates a WireGuard keypair, authenticates with management,
// and creates an embedded NetBird client. Must be called with the account's
// lifecycle mutex held.
// and creates an embedded NetBird client. Must be called with clientsMux held.
func (n *NetBird) createClientEntry(ctx context.Context, accountID types.AccountID, key ServiceKey, authToken string, si serviceInfo) (*clientEntry, error) {
serviceID := si.serviceID
n.logger.WithFields(log.Fields{
@@ -321,9 +276,7 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account
"public_key": publicKey.String(),
}).Debug("authenticating new proxy peer with management")
createCtx, cancel := context.WithTimeout(ctx, createProxyPeerTimeout)
defer cancel()
resp, err := n.mgmtClient.CreateProxyPeer(createCtx, &proto.CreateProxyPeerRequest{
resp, err := n.mgmtClient.CreateProxyPeer(ctx, &proto.CreateProxyPeerRequest{
ServiceId: string(serviceID),
AccountId: string(accountID),
Token: authToken,
@@ -491,15 +444,6 @@ func (n *NetBird) notifyClientReady(accountID types.AccountID, client *embed.Cli
// RemovePeer unregisters a service from an account. The client is only stopped
// when no services are using it anymore.
func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, key ServiceKey) error {
lifecycle := n.accountLifecycle(accountID)
lifecycle.Lock()
transferred := false
defer func() {
if !transferred {
lifecycle.Unlock()
}
}()
n.clientsMux.Lock()
entry, exists := n.clients[accountID]
@@ -522,8 +466,17 @@ func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, key
delete(entry.services, key)
stopClient := len(entry.services) == 0
var client *embed.Client
var transport, insecureTransport *http.Transport
var inbound any
var stopHandler func(types.AccountID, any)
if stopClient {
n.logger.WithField("account_id", accountID).Info("stopping client, no more services")
client = entry.client
transport = entry.transport
insecureTransport = entry.insecureTransport
inbound = entry.inbound
stopHandler = n.stopHandler
delete(n.clients, accountID)
} else {
n.logger.WithFields(log.Fields{
@@ -537,40 +490,19 @@ func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, key
n.notifyDisconnect(ctx, accountID, key, si.serviceID)
if stopClient {
transferred = true
go n.stopClientLocked(accountID, lifecycle, entry)
if inbound != nil && stopHandler != nil {
stopHandler(accountID, inbound)
}
transport.CloseIdleConnections()
insecureTransport.CloseIdleConnections()
if err := client.Stop(ctx); err != nil {
n.logger.WithField("account_id", accountID).WithError(err).Warn("failed to stop netbird client")
}
}
return nil
}
// stopClientLocked releases a client's resources off the caller's goroutine so a
// slow client.Stop cannot wedge the mapping receive loop (which calls RemovePeer
// synchronously). It unlocks lifecycle when done so a new client.Start for the
// same account waits for this teardown.
func (n *NetBird) stopClientLocked(accountID types.AccountID, lifecycle *sync.Mutex, entry *clientEntry) {
defer lifecycle.Unlock()
if entry.inbound != nil && n.stopHandler != nil {
n.stopHandler(accountID, entry.inbound)
}
if entry.transport != nil {
entry.transport.CloseIdleConnections()
}
if entry.insecureTransport != nil {
entry.insecureTransport.CloseIdleConnections()
}
if entry.client == nil {
return
}
ctx, cancel := context.WithTimeout(context.Background(), clientStopTimeout)
defer cancel()
if err := entry.client.Stop(ctx); err != nil {
n.logger.WithField("account_id", accountID).WithError(err).Warn("failed to stop netbird client")
}
}
func (n *NetBird) notifyDisconnect(ctx context.Context, accountID types.AccountID, key ServiceKey, serviceID types.ServiceID) {
if n.statusNotifier == nil {
return

View File

@@ -6,7 +6,6 @@ import (
"net/netip"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -23,18 +22,6 @@ func (m *mockMgmtClient) CreateProxyPeer(_ context.Context, _ *proto.CreateProxy
return &proto.CreateProxyPeerResponse{Success: true}, nil
}
// signalMgmtClient closes entered the first time CreateProxyPeer is called, so
// tests can detect AddPeer reaching client creation.
type signalMgmtClient struct {
entered chan struct{}
once sync.Once
}
func (m *signalMgmtClient) CreateProxyPeer(_ context.Context, _ *proto.CreateProxyPeerRequest, _ ...grpc.CallOption) (*proto.CreateProxyPeerResponse, error) {
m.once.Do(func() { close(m.entered) })
return &proto.CreateProxyPeerResponse{Success: true}, nil
}
type mockStatusNotifier struct {
mu sync.Mutex
statuses []statusCall
@@ -65,15 +52,11 @@ func (m *mockStatusNotifier) calls() []statusCall {
// mockNetBird creates a NetBird instance for testing without actually connecting.
// It uses an invalid management URL to prevent real connections.
func mockNetBird() *NetBird {
nb := NewNetBird(context.Background(), "test-proxy", "invalid.test", ClientConfig{
return NewNetBird(context.Background(), "test-proxy", "invalid.test", ClientConfig{
MgmtAddr: "http://invalid.test:9999",
WGPort: 0,
PreSharedKey: "",
}, nil, nil, &mockMgmtClient{})
// Skip the real embed client.Start, which would hang against the unreachable
// mgmt URL and (now that the lifecycle lock spans startup) serialise removes.
nb.startClient = func(types.AccountID, *embed.Client) {}
return nb
}
func TestNetBird_AddPeer_CreatesClientForNewAccount(t *testing.T) {
@@ -305,7 +288,6 @@ func TestNetBird_AddPeer_ExistingStartedClient_NotifiesStatus(t *testing.T) {
WGPort: 0,
PreSharedKey: "",
}, nil, notifier, &mockMgmtClient{})
nb.startClient = func(types.AccountID, *embed.Client) {}
accountID := types.AccountID("account-1")
// Add first service — creates a new client entry.
@@ -390,117 +372,6 @@ func TestNetBird_RemovePeer_NotifiesDisconnection(t *testing.T) {
assert.False(t, calls[0].connected)
}
// TestNetBird_RemovePeer_TeardownIsAsync proves the fix for the receive-loop
// stall: RemovePeer must return promptly even when the client teardown blocks,
// because teardown runs off the caller's goroutine. The receive loop calls
// RemovePeer synchronously, so a blocking teardown inline would wedge it.
func TestNetBird_RemovePeer_TeardownIsAsync(t *testing.T) {
nb := NewNetBird(context.Background(), "test-proxy", "invalid.test", ClientConfig{
MgmtAddr: "http://invalid.test:9999",
}, nil, &mockStatusNotifier{}, &mockMgmtClient{})
accountID := types.AccountID("acct-async-teardown")
key := DomainServiceKey("svc.example")
teardownEntered := make(chan struct{})
releaseTeardown := make(chan struct{})
nb.SetClientLifecycle(nil, func(types.AccountID, any) {
close(teardownEntered)
<-releaseTeardown
})
nb.clientsMux.Lock()
nb.clients[accountID] = &clientEntry{
services: map[ServiceKey]serviceInfo{key: {serviceID: types.ServiceID("svc-1")}},
started: true,
inbound: struct{}{},
}
nb.clientsMux.Unlock()
done := make(chan error, 1)
go func() { done <- nb.RemovePeer(context.Background(), accountID, key) }()
select {
case err := <-done:
require.NoError(t, err)
case <-time.After(2 * time.Second):
t.Fatal("RemovePeer did not return while teardown was blocked — teardown is not async")
}
select {
case <-teardownEntered:
case <-time.After(2 * time.Second):
t.Fatal("teardown never ran")
}
close(releaseTeardown)
}
// TestNetBird_AddPeer_WaitsForTeardown proves the lifecycle lock serialises a
// new client bringup behind an in-flight teardown for the same account, so a
// slow client.Stop can never race a new client.Start for that account.
//
// It targets the handoff race specifically: AddPeer is launched immediately
// after RemovePeer returns, WITHOUT waiting for the teardown goroutine to start.
// This only passes if RemovePeer acquires the lifecycle lock synchronously
// (before returning) and hands it to the teardown goroutine — if the goroutine
// acquired the lock itself, AddPeer could win the lock in this window and start
// a replacement client while the old teardown is still pending.
func TestNetBird_AddPeer_WaitsForTeardown(t *testing.T) {
nb := NewNetBird(context.Background(), "test-proxy", "invalid.test", ClientConfig{
MgmtAddr: "http://invalid.test:9999",
}, nil, &mockStatusNotifier{}, &mockMgmtClient{})
nb.startClient = func(types.AccountID, *embed.Client) {}
accountID := types.AccountID("acct-serialize")
key := DomainServiceKey("svc.example")
addEntered := make(chan struct{})
releaseTeardown := make(chan struct{})
nb.SetClientLifecycle(nil, func(types.AccountID, any) {
// Block teardown until released. If AddPeer ever reaches createClientEntry
// (signalled via the mgmt client below) while we hold the lock, the lock
// failed to serialise and the test fails before we release.
<-releaseTeardown
})
nb.clientsMux.Lock()
nb.clients[accountID] = &clientEntry{
services: map[ServiceKey]serviceInfo{key: {serviceID: types.ServiceID("svc-1")}},
started: true,
inbound: struct{}{},
}
nb.clientsMux.Unlock()
// createClientEntry calls CreateProxyPeer; closing addEntered there tells us
// AddPeer got past the lifecycle lock and into client creation.
nb.mgmtClient = &signalMgmtClient{entered: addEntered}
require.NoError(t, nb.RemovePeer(context.Background(), accountID, key))
// Launch AddPeer with NO synchronisation against the teardown goroutine.
addReturned := make(chan struct{})
go func() {
_ = nb.AddPeer(context.Background(), accountID, DomainServiceKey("svc2.example"), "key-2", types.ServiceID("svc-2"))
close(addReturned)
}()
select {
case <-addEntered:
t.Fatal("AddPeer entered client creation while teardown held the lifecycle lock — handoff race not closed")
case <-addReturned:
t.Fatal("AddPeer completed while teardown held the lifecycle lock — not serialised")
case <-time.After(300 * time.Millisecond):
}
close(releaseTeardown)
select {
case <-addReturned:
case <-time.After(2 * time.Second):
t.Fatal("AddPeer never completed after teardown released the lifecycle lock")
}
}
// TestNotifyClientReady_UsesBackgroundCtx pins the contract that the
// post-Start hooks (readyHandler + statusNotifier.NotifyStatus) run on
// a fresh context.Background() rather than inheriting the AddPeer

View File

@@ -114,10 +114,6 @@ type Config struct {
MaxDialTimeout time.Duration
// MaxSessionIdleTimeout caps the per-service session idle timeout.
MaxSessionIdleTimeout time.Duration
// MappingBatchWatchdog bounds how long a single mapping batch may spend
// being applied before the receive loop reconnects to resync. Zero falls
// back to the internal default.
MappingBatchWatchdog time.Duration
// GeoDataDir is the directory containing GeoLite2 MMDB files.
GeoDataDir string
@@ -168,7 +164,6 @@ func New(ctx context.Context, cfg Config) *Server {
Private: cfg.Private,
MaxDialTimeout: cfg.MaxDialTimeout,
MaxSessionIdleTimeout: cfg.MaxSessionIdleTimeout,
MappingBatchWatchdog: cfg.MappingBatchWatchdog,
GeoDataDir: cfg.GeoDataDir,
CrowdSecAPIURL: cfg.CrowdSecAPIURL,
CrowdSecAPIKey: cfg.CrowdSecAPIKey,

View File

@@ -1,282 +0,0 @@
package proxy
import (
"context"
"sync"
"sync/atomic"
"testing"
"time"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
"github.com/netbirdio/netbird/proxy/internal/roundtrip"
"github.com/netbirdio/netbird/proxy/internal/types"
"github.com/netbirdio/netbird/shared/management/proto"
)
// blockingMgmtClient implements roundtrip's managementClient interface.
// CreateProxyPeer parks until release is closed, signalling entry on entered.
// This reproduces the confirmed real-world stall: createClientEntry calls
// CreateProxyPeer synchronously while holding clientsMux, and the proxy's
// receive loop calls that path synchronously inside processMappings.
type blockingMgmtClient struct {
entered chan struct{}
once sync.Once
}
func (b *blockingMgmtClient) CreateProxyPeer(ctx context.Context, _ *proto.CreateProxyPeerRequest, _ ...grpc.CallOption) (*proto.CreateProxyPeerResponse, error) {
b.once.Do(func() { close(b.entered) })
// Park until the caller's context is cancelled. In production this ctx is
// the gRPC mapping-stream context with no per-call timeout, so a slow or
// unresponsive CreateProxyPeer parks the receive loop here indefinitely.
<-ctx.Done()
return nil, ctx.Err()
}
// gatedMappingStream is a mock GetMappingUpdate client stream that hands out a
// pre-seeded list of messages, then records how many times Recv advanced. It
// lets the test observe whether the single-threaded receive loop ever gets
// past the first (blocking) batch to pull the second message.
type gatedMappingStream struct {
grpc.ClientStream
messages []*proto.GetMappingUpdateResponse
idx int32
}
func (g *gatedMappingStream) Recv() (*proto.GetMappingUpdateResponse, error) {
i := int(atomic.LoadInt32(&g.idx))
if i >= len(g.messages) {
// Block instead of returning EOF so the loop doesn't exit; we only
// care whether the loop ever reaches this second Recv at all.
select {}
}
msg := g.messages[i]
atomic.AddInt32(&g.idx, 1)
return msg, nil
}
func (g *gatedMappingStream) deliveredCount() int32 { return atomic.LoadInt32(&g.idx) }
func (g *gatedMappingStream) Header() (metadata.MD, error) { return nil, nil } //nolint:nilnil
func (g *gatedMappingStream) Trailer() metadata.MD { return nil }
func (g *gatedMappingStream) CloseSend() error { return nil }
func (g *gatedMappingStream) Context() context.Context { return context.Background() }
func (g *gatedMappingStream) SendMsg(any) error { return nil }
func (g *gatedMappingStream) RecvMsg(any) error { return nil }
// noopNotifier satisfies roundtrip's statusNotifier interface.
type noopNotifier struct{}
func (noopNotifier) NotifyStatus(context.Context, types.AccountID, types.ServiceID, bool) error {
return nil
}
// noopProxyClient is a proto.ProxyServiceClient that no-ops the one method the
// teardown unwind reaches (SendStatusUpdate, via notifyError when the parked
// AddPeer is cancelled). The embedded nil interface satisfies the rest at
// compile time; none of those methods are called by this test.
type noopProxyClient struct {
proto.ProxyServiceClient
}
func (noopProxyClient) SendStatusUpdate(context.Context, *proto.SendStatusUpdateRequest, ...grpc.CallOption) (*proto.SendStatusUpdateResponse, error) {
return &proto.SendStatusUpdateResponse{}, nil
}
// TestMappingStream_StallsWhenApplyBlocks proves the deadlock: the proxy's
// mapping receive loop processes batches strictly serially, so when applying
// one batch blocks (here: createClientEntry parked on a synchronous
// CreateProxyPeer call, exactly as observed in production), the loop never
// advances to Recv the next batch. Management can keep sending updates onto
// the stream with no error and no channel overflow, yet the proxy applies
// nothing further — it is stuck.
func TestMappingStream_StallsWhenApplyBlocks(t *testing.T) {
logger := log.New()
logger.SetLevel(log.PanicLevel)
mgmt := &blockingMgmtClient{
entered: make(chan struct{}),
}
nb := roundtrip.NewNetBird(
context.Background(),
"proxy-test",
"proxy.example.com",
roundtrip.ClientConfig{},
logger,
noopNotifier{},
mgmt,
)
s := &Server{
Logger: logger,
netbird: nb,
mgmtClient: noopProxyClient{},
routerReady: closedChan(),
lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
}
// First batch: a CREATED mapping for a brand-new account. addMapping ->
// netbird.AddPeer -> createClientEntry -> CreateProxyPeer, which blocks.
// Empty Path keeps setupHTTPMapping a no-op (it returns early), so the
// ONLY blocking point is the synchronous CreateProxyPeer in AddPeer —
// no routers/auth need wiring. The second batch exists only to detect
// whether the loop ever advances past the blocked first batch.
stream := &gatedMappingStream{
messages: []*proto.GetMappingUpdateResponse{
{
Mapping: []*proto.ProxyMapping{
{
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
Id: "svc-1",
AccountId: "acct-1",
AuthToken: "token-1",
},
},
},
{
Mapping: []*proto.ProxyMapping{
{
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
Id: "svc-2",
AccountId: "acct-2",
AuthToken: "token-2",
},
},
},
},
}
ctx, cancel := context.WithCancel(context.Background())
// Unblock the parked apply on teardown via ctx (CreateProxyPeer returns
// ctx.Err()), so the wedged loop goroutine unwinds before embed.New —
// avoiding any dependency on collaborators this test deliberately leaves
// nil. The deadlock is fully proven before this fires.
t.Cleanup(cancel)
loopDone := make(chan struct{})
syncDone := false
go func() {
defer close(loopDone)
_ = s.handleMappingStream(ctx, stream, &syncDone, time.Time{})
}()
// The loop must reach the blocking apply for the first batch.
select {
case <-mgmt.entered:
case <-time.After(2 * time.Second):
t.Fatal("receive loop never reached CreateProxyPeer for the first batch")
}
// THE DEADLOCK: while the first batch is parked in CreateProxyPeer, the
// single-threaded loop cannot advance. The second batch is never pulled,
// even though it is already available on the stream. Give it ample time.
// deliveredCount is atomic; syncDone is intentionally not read here because
// the loop goroutine owns it (reading it from the test would race).
time.Sleep(500 * time.Millisecond)
assert.Equal(t, int32(1), stream.deliveredCount(),
"loop must NOT consume the second batch while the first is blocked in apply — proxy is stuck")
select {
case <-loopDone:
t.Fatal("receive loop returned while it should be wedged in apply")
default:
// Still wedged, as expected.
}
}
// TestMappingStream_StallsWhenRemoveBlocks proves the deadlock for the REMOVE
// path observed in production: a mapping remove tears down the account's last
// embedded client via netbird.RemovePeer -> client.Stop -> Engine.Stop, whose
// jobExecutorWG.Wait() is unbounded. Because the receive loop is single-
// threaded, a blocked remove wedges the loop: no further mapping updates of any
// kind (create/modify/remove) are applied, while management keeps sending them
// successfully (no send error, no channel-full). Matches the reported symptom:
// the last log line is a remove that stops a client, then silence.
func TestMappingStream_StallsWhenRemoveBlocks(t *testing.T) {
logger := log.New()
logger.SetLevel(log.PanicLevel)
enteredRemove := make(chan struct{})
blockRemove := make(chan struct{})
var once sync.Once
s := &Server{
Logger: logger,
mgmtClient: noopProxyClient{},
routerReady: closedChan(),
lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
// Stand in for netbird.RemovePeer -> client.Stop hanging on
// Engine.Stop's unbounded jobExecutorWG.Wait(). Only the first remove
// blocks; later removes return immediately so the recovery assertion
// can observe the loop advancing.
removePeer: func(ctx context.Context, _ types.AccountID, _ roundtrip.ServiceKey) error {
first := false
once.Do(func() {
first = true
close(enteredRemove)
})
if !first {
return nil
}
select {
case <-blockRemove:
case <-ctx.Done():
}
return nil
},
}
// Batch 1 removes a service (blocks in teardown). Batch 2 is a later update
// that must never be applied while the remove is wedged.
stream := &gatedMappingStream{
messages: []*proto.GetMappingUpdateResponse{
{
Mapping: []*proto.ProxyMapping{
{Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED, Id: "svc-1", AccountId: "acct-1"},
},
},
{
Mapping: []*proto.ProxyMapping{
{Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED, Id: "svc-2", AccountId: "acct-1"},
},
},
},
}
loopDone := make(chan struct{})
syncDone := false
go func() {
defer close(loopDone)
_ = s.handleMappingStream(context.Background(), stream, &syncDone, time.Time{})
}()
select {
case <-enteredRemove:
case <-time.After(2 * time.Second):
t.Fatal("receive loop never reached the blocking remove for the first batch")
}
// THE DEADLOCK: the loop is parked in the blocked remove and cannot advance.
// syncDone is owned by the loop goroutine, so it is not read here.
time.Sleep(500 * time.Millisecond)
assert.Equal(t, int32(1), stream.deliveredCount(),
"loop must NOT consume the second batch while the first remove is blocked — proxy is stuck")
select {
case <-loopDone:
t.Fatal("receive loop returned while it should be wedged on the remove")
default:
}
// Unblock and confirm the wedge was solely the blocked remove: the loop
// then advances and consumes the next batch.
close(blockRemove)
assert.Eventually(t, func() bool {
return stream.deliveredCount() >= 2
}, 2*time.Second, 5*time.Millisecond,
"once the remove unblocks, the loop must advance and consume the next batch")
}

View File

@@ -24,7 +24,6 @@ import (
"time"
"github.com/cenkalti/backoff/v4"
"github.com/google/uuid"
"github.com/pires/go-proxyproto"
prometheus2 "github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
@@ -76,30 +75,29 @@ type portRouter struct {
}
type Server struct {
ctx context.Context
mgmtClient proto.ProxyServiceClient
proxy *proxy.ReverseProxy
netbird *roundtrip.NetBird
acme *acme.Manager
staticCertWatcher *certwatch.Watcher
auth *auth.Middleware
http *http.Server
https *http.Server
debug *http.Server
healthServer *health.Server
healthChecker *health.Checker
meter *proxymetrics.Metrics
accessLog *accesslog.Logger
mainRouter *nbtcp.Router
mainPort uint16
udpMu sync.Mutex
udpRelays map[types.ServiceID]*udprelay.Relay
udpRelayWg sync.WaitGroup
portMu sync.RWMutex
portRouters map[uint16]*portRouter
svcPorts map[types.ServiceID][]uint16
lastMappings map[types.ServiceID]*proto.ProxyMapping
portRouterWg sync.WaitGroup
ctx context.Context
mgmtClient proto.ProxyServiceClient
proxy *proxy.ReverseProxy
netbird *roundtrip.NetBird
acme *acme.Manager
auth *auth.Middleware
http *http.Server
https *http.Server
debug *http.Server
healthServer *health.Server
healthChecker *health.Checker
meter *proxymetrics.Metrics
accessLog *accesslog.Logger
mainRouter *nbtcp.Router
mainPort uint16
udpMu sync.Mutex
udpRelays map[types.ServiceID]*udprelay.Relay
udpRelayWg sync.WaitGroup
portMu sync.RWMutex
portRouters map[uint16]*portRouter
svcPorts map[types.ServiceID][]uint16
lastMappings map[types.ServiceID]*proto.ProxyMapping
portRouterWg sync.WaitGroup
// hijackTracker tracks hijacked connections (e.g. WebSocket upgrades)
// so they can be closed during graceful shutdown, since http.Server.Shutdown
@@ -120,9 +118,6 @@ type Server struct {
// The mapping worker waits on this before processing updates.
routerReady chan struct{}
// removePeer defaults to netbird.RemovePeer; overridable in tests.
removePeer func(ctx context.Context, accountID types.AccountID, key roundtrip.ServiceKey) error
// inbound, when non-nil, manages per-account inbound listeners. Set by
// initPrivateInbound only when Private is true so the standalone
// proxy keeps its zero-overhead default path.
@@ -232,10 +227,6 @@ type Server struct {
// Zero means no cap (the proxy honors whatever management sends).
// Set via NB_PROXY_MAX_SESSION_IDLE_TIMEOUT for shared deployments.
MaxSessionIdleTimeout time.Duration
// MappingBatchWatchdog bounds how long a single mapping batch may spend
// in processMappings before the receive loop reconnects to resync.
// Zero uses defaultMappingBatchWatchdog.
MappingBatchWatchdog time.Duration
}
// clampIdleTimeout returns d capped to MaxSessionIdleTimeout when configured.
@@ -616,7 +607,7 @@ func (s *Server) initDefaults() {
// If no ID is set then one can be generated.
if s.ID == "" {
s.ID = fmt.Sprintf("netbird-proxy-%s", uuid.NewString())
s.ID = "netbird-proxy-" + s.startTime.Format("20060102150405")
}
// Fallback version option in case it is not set.
if s.Version == "" {
@@ -794,7 +785,6 @@ func (s *Server) configureTLS(ctx context.Context) (*tls.Config, error) {
return nil, fmt.Errorf("initialize certificate watcher: %w", err)
}
go certWatcher.Watch(ctx)
s.staticCertWatcher = certWatcher
tlsConfig.GetCertificate = certWatcher.GetCertificate
return tlsConfig, nil
}
@@ -1182,30 +1172,24 @@ func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.Pr
s.healthChecker.SetManagementConnected(false)
}
connected := false
onConnected := func() { connected = true }
var streamErr error
if syncSupported {
streamErr = s.trySyncMappings(ctx, client, &initialSyncDone, onConnected)
streamErr = s.trySyncMappings(ctx, client, &initialSyncDone)
if isSyncUnimplemented(streamErr) {
syncSupported = false
s.Logger.Info("management does not support SyncMappings, falling back to GetMappingUpdate")
streamErr = s.tryGetMappingUpdate(ctx, client, &initialSyncDone, onConnected)
streamErr = s.tryGetMappingUpdate(ctx, client, &initialSyncDone)
}
} else {
streamErr = s.tryGetMappingUpdate(ctx, client, &initialSyncDone, onConnected)
streamErr = s.tryGetMappingUpdate(ctx, client, &initialSyncDone)
}
if s.healthChecker != nil {
s.healthChecker.SetManagementConnected(false)
}
// Reset backoff only when a stream actually connected, so immediate
// connect failures still back off instead of spinning.
if connected {
bo.Reset()
}
// Stream established — reset backoff so the next failure retries quickly.
bo.Reset()
if streamErr == nil {
return fmt.Errorf("stream closed by server")
@@ -1237,7 +1221,7 @@ func (s *Server) proxyCapabilities() *proto.ProxyCapabilities {
}
}
func (s *Server) tryGetMappingUpdate(ctx context.Context, client proto.ProxyServiceClient, initialSyncDone *bool, onConnected func()) error {
func (s *Server) tryGetMappingUpdate(ctx context.Context, client proto.ProxyServiceClient, initialSyncDone *bool) error {
connectTime := time.Now()
mappingClient, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{
ProxyId: s.ID,
@@ -1250,7 +1234,6 @@ func (s *Server) tryGetMappingUpdate(ctx context.Context, client proto.ProxyServ
return fmt.Errorf("create mapping stream: %w", err)
}
onConnected()
if s.healthChecker != nil {
s.healthChecker.SetManagementConnected(true)
}
@@ -1259,7 +1242,7 @@ func (s *Server) tryGetMappingUpdate(ctx context.Context, client proto.ProxyServ
return s.handleMappingStream(ctx, mappingClient, initialSyncDone, connectTime)
}
func (s *Server) trySyncMappings(ctx context.Context, client proto.ProxyServiceClient, initialSyncDone *bool, onConnected func()) error {
func (s *Server) trySyncMappings(ctx context.Context, client proto.ProxyServiceClient, initialSyncDone *bool) error {
connectTime := time.Now()
stream, err := client.SyncMappings(ctx)
if err != nil {
@@ -1280,7 +1263,6 @@ func (s *Server) trySyncMappings(ctx context.Context, client proto.ProxyServiceC
return fmt.Errorf("send sync init: %w", err)
}
onConnected()
if s.healthChecker != nil {
s.healthChecker.SetManagementConnected(true)
}
@@ -1325,9 +1307,7 @@ func (s *Server) handleSyncMappingsStream(ctx context.Context, stream proto.Prox
batchStart := time.Now()
s.Logger.Debug("Received mapping update, starting processing")
if err := s.processMappingsGuarded(ctx, msg.GetMapping()); err != nil {
return err
}
s.processMappings(ctx, msg.GetMapping())
s.Logger.Debug("Processing mapping update completed")
tracker.recordBatch(ctx, s, msg.GetMapping(), msg.GetInitialSyncComplete(), batchStart)
@@ -1411,9 +1391,7 @@ func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.Pr
batchStart := time.Now()
s.Logger.Debug("Received mapping update, starting processing")
if err := s.processMappingsGuarded(ctx, msg.GetMapping()); err != nil {
return err
}
s.processMappings(ctx, msg.GetMapping())
s.Logger.Debug("Processing mapping update completed")
tracker.recordBatch(ctx, s, msg.GetMapping(), msg.GetInitialSyncComplete(), batchStart)
}
@@ -1478,44 +1456,6 @@ func redactMappingForLog(m *proto.ProxyMapping) *proto.ProxyMapping {
return c
}
const defaultMappingBatchWatchdog = 2 * time.Minute
// mappingBatchWatchdog returns the configured batch watchdog or the default.
func (s *Server) mappingBatchWatchdog() time.Duration {
if s.MappingBatchWatchdog > 0 {
return s.MappingBatchWatchdog
}
return defaultMappingBatchWatchdog
}
// processMappingsGuarded applies a batch under a watchdog, returning an error
// if processing exceeds the watchdog so the caller reconnects and resyncs
// instead of wedging silently.
func (s *Server) processMappingsGuarded(ctx context.Context, mappings []*proto.ProxyMapping) error {
batchCtx, cancel := context.WithCancel(ctx)
defer cancel()
done := make(chan struct{})
go func() {
defer close(done)
s.processMappings(batchCtx, mappings)
}()
watchdog := s.mappingBatchWatchdog()
timer := time.NewTimer(watchdog)
defer timer.Stop()
select {
case <-done:
return nil
case <-ctx.Done():
return ctx.Err()
case <-timer.C:
s.Logger.Errorf("processing mapping batch exceeded %s, cancelling and reconnecting to resync", watchdog)
return fmt.Errorf("mapping batch processing stalled after %s", watchdog)
}
}
func (s *Server) processMappings(ctx context.Context, mappings []*proto.ProxyMapping) {
debug := s.Logger != nil && s.Logger.IsLevelEnabled(log.DebugLevel)
for _, mapping := range mappings {
@@ -1626,8 +1566,6 @@ func (s *Server) setupHTTPMapping(ctx context.Context, mapping *proto.ProxyMappi
var wildcardHit bool
if s.acme != nil {
wildcardHit = s.acme.AddDomain(d, accountID, svcID)
} else {
wildcardHit = s.staticCertCovers(d)
}
httpRoute := nbtcp.Route{
Type: nbtcp.RouteHTTP,
@@ -1652,26 +1590,6 @@ func (s *Server) setupHTTPMapping(ctx context.Context, mapping *proto.ProxyMappi
return nil
}
// staticCertCovers reports whether the static certificate loaded when ACME is
// disabled covers the given domain, making it certificate-ready immediately —
// the equivalent of a wildcard hit in the ACME path. Domains the certificate
// does not cover are logged: clients connecting to them will get TLS errors.
func (s *Server) staticCertCovers(d domain.Domain) bool {
if s.staticCertWatcher == nil {
return false
}
leaf := s.staticCertWatcher.Leaf()
if leaf == nil {
return false
}
name := d.PunycodeString()
if err := leaf.VerifyHostname(name); err != nil {
s.Logger.Warnf("static certificate (SANs %v) does not cover domain %q: %v", leaf.DNSNames, name, err)
return false
}
return true
}
// setupTCPMapping sets up a TCP port-forwarding fallback route on the listen port.
func (s *Server) setupTCPMapping(ctx context.Context, mapping *proto.ProxyMapping) error {
svcID := types.ServiceID(mapping.GetId())
@@ -2033,11 +1951,7 @@ func (s *Server) updateMapping(ctx context.Context, mapping *proto.ProxyMapping)
func (s *Server) removeMapping(ctx context.Context, mapping *proto.ProxyMapping) {
accountID := types.AccountID(mapping.GetAccountId())
svcKey := s.serviceKeyForMapping(mapping)
removePeer := s.removePeer
if removePeer == nil {
removePeer = s.netbird.RemovePeer
}
if err := removePeer(ctx, accountID, svcKey); err != nil {
if err := s.netbird.RemovePeer(ctx, accountID, svcKey); err != nil {
s.Logger.WithFields(log.Fields{
"account_id": accountID,
"service_id": mapping.GetId(),

View File

@@ -1,89 +0,0 @@
package proxy
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"math/big"
"os"
"path/filepath"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/proxy/internal/certwatch"
"github.com/netbirdio/netbird/shared/management/domain"
)
func generateCertWithSANs(t *testing.T, dnsNames []string) (certPEM, keyPEM []byte) {
t.Helper()
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)
template := &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{CommonName: dnsNames[0]},
DNSNames: dnsNames,
NotBefore: time.Now().Add(-time.Hour),
NotAfter: time.Now().Add(24 * time.Hour),
}
certDER, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key)
require.NoError(t, err)
certPEM = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
keyDER, err := x509.MarshalECPrivateKey(key)
require.NoError(t, err)
keyPEM = pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})
return certPEM, keyPEM
}
func newStaticWatcher(t *testing.T, dnsNames []string) *certwatch.Watcher {
t.Helper()
dir := t.TempDir()
certPEM, keyPEM := generateCertWithSANs(t, dnsNames)
certPath := filepath.Join(dir, "tls.crt")
keyPath := filepath.Join(dir, "tls.key")
require.NoError(t, os.WriteFile(certPath, certPEM, 0o600))
require.NoError(t, os.WriteFile(keyPath, keyPEM, 0o600))
w, err := certwatch.NewWatcher(certPath, keyPath, quietLifecycleLogger())
require.NoError(t, err)
return w
}
func TestStaticCertCovers(t *testing.T) {
s := &Server{
Logger: quietLifecycleLogger(),
staticCertWatcher: newStaticWatcher(t, []string{"*.p.example.com", "exact.example.com"}),
}
cases := []struct {
domain string
covered bool
}{
{"svc.p.example.com", true},
{"exact.example.com", true},
{"a.b.p.example.com", false}, // wildcard does not span labels
{"p.example.com", false},
{"other.example.com", false},
}
for _, tc := range cases {
t.Run(tc.domain, func(t *testing.T) {
assert.Equal(t, tc.covered, s.staticCertCovers(domain.Domain(tc.domain)))
})
}
}
func TestStaticCertCoversNoWatcher(t *testing.T) {
s := &Server{Logger: quietLifecycleLogger()}
assert.False(t, s.staticCertCovers(domain.Domain("svc.p.example.com")))
}

View File

@@ -6,5 +6,4 @@ const (
RoleKey = "role"
UserIDKey = "userID"
PeerIDKey = "peerID"
UserAgentKey = "userAgent"
)

View File

@@ -322,21 +322,15 @@ func TestClient_Sync(t *testing.T) {
if resp.GetNetbirdConfig() == nil {
t.Error("expecting non nil NetbirdConfig got nil")
}
// we test network map peers from 0.29.3 and dev builds
if len(resp.GetRemotePeers()) != 0 {
t.Error("expecting top-level RemotePeers to be empty for v0.29.3+ clients")
}
networkMap := resp.GetNetworkMap()
if len(networkMap.GetRemotePeers()) != 1 {
t.Errorf("expecting RemotePeers size %d got %d", 1, len(networkMap.GetRemotePeers()))
if len(resp.GetRemotePeers()) != 1 {
t.Errorf("expecting RemotePeers size %d got %d", 1, len(resp.GetRemotePeers()))
return
}
if networkMap.GetRemotePeersIsEmpty() {
if resp.GetRemotePeersIsEmpty() == true {
t.Error("expecting RemotePeers property to be false, got true")
}
if networkMap.GetRemotePeers()[0].GetWgPubKey() != remoteKey.PublicKey().String() {
t.Errorf("expecting RemotePeer public key %s got %s", remoteKey.PublicKey().String(), networkMap.GetRemotePeers()[0].GetWgPubKey())
if resp.GetRemotePeers()[0].GetWgPubKey() != remoteKey.PublicKey().String() {
t.Errorf("expecting RemotePeer public key %s got %s", remoteKey.PublicKey().String(), resp.GetRemotePeers()[0].GetWgPubKey())
}
case <-time.After(3 * time.Second):
t.Error("timeout waiting for test to finish")

View File

@@ -5107,63 +5107,31 @@ components:
responses:
not_found:
description: Resource not found
headers:
X-Request-Id:
$ref: '#/components/headers/X-Request-Id'
content: { }
validation_failed_simple:
description: Validation failed
headers:
X-Request-Id:
$ref: '#/components/headers/X-Request-Id'
content: { }
bad_request:
description: Bad Request
headers:
X-Request-Id:
$ref: '#/components/headers/X-Request-Id'
content: { }
internal_error:
description: Internal Server Error
headers:
X-Request-Id:
$ref: '#/components/headers/X-Request-Id'
content: { }
validation_failed:
description: Validation failed
headers:
X-Request-Id:
$ref: '#/components/headers/X-Request-Id'
content: { }
forbidden:
description: Forbidden
headers:
X-Request-Id:
$ref: '#/components/headers/X-Request-Id'
content: { }
requires_authentication:
description: Requires authentication
headers:
X-Request-Id:
$ref: '#/components/headers/X-Request-Id'
content: { }
conflict:
description: Conflict
headers:
X-Request-Id:
$ref: '#/components/headers/X-Request-Id'
content:
application/json:
schema:
$ref: '#/components/schemas/ErrorResponse'
headers:
X-Request-Id:
description: |
Unique identifier assigned to the request by the server and set on every
response. Useful for correlating client requests with server-side logs.
schema:
type: string
example: cot7r4n3l3vh3qj4qveg
securitySchemes:
BearerAuth:
type: http

View File

@@ -9,14 +9,12 @@ import (
"net/url"
"strings"
"sync"
"sync/atomic"
"time"
log "github.com/sirupsen/logrus"
auth "github.com/netbirdio/netbird/shared/relay/auth/hmac"
"github.com/netbirdio/netbird/shared/relay/client/dialer"
netErr "github.com/netbirdio/netbird/shared/relay/client/dialer/net"
"github.com/netbirdio/netbird/shared/relay/healthcheck"
"github.com/netbirdio/netbird/shared/relay/messages"
)
@@ -174,19 +172,6 @@ type Client struct {
stateSubscription *PeersStateSubscription
mtu uint16
// transportFallback, when set, records datagram-too-large failures so a
// datagram-sized transport is avoided on subsequent connects. Shared via
// the manager.
transportFallback *transportFallback
// datagramFallbackTriggered guards a single fallback per connection so a
// burst of oversized datagrams triggers one reconnect, not many.
datagramFallbackTriggered atomic.Bool
}
// SetTransportFallback wires the shared datagram-transport fallback tracker.
func (c *Client) SetTransportFallback(tf *transportFallback) {
c.transportFallback = tf
}
// NewClient creates a new client for the relay server. The client is not connected to the server until the Connect
@@ -376,13 +361,12 @@ func (c *Client) Close() error {
}
func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
mode := transportModeFromEnv()
dialers := c.getDialers(mode)
dialers := c.getDialers()
var conn net.Conn
if c.serverIP.IsValid() {
var err error
conn, err = c.dialRaceDirect(ctx, mode, dialers)
conn, err = c.dialRaceDirect(ctx, dialers)
if err != nil {
c.log.Infof("dial via server IP %s failed, falling back to FQDN: %v", c.serverIP, err)
conn = nil
@@ -391,9 +375,6 @@ func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
if conn == nil {
rd := dialer.NewRaceDial(c.log, dialer.DefaultConnectionTimeout, c.connectionURL, dialers...)
if mode.sequential() {
rd.WithSequential()
}
var err error
conn, err = rd.Dial(ctx)
if err != nil {
@@ -401,7 +382,6 @@ func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
}
}
c.relayConn = conn
c.datagramFallbackTriggered.Store(false)
instanceURL, err := c.handShake(ctx)
if err != nil {
@@ -416,7 +396,7 @@ func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
}
// dialRaceDirect dials c.serverIP, preserving the original FQDN as the TLS ServerName for SNI.
func (c *Client) dialRaceDirect(ctx context.Context, mode TransportMode, dialers []dialer.DialeFn) (net.Conn, error) {
func (c *Client) dialRaceDirect(ctx context.Context, dialers []dialer.DialeFn) (net.Conn, error) {
directURL, serverName, err := substituteHost(c.connectionURL, c.serverIP)
if err != nil {
return nil, fmt.Errorf("substitute host: %w", err)
@@ -426,9 +406,6 @@ func (c *Client) dialRaceDirect(ctx context.Context, mode TransportMode, dialers
rd := dialer.NewRaceDial(c.log, dialer.DefaultConnectionTimeout, directURL, dialers...).
WithServerName(serverName)
if mode.sequential() {
rd.WithSequential()
}
return rd.Dial(ctx)
}
@@ -654,53 +631,13 @@ func (c *Client) writeTo(containerRef *connContainer, dstID messages.PeerID, pay
}
// the write always return with 0 length because the underling does not support the size feedback.
conn := c.relayConn
_, err = conn.Write(msg)
_, err = c.relayConn.Write(msg)
if err != nil {
if errors.Is(err, netErr.ErrDatagramTooLarge) {
c.onDatagramTooLarge(conn, err)
} else {
c.log.Errorf("failed to write transport message: %s", err)
}
c.log.Errorf("failed to write transport message: %s", err)
}
return len(payload), err
}
// onDatagramTooLarge reacts to a datagram rejected as too large for the path.
// When a non-datagram transport is available, it records a fallback for this
// server and closes the connection so the reconnect avoids datagram-sized
// transports. A single fallback is triggered per connection regardless of how
// many oversized datagrams arrive. cause carries the datagram size and budget.
func (c *Client) onDatagramTooLarge(conn net.Conn, cause error) {
// Handle one oversized datagram per connection; a burst triggers a single
// fallback (and a single log line), not many.
if !c.datagramFallbackTriggered.CompareAndSwap(false, true) {
return
}
// If the selected mode offers no non-datagram transport (e.g. pinned to a
// datagram-sized transport), reconnecting would just re-fail, so leave the
// connection up rather than loop.
if len(nonDatagramSized(c.baseDialers(transportModeFromEnv()))) == 0 {
c.log.Warnf("%s, but no non-datagram transport is available, not falling back", cause)
return
}
// Without the shared tracker a reconnect would just select the same
// transport again and re-fail, so leave the connection up rather than loop.
if c.transportFallback == nil {
c.log.Debugf("%s, but no transport fallback configured, leaving connection up", cause)
return
}
window := c.transportFallback.recordFailure(c.connectionURL)
c.log.Warnf("%s, avoiding datagram-sized transport for %s", cause, window)
if err := conn.Close(); err != nil {
c.log.Debugf("close relay connection for transport fallback: %s", err)
}
}
func (c *Client) listenForStopEvents(ctx context.Context, hc *healthcheck.Receiver, conn net.Conn, internalStopFlag *internalStopFlag) {
for {
select {

View File

@@ -1,18 +0,0 @@
package dialer
// DatagramSized is implemented by dialers whose connections carry each write in
// a single datagram, so a write can be rejected when it exceeds the path's
// datagram budget (e.g. QUIC). Transports without this capability (e.g.
// WebSocket over TCP) impose no per-write size limit, so the relay client can
// fall back to them when a datagram-sized transport rejects a write as too
// large. The capability is advertised per dialer rather than hardcoded, so a
// new transport only needs to declare whether it is datagram-sized.
type DatagramSized interface {
DatagramSized()
}
// IsDatagramSized reports whether d produces datagram-sized connections.
func IsDatagramSized(d DialeFn) bool {
_, ok := d.(DatagramSized)
return ok
}

View File

@@ -4,9 +4,4 @@ import "errors"
var (
ErrClosedByServer = errors.New("closed by server")
// ErrDatagramTooLarge is returned when a transport message exceeds the
// QUIC datagram size the path to the relay can carry. The relay client
// treats it as a signal to fall back to a non-datagram transport.
ErrDatagramTooLarge = errors.New("datagram frame too large")
)

View File

@@ -8,6 +8,7 @@ import (
"time"
"github.com/quic-go/quic-go"
log "github.com/sirupsen/logrus"
netErr "github.com/netbirdio/netbird/shared/relay/client/dialer/net"
)
@@ -51,8 +52,11 @@ func (c *Conn) Read(b []byte) (n int, err error) {
}
func (c *Conn) Write(b []byte) (int, error) {
if err := c.session.SendDatagram(b); err != nil {
return 0, c.writeErrHandling(err, len(b))
err := c.session.SendDatagram(b)
if err != nil {
err = c.remoteCloseErrHandling(err)
log.Errorf("failed to write to QUIC stream: %v", err)
return 0, err
}
return len(b), nil
}
@@ -91,15 +95,3 @@ func (c *Conn) remoteCloseErrHandling(err error) error {
}
return err
}
// writeErrHandling normalizes SendDatagram errors. A datagram that exceeds the
// path's QUIC packet budget is mapped to ErrDatagramTooLarge (annotated with the
// datagram size and path budget) so the relay client can fall back to a
// non-datagram transport.
func (c *Conn) writeErrHandling(err error, size int) error {
var tooLarge *quic.DatagramTooLargeError
if errors.As(err, &tooLarge) {
return fmt.Errorf("%w: %d byte datagram over path budget %d", netErr.ErrDatagramTooLarge, size, tooLarge.MaxDatagramPayloadSize)
}
return c.remoteCloseErrHandling(err)
}

View File

@@ -9,7 +9,6 @@ import (
"time"
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/logging"
log "github.com/sirupsen/logrus"
nbnet "github.com/netbirdio/netbird/client/net"
@@ -24,12 +23,6 @@ func (d Dialer) Protocol() string {
return Network
}
// DatagramSized marks QUIC as a datagram-sized transport: relay traffic is
// carried in QUIC DATAGRAM frames, which must fit a single packet.
func (d Dialer) DatagramSized() {
// Intentional marker method; presence is the capability signal.
}
func (d Dialer) Dial(ctx context.Context, address, serverName string) (net.Conn, error) {
quicURL, err := prepareURL(address)
if err != nil {
@@ -54,7 +47,6 @@ func (d Dialer) Dial(ctx context.Context, address, serverName string) (net.Conn,
MaxIdleTimeout: 4 * time.Minute,
EnableDatagrams: true,
InitialPacketSize: nbRelay.QUICInitialPacketSize,
Tracer: connectionTracer(quicURL),
}
udpConn, err := nbnet.ListenUDP("udp", &net.UDPAddr{Port: 0})
@@ -82,28 +74,6 @@ func (d Dialer) Dial(ctx context.Context, address, serverName string) (net.Conn,
return conn, nil
}
// connectionTracer returns a QUIC tracer that logs the DPLPMTUD result and the
// reason a relay connection closed, so the path MTU settled on and teardown
// cause are visible in logs. Lines carry the relay address as a structured
// field, matching the rest of the relay client logging.
func connectionTracer(addr string) func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
relayLog := log.WithField("relay", addr)
return func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
return &logging.ConnectionTracer{
UpdatedMTU: func(mtu logging.ByteCount, done bool) {
if done {
relayLog.Infof("QUIC path MTU settled at %d", mtu)
return
}
relayLog.Debugf("QUIC path MTU probing at %d", mtu)
},
ClosedConnection: func(err error) {
relayLog.Debugf("QUIC connection closed: %v", err)
},
}
}
}
func prepareURL(address string) (string, error) {
var host string
var defaultPort string

View File

@@ -32,7 +32,6 @@ type RaceDial struct {
serverName string
dialerFns []DialeFn
connectionTimeout time.Duration
sequential bool
}
func NewRaceDial(log *log.Entry, connectionTimeout time.Duration, serverURL string, dialerFns ...DialeFn) *RaceDial {
@@ -54,21 +53,7 @@ func (r *RaceDial) WithServerName(serverName string) *RaceDial {
return r
}
// WithSequential makes Dial try the dialers in order, falling back to the next
// only when one fails to connect, instead of racing them concurrently.
//
// Mutates the receiver and is not safe for concurrent reconfiguration; a
// RaceDial is intended to be constructed per dial and discarded.
func (r *RaceDial) WithSequential() *RaceDial {
r.sequential = true
return r
}
func (r *RaceDial) Dial(ctx context.Context) (net.Conn, error) {
if r.sequential {
return r.dialSequential(ctx)
}
connChan := make(chan dialResult, len(r.dialerFns))
winnerConn := make(chan net.Conn, 1)
abortCtx, abort := context.WithCancel(ctx)
@@ -87,30 +72,6 @@ func (r *RaceDial) Dial(ctx context.Context) (net.Conn, error) {
return conn, nil
}
// dialSequential tries each dialer in order, returning the first connection and
// falling back to the next on failure.
func (r *RaceDial) dialSequential(ctx context.Context) (net.Conn, error) {
for _, dfn := range r.dialerFns {
if err := ctx.Err(); err != nil {
return nil, err
}
attemptCtx, cancel := context.WithTimeout(ctx, r.connectionTimeout)
r.log.Infof("dialing Relay server via %s", dfn.Protocol())
conn, err := dfn.Dial(attemptCtx, r.serverURL, r.serverName)
cancel()
if err != nil {
if errors.Is(err, context.Canceled) {
return nil, err
}
r.log.Errorf("failed to dial via %s: %s", dfn.Protocol(), err)
continue
}
r.log.Infof("successfully dialed via: %s", dfn.Protocol())
return conn, nil
}
return nil, errors.New("failed to dial to Relay server on any protocol")
}
func (r *RaceDial) dial(dfn DialeFn, abortCtx context.Context, connChan chan dialResult) {
ctx, cancel := context.WithTimeout(abortCtx, r.connectionTimeout)
defer cancel()

View File

@@ -250,66 +250,3 @@ func TestRaceDialFirstSuccessfulDialerWins(t *testing.T) {
}
}
}
func TestRaceDialSequentialFallback(t *testing.T) {
logger := logrus.NewEntry(logrus.New())
serverURL := "test.server.com"
var firstDialed, secondDialed bool
preferred := &MockDialer{
protocolStr: "quic",
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
firstDialed = true
return nil, errors.New("quic unreachable")
},
}
fallbackConn := &MockConn{remoteAddr: &MockAddr{network: "ws"}}
fallback := &MockDialer{
protocolStr: "ws",
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
secondDialed = true
return fallbackConn, nil
},
}
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, preferred, fallback).WithSequential()
conn, err := rd.Dial(context.Background())
if err != nil {
t.Fatalf("expected fallback to succeed, got %v", err)
}
if conn != fallbackConn {
t.Errorf("expected fallback connection, got %v", conn)
}
if !firstDialed || !secondDialed {
t.Errorf("expected both dialers attempted in order, first=%v second=%v", firstDialed, secondDialed)
}
}
func TestRaceDialSequentialPreferredWins(t *testing.T) {
logger := logrus.NewEntry(logrus.New())
serverURL := "test.server.com"
preferredConn := &MockConn{remoteAddr: &MockAddr{network: "quic"}}
preferred := &MockDialer{
protocolStr: "quic",
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
return preferredConn, nil
},
}
fallback := &MockDialer{
protocolStr: "ws",
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
t.Errorf("fallback dialer must not be tried when preferred succeeds")
return nil, errors.New("should not happen")
},
}
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, preferred, fallback).WithSequential()
conn, err := rd.Dial(context.Background())
if err != nil {
t.Fatalf("expected preferred to succeed, got %v", err)
}
if conn != preferredConn {
t.Errorf("expected preferred connection, got %v", conn)
}
}

View File

@@ -9,42 +9,11 @@ import (
"github.com/netbirdio/netbird/shared/relay/client/dialer/ws"
)
// getDialers returns the ordered dialers for connecting to the relay server. It
// applies the datagram fallback generically: if this server recently rejected a
// datagram-sized transport, those dialers are dropped, leaving the rest.
func (c *Client) getDialers(mode TransportMode) []dialer.DialeFn {
dialers := c.baseDialers(mode)
if c.transportFallback != nil && c.transportFallback.avoidDatagramSized(c.connectionURL) {
if filtered := nonDatagramSized(dialers); len(filtered) > 0 {
c.log.Infof("relay recently rejected a datagram-sized transport, avoiding it")
return filtered
}
}
return dialers
}
// baseDialers returns the ordered dialers for the mode, before any datagram
// fallback filtering. For racing modes (auto) the order is irrelevant; for
// prefer modes the first entry is tried before falling back to the second.
func (c *Client) baseDialers(mode TransportMode) []dialer.DialeFn {
switch mode {
case TransportModeWS:
c.log.Infof("%s=ws, using WebSocket transport", EnvRelayTransport)
return []dialer.DialeFn{ws.Dialer{}}
case TransportModeQUIC:
c.log.Infof("%s=quic, using QUIC transport", EnvRelayTransport)
return []dialer.DialeFn{quic.Dialer{}}
}
all := []dialer.DialeFn{quic.Dialer{}, ws.Dialer{}}
if mode == TransportModePreferWS {
all = []dialer.DialeFn{ws.Dialer{}, quic.Dialer{}}
}
// getDialers returns the list of dialers to use for connecting to the relay server.
func (c *Client) getDialers() []dialer.DialeFn {
if c.mtu > 0 && c.mtu > iface.DefaultMTU {
c.log.Infof("MTU %d exceeds default (%d), avoiding datagram-sized transports", c.mtu, iface.DefaultMTU)
return nonDatagramSized(all)
c.log.Infof("MTU %d exceeds default (%d), forcing WebSocket transport to avoid DATAGRAM frame size issues", c.mtu, iface.DefaultMTU)
return []dialer.DialeFn{ws.Dialer{}}
}
return all
return []dialer.DialeFn{quic.Dialer{}, ws.Dialer{}}
}

View File

@@ -1,101 +0,0 @@
//go:build !js
package client
import (
"os"
"testing"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/shared/relay/client/dialer"
netErr "github.com/netbirdio/netbird/shared/relay/client/dialer/net"
"github.com/netbirdio/netbird/shared/relay/client/dialer/quic"
"github.com/netbirdio/netbird/shared/relay/client/dialer/ws"
)
// TestDatagramSizedCapability locks the capability the generic fallback relies
// on: QUIC is datagram-sized, WebSocket is not.
func TestDatagramSizedCapability(t *testing.T) {
assert.True(t, dialer.IsDatagramSized(quic.Dialer{}), "QUIC must advertise datagram-sized")
assert.False(t, dialer.IsDatagramSized(ws.Dialer{}), "WebSocket must not advertise datagram-sized")
}
func protocols(dialers []dialer.DialeFn) []string {
out := make([]string, len(dialers))
for i, d := range dialers {
out[i] = d.Protocol()
}
return out
}
func TestGetDialers(t *testing.T) {
const url = "rels://relay.example:443"
tests := []struct {
name string
mode string
mtu uint16
preferWS bool
want []string
}{
{name: "auto races quic and ws", mode: "auto", mtu: iface.DefaultMTU, want: []string{"quic", "WS"}},
{name: "ws pinned", mode: "ws", mtu: iface.DefaultMTU, want: []string{"WS"}},
{name: "quic pinned", mode: "quic", mtu: iface.DefaultMTU, want: []string{"quic"}},
{name: "prefer-quic orders quic first", mode: "prefer-quic", mtu: iface.DefaultMTU, want: []string{"quic", "WS"}},
{name: "prefer-ws orders ws first", mode: "prefer-ws", mtu: iface.DefaultMTU, want: []string{"WS", "quic"}},
{name: "mtu above default forces ws", mode: "auto", mtu: iface.DefaultMTU + 100, want: []string{"WS"}},
{name: "sticky fallback forces ws in auto", mode: "auto", mtu: iface.DefaultMTU, preferWS: true, want: []string{"WS"}},
{name: "sticky fallback forces ws in prefer-quic", mode: "prefer-quic", mtu: iface.DefaultMTU, preferWS: true, want: []string{"WS"}},
{name: "quic pin overrides sticky fallback", mode: "quic", mtu: iface.DefaultMTU, preferWS: true, want: []string{"quic"}},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Setenv(EnvRelayTransport, tc.mode)
if tc.mode == "" {
os.Unsetenv(EnvRelayTransport)
}
tf := newTransportFallback()
if tc.preferWS {
tf.recordFailure(url)
}
c := &Client{
log: log.WithField("test", t.Name()),
connectionURL: url,
mtu: tc.mtu,
transportFallback: tf,
}
assert.Equal(t, tc.want, protocols(c.getDialers(transportModeFromEnv())))
})
}
}
// TestStickyFallbackAfterDatagramTooLarge verifies the full chain: an oversized
// datagram records a fallback that makes the next dial pick WebSocket, the way a
// reconnect would after the connection is closed.
func TestStickyFallbackAfterDatagramTooLarge(t *testing.T) {
const url = "rels://relay.example:443"
t.Setenv(EnvRelayTransport, string(TransportModeAuto))
c := &Client{
log: log.WithField("test", t.Name()),
connectionURL: url,
mtu: iface.DefaultMTU,
transportFallback: newTransportFallback(),
}
// First dial races both transports.
assert.Equal(t, []string{"quic", "WS"}, protocols(c.getDialers(transportModeFromEnv())))
// An oversized datagram records the fallback for this server.
c.onDatagramTooLarge(&closeTrackingConn{}, netErr.ErrDatagramTooLarge)
// The reconnect now sticks to WebSocket.
assert.Equal(t, []string{"WS"}, protocols(c.getDialers(transportModeFromEnv())))
}

View File

@@ -7,11 +7,7 @@ import (
"github.com/netbirdio/netbird/shared/relay/client/dialer/ws"
)
func (c *Client) getDialers(_ TransportMode) []dialer.DialeFn {
func (c *Client) getDialers() []dialer.DialeFn {
// JS/WASM build only uses WebSocket transport
return []dialer.DialeFn{ws.Dialer{}}
}
func (c *Client) baseDialers(_ TransportMode) []dialer.DialeFn {
return []dialer.DialeFn{ws.Dialer{}}
}

View File

@@ -79,30 +79,23 @@ type Manager struct {
cleanupInterval time.Duration
keepUnusedServerTime time.Duration
// transportFallback is shared across home and foreign relay clients so a
// datagram-too-large failure makes that server avoid datagram-sized transports across reconnects.
transportFallback *transportFallback
}
// NewManager creates a new manager instance.
// The serverURL address can be empty. In this case, the manager will not serve.
func NewManager(ctx context.Context, serverURLs []string, peerID string, mtu uint16, opts ...ManagerOption) *Manager {
tokenStore := &relayAuth.TokenStore{}
tf := newTransportFallback()
m := &Manager{
ctx: ctx,
peerID: peerID,
tokenStore: tokenStore,
mtu: mtu,
transportFallback: tf,
ctx: ctx,
peerID: peerID,
tokenStore: tokenStore,
mtu: mtu,
serverPicker: &ServerPicker{
TokenStore: tokenStore,
PeerID: peerID,
MTU: mtu,
ConnectionTimeout: defaultConnectionTimeout,
TransportFallback: tf,
},
relayClients: make(map[string]*RelayTrack),
onDisconnectedListeners: make(map[string]*list.List),
@@ -294,7 +287,6 @@ func (m *Manager) openConnVia(ctx context.Context, serverAddress, peerKey string
m.relayClientsMutex.Unlock()
relayClient := NewClientWithServerIP(serverAddress, serverIP, m.tokenStore, m.peerID, m.mtu)
relayClient.SetTransportFallback(m.transportFallback)
err := relayClient.Connect(m.ctx)
if err != nil {
rt.err = err

View File

@@ -29,7 +29,6 @@ type ServerPicker struct {
PeerID string
MTU uint16
ConnectionTimeout time.Duration
TransportFallback *transportFallback
}
func (sp *ServerPicker) PickServer(parentCtx context.Context) (*Client, error) {
@@ -71,7 +70,6 @@ func (sp *ServerPicker) PickServer(parentCtx context.Context) (*Client, error) {
func (sp *ServerPicker) startConnection(ctx context.Context, resultChan chan connResult, url string) {
log.Infof("try to connecting to relay server: %s", url)
relayClient := NewClient(url, sp.TokenStore, sp.PeerID, sp.MTU)
relayClient.SetTransportFallback(sp.TransportFallback)
err := relayClient.Connect(ctx)
resultChan <- connResult{
RelayClient: relayClient,

View File

@@ -1,129 +0,0 @@
package client
import (
"os"
"strings"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/shared/relay/client/dialer"
)
// EnvRelayTransport pins the relay transport. Valid values: "auto" (default,
// race QUIC and WebSocket), "quic" (QUIC only), "ws" (WebSocket only),
// "prefer-quic" / "prefer-ws" (try the preferred transport first, fall back to
// the other only if it fails to connect; no race). The prefer modes trade a
// slower connect when the preferred transport is blackholed for deterministic
// transport selection.
const EnvRelayTransport = "NB_RELAY_TRANSPORT"
const (
// transportFallbackBase is the initial window a relay server avoids
// datagram-sized transports after a datagram is rejected as too large.
transportFallbackBase = 10 * time.Minute
// transportFallbackMax caps the pinned window when failures repeat.
transportFallbackMax = 60 * time.Minute
)
// TransportMode selects which relay dialers are used.
type TransportMode string
const (
TransportModeAuto TransportMode = "auto"
TransportModeQUIC TransportMode = "quic"
TransportModeWS TransportMode = "ws"
TransportModePreferQUIC TransportMode = "prefer-quic"
TransportModePreferWS TransportMode = "prefer-ws"
)
// transportModeFromEnv reads EnvRelayTransport, defaulting to auto for an empty
// or unrecognized value.
func transportModeFromEnv() TransportMode {
switch TransportMode(strings.ToLower(strings.TrimSpace(os.Getenv(EnvRelayTransport)))) {
case "", TransportModeAuto:
return TransportModeAuto
case TransportModeQUIC:
return TransportModeQUIC
case TransportModeWS:
return TransportModeWS
case TransportModePreferQUIC:
return TransportModePreferQUIC
case TransportModePreferWS:
return TransportModePreferWS
default:
log.Warnf("invalid %s value %q, using %q", EnvRelayTransport, os.Getenv(EnvRelayTransport), TransportModeAuto)
return TransportModeAuto
}
}
// sequential reports whether the mode tries dialers in order with fallback
// instead of racing them concurrently.
func (m TransportMode) sequential() bool {
return m == TransportModePreferQUIC || m == TransportModePreferWS
}
// transportFallback tracks relay servers that have rejected a datagram-sized
// transport (a write too large for the path) and should temporarily avoid such
// transports. It is shared across the relay manager so the preference survives
// client recreation (foreign relay clients are evicted and rebuilt on
// disconnect). Entries are keyed by server URL and expire after a window that
// grows on repeated failures.
type transportFallback struct {
mu sync.Mutex
entries map[string]*fallbackEntry
}
type fallbackEntry struct {
until time.Time
duration time.Duration
}
func newTransportFallback() *transportFallback {
return &transportFallback{entries: make(map[string]*fallbackEntry)}
}
// avoidDatagramSized reports whether serverURL is currently within a window
// where datagram-sized transports should be avoided.
func (f *transportFallback) avoidDatagramSized(serverURL string) bool {
f.mu.Lock()
defer f.mu.Unlock()
e := f.entries[serverURL]
return e != nil && time.Now().Before(e.until)
}
// recordFailure makes serverURL avoid datagram-sized transports for a window:
// transportFallbackBase on the first failure, doubling up to transportFallbackMax
// when a datagram transport fails again after a previous window expired. It
// returns the active window duration.
func (f *transportFallback) recordFailure(serverURL string) time.Duration {
f.mu.Lock()
defer f.mu.Unlock()
now := time.Now()
e := f.entries[serverURL]
switch {
case e == nil:
e = &fallbackEntry{duration: transportFallbackBase}
f.entries[serverURL] = e
case now.Before(e.until):
return time.Until(e.until)
default:
e.duration = min(e.duration*2, transportFallbackMax)
}
e.until = now.Add(e.duration)
return e.duration
}
// nonDatagramSized returns the dialers from in that are not datagram-sized,
// preserving order.
func nonDatagramSized(in []dialer.DialeFn) []dialer.DialeFn {
out := make([]dialer.DialeFn, 0, len(in))
for _, d := range in {
if !dialer.IsDatagramSized(d) {
out = append(out, d)
}
}
return out
}

View File

@@ -1,140 +0,0 @@
package client
import (
"net"
"os"
"testing"
"time"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
netErr "github.com/netbirdio/netbird/shared/relay/client/dialer/net"
)
// closeTrackingConn records whether Close was called; only Close is exercised.
type closeTrackingConn struct {
net.Conn
closed bool
}
func (c *closeTrackingConn) Close() error {
c.closed = true
return nil
}
func TestTransportModeFromEnv(t *testing.T) {
tests := []struct {
value string
want TransportMode
}{
{"", TransportModeAuto},
{"auto", TransportModeAuto},
{"quic", TransportModeQUIC},
{"QUIC", TransportModeQUIC},
{"ws", TransportModeWS},
{" Ws ", TransportModeWS},
{"prefer-quic", TransportModePreferQUIC},
{"prefer-ws", TransportModePreferWS},
{"garbage", TransportModeAuto},
}
for _, tc := range tests {
t.Run(tc.value, func(t *testing.T) {
t.Setenv(EnvRelayTransport, tc.value)
if tc.value == "" {
os.Unsetenv(EnvRelayTransport)
}
assert.Equal(t, tc.want, transportModeFromEnv())
})
}
}
func TestTransportFallbackRecordAndExpiry(t *testing.T) {
const url = "rels://relay.example:443"
f := newTransportFallback()
assert.False(t, f.avoidDatagramSized(url), "no fallback recorded yet")
d := f.recordFailure(url)
assert.Equal(t, transportFallbackBase, d, "first failure pins for the base window")
assert.True(t, f.avoidDatagramSized(url), "datagram-sized transport avoided within the window")
// A second failure while still inside the window must not grow the window.
d = f.recordFailure(url)
assert.LessOrEqual(t, d, transportFallbackBase, "still within the active window")
require.NotNil(t, f.entries[url])
assert.Equal(t, transportFallbackBase, f.entries[url].duration, "duration unchanged inside window")
// Expire the window: datagram-sized transport allowed again.
f.entries[url].until = time.Now().Add(-time.Second)
assert.False(t, f.avoidDatagramSized(url), "window expired, datagram-sized transport allowed")
}
func TestTransportFallbackGrowsOnRepeat(t *testing.T) {
const url = "rels://relay.example:443"
f := newTransportFallback()
want := transportFallbackBase
for i := range 6 {
d := f.recordFailure(url)
assert.Equal(t, want, d, "window after %d expiries", i)
// expire the window so the next failure is treated as a repeat
f.entries[url].until = time.Now().Add(-time.Second)
want = min(want*2, transportFallbackMax)
}
assert.Equal(t, transportFallbackMax, f.entries[url].duration, "window caps at the max")
}
func TestOnDatagramTooLargeAuto(t *testing.T) {
const url = "rels://relay.example:443"
t.Setenv(EnvRelayTransport, string(TransportModeAuto))
tf := newTransportFallback()
c := &Client{
log: log.WithField("test", t.Name()),
connectionURL: url,
transportFallback: tf,
}
conn := &closeTrackingConn{}
c.onDatagramTooLarge(conn, netErr.ErrDatagramTooLarge)
assert.True(t, conn.closed, "connection closed to force reconnect")
assert.True(t, tf.avoidDatagramSized(url), "fallback recorded for the server")
// A second oversized datagram on the same connection must not re-close.
conn.closed = false
c.onDatagramTooLarge(conn, netErr.ErrDatagramTooLarge)
assert.False(t, conn.closed, "single fallback per connection")
}
func TestOnDatagramTooLargeQUICPinned(t *testing.T) {
const url = "rels://relay.example:443"
t.Setenv(EnvRelayTransport, string(TransportModeQUIC))
tf := newTransportFallback()
c := &Client{
log: log.WithField("test", t.Name()),
connectionURL: url,
transportFallback: tf,
}
conn := &closeTrackingConn{}
c.onDatagramTooLarge(conn, netErr.ErrDatagramTooLarge)
assert.False(t, conn.closed, "QUIC pin keeps the connection, no fallback redial")
assert.False(t, tf.avoidDatagramSized(url), "QUIC pin records no fallback")
}
func TestTransportFallbackPerServer(t *testing.T) {
f := newTransportFallback()
f.recordFailure("rels://a.example:443")
assert.True(t, f.avoidDatagramSized("rels://a.example:443"))
assert.False(t, f.avoidDatagramSized("rels://b.example:443"), "fallback is scoped to one server")
}