Compare commits

...

24 Commits

Author SHA1 Message Date
Zoltán Papp
0035ddde8c Merge branch 'main' into fix/exit-node-v6-deselect-propagation
# Conflicts:
#	client/internal/routemanager/manager.go
2026-06-11 21:56:47 +02:00
Maycon Santos
d7703767d5 [client, proxy] cancel context before stopping engine on embedded client (#6397)
- Engine.Start takes syncMsgMux with a deferred unlock (engine.go:445) and parks in receiveSignalEvents → WaitStreamConnected (engine.go:1762), which only wakes on
  signal-stream connect or client-context cancellation.
  - When signal never connects, the 30s startup timeout fires and embed.Client.Start's rollback (embed.go:281) called client.Stop() → Engine.Stop, which blocks acquiring
  syncMsgMux (engine.go:318). The cancel() that would unpark Start was deferred until Start returned — permanent cycle. RemovePeer calls (g43/g385) then queue behind the
  lifecycle mutex.
  - Notably, embed.Client.Stop and the daemon's cleanupConnection both cancel before stopping — the startup rollback was the only path that didn't.
  - Engine.Start takes syncMsgMux with a deferred unlock (engine.go:445) and parks in receiveSignalEvents → WaitStreamConnected (engine.go:1762), which only wakes on
  signal-stream connect or client-context cancellation.
  - When signal never connects, the 30s startup timeout fires and embed.Client.Start's rollback (embed.go:281) called client.Stop() → Engine.Stop, which blocks acquiring
  syncMsgMux (engine.go:318). The cancel() that would unpark Start was deferred until Start returned — permanent cycle. RemovePeer calls (g43/g385) then queue behind the
  lifecycle mutex.
  - Notably, embed.Client.Stop and the daemon's cleanupConnection both cancel before stopping — the startup rollback was the only path that didn't.
2026-06-10 21:26:54 +02:00
Maycon Santos
7feda907ca [management] fix L4 service update when no custom port (#6396)
This fixes an issue where L4 service update is not possible when proxy clusters don't support custom ports
2026-06-10 18:55:24 +02:00
Maycon Santos
62da482133 [management] Add version gate to stop sending deprecated RemotePeers field (#6371)
* [management] Add version gate to stop sending deprecated RemotePeers field

don't send top-level remote peers on peers in the  v0.29.3 or newer

* precompute deprecated remote peers version constraint

* [management] update tests to validate network map-based remote peers

* [management] move deprecatedRemotePeersVersion constant closer to its usage

* fix misplaced precomputed constraint definition

* ensure top-level RemotePeers is empty for v0.29.3+ clients
2026-06-10 16:59:09 +02:00
Philip Laine
079bce3c2f Add commands to discover and write Kubernetes configuration (#6260) 2026-06-10 15:00:10 +02:00
Maycon Santos
1a09aa6715 [misc] Update Go toolchain version in go.mod (#6377) 2026-06-10 14:50:57 +02:00
Maycon Santos
61abf5b9ea [proxy] Use UUID for proxy ID generation (#6391)
Use UUID for proxy ID instead of the second to avoid race conditions when running multiple nodes at the same time.
2026-06-10 13:35:26 +02:00
Boris Dolgov
e229050ba3 [proxy] Notify certificate ready for domains covered by the static certificate (#6389) 2026-06-10 12:05:34 +02:00
Zoltan Papp
e919b2d55d [client] Preserve posture checks on config-only sync updates (#6373)
* [client] Preserve posture checks on config-only sync updates

When management sends a MessageTypeControlConfig update (e.g. relay token
rotation), the SyncResponse carries no NetworkMap and no Checks. Moving the
updateChecksIfNew call after the nm == nil guard ensures posture checks are
only updated when a full network map is present, preventing relay token
rotation from silently clearing the previously applied posture check state.

* [client] Clarify posture check update logic with explicit comment

* [client] Extract NetBird config and sync persistence into helpers

Move the NetbirdConfig handling block out of handleSync into
updateNetbirdConfig and the sync response persistence into
persistSyncResponse, mirroring updateChecksIfNew. This flattens
handleSync and makes the individual update steps unit-testable.
2026-06-10 11:43:24 +02:00
Pascal Fischer
a40028092d [management] log user agent and return request id (#6380) 2026-06-09 15:24:26 +02:00
Pascal Fischer
13200265d8 [proxy] Add no-blocking mapping updates (#6369) 2026-06-09 13:57:17 +02:00
Viktor Liu
ed7a9363aa [management] Emit IPv6 default permit firewall rule for exit node routes (#6368) 2026-06-09 13:26:43 +02:00
Viktor Liu
d56859dc5d [client] Filter DNS fallback upstreams matching our server IP to prevent loops (#6183) 2026-06-09 12:26:03 +02:00
Viktor Liu
367d37050b [relay, client] Fall back to WebSocket relay transport on oversized QUIC datagrams (#6339) 2026-06-09 10:25:46 +02:00
Viktor Liu
106527182f [client] Snapshot iptables rule maps before persisting state (#6345) 2026-06-09 10:24:51 +02:00
Viktor Liu
8e1d5b78c2 [client] Preserve user deselect-all across management route sync (#6363) 2026-06-09 10:24:17 +02:00
Zoltan Papp
764642d8f2 [client] remove v6 exit-pair mirror DIAG logging
Drop the temporary DIAG diagnostics added to trace the v4/v6 exit-pair mirror.
The field log confirmed the write-time mirror keeps the pair consistent (the
::/0 route is only ever applied alongside its v4 base and is dropped on deselect),
so the diagnostics are no longer needed.
2026-06-03 01:27:04 +02:00
Zoltan Papp
ed76f8f065 [client] add DIAG logging to trace v6 exit-pair mirror
The write-time mirror did not eliminate the leak in field testing. Re-add the
DIAG diagnostics around the exit-node selection flow to capture a fresh trace:

- UpdateRoutes: incoming client networks, selector state before/after the
  management update, and the networks remaining after FilterSelectedExitNodes.
- mirrorV6ExitPairSelections: the NetIDs present in this update and the v6 pairs
  V6ExitMergeSet derives from them (reveals whether the v4 base and its ::/0 pair
  are present in the same update so the pair can be matched).
- SyncPairedSelection: the base/paired state before and after the sync.
- FilterSelectedExitNodes / applyExitNodeFilter: per-route SKIP/KEEP/DROP and the
  selection lookups behind each decision.
- updateExitNodeSelections / logExitNodeUpdate: categorization and deselect set.

Temporary; to be removed once the root cause is confirmed.
2026-06-03 01:01:53 +02:00
Zoltan Papp
d25c8d881d [client] mirror v4 exit selection onto v6 pair at write time
The synthesized "-v6" exit route shares its v4 base's NetID plus a "-v6"
suffix. Selection state was reconciled at read time via effectiveNetID, a
mirror that could only be applied on exit-node code paths, which forced a
parallel IsSelectedForExitNode() alongside IsSelected() and a clearPairedV6Locked()
orphan cleanup on every toggle. That machinery still missed the case observed
in the field: a persisted state with the v4 base deselected but its "-v6"
sibling explicitly selected (orphaned). Because effectiveNetID returns the v6
entry itself once it carries explicit state, and clearPairedV6Locked only fires
on a live toggle, the loaded orphan survived and the ::/0 route leaked onto the
tunnel despite the exit node being disabled, breaking IPv6 (happy eyeballs).

Treat the v4/v6 exit pair as a single toggle and keep state consistent at write
time instead. RouteSelector.SyncPairedSelection forces the "-v6" entry to match
its v4 base unconditionally, resetting any orphaned explicit state. The route
manager, which knows the route prefixes, computes the pairs (V6ExitMergeSet) and
calls it from updateRouteSelectorFromManagement before selection is read, so both
collectExitNodeInfo and FilterSelectedExitNodes see consistent state, including
pairs loaded from persisted selector state.

This removes effectiveNetID, IsSelectedForExitNode and clearPairedV6Locked; the
selector is literal again and no longer needs the "exit-node paths only" caveat.
HasUserSelectionForRoute and applyExitNodeFilter use the raw NetID.

Adds a selector test for SyncPairedSelection (including the orphaned-v6 case) and
a route-manager test reproducing the persisted-orphan scenario from the field log.
2026-06-03 00:33:36 +02:00
Zoltán Papp
8e5130cda7 [client] remove exit-node v6 DIAG logging and tidy routeselector
Drop the temporary DIAG diagnostics added to trace the leaking ::/0 route
(the root cause is fixed and confirmed). Also reorganize routeselector.go so
the exit-node helpers (clearPairedV6Locked, isExitNode) sit next to the
exit-node code paths and MarshalJSON/UnmarshalJSON are grouped together.
2026-06-01 10:55:07 +02:00
Zoltán Papp
aa164c93cf [ios] compute route connection status in the bridge
The iOS bridge exposed a route's Network as a possibly comma-joined string
("0.0.0.0/0, ::/0" for a merged exit node) but no connection status, forcing
the UI to infer status by string-matching that joined value against peer
routes — which never matched for the merged exit node, leaving it stuck as
not-connected. Android already computes status in the core (findBestRoutePeer).

Mirror that here: add a Status field to RoutesSelectionInfo and compute it from
the connected peers' route tables, matching the route's primary prefix, a merged
exit node's extra v6 prefix, or a dynamic route's domain pattern (the key the
route manager records). The UI can now read the status directly.
2026-05-31 21:01:07 +02:00
Zoltán Papp
99223a310d [client] clear orphaned v6 exit selection when v4 pair is toggled
Root cause of the leaking ::/0 route, confirmed from client logs: the
synthesized "-v6" exit route could stay explicitly selected in the persisted
route-selector state while its v4 base was deselected (selected=[...-v6],
deselected=[...v4base]). Because the v6 entry then has its own explicit state,
effectiveNetID stops mirroring the v4 base, so FilterSelectedExitNodes keeps
::/0 and it is installed on the tunnel even though the user disabled the exit
node. This happened because the iOS SDK's deselect only pairs the "-v6" sibling
via ExpandV6ExitPairs when the v6 route is present in the current routesMap; a
deselect at a moment it wasn't expanded left the v6 selection orphaned.

Fix at the selector write path so it is independent of routesMap timing: when a
v4 exit NetID is selected or deselected, clear any orphaned explicit state on
its "-v6" sibling (clearPairedV6Locked), unless the sibling is part of the same
batch (the deliberate ExpandV6ExitPairs case). The v6 then falls back to
inheriting the v4 base via effectiveNetID, so a v4 deselect also drops ::/0 and
a v4 select brings both back.

Adds regression tests: a stale explicit v6 selection is cleared by a later v4
deselect, and an explicit v6 select made in the same batch is preserved.
2026-05-31 15:33:13 +02:00
Zoltán Papp
84867c7e45 [client] add DIAG logging to trace exit-node v6 (::/0) route filtering
Temporary diagnostics to find why a deselected v4 exit node's synthesized
::/0 route still reaches the tunnel. Logs the full install path: incoming
client networks, route-selector state before/after the management-driven
update, what updateExitNodeSelections deselects/selects, and per-route
KEEP/SKIP/DROP decisions in FilterSelectedExitNodes and applyExitNodeFilter.
To be reverted once the real root cause is confirmed from a client log.
2026-05-31 14:25:50 +02:00
Zoltán Papp
c7499cf8fc [client] propagate exit-node deselect to synthesized v6 (::/0) route
When a client deselects an IPv4 exit node, the auto-generated IPv6 default
route (::/0) was still selected and pushed onto the tunnel interface, even
though the user disabled the exit node. On an exit node without a real IPv6
egress this blackholes IPv6 traffic, and because clients prefer IPv6 (happy
eyeballs) it can break general connectivity.

Root cause: the synthesized v6 route gets a different NetID than its v4 base
(base + "-v6"). The route selector keys deselects by NetID and defaults
unknown NetIDs to selected, so the "-v6" entry was never matched by the v4
deselect. The effectiveNetID() mirror that solves exactly this is used by
HasUserSelectionForRoute and FilterSelectedExitNodes, but categorizeUserSelection
called the raw IsSelected(), bypassing it and mis-categorizing the v6 pair as
user-selected.

Add RouteSelector.IsSelectedForExitNode(), which applies effectiveNetID before
the selection check, and use it in categorizeUserSelection. IsSelected() is left
untouched so non-exit code paths don't make unrelated "*-v6" routes inherit v4
state. Adds regression tests for the v4/v6 deselect mirror and explicit-v6
override.
2026-05-31 13:48:51 +02:00
53 changed files with 2893 additions and 302 deletions

301
client/cmd/kubernetes.go Normal file
View File

@@ -0,0 +1,301 @@
package cmd
import (
"context"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"os"
"path/filepath"
"slices"
"strings"
"github.com/goccy/go-yaml"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/proto"
)
const (
KubernetesDNSSuffix = "netbird-kubeapi-proxy"
)
var kubernetesCmd = &cobra.Command{
Use: "kubernetes",
Short: "Kubernetes cluster commands.",
Long: "Kubernetes cluster commands.",
}
var kubernetesListCmd = &cobra.Command{
Use: "list",
RunE: kubernetesList,
Short: "List Kubernetes clusters.",
Long: "List Kubernetes clusters by discovering NetBird peers running netbird-kubeapi-proxy.",
}
var kubernetesWriteKubeconfigCmd = &cobra.Command{
Use: "write-kubeconfig",
RunE: kubernetesWriteKubeconfig,
Args: cobra.ExactArgs(1),
Short: "Write kubeconfig for a Kubernetes cluster.",
Long: "Updates kubeconfig in place to allow token-less access to the Kubernetes cluster through NetBird.",
}
func init() {
kubernetesWriteKubeconfigCmd.Flags().String("kubeconfig", "", "path to kubeconfig file")
}
func kubernetesList(cmd *cobra.Command, _ []string) error {
conn, err := getClient(cmd)
if err != nil {
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
statusResp, err := client.Status(cmd.Context(), &proto.StatusRequest{GetFullPeerStatus: true})
if err != nil {
return err
}
kcs, err := getKubernetesClusters(cmd.Context(), statusResp.FullStatus.Peers, "")
if err != nil {
return err
}
if len(kcs) == 0 {
cmd.Println("No Kubernetes clusters available.")
return nil
}
cmd.Println("Available Kubernetes clusters:")
for _, k := range kcs {
cmd.Printf("\n - Name: %s\n FQDN: %s\n Version: %s\n", k.name, k.url.Host, k.version)
}
return nil
}
func kubernetesWriteKubeconfig(cmd *cobra.Command, args []string) error {
kubeconfigPath, err := resolveKubeconfigPath(cmd)
if err != nil {
return err
}
conn, err := getClient(cmd)
if err != nil {
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
statusResp, err := client.Status(cmd.Context(), &proto.StatusRequest{GetFullPeerStatus: true})
if err != nil {
return err
}
clusterName := args[0]
kcs, err := getKubernetesClusters(cmd.Context(), statusResp.FullStatus.Peers, clusterName)
if err != nil {
return err
}
if len(kcs) == 0 {
return fmt.Errorf("kubernetes cluster named %s not found", clusterName)
}
if len(kcs) > 1 {
return fmt.Errorf("too many Kubernetes clusters returned")
}
err = writeKubeconfig(kubeconfigPath, kcs[0])
if err != nil {
return err
}
return nil
}
type kubernetesCluster struct {
name string
url *url.URL
version string
}
func getKubernetesClusters(ctx context.Context, peers []*proto.PeerState, nameFilter string) ([]kubernetesCluster, error) {
transport := http.DefaultTransport.(*http.Transport).Clone()
transport.TLSClientConfig = &tls.Config{
InsecureSkipVerify: true,
}
httpClient := &http.Client{
Transport: transport,
}
resolver := net.Resolver{
// Required so both DNS records are returned.
// https://github.com/golang/go/issues/17093
PreferGo: true,
}
kcs := []kubernetesCluster{}
attempted := map[string]struct{}{}
for _, peer := range peers {
fqdns, err := resolver.LookupAddr(ctx, peer.IP)
if err != nil {
return nil, err
}
for _, fqdn := range fqdns {
if _, ok := attempted[fqdn]; ok {
continue
}
attempted[fqdn] = struct{}{}
comps := strings.Split(fqdn, ".")
if len(comps) < 2 {
continue
}
if comps[1] != KubernetesDNSSuffix {
continue
}
if nameFilter != "" && nameFilter != comps[0] {
continue
}
clusterURL, clusterVersion, err := fingerprintClusters(ctx, httpClient, fqdn)
if err != nil {
log.Debugf("could not fingerprint Kubernetes cluster %s %q", fqdn, err)
continue
}
kc := kubernetesCluster{
name: comps[0],
url: clusterURL,
version: clusterVersion,
}
if nameFilter != "" {
return []kubernetesCluster{kc}, nil
}
kcs = append(kcs, kc)
}
}
return kcs, nil
}
func fingerprintClusters(ctx context.Context, httpClient *http.Client, fqdn string) (*url.URL, string, error) {
clusterURL, err := url.Parse("https://" + fqdn)
if err != nil {
return nil, "", err
}
versionURL, err := clusterURL.Parse("/version")
if err != nil {
return nil, "", err
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, versionURL.String(), nil)
if err != nil {
return nil, "", err
}
resp, err := httpClient.Do(req)
if err != nil {
return nil, "", err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, "", fmt.Errorf("expected %d response but got %s", http.StatusOK, resp.Status)
}
b, err := io.ReadAll(resp.Body)
if err != nil {
return nil, "", err
}
versionData := map[string]string{}
err = json.Unmarshal(b, &versionData)
if err != nil {
return nil, "", err
}
version, ok := versionData["gitVersion"]
if !ok {
return nil, "", errors.New("no version found in response")
}
return clusterURL, version, nil
}
func resolveKubeconfigPath(cmd *cobra.Command) (string, error) {
if cmd.Flags().Changed("kubeconfig") {
path, err := cmd.Flags().GetString("kubeconfig")
if err != nil {
return "", err
}
return path, nil
}
if env := os.Getenv("KUBECONFIG"); env != "" {
return env, nil
}
home, err := os.UserHomeDir()
if err != nil {
return "", fmt.Errorf("could not determine home directory: %w", err)
}
return filepath.Join(home, ".kube", "config"), nil
}
func writeKubeconfig(kubeconfigPath string, kc kubernetesCluster) error {
b, err := os.ReadFile(kubeconfigPath)
if err != nil && !errors.Is(err, os.ErrNotExist) {
return err
}
var cfg map[string]any
if err := yaml.Unmarshal(b, &cfg); err != nil {
return err
}
if cfg == nil {
cfg = map[string]any{
"apiVersion": "v1",
"kind": "Config",
}
}
cfg["clusters"] = appendWithName(cfg["clusters"], map[string]any{
"name": kc.name,
"cluster": map[string]any{
"server": kc.url.String(),
"insecure-skip-tls-verify": true,
},
})
cfg["users"] = appendWithName(cfg["users"], map[string]any{
"name": "netbird",
"user": map[string]any{
"token": "none",
},
})
cfg["contexts"] = appendWithName(cfg["contexts"], map[string]any{
"name": kc.name,
"context": map[string]any{
"cluster": kc.name,
"user": "netbird",
"namespace": "default",
},
})
cfg["current-context"] = kc.name
out, err := yaml.Marshal(cfg)
if err != nil {
return err
}
if err := os.WriteFile(kubeconfigPath, out, 0o600); err != nil {
return err
}
return nil
}
func appendWithName(data any, add map[string]any) any {
if data == nil {
return []any{add}
}
v, ok := data.([]any)
if !ok {
return []any{add}
}
i := slices.IndexFunc(v, func(item any) bool {
m, ok := item.(map[string]any)
if !ok {
return false
}
return m["name"] == add["name"]
})
if i == -1 {
return append(v, add)
}
v[i] = add
return v
}

View File

@@ -0,0 +1,120 @@
package cmd
import (
"net/http"
"net/http/httptest"
"net/url"
"os"
"path/filepath"
"testing"
"github.com/spf13/cobra"
"github.com/stretchr/testify/require"
)
func TestFingerprintClusters(t *testing.T) {
t.Parallel()
srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
//nolint: errcheck
w.Write([]byte(`{"gitVersion": "foobar"}`))
}))
defer srv.Close()
clusterURL, clusterVersion, err := fingerprintClusters(t.Context(), srv.Client(), srv.Listener.Addr().String())
require.NoError(t, err)
require.Equal(t, srv.URL, clusterURL.String())
require.Equal(t, "foobar", clusterVersion)
}
func TestResolveKubeconfigPath(t *testing.T) {
home, err := os.UserHomeDir()
if err != nil {
t.Fatalf("could not determine home directory: %v", err)
}
defaultPath := filepath.Join(home, ".kube", "config")
path, err := resolveKubeconfigPath(&cobra.Command{})
require.NoError(t, err)
require.Equal(t, defaultPath, path)
flagPath := "flag-path"
cmd := &cobra.Command{}
cmd.Flags().String("kubeconfig", "", "")
err = cmd.Flags().Set("kubeconfig", flagPath)
require.NoError(t, err)
path, err = resolveKubeconfigPath(cmd)
require.NoError(t, err)
require.Equal(t, flagPath, path)
envPath := "env-path"
t.Setenv("KUBECONFIG", envPath)
path, err = resolveKubeconfigPath(&cobra.Command{})
require.NoError(t, err)
require.Equal(t, envPath, path)
}
func TestWriteKubeconfig(t *testing.T) {
t.Parallel()
tests := []struct {
name string
existing string
}{
{
name: "empty file",
},
{
name: "existing content",
existing: `apiVersion: v1
clusters:
- cluster:
insecure-skip-tls-verify: true
server: https://foobar.com
name: foo
current-context: test
kind: Config
users: []
`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
kubeconfigPath := filepath.Join(t.TempDir(), "config")
err := os.WriteFile(kubeconfigPath, []byte(tt.existing), 0o644)
require.NoError(t, err)
kc := kubernetesCluster{
name: "foo",
url: &url.URL{Scheme: "https", Host: "example.com"},
}
err = writeKubeconfig(kubeconfigPath, kc)
require.NoError(t, err)
b, err := os.ReadFile(kubeconfigPath)
require.NoError(t, err)
expected := `apiVersion: v1
clusters:
- cluster:
insecure-skip-tls-verify: true
server: https://example.com
name: foo
contexts:
- context:
cluster: foo
namespace: default
user: netbird
name: foo
current-context: foo
kind: Config
users:
- name: netbird
user:
token: none
`
require.Equal(t, expected, string(b))
})
}
}

View File

@@ -169,6 +169,11 @@ func init() {
debugCmd.AddCommand(forCmd)
debugCmd.AddCommand(persistenceCmd)
// kubernetes commands
rootCmd.AddCommand(kubernetesCmd)
kubernetesCmd.AddCommand(kubernetesListCmd)
kubernetesCmd.AddCommand(kubernetesWriteKubeconfigCmd)
// profile commands
profileCmd.AddCommand(profileListCmd)
profileCmd.AddCommand(profileAddCmd)

View File

@@ -279,6 +279,10 @@ func (c *Client) Start(startCtx context.Context) error {
select {
case <-startCtx.Done():
// Cancel the client context before stopping: Engine.Start blocks on the
// signal stream while holding the engine mutex and only unblocks on
// cancellation. Stopping first would deadlock on that mutex.
cancel()
if stopErr := client.Stop(); stopErr != nil {
return fmt.Errorf("stop error after context done. Stop error: %w. Context done: %w", stopErr, startCtx.Err())
}

168
client/embed/embed_test.go Normal file
View File

@@ -0,0 +1,168 @@
package embed
import (
"context"
"net"
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/internals/server/config"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
mgmt "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
nbcache "github.com/netbirdio/netbird/management/server/cache"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/job"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/util"
)
const testSetupKey = "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"
// TestClientStartTimeoutRollback reproduces a deadlock between Engine.Start and
// Engine.Stop. The signal endpoint accepts gRPC connections but never serves the
// SignalExchange service, so Engine.Start parks in WaitStreamConnected while
// holding the engine mutex. When the Start context expires, the rollback path
// calls ConnectClient.Stop, which must not block forever acquiring that mutex.
func TestClientStartTimeoutRollback(t *testing.T) {
signalAddr := startBlackholeSignal(t)
mgmAddr := startManagement(t, signalAddr)
wgPort := 0
client, err := New(Options{
DeviceName: "embed-rollback-test",
SetupKey: testSetupKey,
ManagementURL: "http://" + mgmAddr,
WireguardPort: &wgPort,
})
require.NoError(t, err, "embed client creation must succeed")
startCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
startErr := make(chan error, 1)
go func() {
startErr <- client.Start(startCtx)
}()
select {
case err := <-startErr:
require.ErrorIs(t, err, context.DeadlineExceeded)
case <-time.After(60 * time.Second):
t.Fatal("client.Start did not return after its context expired: Engine.Stop deadlocked against Engine.Start waiting for the signal stream")
}
}
// startBlackholeSignal starts a gRPC server without the SignalExchange service
// registered. Connections succeed, but the signal stream can never be
// established, which keeps Engine.Start parked in WaitStreamConnected.
func startBlackholeSignal(t *testing.T) string {
t.Helper()
lis, err := net.Listen("tcp", "localhost:0")
require.NoError(t, err)
s := grpc.NewServer()
go func() {
if err := s.Serve(lis); err != nil {
t.Error(err)
}
}()
t.Cleanup(s.Stop)
return lis.Addr().String()
}
func startManagement(t *testing.T, signalAddr string) string {
t.Helper()
cfg := &config.Config{
Stuns: []*config.Host{},
TURNConfig: &config.TURNConfig{},
Relay: &config.Relay{
Addresses: []string{"127.0.0.1:1234"},
CredentialsTTL: util.Duration{Duration: time.Hour},
Secret: "222222222222222222",
},
Signal: &config.Host{
Proto: "http",
URI: signalAddr,
},
Datadir: t.TempDir(),
HttpConfig: nil,
}
lis, err := net.Listen("tcp", "localhost:0")
require.NoError(t, err)
s := grpc.NewServer()
testStore, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", cfg.Datadir)
require.NoError(t, err)
t.Cleanup(cleanUp)
eventStore := &activity.InMemoryEventStore{}
permissionsManager := permissions.NewManager(testStore)
peersManager := peers.NewManager(testStore, permissionsManager)
jobManager := job.NewJobManager(nil, testStore, peersManager)
cacheStore, err := nbcache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)
require.NoError(t, err)
iv, err := validator.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore, cacheStore)
require.NoError(t, err)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl)
settingsMockManager.EXPECT().
GetSettings(gomock.Any(), gomock.Any(), gomock.Any()).
Return(&types.Settings{}, nil).
AnyTimes()
settingsMockManager.EXPECT().
GetExtraSettings(gomock.Any(), gomock.Any()).
Return(&types.ExtraSettings{}, nil).
AnyTimes()
groupsManager := groups.NewManagerMock()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := mgmt.NewAccountRequestBuffer(context.Background(), testStore)
networkMapController := controller.NewController(context.Background(), testStore, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(testStore, peersManager), cfg)
accountManager, err := mgmt.BuildManager(context.Background(), cfg, testStore, networkMapController, jobManager, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore)
require.NoError(t, err)
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, cfg.TURNConfig, cfg.Relay, settingsMockManager, groupsManager)
require.NoError(t, err)
mgmtServer, err := nbgrpc.NewServer(cfg, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController, nil, nil)
require.NoError(t, err)
mgmtProto.RegisterManagementServiceServer(s, mgmtServer)
go func() {
if err := s.Serve(lis); err != nil {
t.Error(err)
}
}()
t.Cleanup(s.Stop)
return lis.Addr().String()
}

View File

@@ -3,6 +3,7 @@ package iptables
import (
"errors"
"fmt"
"maps"
"net"
"slices"
@@ -421,12 +422,17 @@ func (m *aclManager) updateState() {
currentState.Lock()
defer currentState.Unlock()
// Clone the maps so the persisted state holds a private snapshot. The
// live maps keep being mutated by subsequent rule operations while the
// state manager marshals the state from its periodic-save goroutine.
// Sharing them by reference races the two and aborts the process with a
// concurrent map iteration and write.
if m.v6 {
currentState.ACLEntries6 = m.entries
currentState.ACLIPsetStore6 = m.ipsetStore
currentState.ACLEntries6 = maps.Clone(m.entries)
currentState.ACLIPsetStore6 = m.ipsetStore.clone()
} else {
currentState.ACLEntries = m.entries
currentState.ACLIPsetStore = m.ipsetStore
currentState.ACLEntries = maps.Clone(m.entries)
currentState.ACLIPsetStore = m.ipsetStore.clone()
}
if err := m.stateManager.UpdateState(currentState); err != nil {

View File

@@ -4,6 +4,7 @@ package iptables
import (
"fmt"
"maps"
"net/netip"
"strconv"
"strings"
@@ -749,11 +750,17 @@ func (r *router) updateState() {
currentState.Lock()
defer currentState.Unlock()
// Clone the rule map so the persisted state holds a private snapshot. The
// live map keeps being mutated by subsequent rule operations while the
// state manager marshals the state from its periodic-save goroutine.
// Sharing it by reference races the two and aborts the process with a
// concurrent map iteration and write. The ipset counter guards itself
// during marshaling, so it can be shared directly.
if r.v6 {
currentState.RouteRules6 = r.rules
currentState.RouteRules6 = maps.Clone(r.rules)
currentState.RouteIPsetCounter6 = r.ipsetCounter
} else {
currentState.RouteRules = r.rules
currentState.RouteRules = maps.Clone(r.rules)
currentState.RouteIPsetCounter = r.ipsetCounter
}

View File

@@ -1,6 +1,9 @@
package iptables
import "encoding/json"
import (
"encoding/json"
"maps"
)
type ipList struct {
ips map[string]struct{}
@@ -19,6 +22,14 @@ func (s *ipList) addIP(ip string) {
s.ips[ip] = struct{}{}
}
// clone returns a deep copy of the ipList with its own ips map.
func (s *ipList) clone() *ipList {
if s == nil {
return nil
}
return &ipList{ips: maps.Clone(s.ips)}
}
// MarshalJSON implements json.Marshaler
func (s *ipList) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
@@ -55,6 +66,19 @@ func newIpsetStore() *ipsetStore {
}
}
// clone returns a deep copy of the ipsetStore with its own ipsets map and
// independent ipList entries.
func (s *ipsetStore) clone() *ipsetStore {
if s == nil {
return nil
}
cloned := &ipsetStore{ipsets: make(map[string]*ipList, len(s.ipsets))}
for name, list := range s.ipsets {
cloned.ipsets[name] = list.clone()
}
return cloned
}
func (s *ipsetStore) ipset(ipsetName string) (*ipList, bool) {
r, ok := s.ipsets[ipsetName]
return r, ok

View File

@@ -777,13 +777,24 @@ func (s *DefaultServer) applyHostConfig() {
// context is released rather than leaked until GC.
func (s *DefaultServer) registerFallback() {
originalNameservers := s.hostManager.getOriginalNameservers()
if len(originalNameservers) == 0 {
serverIP := s.service.RuntimeIP()
var servers []netip.AddrPort
for _, ns := range originalNameservers {
if ns == serverIP {
log.Debugf("skipping original nameserver %s as it is the same as the server IP %s", ns, serverIP)
continue
}
servers = append(servers, netip.AddrPortFrom(ns, DefaultPort))
}
if len(servers) == 0 {
log.Debugf("no fallback upstreams to register; clearing PriorityFallback handler")
s.clearFallback()
return
}
log.Infof("registering original nameservers %v as upstream handlers with priority %d", originalNameservers, PriorityFallback)
log.Infof("registering original nameservers %v as upstream handlers with priority %d", servers, PriorityFallback)
handler, err := newUpstreamResolver(
s.ctx,
@@ -797,11 +808,6 @@ func (s *DefaultServer) registerFallback() {
return
}
handler.selectedRoutes = s.selectedRoutes
var servers []netip.AddrPort
for _, ns := range originalNameservers {
servers = append(servers, netip.AddrPortFrom(ns, DefaultPort))
}
handler.addRace(servers)
prev := s.fallbackHandler

View File

@@ -880,62 +880,25 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
e.handleAutoUpdateVersion(update.NetworkMap.PeerConfig.AutoUpdate)
}
if update.GetNetbirdConfig() != nil {
wCfg := update.GetNetbirdConfig()
err := e.updateTURNs(wCfg.GetTurns())
if err != nil {
return fmt.Errorf("update TURNs: %w", err)
}
if err := e.updateNetbirdConfig(update.GetNetbirdConfig()); err != nil {
return err
}
err = e.updateSTUNs(wCfg.GetStuns())
if err != nil {
return fmt.Errorf("update STUNs: %w", err)
}
var stunTurn []*stun.URI
stunTurn = append(stunTurn, e.STUNs...)
stunTurn = append(stunTurn, e.TURNs...)
e.stunTurn.Store(stunTurn)
err = e.handleRelayUpdate(wCfg.GetRelay())
if err != nil {
return err
}
err = e.handleFlowUpdate(wCfg.GetFlow())
if err != nil {
return fmt.Errorf("handle the flow configuration: %w", err)
}
if err := e.PopulateNetbirdConfig(wCfg, nil); err != nil {
log.Warnf("Failed to update DNS server config: %v", err)
}
// todo update signal
// Posture checks are bound to the network map presence:
// NetworkMap != nil, checks present -> apply the received checks
// NetworkMap != nil, checks nil -> posture checks were removed, clear them
// NetworkMap == nil -> config-only update (e.g. relay token rotation),
// leave the previously applied checks untouched
nm := update.GetNetworkMap()
if nm == nil {
return nil
}
if err := e.updateChecksIfNew(update.Checks); err != nil {
return err
}
nm := update.GetNetworkMap()
if nm == nil {
return nil
}
// Persist sync response under the dedicated lock (syncRespMux), not under syncMsgMux.
// A non-nil syncStore is what marks persistence as enabled. Hold the lock for
// the whole Set so the store cannot be cleared (disabled / engine close)
// mid-call and have this write resurrect a file that was just removed.
e.syncRespMux.RLock()
if e.syncStore != nil {
if err := e.syncStore.Set(update); err != nil {
log.Errorf("failed to persist sync response: %v", err)
} else {
log.Debugf("sync response persisted with serial %d", nm.GetSerial())
}
}
e.syncRespMux.RUnlock()
e.persistSyncResponse(update)
// only apply new changes and ignore old ones
if err := e.updateNetworkMap(nm); err != nil {
@@ -947,6 +910,64 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
return nil
}
// updateNetbirdConfig applies the management-provided NetBird configuration:
// STUN/TURN and relay servers, flow logging and DNS settings. A nil config is a no-op,
// which is the case for sync updates carrying only a network map.
func (e *Engine) updateNetbirdConfig(wCfg *mgmProto.NetbirdConfig) error {
if wCfg == nil {
return nil
}
if err := e.updateTURNs(wCfg.GetTurns()); err != nil {
return fmt.Errorf("update TURNs: %w", err)
}
if err := e.updateSTUNs(wCfg.GetStuns()); err != nil {
return fmt.Errorf("update STUNs: %w", err)
}
var stunTurn []*stun.URI
stunTurn = append(stunTurn, e.STUNs...)
stunTurn = append(stunTurn, e.TURNs...)
e.stunTurn.Store(stunTurn)
if err := e.handleRelayUpdate(wCfg.GetRelay()); err != nil {
return err
}
if err := e.handleFlowUpdate(wCfg.GetFlow()); err != nil {
return fmt.Errorf("handle the flow configuration: %w", err)
}
if err := e.PopulateNetbirdConfig(wCfg, nil); err != nil {
log.Warnf("Failed to update DNS server config: %v", err)
}
// todo update signal
return nil
}
// persistSyncResponse stores the full sync response so it can be restored on the next
// startup. Persistence is enabled only when syncStore is set. The dedicated syncRespMux
// (not syncMsgMux) is held for the whole Set so the store cannot be cleared (disabled /
// engine close) mid-call and have this write resurrect a file that was just removed.
func (e *Engine) persistSyncResponse(update *mgmProto.SyncResponse) {
e.syncRespMux.RLock()
defer e.syncRespMux.RUnlock()
if e.syncStore == nil {
return
}
if err := e.syncStore.Set(update); err != nil {
log.Errorf("failed to persist sync response: %v", err)
return
}
log.Debugf("sync response persisted with serial %d", update.GetNetworkMap().GetSerial())
}
func (e *Engine) handleRelayUpdate(update *mgmProto.RelayConfig) error {
if update != nil {
// when we receive token we expect valid address list too

View File

@@ -9,6 +9,7 @@ import (
"net/url"
"runtime"
"slices"
"strings"
"sync"
"sync/atomic"
"time"
@@ -700,6 +701,15 @@ func resolveURLsToIPs(urls []string) []net.IP {
// updateRouteSelectorFromManagement updates the route selector based on the isSelected status from the management server
func (m *DefaultManager) updateRouteSelectorFromManagement(clientRoutes route.HAMap) {
m.mirrorV6ExitPairSelections(clientRoutes)
// An explicit user "deselect all" must not be overridden by management auto-apply.
// Auto-applying an exit node here would call SelectRoutes, which clears the
// deselect-all flag and re-enables every route the user turned off.
if m.routeSelector.IsDeselectAll() {
return
}
exitNodeInfo := m.collectExitNodeInfo(clientRoutes)
if len(exitNodeInfo.allIDs) == 0 {
return
@@ -709,6 +719,24 @@ func (m *DefaultManager) updateRouteSelectorFromManagement(clientRoutes route.HA
m.logExitNodeUpdate(exitNodeInfo)
}
// mirrorV6ExitPairSelections keeps every synthesized "-v6" exit route's selection
// consistent with its v4 base. The v4/v6 exit pair is a single toggle, so the v6
// entry always follows the base: deselecting the v4 exit node also drops its ::/0
// pair, and any stale (orphaned) explicit selection on the v6 entry is reset. This
// runs before selection is read so both collectExitNodeInfo and FilterSelectedExitNodes
// see consistent state, including pairs loaded from persisted selector state.
func (m *DefaultManager) mirrorV6ExitPairSelections(clientRoutes route.HAMap) {
routesByNetID := make(map[route.NetID][]*route.Route, len(clientRoutes))
for haID, routes := range clientRoutes {
routesByNetID[haID.NetID()] = routes
}
for v6ID := range route.V6ExitMergeSet(routesByNetID) {
baseID := route.NetID(strings.TrimSuffix(string(v6ID), route.V6ExitSuffix))
m.routeSelector.SyncPairedSelection(baseID, v6ID)
}
}
type exitNodeInfo struct {
allIDs []route.NetID
selectedByManagement []route.NetID

View File

@@ -0,0 +1,47 @@
package routemanager
import (
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal/routeselector"
"github.com/netbirdio/netbird/route"
)
// TestUpdateRouteSelectorFromManagement_MirrorsV6ExitPair reproduces the bug seen
// in netbird-engine.log: persisted selector state has the v4 exit node deselected
// but its synthesized "-v6" pair explicitly selected (orphaned), so the ::/0 route
// leaked onto the tunnel. The management update must mirror the v4 deselect onto the
// v6 pair so FilterSelectedExitNodes drops it.
func TestUpdateRouteSelectorFromManagement_MirrorsV6ExitPair(t *testing.T) {
const (
v4ID = route.NetID("Exit Node (raspberrypi)")
v6ID = route.NetID("Exit Node (raspberrypi)-v6")
)
all := []route.NetID{v4ID, v6ID}
rs := routeselector.NewRouteSelector()
// Orphan the v6 selection: select the pair, then deselect only the v4 base.
require.NoError(t, rs.SelectRoutes([]route.NetID{v4ID, v6ID}, true, all))
require.NoError(t, rs.DeselectRoutes([]route.NetID{v4ID}, all))
require.True(t, rs.IsSelected(v6ID), "precondition: orphaned v6 selection survives v4 deselect")
m := &DefaultManager{routeSelector: rs}
v4Route := &route.Route{NetID: v4ID, Network: netip.MustParsePrefix("0.0.0.0/0")}
v6Route := &route.Route{NetID: v6ID, Network: netip.MustParsePrefix("::/0")}
clientRoutes := route.HAMap{
"Exit Node (raspberrypi)|0.0.0.0/0": {v4Route},
"Exit Node (raspberrypi)-v6|::/0": {v6Route},
}
m.updateRouteSelectorFromManagement(clientRoutes)
assert.False(t, rs.IsSelected(v6ID), "v6 pair must follow the v4 base deselect after the management update")
filtered := rs.FilterSelectedExitNodes(clientRoutes)
assert.Empty(t, filtered, "deselected v4 exit node must not leak its ::/0 pair onto the tunnel")
}

View File

@@ -0,0 +1,71 @@
package routemanager
import (
"net/netip"
"testing"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal/routeselector"
"github.com/netbirdio/netbird/route"
)
func exitNodeRoutes(netID route.NetID, skipAutoApply bool) route.HAMap {
haID := route.HAUniqueID(string(netID) + "|0.0.0.0/0")
return route.HAMap{
haID: []*route.Route{
{
ID: "r-" + route.ID(netID),
NetID: netID,
Network: netip.MustParsePrefix("0.0.0.0/0"),
NetworkType: route.IPv4Network,
Enabled: true,
SkipAutoApply: skipAutoApply,
},
},
}
}
func TestUpdateRouteSelectorFromManagement(t *testing.T) {
t.Run("management auto-apply selects exit node without user selection", func(t *testing.T) {
m := &DefaultManager{routeSelector: routeselector.NewRouteSelector()}
routes := exitNodeRoutes("exit1", false)
m.updateRouteSelectorFromManagement(routes)
require.True(t, m.routeSelector.IsSelected("exit1"), "auto-apply exit node should be selected")
require.Len(t, m.routeSelector.FilterSelectedExitNodes(routes), 1, "selected exit node should pass the filter")
})
t.Run("management SkipAutoApply leaves exit node deselected", func(t *testing.T) {
m := &DefaultManager{routeSelector: routeselector.NewRouteSelector()}
routes := exitNodeRoutes("exit1", true)
m.updateRouteSelectorFromManagement(routes)
require.False(t, m.routeSelector.IsSelected("exit1"), "SkipAutoApply exit node should not be selected")
require.Empty(t, m.routeSelector.FilterSelectedExitNodes(routes), "deselected exit node should be filtered out")
})
t.Run("user selection is not overridden by management", func(t *testing.T) {
m := &DefaultManager{routeSelector: routeselector.NewRouteSelector()}
require.NoError(t, m.routeSelector.SelectRoutes([]route.NetID{"exit1"}, true, []route.NetID{"exit1"}))
routes := exitNodeRoutes("exit1", true)
m.updateRouteSelectorFromManagement(routes)
require.True(t, m.routeSelector.IsSelected("exit1"), "explicit user selection must survive a management sync that wants to skip auto-apply")
require.Len(t, m.routeSelector.FilterSelectedExitNodes(routes), 1, "user-selected exit node should pass the filter")
})
t.Run("deselect-all is preserved across a management sync", func(t *testing.T) {
m := &DefaultManager{routeSelector: routeselector.NewRouteSelector()}
m.routeSelector.DeselectAllRoutes()
routes := exitNodeRoutes("exit1", false)
m.updateRouteSelectorFromManagement(routes)
require.True(t, m.routeSelector.IsDeselectAll(), "an explicit deselect-all must not be cleared by management auto-apply")
require.Empty(t, m.routeSelector.FilterSelectedExitNodes(routes), "no routes should be selected while deselect-all is set")
})
}

View File

@@ -4,7 +4,6 @@ import (
"encoding/json"
"fmt"
"slices"
"strings"
"sync"
"github.com/hashicorp/go-multierror"
@@ -116,6 +115,14 @@ func (rs *RouteSelector) DeselectAllRoutes() {
clear(rs.selectedRoutes)
}
// IsDeselectAll reports whether the user has explicitly deselected all routes.
func (rs *RouteSelector) IsDeselectAll() bool {
rs.mu.RLock()
defer rs.mu.RUnlock()
return rs.deselectAll
}
// IsSelected checks if a specific route is selected.
func (rs *RouteSelector) IsSelected(routeID route.NetID) bool {
rs.mu.RLock()
@@ -124,6 +131,33 @@ func (rs *RouteSelector) IsSelected(routeID route.NetID) bool {
return rs.isSelectedLocked(routeID)
}
// SyncPairedSelection forces pairedID's explicit selection state to match baseID's,
// so a synthesized "-v6" exit route always follows its v4 base: selecting or
// deselecting the v4 exit node governs the ::/0 pair, and any stale (orphaned)
// explicit state on the v6 entry is reset. The v4/v6 exit pair is treated as a single
// toggle, so the v6 entry carries no independent selection of its own.
func (rs *RouteSelector) SyncPairedSelection(baseID, pairedID route.NetID) {
rs.mu.Lock()
defer rs.mu.Unlock()
if rs.deselectAll {
return
}
_, baseSelected := rs.selectedRoutes[baseID]
_, baseDeselected := rs.deselectedRoutes[baseID]
delete(rs.selectedRoutes, pairedID)
delete(rs.deselectedRoutes, pairedID)
switch {
case baseSelected:
rs.selectedRoutes[pairedID] = struct{}{}
case baseDeselected:
rs.deselectedRoutes[pairedID] = struct{}{}
}
}
// FilterSelected removes unselected routes from the provided map.
func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap {
rs.mu.RLock()
@@ -143,14 +177,13 @@ func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap {
}
// HasUserSelectionForRoute returns true if the user has explicitly selected or deselected this route.
// Intended for exit-node code paths: a v6 exit-node pair (e.g. "MyExit-v6") with no explicit state of
// its own inherits its v4 base's state, so legacy persisted selections that predate v6 pairing
// transparently apply to the synthesized v6 entry.
// The lookup is literal; v4/v6 exit pairs are kept consistent at write time via SyncPairedSelection,
// so a synthesized "-v6" entry carries the same explicit state as its v4 base.
func (rs *RouteSelector) HasUserSelectionForRoute(routeID route.NetID) bool {
rs.mu.RLock()
defer rs.mu.RUnlock()
return rs.hasUserSelectionForRouteLocked(rs.effectiveNetID(routeID))
return rs.hasUserSelectionForRouteLocked(routeID)
}
func (rs *RouteSelector) FilterSelectedExitNodes(routes route.HAMap) route.HAMap {
@@ -179,83 +212,6 @@ func (rs *RouteSelector) FilterSelectedExitNodes(routes route.HAMap) route.HAMap
return filtered
}
// effectiveNetID returns the v4 base for a "-v6" exit pair entry that has no explicit
// state of its own, so selections made on the v4 entry govern the v6 entry automatically.
// Only call this from exit-node-specific code paths: applying it to a non-exit "-v6" route
// would make it inherit unrelated v4 state. Must be called with rs.mu held.
func (rs *RouteSelector) effectiveNetID(id route.NetID) route.NetID {
name := string(id)
if !strings.HasSuffix(name, route.V6ExitSuffix) {
return id
}
if _, ok := rs.selectedRoutes[id]; ok {
return id
}
if _, ok := rs.deselectedRoutes[id]; ok {
return id
}
return route.NetID(strings.TrimSuffix(name, route.V6ExitSuffix))
}
func (rs *RouteSelector) isSelectedLocked(routeID route.NetID) bool {
if rs.deselectAll {
return false
}
_, deselected := rs.deselectedRoutes[routeID]
return !deselected
}
func (rs *RouteSelector) isDeselectedLocked(netID route.NetID) bool {
if rs.deselectAll {
return true
}
_, deselected := rs.deselectedRoutes[netID]
return deselected
}
func (rs *RouteSelector) hasUserSelectionForRouteLocked(routeID route.NetID) bool {
_, selected := rs.selectedRoutes[routeID]
_, deselected := rs.deselectedRoutes[routeID]
return selected || deselected
}
func isExitNode(rt []*route.Route) bool {
return len(rt) > 0 && (route.IsV4DefaultRoute(rt[0].Network) || route.IsV6DefaultRoute(rt[0].Network))
}
func (rs *RouteSelector) applyExitNodeFilter(
id route.HAUniqueID,
netID route.NetID,
rt []*route.Route,
out route.HAMap,
) {
// Exit-node path: apply the v4/v6 pair mirror so a deselect on the v4 base also
// drops the synthesized v6 entry that lacks its own explicit state.
effective := rs.effectiveNetID(netID)
if rs.hasUserSelectionForRouteLocked(effective) {
if rs.isSelectedLocked(effective) {
out[id] = rt
}
return
}
// no explicit selection for this route: defer to management's SkipAutoApply flag
sel := collectSelected(rt)
if len(sel) > 0 {
out[id] = sel
}
}
func collectSelected(rt []*route.Route) []*route.Route {
var sel []*route.Route
for _, r := range rt {
if !r.SkipAutoApply {
sel = append(sel, r)
}
}
return sel
}
// MarshalJSON implements the json.Marshaler interface
func (rs *RouteSelector) MarshalJSON() ([]byte, error) {
rs.mu.RLock()
@@ -309,3 +265,59 @@ func (rs *RouteSelector) UnmarshalJSON(data []byte) error {
return nil
}
func (rs *RouteSelector) isSelectedLocked(routeID route.NetID) bool {
if rs.deselectAll {
return false
}
_, deselected := rs.deselectedRoutes[routeID]
return !deselected
}
func (rs *RouteSelector) isDeselectedLocked(netID route.NetID) bool {
if rs.deselectAll {
return true
}
_, deselected := rs.deselectedRoutes[netID]
return deselected
}
func (rs *RouteSelector) hasUserSelectionForRouteLocked(routeID route.NetID) bool {
_, selected := rs.selectedRoutes[routeID]
_, deselected := rs.deselectedRoutes[routeID]
return selected || deselected
}
func (rs *RouteSelector) applyExitNodeFilter(
id route.HAUniqueID,
netID route.NetID,
rt []*route.Route,
out route.HAMap,
) {
if rs.hasUserSelectionForRouteLocked(netID) {
if rs.isSelectedLocked(netID) {
out[id] = rt
}
return
}
// no explicit selection for this route: defer to management's SkipAutoApply flag
sel := collectSelected(rt)
if len(sel) > 0 {
out[id] = sel
}
}
func isExitNode(rt []*route.Route) bool {
return len(rt) > 0 && (route.IsV4DefaultRoute(rt[0].Network) || route.IsV6DefaultRoute(rt[0].Network))
}
func collectSelected(rt []*route.Route) []*route.Route {
var sel []*route.Route
for _, r := range rt {
if !r.SkipAutoApply {
sel = append(sel, r)
}
}
return sel
}

View File

@@ -330,39 +330,73 @@ func TestRouteSelector_FilterSelectedExitNodes(t *testing.T) {
assert.Len(t, filtered, 0) // No routes should be selected
}
// TestRouteSelector_V6ExitPairInherits covers the v4/v6 exit-node pair selection
// mirror. The mirror is scoped to exit-node code paths: HasUserSelectionForRoute
// and FilterSelectedExitNodes resolve a "-v6" entry without explicit state to its
// v4 base, so legacy persisted selections that predate v6 pairing transparently
// apply to the synthesized v6 entry. General lookups (IsSelected, FilterSelected)
// stay literal so unrelated routes named "*-v6" don't inherit unrelated state.
func TestRouteSelector_V6ExitPairInherits(t *testing.T) {
// TestRouteSelector_V6ExitPairSync covers SyncPairedSelection, which keeps a v4
// exit node and its synthesized "-v6" counterpart consistent. The selector itself
// is literal and never infers a v6 entry's state from its v4 base; callers that know
// the pairing (exit-node code paths) call SyncPairedSelection to force the v6 entry
// to follow the base, treating the pair as a single toggle.
func TestRouteSelector_V6ExitPairSync(t *testing.T) {
all := []route.NetID{"exit1", "exit1-v6", "exit2", "exit2-v6", "corp", "corp-v6"}
t.Run("HasUserSelectionForRoute mirrors deselected v4 base", func(t *testing.T) {
t.Run("selector lookups stay literal without sync", func(t *testing.T) {
rs := routeselector.NewRouteSelector()
require.NoError(t, rs.DeselectRoutes([]route.NetID{"exit1"}, all))
assert.True(t, rs.HasUserSelectionForRoute("exit1-v6"), "v6 pair sees v4 base's user selection")
// The selector does not pair-resolve: the v6 entry is independent until synced.
assert.False(t, rs.HasUserSelectionForRoute("exit1-v6"), "v6 entry has no state of its own")
assert.True(t, rs.IsSelected("exit1-v6"), "unsynced v6 entry stays selected by default")
// unrelated v6 with no v4 base touched is unaffected
assert.False(t, rs.HasUserSelectionForRoute("exit2-v6"))
// A route literally named "exit1-something" must never pair-resolve either.
assert.False(t, rs.HasUserSelectionForRoute("exit1-something"))
})
t.Run("IsSelected stays literal for non-exit lookups", func(t *testing.T) {
rs := routeselector.NewRouteSelector()
require.NoError(t, rs.DeselectRoutes([]route.NetID{"corp"}, all))
// A non-exit route literally named "corp-v6" must not inherit "corp"'s state
// via the mirror; the mirror only applies in exit-node code paths.
assert.False(t, rs.IsSelected("corp"))
assert.True(t, rs.IsSelected("corp-v6"), "non-exit *-v6 routes must not inherit unrelated v4 state")
})
t.Run("explicit v6 state overrides v4 base in filter", func(t *testing.T) {
t.Run("sync mirrors deselected v4 base onto v6", func(t *testing.T) {
rs := routeselector.NewRouteSelector()
require.NoError(t, rs.DeselectRoutes([]route.NetID{"exit1"}, all))
rs.SyncPairedSelection("exit1", "exit1-v6")
assert.False(t, rs.IsSelected("exit1"))
assert.False(t, rs.IsSelected("exit1-v6"), "v6 pair follows v4 base deselect")
assert.True(t, rs.HasUserSelectionForRoute("exit1-v6"), "v6 carries explicit deselect after sync")
})
t.Run("sync mirrors selected v4 base onto v6", func(t *testing.T) {
rs := routeselector.NewRouteSelector()
require.NoError(t, rs.SelectRoutes([]route.NetID{"exit1"}, false, all))
rs.SyncPairedSelection("exit1", "exit1-v6")
assert.True(t, rs.IsSelected("exit1"))
assert.True(t, rs.IsSelected("exit1-v6"), "v6 pair follows v4 base select")
})
t.Run("sync clears v6 state when base has no explicit selection", func(t *testing.T) {
rs := routeselector.NewRouteSelector()
require.NoError(t, rs.SelectRoutes([]route.NetID{"exit1-v6"}, true, all))
require.True(t, rs.HasUserSelectionForRoute("exit1-v6"))
rs.SyncPairedSelection("exit1", "exit1-v6")
assert.False(t, rs.HasUserSelectionForRoute("exit1-v6"),
"v6 explicit state is cleared so it follows management like its base")
})
// Regression for the observed bug (see netbird-engine.log): persisted state has
// the v4 base deselected but the v6 sibling explicitly selected (orphaned). The
// sync must reset the orphan so the ::/0 route does not leak onto the tunnel.
t.Run("sync clears orphaned explicit v6 selection on deselected base", func(t *testing.T) {
rs := routeselector.NewRouteSelector()
// Prior state: both explicitly selected, then only the v4 base deselected,
// leaving the v6 entry as a stale explicit selection.
require.NoError(t, rs.SelectRoutes([]route.NetID{"exit1", "exit1-v6"}, true, all))
require.NoError(t, rs.DeselectRoutes([]route.NetID{"exit1"}, all))
require.True(t, rs.IsSelected("exit1-v6"), "precondition: orphaned v6 selection")
rs.SyncPairedSelection("exit1", "exit1-v6")
assert.False(t, rs.IsSelected("exit1-v6"), "orphaned v6 selection reset to follow v4 deselect")
v4Route := &route.Route{NetID: "exit1", Network: netip.MustParsePrefix("0.0.0.0/0")}
v6Route := &route.Route{NetID: "exit1-v6", Network: netip.MustParsePrefix("::/0")}
@@ -370,23 +404,14 @@ func TestRouteSelector_V6ExitPairInherits(t *testing.T) {
"exit1|0.0.0.0/0": {v4Route},
"exit1-v6|::/0": {v6Route},
}
filtered := rs.FilterSelectedExitNodes(routes)
assert.NotContains(t, filtered, route.HAUniqueID("exit1|0.0.0.0/0"))
assert.Contains(t, filtered, route.HAUniqueID("exit1-v6|::/0"), "explicit v6 select wins over v4 base")
assert.Empty(t, filtered, "deselecting v4 base must drop the v6 pair even if it was explicitly selected before")
})
t.Run("non-v6-suffix routes unaffected", func(t *testing.T) {
rs := routeselector.NewRouteSelector()
require.NoError(t, rs.DeselectRoutes([]route.NetID{"exit1"}, all))
// A route literally named "exit1-something" must not pair-resolve.
assert.False(t, rs.HasUserSelectionForRoute("exit1-something"))
})
t.Run("filter v6 paired with deselected v4 base", func(t *testing.T) {
t.Run("filter drops synced v6 pair of deselected v4 base", func(t *testing.T) {
rs := routeselector.NewRouteSelector()
require.NoError(t, rs.DeselectRoutes([]route.NetID{"exit1"}, all))
rs.SyncPairedSelection("exit1", "exit1-v6")
v4Route := &route.Route{NetID: "exit1", Network: netip.MustParsePrefix("0.0.0.0/0")}
v6Route := &route.Route{NetID: "exit1-v6", Network: netip.MustParsePrefix("::/0")}
@@ -399,6 +424,15 @@ func TestRouteSelector_V6ExitPairInherits(t *testing.T) {
assert.Empty(t, filtered, "deselecting v4 base must also drop the v6 pair")
})
t.Run("deselectAll makes sync a no-op", func(t *testing.T) {
rs := routeselector.NewRouteSelector()
rs.DeselectAllRoutes()
rs.SyncPairedSelection("exit1", "exit1-v6")
assert.False(t, rs.HasUserSelectionForRoute("exit1-v6"), "sync must not write explicit state under deselectAll")
})
t.Run("non-exit *-v6 routes pass through FilterSelectedExitNodes", func(t *testing.T) {
rs := routeselector.NewRouteSelector()
require.NoError(t, rs.DeselectRoutes([]route.NetID{"corp"}, all))

View File

@@ -54,6 +54,7 @@ type selectRoute struct {
Network netip.Prefix
Domains domain.List
Selected bool
Status string
extraNetworks []netip.Prefix
}
@@ -377,9 +378,57 @@ func (c *Client) GetRoutesSelectionDetails() (*RoutesSelectionDetails, error) {
routes := buildSelectRoutes(routesMap, routeSelector.IsSelected, v6ExitMerged)
resolvedDomains := c.recorder.GetResolvedDomainsStates()
// Compute each route's connection status in the core (mirroring the Android
// bridge), so the UI doesn't have to infer it by string-matching the joined
// Network value against peer routes. For a merged exit node the status reflects
// whichever of the v4/v6 prefixes is served by a connected peer; for dynamic
// (DNS) routes the peer route key is the domain pattern (see dynamic.Route.String).
connectedRoutes := c.connectedRouteSet()
for _, r := range routes {
r.Status = routeStatus(r, connectedRoutes)
}
return prepareRouteSelectionDetails(routes, resolvedDomains), nil
}
// connectedRouteSet returns the set of route keys (as strings) currently served by a
// connected peer, gathered across all connected peers' route tables. The keys match
// what the route manager records: a prefix string for static routes (e.g. "0.0.0.0/0")
// and the domain pattern for dynamic routes (e.g. "*.example.com").
func (c *Client) connectedRouteSet() map[string]struct{} {
connected := map[string]struct{}{}
for _, p := range c.recorder.GetFullStatus().Peers {
if p.ConnStatus != peer.StatusConnected {
continue
}
for r := range p.GetRoutes() {
connected[r] = struct{}{}
}
}
return connected
}
// routeStatus reports "Connected" if any of the route's keys is served by a connected
// peer: the primary Network prefix, an extra v6 network of a merged exit node, or the
// domain pattern for a dynamic DNS route. Otherwise "Idle".
func routeStatus(r *selectRoute, connectedRoutes map[string]struct{}) string {
keys := make([]string, 0, 1+len(r.extraNetworks))
if len(r.Domains) > 0 {
keys = append(keys, r.Domains.SafeString())
} else {
keys = append(keys, r.Network.String())
}
for _, extra := range r.extraNetworks {
keys = append(keys, extra.String())
}
for _, k := range keys {
if _, ok := connectedRoutes[k]; ok {
return peer.StatusConnected.String()
}
}
return peer.StatusIdle.String()
}
func buildSelectRoutes(routesMap map[route.NetID][]*route.Route, isSelected func(route.NetID) bool, v6Merged map[route.NetID]struct{}) []*selectRoute {
var routes []*selectRoute
for id, rt := range routesMap {
@@ -462,6 +511,7 @@ func prepareRouteSelectionDetails(routes []*selectRoute, resolvedDomains map[dom
Network: netStr,
Domains: &domainDetails,
Selected: r.Selected,
Status: r.Status,
})
}

View File

@@ -20,6 +20,7 @@ type RoutesSelectionInfo struct {
Network string
Domains *DomainDetails
Selected bool
Status string
}
type DomainCollection interface {

View File

@@ -99,6 +99,9 @@ func addFields(entry *logrus.Entry) {
if ctxAccountID, ok := entry.Context.Value(context.AccountIDKey).(string); ok {
entry.Data[context.AccountIDKey] = ctxAccountID
}
if ctxUserAgent, ok := entry.Context.Value(context.UserAgentKey).(string); ok {
entry.Data[context.UserAgentKey] = ctxUserAgent
}
if ctxInitiatorID, ok := entry.Context.Value(context.UserIDKey).(string); ok {
entry.Data[context.UserIDKey] = ctxInitiatorID
}

6
go.mod
View File

@@ -2,6 +2,8 @@ module github.com/netbirdio/netbird
go 1.25.5
toolchain go1.25.11
require (
cunicu.li/go-rosenpass v0.5.42
github.com/cenkalti/backoff/v4 v4.3.0
@@ -54,6 +56,7 @@ require (
github.com/fsnotify/fsnotify v1.9.0
github.com/gliderlabs/ssh v0.3.8
github.com/go-jose/go-jose/v4 v4.1.4
github.com/goccy/go-yaml v1.18.0
github.com/godbus/dbus/v5 v5.1.0
github.com/golang-jwt/jwt/v5 v5.3.1
github.com/golang/mock v1.6.0
@@ -211,10 +214,9 @@ require (
github.com/go-viper/mapstructure/v2 v2.5.0 // indirect
github.com/go-webauthn/webauthn v0.16.4 // indirect
github.com/go-webauthn/x v0.2.3 // indirect
github.com/goccy/go-yaml v1.18.0 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/golang-jwt/jwt/v4 v4.5.2 // indirect
github.com/google/btree v1.1.2 // indirect
github.com/google/btree v1.1.3 // indirect
github.com/google/go-querystring v1.1.0 // indirect
github.com/google/go-tpm v0.9.8 // indirect
github.com/google/s2a-go v0.1.9 // indirect

4
go.sum
View File

@@ -275,8 +275,8 @@ github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiu
github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
github.com/google/btree v1.1.2 h1:xf4v41cLI2Z6FxbKm+8Bu+m8ifhj15JuZ9sa0jZCMUU=
github.com/google/btree v1.1.2/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4=
github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg=
github.com/google/btree v1.1.3/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4=
github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=

View File

@@ -488,6 +488,195 @@ func TestUpdate_AllowsPortChange(t *testing.T) {
assert.Equal(t, uint16(54321), updated.ListenPort, "explicit port change should be applied")
}
func TestUpdate_PreservesPortWhenCustomPortsNotSupported(t *testing.T) {
mgr, testStore, _ := setupL4Test(t, boolPtr(false))
ctx := context.Background()
existing := seedService(t, testStore, "tcp-svc", "tcp", testCluster, testCluster, 12345)
updated := &rpservice.Service{
ID: existing.ID,
AccountID: testAccountID,
Name: "tcp-svc-renamed",
Mode: "tcp",
Domain: testCluster,
ProxyCluster: testCluster,
ListenPort: 0,
Enabled: true,
Targets: []*rpservice.Target{
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 9090, Enabled: true},
},
}
_, err := mgr.persistServiceUpdate(ctx, testAccountID, updated)
require.NoError(t, err, "update must not be rejected by the custom-port capability check")
assert.Equal(t, uint16(12345), updated.ListenPort, "existing listen port should be preserved on unsupported cluster")
}
func TestUpdate_PreservesPortWhenCustomPortsUnknown(t *testing.T) {
mgr, testStore, _ := setupL4Test(t, nil)
ctx := context.Background()
existing := seedService(t, testStore, "tcp-svc", "tcp", testCluster, testCluster, 12345)
updated := &rpservice.Service{
ID: existing.ID,
AccountID: testAccountID,
Name: "tcp-svc-renamed",
Mode: "tcp",
Domain: testCluster,
ProxyCluster: testCluster,
ListenPort: 0,
Enabled: true,
Targets: []*rpservice.Target{
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 9090, Enabled: true},
},
}
_, err := mgr.persistServiceUpdate(ctx, testAccountID, updated)
require.NoError(t, err, "update must not be rejected when cluster capability is unknown")
assert.Equal(t, uint16(12345), updated.ListenPort, "existing listen port should be preserved when capability is unknown")
}
func TestUpdate_RejectsPortChangeWhenCustomPortsNotSupported(t *testing.T) {
mgr, testStore, _ := setupL4Test(t, boolPtr(false))
ctx := context.Background()
existing := seedService(t, testStore, "tcp-svc", "tcp", testCluster, testCluster, 12345)
updated := &rpservice.Service{
ID: existing.ID,
AccountID: testAccountID,
Name: "tcp-svc",
Mode: "tcp",
Domain: testCluster,
ProxyCluster: testCluster,
ListenPort: 54321,
Enabled: true,
Targets: []*rpservice.Target{
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 9090, Enabled: true},
},
}
_, err := mgr.persistServiceUpdate(ctx, testAccountID, updated)
require.Error(t, err, "explicit port change on update must be rejected on unsupported clusters")
assert.Contains(t, err.Error(), "custom ports not supported on target cluster")
}
func TestUpdate_TLSPortChangeAllowedWhenNotSupported(t *testing.T) {
mgr, testStore, _ := setupL4Test(t, boolPtr(false))
ctx := context.Background()
existing := seedService(t, testStore, "tls-svc", "tls", "app.example.com", testCluster, 443)
updated := &rpservice.Service{
ID: existing.ID,
AccountID: testAccountID,
Name: "tls-svc",
Mode: "tls",
Domain: "app.example.com",
ProxyCluster: testCluster,
ListenPort: 9999,
Enabled: true,
Targets: []*rpservice.Target{
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8443, Enabled: true},
},
}
_, err := mgr.persistServiceUpdate(ctx, testAccountID, updated)
require.NoError(t, err, "TLS port change uses SNI routing and is exempt from the custom-port check")
assert.Equal(t, uint16(9999), updated.ListenPort, "TLS port change should be applied")
}
func TestValidateL4PortDiffOnClusterDiff(t *testing.T) {
tests := []struct {
name string
mode string
customPorts *bool
newPort uint16
oldPort uint16
wantErr bool
}{
{"tcp port change unsupported", "tcp", boolPtr(false), 54321, 12345, true},
{"tcp port change unknown capability", "tcp", nil, 54321, 12345, true},
{"udp port change unsupported", "udp", boolPtr(false), 54321, 12345, true},
{"tcp first port assignment unsupported", "tcp", boolPtr(false), 54321, 0, true},
{"tcp port change supported", "tcp", boolPtr(true), 54321, 12345, false},
{"tcp port unchanged unsupported", "tcp", boolPtr(false), 12345, 12345, false},
{"tcp zero port unsupported", "tcp", boolPtr(false), 0, 12345, false},
{"tls port change unsupported", "tls", boolPtr(false), 9999, 443, false},
{"http mode ignored", "http", boolPtr(false), 54321, 12345, false},
{"empty mode ignored", "", boolPtr(false), 54321, 12345, false},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
newSvc := &rpservice.Service{Mode: tc.mode, ListenPort: tc.newPort, ProxyCluster: testCluster}
oldSvc := &rpservice.Service{Mode: tc.mode, ListenPort: tc.oldPort, ProxyCluster: testCluster}
err := validateL4PortDiffOnClusterDiff(tc.customPorts, newSvc, oldSvc)
if tc.wantErr {
assert.Error(t, err, "port diff should be rejected for %s", tc.name)
} else {
assert.NoError(t, err, "port diff should be allowed for %s", tc.name)
}
})
}
}
func TestUpdate_PortConflictRejected(t *testing.T) {
mgr, testStore, _ := setupL4Test(t, boolPtr(true))
ctx := context.Background()
seedService(t, testStore, "tcp-a", "tcp", "tcp-a."+testCluster, testCluster, 5432)
svcB := seedService(t, testStore, "tcp-b", "tcp", "tcp-b."+testCluster, testCluster, 6543)
updated := &rpservice.Service{
ID: svcB.ID,
AccountID: testAccountID,
Name: "tcp-b",
Mode: "tcp",
Domain: "tcp-b." + testCluster,
ProxyCluster: testCluster,
ListenPort: 5432,
Enabled: true,
Targets: []*rpservice.Target{
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 9090, Enabled: true},
},
}
_, err := mgr.persistServiceUpdate(ctx, testAccountID, updated)
require.Error(t, err, "updating to a port held by another service should be rejected")
assert.Contains(t, err.Error(), "already in use")
}
func TestUpdate_AutoAssignsWhenNoPort(t *testing.T) {
mgr, testStore, _ := setupL4Test(t, boolPtr(false))
ctx := context.Background()
existing := seedService(t, testStore, "tcp-svc", "tcp", testCluster, testCluster, 0)
updated := &rpservice.Service{
ID: existing.ID,
AccountID: testAccountID,
Name: "tcp-svc",
Mode: "tcp",
Domain: testCluster,
ProxyCluster: testCluster,
ListenPort: 0,
Enabled: true,
Targets: []*rpservice.Target{
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 9090, Enabled: true},
},
}
_, err := mgr.persistServiceUpdate(ctx, testAccountID, updated)
require.NoError(t, err)
assert.True(t, updated.ListenPort >= autoAssignPortMin && updated.ListenPort <= autoAssignPortMax,
"auto-assigned port %d should be in range [%d, %d]", updated.ListenPort, autoAssignPortMin, autoAssignPortMax)
assert.True(t, updated.PortAutoAssigned, "PortAutoAssigned should be set when update triggers auto-assignment")
}
func TestCreateServiceFromPeer_TCP(t *testing.T) {
mgr, _, _ := setupL4Test(t, boolPtr(false))
ctx := context.Background()

View File

@@ -338,7 +338,7 @@ func (m *Manager) persistNewService(ctx context.Context, accountID string, svc *
}
}
if err := m.ensureL4Port(ctx, transaction, svc, customPorts); err != nil {
if err := m.ensureL4Port(ctx, transaction, svc, customPorts, false); err != nil {
return err
}
@@ -367,11 +367,11 @@ func (m *Manager) clusterCustomPorts(ctx context.Context, svc *service.Service)
// ensureL4Port auto-assigns a listen port when needed and validates cluster support.
// customPorts must be pre-computed via clusterCustomPorts before entering a transaction.
func (m *Manager) ensureL4Port(ctx context.Context, tx store.Store, svc *service.Service, customPorts *bool) error {
func (m *Manager) ensureL4Port(ctx context.Context, tx store.Store, svc *service.Service, customPorts *bool, serviceUpdate bool) error {
if !service.IsL4Protocol(svc.Mode) {
return nil
}
if service.IsPortBasedProtocol(svc.Mode) && svc.ListenPort > 0 && (customPorts == nil || !*customPorts) {
if service.IsPortBasedProtocol(svc.Mode) && svc.ListenPort > 0 && !serviceUpdate && (customPorts == nil || !*customPorts) {
if svc.Source != service.SourceEphemeral {
return status.Errorf(status.InvalidArgument, "custom ports not supported on cluster %s", svc.ProxyCluster)
}
@@ -465,7 +465,7 @@ func (m *Manager) persistNewEphemeralService(ctx context.Context, accountID, pee
return err
}
if err := m.ensureL4Port(ctx, transaction, svc, customPorts); err != nil {
if err := m.ensureL4Port(ctx, transaction, svc, customPorts, false); err != nil {
return err
}
@@ -651,12 +651,22 @@ func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.St
m.preserveListenPort(service, existingService)
updateInfo.serviceEnabledChanged = existingService.Enabled != service.Enabled
if err := m.ensureL4Port(ctx, transaction, service, customPorts); err != nil {
// if the service is being updated, and we decide in the future to allow mode update,
// we should reconsider the currently assigned port if not 0 for clusters that don't support custom ports
if err := validateL4PortDiffOnClusterDiff(customPorts, service, existingService); err != nil {
return err
}
if err := m.ensureL4Port(ctx, transaction, service, customPorts, true); err != nil {
return err
}
// we can try carrying the previous service port into a new cluster, if this becomes a problem for multiple users,
// we should reconsider adding another check
if err := m.checkPortConflict(ctx, transaction, service); err != nil {
return err
}
if err := transaction.UpdateService(ctx, service); err != nil {
return fmt.Errorf("update service: %w", err)
}
@@ -664,6 +674,21 @@ func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.St
return nil
}
// validateL4PortDiffOnClusterDiff checks if custom L4 ports are configured and validates port changes across clusters.
// It ensures no port changes if custom ports are unsupported for a given cluster and protocol mode.
// Returns an error if validation fails, otherwise returns nil.
func validateL4PortDiffOnClusterDiff(customPorts *bool, newSVC, oldSVC *service.Service) error {
if !service.IsPortBasedProtocol(newSVC.Mode) || (customPorts != nil && *customPorts) {
return nil
}
if newSVC.ListenPort != 0 && newSVC.ListenPort != oldSVC.ListenPort {
return status.Errorf(status.InvalidArgument, "custom ports not supported on target cluster %s", newSVC.ProxyCluster)
}
return nil
}
// handleDomainChange validates the new domain is free inside the transaction
// and applies the pre-resolved cluster (computed outside the tx by
// resolveEffectiveCluster). It must NOT call clusterDeriver here: that talks

View File

@@ -8,6 +8,8 @@ import (
"strings"
"time"
"github.com/hashicorp/go-version"
nbversion "github.com/netbirdio/netbird/version"
log "github.com/sirupsen/logrus"
goproto "google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/timestamppb"
@@ -28,6 +30,23 @@ import (
"github.com/netbirdio/netbird/shared/sshauth"
)
const (
// deprecatedRemotePeersVersion is the version of Netbird that introduced the NetworkMap.RemotePeers field, deprecated in favor of RemotePeers.
deprecatedRemotePeersVersion = "0.29.3"
)
// precomputedDeprecatedRemotePeersConstraint is the parsed ">= 0.29.3" constraint,
// built once at init since the bound is a compile-time constant.
var precomputedDeprecatedRemotePeersConstraint version.Constraints
func init() {
constraint, err := version.NewConstraint(">= " + deprecatedRemotePeersVersion)
if err != nil {
panic("parse deprecated remote peers version constraint: " + err.Error())
}
precomputedDeprecatedRemotePeersConstraint = constraint
}
func toNetbirdConfig(config *nbconfig.Config, turnCredentials *Token, relayToken *Token, extraSettings *types.ExtraSettings) *proto.NetbirdConfig {
if config == nil {
return nil
@@ -155,7 +174,11 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
remotePeers := make([]*proto.RemotePeerConfig, 0, len(networkMap.Peers)+len(networkMap.OfflinePeers))
remotePeers = appendRemotePeerConfig(remotePeers, networkMap.Peers, dnsName, includeIPv6)
response.RemotePeers = remotePeers
if !shouldSkipSendingDeprecatedRemotePeers(peer.Meta.WtVersion) {
response.RemotePeers = remotePeers
}
response.NetworkMap.RemotePeers = remotePeers
response.RemotePeersIsEmpty = len(remotePeers) == 0
response.NetworkMap.RemotePeersIsEmpty = response.RemotePeersIsEmpty
@@ -246,6 +269,19 @@ func buildAuthorizedUsersProto(ctx context.Context, authorizedUsers map[string]m
return hashedUsers, machineUsers
}
func shouldSkipSendingDeprecatedRemotePeers(peerVersion string) bool {
if nbversion.IsDevelopmentVersion(peerVersion) {
return true
}
peerNBVersion, err := version.NewVersion(peerVersion)
if err != nil {
return false
}
return precomputedDeprecatedRemotePeersConstraint.Check(peerNBVersion)
}
func appendRemotePeerConfig(dst []*proto.RemotePeerConfig, peers []*nbpeer.Peer, dnsName string, includeIPv6 bool) []*proto.RemotePeerConfig {
for _, rPeer := range peers {
allowedIPs := []string{rPeer.IP.String() + "/32"}
@@ -363,7 +399,6 @@ func toProtocolFirewallRules(rules []*types.FirewallRule, includeIPv6, useSource
return result
}
// populateSourcePrefixes sets SourcePrefixes on fwRule and returns any
// additional rules needed (e.g. a v6 wildcard clone when the peer IP is unspecified).
func populateSourcePrefixes(fwRule *proto.FirewallRule, rule *types.FirewallRule, includeIPv6 bool) []*proto.FirewallRule {

View File

@@ -202,6 +202,42 @@ func TestBuildJWTConfig_Audiences(t *testing.T) {
}
}
// TestShouldSkipSendingDeprecatedRemotePeers covers the version gate that
// stops populating the deprecated top-level SyncResponse.RemotePeers field for
// peers new enough to read RemotePeers off the NetworkMap. Development builds
// are treated as latest and skip the field. The gate otherwise fails safe: a
// release version older than the boundary, or one that can't be parsed (empty,
// garbage, prereleases of the boundary) still receives the deprecated field so
// older/unknown clients keep working.
func TestShouldSkipSendingDeprecatedRemotePeers(t *testing.T) {
tests := []struct {
name string
peerVersion string
wantSkip bool
}{
{"exact boundary skips", "0.29.3", true},
{"newer patch skips", "0.29.4", true},
{"newer minor skips", "0.30.0", true},
{"newer major skips", "1.0.0", true},
{"v-prefixed newer skips", "v0.30.0", true},
{"development build skips", "development", true},
{"development build with commit skips", "development-abc123def456-dirty", true},
{"older patch keeps field", "0.29.2", false},
{"older minor keeps field", "0.28.0", false},
{"prerelease of boundary keeps field", "0.29.3-SNAPSHOT", false},
{"tagged dev prerelease keeps field", "v0.31.1-dev", false},
{"empty version keeps field", "", false},
{"garbage version keeps field", "not-a-version", false},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := shouldSkipSendingDeprecatedRemotePeers(tc.peerVersion)
assert.Equal(t, tc.wantSkip, got, "skip decision for peer version %q", tc.peerVersion)
})
}
}
// TestEncodeSessionExpiresAt pins the wire encoding the client's
// applySessionDeadline depends on:
//

View File

@@ -666,8 +666,10 @@ func (s *ProxyServiceServer) sender(conn *proxyConnection, errChan chan<- error)
case resp := <-conn.sendChan:
if err := conn.sendResponse(resp); err != nil {
errChan <- err
log.WithContext(conn.ctx).Tracef("Failed to send response to proxy %s: %v", conn.proxyID, err)
return
}
log.WithContext(conn.ctx).Tracef("Send response to proxy %s", conn.proxyID)
case <-conn.ctx.Done():
return
}

View File

@@ -12,6 +12,7 @@ const (
RoleKey = nbcontext.RoleKey
UserIDKey = nbcontext.UserIDKey
PeerIDKey = nbcontext.PeerIDKey
UserAgentKey = nbcontext.UserAgentKey
)
// RoleFromContext returns the role stored in ctx, or empty string and false if absent.

View File

@@ -21,6 +21,8 @@ const (
httpRequestCounterPrefix = "management.http.request.counter"
httpResponseCounterPrefix = "management.http.response.counter"
httpRequestDurationPrefix = "management.http.request.duration.ms"
RequestIDHeader = "X-Request-Id"
)
// WrappedResponseWriter is a wrapper for http.ResponseWriter that allows the
@@ -172,6 +174,10 @@ func (m *HTTPMiddleware) Handler(h http.Handler) http.Handler {
reqID := xid.New().String()
//nolint
ctx = context.WithValue(ctx, nbContext.RequestIDKey, reqID)
//nolint
ctx = context.WithValue(ctx, nbContext.UserAgentKey, r.UserAgent())
rw.Header().Set(RequestIDHeader, reqID)
log.WithContext(ctx).Tracef("HTTP request %v: %v %v", reqID, r.Method, r.URL)

View File

@@ -557,7 +557,6 @@ func (c *NetworkMapComponents) getRoutingPeerRoutes(peerID string) (enabledRoute
return enabledRoutes, disabledRoutes
}
func (c *NetworkMapComponents) filterRoutesByGroups(routes []*route.Route, groupListMap LookupMap) []*route.Route {
var filteredRoutes []*route.Route
for _, r := range routes {
@@ -628,9 +627,14 @@ func (c *NetworkMapComponents) getDefaultPermit(r *route.Route, includeIPv6 bool
rules := []*RouteFirewallRule{&rule}
if includeIPv6 && r.IsDynamic() {
isDefaultV4 := r.Network.Addr().Is4() && r.Network.Bits() == 0
if includeIPv6 && (r.IsDynamic() || isDefaultV4) {
ruleV6 := rule
ruleV6.SourceRanges = []string{"::/0"}
if isDefaultV4 {
ruleV6.Destination = "::/0"
ruleV6.RouteID = r.ID + "-v6-default"
}
rules = append(rules, &ruleV6)
}

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"net"
"net/netip"
"slices"
"testing"
"time"
@@ -1029,6 +1030,48 @@ func TestComponents_RouteDefaultPermit(t *testing.T) {
assert.True(t, hasDefaultPermit, "route without ACG should have default permit rule with 0.0.0.0/0 source")
}
// TestComponents_ExitNodeDefaultPermitIPv6 verifies that a default exit node route
// (0.0.0.0/0) without AccessControlGroups also emits an IPv6 default permit rule
// (::/0 source and destination) for peers that support IPv6, mirroring the route
// the client installs. Without it, IPv6 traffic is routed to the exit node but
// dropped at the forward chain.
func TestComponents_ExitNodeDefaultPermitIPv6(t *testing.T) {
account, validatedPeers := scalableTestAccount(20, 2)
routingPeerID := "peer-5"
routingPeer := account.Peers[routingPeerID]
routingPeer.IPv6 = netip.MustParseAddr("fd00::5")
routingPeer.Meta.Capabilities = append(routingPeer.Meta.Capabilities, nbpeer.PeerCapabilityIPv6Overlay)
account.Routes["route-exit"] = &route.Route{
ID: "route-exit", Network: netip.MustParsePrefix("0.0.0.0/0"),
PeerID: routingPeerID, Peer: routingPeer.Key,
Enabled: true, Groups: []string{"group-all"}, PeerGroups: []string{"group-0"},
AccessControlGroups: []string{},
AccountID: "test-account",
}
nm := componentsNetworkMap(account, routingPeerID, validatedPeers)
require.NotNil(t, nm)
hasV4 := false
hasV6 := false
for _, rfr := range nm.RoutesFirewallRules {
switch rfr.Destination {
case "0.0.0.0/0":
if slices.Contains(rfr.SourceRanges, "0.0.0.0/0") {
hasV4 = true
}
case "::/0":
if slices.Contains(rfr.SourceRanges, "::/0") {
hasV6 = true
}
}
}
assert.True(t, hasV4, "exit node route should have an IPv4 default permit rule (0.0.0.0/0)")
assert.True(t, hasV6, "exit node route should have an IPv6 default permit rule (::/0)")
}
// ──────────────────────────────────────────────────────────────────────────────
// 15. MULTIPLE ROUTERS PER NETWORK
// ──────────────────────────────────────────────────────────────────────────────

View File

@@ -249,6 +249,7 @@ func runServer(cmd *cobra.Command, args []string) error {
Private: private,
MaxDialTimeout: maxDialTimeout,
MaxSessionIdleTimeout: maxSessionIdleTimeout,
MappingBatchWatchdog: envDurationOrDefault("NB_PROXY_MAPPING_BATCH_WATCHDOG", 0),
GeoDataDir: geoDataDir,
CrowdSecAPIURL: crowdsecAPIURL,
CrowdSecAPIKey: crowdsecAPIKey,

View File

@@ -28,6 +28,10 @@ import (
const deviceNamePrefix = "ingress-proxy-"
const clientStopTimeout = 30 * time.Second
const createProxyPeerTimeout = 30 * time.Second
// backendKey identifies a backend by its host:port from the target URL.
type backendKey string
@@ -162,6 +166,7 @@ type NetBird struct {
clientsMux sync.RWMutex
clients map[types.AccountID]*clientEntry
lifecycleMu sync.Map
initLogOnce sync.Once
statusNotifier statusNotifier
// readyHandler runs after the embedded client for an account reports
@@ -177,6 +182,10 @@ type NetBird struct {
// (i.e. when a new client was actually created, not when an existing one
// was reused). The duration covers keygen + gRPC CreateProxyPeer + embed.New.
OnAddPeer func(d time.Duration, err error)
// startClient runs the post-create client startup. Nil uses runClientStartup;
// tests override it to avoid a real embed client.Start.
startClient func(accountID types.AccountID, client *embed.Client)
}
// ClientDebugInfo contains debug information about a client.
@@ -200,31 +209,20 @@ type skipTLSVerifyContextKey struct{}
func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, key ServiceKey, authToken string, serviceID types.ServiceID) error {
si := serviceInfo{serviceID: serviceID}
n.clientsMux.Lock()
if n.registerExistingClient(accountID, key, si) {
return nil
}
entry, exists := n.clients[accountID]
if exists {
entry.services[key] = si
started := entry.started
n.clientsMux.Unlock()
n.logger.WithFields(log.Fields{
"account_id": accountID,
"service_key": key,
}).Debug("registered service with existing client")
if started && n.statusNotifier != nil {
// Use a background context, not the caller's: the management
// connection notification must land even if the request /
// stream that triggered this registration is cancelled.
// Mirrors the async runClientStartup path.
if err := n.statusNotifier.NotifyStatus(context.Background(), accountID, serviceID, true); err != nil {
n.logger.WithFields(log.Fields{
"account_id": accountID,
"service_key": key,
}).WithError(err).Warn("failed to notify status for existing client")
}
lifecycle := n.accountLifecycle(accountID)
lifecycle.Lock()
transferred := false
defer func() {
if !transferred {
lifecycle.Unlock()
}
}()
if n.registerExistingClient(accountID, key, si) {
return nil
}
@@ -234,10 +232,10 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, key Se
n.OnAddPeer(time.Since(createStart), err)
}
if err != nil {
n.clientsMux.Unlock()
return err
}
n.clientsMux.Lock()
n.clients[accountID] = entry
n.clientsMux.Unlock()
@@ -246,17 +244,64 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, key Se
"service_key": key,
}).Info("created new client for account")
// Attempt to start the client in the background; if this fails we will
// retry on the first request via RoundTrip. runClientStartup uses its
// own background context so the caller's request-scoped ctx can't
// cancel the inbound bring-up.
go n.runClientStartup(accountID, entry.client)
transferred = true
go func() {
defer lifecycle.Unlock()
n.startClientStartup(accountID, entry.client)
}()
return nil
}
func (n *NetBird) startClientStartup(accountID types.AccountID, client *embed.Client) {
if n.startClient != nil {
n.startClient(accountID, client)
return
}
n.runClientStartup(accountID, client)
}
// registerExistingClient registers the service against an already-present
// client for the account and returns true when it did. It notifies management
// of the new service when the client is already started.
func (n *NetBird) registerExistingClient(accountID types.AccountID, key ServiceKey, si serviceInfo) bool {
n.clientsMux.Lock()
entry, exists := n.clients[accountID]
if !exists {
n.clientsMux.Unlock()
return false
}
entry.services[key] = si
started := entry.started
n.clientsMux.Unlock()
n.logger.WithFields(log.Fields{
"account_id": accountID,
"service_key": key,
}).Debug("registered service with existing client")
if started && n.statusNotifier != nil {
if err := n.statusNotifier.NotifyStatus(context.Background(), accountID, si.serviceID, true); err != nil {
n.logger.WithFields(log.Fields{
"account_id": accountID,
"service_key": key,
}).WithError(err).Warn("failed to notify status for existing client")
}
}
return true
}
// accountLifecycle returns the per-account lifecycle mutex, serialising client
// creation against teardown so a slow client.Stop cannot race a new
// client.Start for the same account, without blocking clientsMux.
func (n *NetBird) accountLifecycle(accountID types.AccountID) *sync.Mutex {
mu, _ := n.lifecycleMu.LoadOrStore(accountID, &sync.Mutex{})
return mu.(*sync.Mutex)
}
// createClientEntry generates a WireGuard keypair, authenticates with management,
// and creates an embedded NetBird client. Must be called with clientsMux held.
// and creates an embedded NetBird client. Must be called with the account's
// lifecycle mutex held.
func (n *NetBird) createClientEntry(ctx context.Context, accountID types.AccountID, key ServiceKey, authToken string, si serviceInfo) (*clientEntry, error) {
serviceID := si.serviceID
n.logger.WithFields(log.Fields{
@@ -276,7 +321,9 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account
"public_key": publicKey.String(),
}).Debug("authenticating new proxy peer with management")
resp, err := n.mgmtClient.CreateProxyPeer(ctx, &proto.CreateProxyPeerRequest{
createCtx, cancel := context.WithTimeout(ctx, createProxyPeerTimeout)
defer cancel()
resp, err := n.mgmtClient.CreateProxyPeer(createCtx, &proto.CreateProxyPeerRequest{
ServiceId: string(serviceID),
AccountId: string(accountID),
Token: authToken,
@@ -444,6 +491,15 @@ func (n *NetBird) notifyClientReady(accountID types.AccountID, client *embed.Cli
// RemovePeer unregisters a service from an account. The client is only stopped
// when no services are using it anymore.
func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, key ServiceKey) error {
lifecycle := n.accountLifecycle(accountID)
lifecycle.Lock()
transferred := false
defer func() {
if !transferred {
lifecycle.Unlock()
}
}()
n.clientsMux.Lock()
entry, exists := n.clients[accountID]
@@ -466,17 +522,8 @@ func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, key
delete(entry.services, key)
stopClient := len(entry.services) == 0
var client *embed.Client
var transport, insecureTransport *http.Transport
var inbound any
var stopHandler func(types.AccountID, any)
if stopClient {
n.logger.WithField("account_id", accountID).Info("stopping client, no more services")
client = entry.client
transport = entry.transport
insecureTransport = entry.insecureTransport
inbound = entry.inbound
stopHandler = n.stopHandler
delete(n.clients, accountID)
} else {
n.logger.WithFields(log.Fields{
@@ -490,19 +537,40 @@ func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, key
n.notifyDisconnect(ctx, accountID, key, si.serviceID)
if stopClient {
if inbound != nil && stopHandler != nil {
stopHandler(accountID, inbound)
}
transport.CloseIdleConnections()
insecureTransport.CloseIdleConnections()
if err := client.Stop(ctx); err != nil {
n.logger.WithField("account_id", accountID).WithError(err).Warn("failed to stop netbird client")
}
transferred = true
go n.stopClientLocked(accountID, lifecycle, entry)
}
return nil
}
// stopClientLocked releases a client's resources off the caller's goroutine so a
// slow client.Stop cannot wedge the mapping receive loop (which calls RemovePeer
// synchronously). It unlocks lifecycle when done so a new client.Start for the
// same account waits for this teardown.
func (n *NetBird) stopClientLocked(accountID types.AccountID, lifecycle *sync.Mutex, entry *clientEntry) {
defer lifecycle.Unlock()
if entry.inbound != nil && n.stopHandler != nil {
n.stopHandler(accountID, entry.inbound)
}
if entry.transport != nil {
entry.transport.CloseIdleConnections()
}
if entry.insecureTransport != nil {
entry.insecureTransport.CloseIdleConnections()
}
if entry.client == nil {
return
}
ctx, cancel := context.WithTimeout(context.Background(), clientStopTimeout)
defer cancel()
if err := entry.client.Stop(ctx); err != nil {
n.logger.WithField("account_id", accountID).WithError(err).Warn("failed to stop netbird client")
}
}
func (n *NetBird) notifyDisconnect(ctx context.Context, accountID types.AccountID, key ServiceKey, serviceID types.ServiceID) {
if n.statusNotifier == nil {
return

View File

@@ -6,6 +6,7 @@ import (
"net/netip"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -22,6 +23,18 @@ func (m *mockMgmtClient) CreateProxyPeer(_ context.Context, _ *proto.CreateProxy
return &proto.CreateProxyPeerResponse{Success: true}, nil
}
// signalMgmtClient closes entered the first time CreateProxyPeer is called, so
// tests can detect AddPeer reaching client creation.
type signalMgmtClient struct {
entered chan struct{}
once sync.Once
}
func (m *signalMgmtClient) CreateProxyPeer(_ context.Context, _ *proto.CreateProxyPeerRequest, _ ...grpc.CallOption) (*proto.CreateProxyPeerResponse, error) {
m.once.Do(func() { close(m.entered) })
return &proto.CreateProxyPeerResponse{Success: true}, nil
}
type mockStatusNotifier struct {
mu sync.Mutex
statuses []statusCall
@@ -52,11 +65,15 @@ func (m *mockStatusNotifier) calls() []statusCall {
// mockNetBird creates a NetBird instance for testing without actually connecting.
// It uses an invalid management URL to prevent real connections.
func mockNetBird() *NetBird {
return NewNetBird(context.Background(), "test-proxy", "invalid.test", ClientConfig{
nb := NewNetBird(context.Background(), "test-proxy", "invalid.test", ClientConfig{
MgmtAddr: "http://invalid.test:9999",
WGPort: 0,
PreSharedKey: "",
}, nil, nil, &mockMgmtClient{})
// Skip the real embed client.Start, which would hang against the unreachable
// mgmt URL and (now that the lifecycle lock spans startup) serialise removes.
nb.startClient = func(types.AccountID, *embed.Client) {}
return nb
}
func TestNetBird_AddPeer_CreatesClientForNewAccount(t *testing.T) {
@@ -288,6 +305,7 @@ func TestNetBird_AddPeer_ExistingStartedClient_NotifiesStatus(t *testing.T) {
WGPort: 0,
PreSharedKey: "",
}, nil, notifier, &mockMgmtClient{})
nb.startClient = func(types.AccountID, *embed.Client) {}
accountID := types.AccountID("account-1")
// Add first service — creates a new client entry.
@@ -372,6 +390,117 @@ func TestNetBird_RemovePeer_NotifiesDisconnection(t *testing.T) {
assert.False(t, calls[0].connected)
}
// TestNetBird_RemovePeer_TeardownIsAsync proves the fix for the receive-loop
// stall: RemovePeer must return promptly even when the client teardown blocks,
// because teardown runs off the caller's goroutine. The receive loop calls
// RemovePeer synchronously, so a blocking teardown inline would wedge it.
func TestNetBird_RemovePeer_TeardownIsAsync(t *testing.T) {
nb := NewNetBird(context.Background(), "test-proxy", "invalid.test", ClientConfig{
MgmtAddr: "http://invalid.test:9999",
}, nil, &mockStatusNotifier{}, &mockMgmtClient{})
accountID := types.AccountID("acct-async-teardown")
key := DomainServiceKey("svc.example")
teardownEntered := make(chan struct{})
releaseTeardown := make(chan struct{})
nb.SetClientLifecycle(nil, func(types.AccountID, any) {
close(teardownEntered)
<-releaseTeardown
})
nb.clientsMux.Lock()
nb.clients[accountID] = &clientEntry{
services: map[ServiceKey]serviceInfo{key: {serviceID: types.ServiceID("svc-1")}},
started: true,
inbound: struct{}{},
}
nb.clientsMux.Unlock()
done := make(chan error, 1)
go func() { done <- nb.RemovePeer(context.Background(), accountID, key) }()
select {
case err := <-done:
require.NoError(t, err)
case <-time.After(2 * time.Second):
t.Fatal("RemovePeer did not return while teardown was blocked — teardown is not async")
}
select {
case <-teardownEntered:
case <-time.After(2 * time.Second):
t.Fatal("teardown never ran")
}
close(releaseTeardown)
}
// TestNetBird_AddPeer_WaitsForTeardown proves the lifecycle lock serialises a
// new client bringup behind an in-flight teardown for the same account, so a
// slow client.Stop can never race a new client.Start for that account.
//
// It targets the handoff race specifically: AddPeer is launched immediately
// after RemovePeer returns, WITHOUT waiting for the teardown goroutine to start.
// This only passes if RemovePeer acquires the lifecycle lock synchronously
// (before returning) and hands it to the teardown goroutine — if the goroutine
// acquired the lock itself, AddPeer could win the lock in this window and start
// a replacement client while the old teardown is still pending.
func TestNetBird_AddPeer_WaitsForTeardown(t *testing.T) {
nb := NewNetBird(context.Background(), "test-proxy", "invalid.test", ClientConfig{
MgmtAddr: "http://invalid.test:9999",
}, nil, &mockStatusNotifier{}, &mockMgmtClient{})
nb.startClient = func(types.AccountID, *embed.Client) {}
accountID := types.AccountID("acct-serialize")
key := DomainServiceKey("svc.example")
addEntered := make(chan struct{})
releaseTeardown := make(chan struct{})
nb.SetClientLifecycle(nil, func(types.AccountID, any) {
// Block teardown until released. If AddPeer ever reaches createClientEntry
// (signalled via the mgmt client below) while we hold the lock, the lock
// failed to serialise and the test fails before we release.
<-releaseTeardown
})
nb.clientsMux.Lock()
nb.clients[accountID] = &clientEntry{
services: map[ServiceKey]serviceInfo{key: {serviceID: types.ServiceID("svc-1")}},
started: true,
inbound: struct{}{},
}
nb.clientsMux.Unlock()
// createClientEntry calls CreateProxyPeer; closing addEntered there tells us
// AddPeer got past the lifecycle lock and into client creation.
nb.mgmtClient = &signalMgmtClient{entered: addEntered}
require.NoError(t, nb.RemovePeer(context.Background(), accountID, key))
// Launch AddPeer with NO synchronisation against the teardown goroutine.
addReturned := make(chan struct{})
go func() {
_ = nb.AddPeer(context.Background(), accountID, DomainServiceKey("svc2.example"), "key-2", types.ServiceID("svc-2"))
close(addReturned)
}()
select {
case <-addEntered:
t.Fatal("AddPeer entered client creation while teardown held the lifecycle lock — handoff race not closed")
case <-addReturned:
t.Fatal("AddPeer completed while teardown held the lifecycle lock — not serialised")
case <-time.After(300 * time.Millisecond):
}
close(releaseTeardown)
select {
case <-addReturned:
case <-time.After(2 * time.Second):
t.Fatal("AddPeer never completed after teardown released the lifecycle lock")
}
}
// TestNotifyClientReady_UsesBackgroundCtx pins the contract that the
// post-Start hooks (readyHandler + statusNotifier.NotifyStatus) run on
// a fresh context.Background() rather than inheriting the AddPeer

View File

@@ -114,6 +114,10 @@ type Config struct {
MaxDialTimeout time.Duration
// MaxSessionIdleTimeout caps the per-service session idle timeout.
MaxSessionIdleTimeout time.Duration
// MappingBatchWatchdog bounds how long a single mapping batch may spend
// being applied before the receive loop reconnects to resync. Zero falls
// back to the internal default.
MappingBatchWatchdog time.Duration
// GeoDataDir is the directory containing GeoLite2 MMDB files.
GeoDataDir string
@@ -164,6 +168,7 @@ func New(ctx context.Context, cfg Config) *Server {
Private: cfg.Private,
MaxDialTimeout: cfg.MaxDialTimeout,
MaxSessionIdleTimeout: cfg.MaxSessionIdleTimeout,
MappingBatchWatchdog: cfg.MappingBatchWatchdog,
GeoDataDir: cfg.GeoDataDir,
CrowdSecAPIURL: cfg.CrowdSecAPIURL,
CrowdSecAPIKey: cfg.CrowdSecAPIKey,

282
proxy/mapping_stall_test.go Normal file
View File

@@ -0,0 +1,282 @@
package proxy
import (
"context"
"sync"
"sync/atomic"
"testing"
"time"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
"github.com/netbirdio/netbird/proxy/internal/roundtrip"
"github.com/netbirdio/netbird/proxy/internal/types"
"github.com/netbirdio/netbird/shared/management/proto"
)
// blockingMgmtClient implements roundtrip's managementClient interface.
// CreateProxyPeer parks until release is closed, signalling entry on entered.
// This reproduces the confirmed real-world stall: createClientEntry calls
// CreateProxyPeer synchronously while holding clientsMux, and the proxy's
// receive loop calls that path synchronously inside processMappings.
type blockingMgmtClient struct {
entered chan struct{}
once sync.Once
}
func (b *blockingMgmtClient) CreateProxyPeer(ctx context.Context, _ *proto.CreateProxyPeerRequest, _ ...grpc.CallOption) (*proto.CreateProxyPeerResponse, error) {
b.once.Do(func() { close(b.entered) })
// Park until the caller's context is cancelled. In production this ctx is
// the gRPC mapping-stream context with no per-call timeout, so a slow or
// unresponsive CreateProxyPeer parks the receive loop here indefinitely.
<-ctx.Done()
return nil, ctx.Err()
}
// gatedMappingStream is a mock GetMappingUpdate client stream that hands out a
// pre-seeded list of messages, then records how many times Recv advanced. It
// lets the test observe whether the single-threaded receive loop ever gets
// past the first (blocking) batch to pull the second message.
type gatedMappingStream struct {
grpc.ClientStream
messages []*proto.GetMappingUpdateResponse
idx int32
}
func (g *gatedMappingStream) Recv() (*proto.GetMappingUpdateResponse, error) {
i := int(atomic.LoadInt32(&g.idx))
if i >= len(g.messages) {
// Block instead of returning EOF so the loop doesn't exit; we only
// care whether the loop ever reaches this second Recv at all.
select {}
}
msg := g.messages[i]
atomic.AddInt32(&g.idx, 1)
return msg, nil
}
func (g *gatedMappingStream) deliveredCount() int32 { return atomic.LoadInt32(&g.idx) }
func (g *gatedMappingStream) Header() (metadata.MD, error) { return nil, nil } //nolint:nilnil
func (g *gatedMappingStream) Trailer() metadata.MD { return nil }
func (g *gatedMappingStream) CloseSend() error { return nil }
func (g *gatedMappingStream) Context() context.Context { return context.Background() }
func (g *gatedMappingStream) SendMsg(any) error { return nil }
func (g *gatedMappingStream) RecvMsg(any) error { return nil }
// noopNotifier satisfies roundtrip's statusNotifier interface.
type noopNotifier struct{}
func (noopNotifier) NotifyStatus(context.Context, types.AccountID, types.ServiceID, bool) error {
return nil
}
// noopProxyClient is a proto.ProxyServiceClient that no-ops the one method the
// teardown unwind reaches (SendStatusUpdate, via notifyError when the parked
// AddPeer is cancelled). The embedded nil interface satisfies the rest at
// compile time; none of those methods are called by this test.
type noopProxyClient struct {
proto.ProxyServiceClient
}
func (noopProxyClient) SendStatusUpdate(context.Context, *proto.SendStatusUpdateRequest, ...grpc.CallOption) (*proto.SendStatusUpdateResponse, error) {
return &proto.SendStatusUpdateResponse{}, nil
}
// TestMappingStream_StallsWhenApplyBlocks proves the deadlock: the proxy's
// mapping receive loop processes batches strictly serially, so when applying
// one batch blocks (here: createClientEntry parked on a synchronous
// CreateProxyPeer call, exactly as observed in production), the loop never
// advances to Recv the next batch. Management can keep sending updates onto
// the stream with no error and no channel overflow, yet the proxy applies
// nothing further — it is stuck.
func TestMappingStream_StallsWhenApplyBlocks(t *testing.T) {
logger := log.New()
logger.SetLevel(log.PanicLevel)
mgmt := &blockingMgmtClient{
entered: make(chan struct{}),
}
nb := roundtrip.NewNetBird(
context.Background(),
"proxy-test",
"proxy.example.com",
roundtrip.ClientConfig{},
logger,
noopNotifier{},
mgmt,
)
s := &Server{
Logger: logger,
netbird: nb,
mgmtClient: noopProxyClient{},
routerReady: closedChan(),
lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
}
// First batch: a CREATED mapping for a brand-new account. addMapping ->
// netbird.AddPeer -> createClientEntry -> CreateProxyPeer, which blocks.
// Empty Path keeps setupHTTPMapping a no-op (it returns early), so the
// ONLY blocking point is the synchronous CreateProxyPeer in AddPeer —
// no routers/auth need wiring. The second batch exists only to detect
// whether the loop ever advances past the blocked first batch.
stream := &gatedMappingStream{
messages: []*proto.GetMappingUpdateResponse{
{
Mapping: []*proto.ProxyMapping{
{
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
Id: "svc-1",
AccountId: "acct-1",
AuthToken: "token-1",
},
},
},
{
Mapping: []*proto.ProxyMapping{
{
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
Id: "svc-2",
AccountId: "acct-2",
AuthToken: "token-2",
},
},
},
},
}
ctx, cancel := context.WithCancel(context.Background())
// Unblock the parked apply on teardown via ctx (CreateProxyPeer returns
// ctx.Err()), so the wedged loop goroutine unwinds before embed.New —
// avoiding any dependency on collaborators this test deliberately leaves
// nil. The deadlock is fully proven before this fires.
t.Cleanup(cancel)
loopDone := make(chan struct{})
syncDone := false
go func() {
defer close(loopDone)
_ = s.handleMappingStream(ctx, stream, &syncDone, time.Time{})
}()
// The loop must reach the blocking apply for the first batch.
select {
case <-mgmt.entered:
case <-time.After(2 * time.Second):
t.Fatal("receive loop never reached CreateProxyPeer for the first batch")
}
// THE DEADLOCK: while the first batch is parked in CreateProxyPeer, the
// single-threaded loop cannot advance. The second batch is never pulled,
// even though it is already available on the stream. Give it ample time.
// deliveredCount is atomic; syncDone is intentionally not read here because
// the loop goroutine owns it (reading it from the test would race).
time.Sleep(500 * time.Millisecond)
assert.Equal(t, int32(1), stream.deliveredCount(),
"loop must NOT consume the second batch while the first is blocked in apply — proxy is stuck")
select {
case <-loopDone:
t.Fatal("receive loop returned while it should be wedged in apply")
default:
// Still wedged, as expected.
}
}
// TestMappingStream_StallsWhenRemoveBlocks proves the deadlock for the REMOVE
// path observed in production: a mapping remove tears down the account's last
// embedded client via netbird.RemovePeer -> client.Stop -> Engine.Stop, whose
// jobExecutorWG.Wait() is unbounded. Because the receive loop is single-
// threaded, a blocked remove wedges the loop: no further mapping updates of any
// kind (create/modify/remove) are applied, while management keeps sending them
// successfully (no send error, no channel-full). Matches the reported symptom:
// the last log line is a remove that stops a client, then silence.
func TestMappingStream_StallsWhenRemoveBlocks(t *testing.T) {
logger := log.New()
logger.SetLevel(log.PanicLevel)
enteredRemove := make(chan struct{})
blockRemove := make(chan struct{})
var once sync.Once
s := &Server{
Logger: logger,
mgmtClient: noopProxyClient{},
routerReady: closedChan(),
lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
// Stand in for netbird.RemovePeer -> client.Stop hanging on
// Engine.Stop's unbounded jobExecutorWG.Wait(). Only the first remove
// blocks; later removes return immediately so the recovery assertion
// can observe the loop advancing.
removePeer: func(ctx context.Context, _ types.AccountID, _ roundtrip.ServiceKey) error {
first := false
once.Do(func() {
first = true
close(enteredRemove)
})
if !first {
return nil
}
select {
case <-blockRemove:
case <-ctx.Done():
}
return nil
},
}
// Batch 1 removes a service (blocks in teardown). Batch 2 is a later update
// that must never be applied while the remove is wedged.
stream := &gatedMappingStream{
messages: []*proto.GetMappingUpdateResponse{
{
Mapping: []*proto.ProxyMapping{
{Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED, Id: "svc-1", AccountId: "acct-1"},
},
},
{
Mapping: []*proto.ProxyMapping{
{Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED, Id: "svc-2", AccountId: "acct-1"},
},
},
},
}
loopDone := make(chan struct{})
syncDone := false
go func() {
defer close(loopDone)
_ = s.handleMappingStream(context.Background(), stream, &syncDone, time.Time{})
}()
select {
case <-enteredRemove:
case <-time.After(2 * time.Second):
t.Fatal("receive loop never reached the blocking remove for the first batch")
}
// THE DEADLOCK: the loop is parked in the blocked remove and cannot advance.
// syncDone is owned by the loop goroutine, so it is not read here.
time.Sleep(500 * time.Millisecond)
assert.Equal(t, int32(1), stream.deliveredCount(),
"loop must NOT consume the second batch while the first remove is blocked — proxy is stuck")
select {
case <-loopDone:
t.Fatal("receive loop returned while it should be wedged on the remove")
default:
}
// Unblock and confirm the wedge was solely the blocked remove: the loop
// then advances and consumes the next batch.
close(blockRemove)
assert.Eventually(t, func() bool {
return stream.deliveredCount() >= 2
}, 2*time.Second, 5*time.Millisecond,
"once the remove unblocks, the loop must advance and consume the next batch")
}

View File

@@ -24,6 +24,7 @@ import (
"time"
"github.com/cenkalti/backoff/v4"
"github.com/google/uuid"
"github.com/pires/go-proxyproto"
prometheus2 "github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
@@ -75,29 +76,30 @@ type portRouter struct {
}
type Server struct {
ctx context.Context
mgmtClient proto.ProxyServiceClient
proxy *proxy.ReverseProxy
netbird *roundtrip.NetBird
acme *acme.Manager
auth *auth.Middleware
http *http.Server
https *http.Server
debug *http.Server
healthServer *health.Server
healthChecker *health.Checker
meter *proxymetrics.Metrics
accessLog *accesslog.Logger
mainRouter *nbtcp.Router
mainPort uint16
udpMu sync.Mutex
udpRelays map[types.ServiceID]*udprelay.Relay
udpRelayWg sync.WaitGroup
portMu sync.RWMutex
portRouters map[uint16]*portRouter
svcPorts map[types.ServiceID][]uint16
lastMappings map[types.ServiceID]*proto.ProxyMapping
portRouterWg sync.WaitGroup
ctx context.Context
mgmtClient proto.ProxyServiceClient
proxy *proxy.ReverseProxy
netbird *roundtrip.NetBird
acme *acme.Manager
staticCertWatcher *certwatch.Watcher
auth *auth.Middleware
http *http.Server
https *http.Server
debug *http.Server
healthServer *health.Server
healthChecker *health.Checker
meter *proxymetrics.Metrics
accessLog *accesslog.Logger
mainRouter *nbtcp.Router
mainPort uint16
udpMu sync.Mutex
udpRelays map[types.ServiceID]*udprelay.Relay
udpRelayWg sync.WaitGroup
portMu sync.RWMutex
portRouters map[uint16]*portRouter
svcPorts map[types.ServiceID][]uint16
lastMappings map[types.ServiceID]*proto.ProxyMapping
portRouterWg sync.WaitGroup
// hijackTracker tracks hijacked connections (e.g. WebSocket upgrades)
// so they can be closed during graceful shutdown, since http.Server.Shutdown
@@ -118,6 +120,9 @@ type Server struct {
// The mapping worker waits on this before processing updates.
routerReady chan struct{}
// removePeer defaults to netbird.RemovePeer; overridable in tests.
removePeer func(ctx context.Context, accountID types.AccountID, key roundtrip.ServiceKey) error
// inbound, when non-nil, manages per-account inbound listeners. Set by
// initPrivateInbound only when Private is true so the standalone
// proxy keeps its zero-overhead default path.
@@ -227,6 +232,10 @@ type Server struct {
// Zero means no cap (the proxy honors whatever management sends).
// Set via NB_PROXY_MAX_SESSION_IDLE_TIMEOUT for shared deployments.
MaxSessionIdleTimeout time.Duration
// MappingBatchWatchdog bounds how long a single mapping batch may spend
// in processMappings before the receive loop reconnects to resync.
// Zero uses defaultMappingBatchWatchdog.
MappingBatchWatchdog time.Duration
}
// clampIdleTimeout returns d capped to MaxSessionIdleTimeout when configured.
@@ -607,7 +616,7 @@ func (s *Server) initDefaults() {
// If no ID is set then one can be generated.
if s.ID == "" {
s.ID = "netbird-proxy-" + s.startTime.Format("20060102150405")
s.ID = fmt.Sprintf("netbird-proxy-%s", uuid.NewString())
}
// Fallback version option in case it is not set.
if s.Version == "" {
@@ -785,6 +794,7 @@ func (s *Server) configureTLS(ctx context.Context) (*tls.Config, error) {
return nil, fmt.Errorf("initialize certificate watcher: %w", err)
}
go certWatcher.Watch(ctx)
s.staticCertWatcher = certWatcher
tlsConfig.GetCertificate = certWatcher.GetCertificate
return tlsConfig, nil
}
@@ -1172,24 +1182,30 @@ func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.Pr
s.healthChecker.SetManagementConnected(false)
}
connected := false
onConnected := func() { connected = true }
var streamErr error
if syncSupported {
streamErr = s.trySyncMappings(ctx, client, &initialSyncDone)
streamErr = s.trySyncMappings(ctx, client, &initialSyncDone, onConnected)
if isSyncUnimplemented(streamErr) {
syncSupported = false
s.Logger.Info("management does not support SyncMappings, falling back to GetMappingUpdate")
streamErr = s.tryGetMappingUpdate(ctx, client, &initialSyncDone)
streamErr = s.tryGetMappingUpdate(ctx, client, &initialSyncDone, onConnected)
}
} else {
streamErr = s.tryGetMappingUpdate(ctx, client, &initialSyncDone)
streamErr = s.tryGetMappingUpdate(ctx, client, &initialSyncDone, onConnected)
}
if s.healthChecker != nil {
s.healthChecker.SetManagementConnected(false)
}
// Stream established — reset backoff so the next failure retries quickly.
bo.Reset()
// Reset backoff only when a stream actually connected, so immediate
// connect failures still back off instead of spinning.
if connected {
bo.Reset()
}
if streamErr == nil {
return fmt.Errorf("stream closed by server")
@@ -1221,7 +1237,7 @@ func (s *Server) proxyCapabilities() *proto.ProxyCapabilities {
}
}
func (s *Server) tryGetMappingUpdate(ctx context.Context, client proto.ProxyServiceClient, initialSyncDone *bool) error {
func (s *Server) tryGetMappingUpdate(ctx context.Context, client proto.ProxyServiceClient, initialSyncDone *bool, onConnected func()) error {
connectTime := time.Now()
mappingClient, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{
ProxyId: s.ID,
@@ -1234,6 +1250,7 @@ func (s *Server) tryGetMappingUpdate(ctx context.Context, client proto.ProxyServ
return fmt.Errorf("create mapping stream: %w", err)
}
onConnected()
if s.healthChecker != nil {
s.healthChecker.SetManagementConnected(true)
}
@@ -1242,7 +1259,7 @@ func (s *Server) tryGetMappingUpdate(ctx context.Context, client proto.ProxyServ
return s.handleMappingStream(ctx, mappingClient, initialSyncDone, connectTime)
}
func (s *Server) trySyncMappings(ctx context.Context, client proto.ProxyServiceClient, initialSyncDone *bool) error {
func (s *Server) trySyncMappings(ctx context.Context, client proto.ProxyServiceClient, initialSyncDone *bool, onConnected func()) error {
connectTime := time.Now()
stream, err := client.SyncMappings(ctx)
if err != nil {
@@ -1263,6 +1280,7 @@ func (s *Server) trySyncMappings(ctx context.Context, client proto.ProxyServiceC
return fmt.Errorf("send sync init: %w", err)
}
onConnected()
if s.healthChecker != nil {
s.healthChecker.SetManagementConnected(true)
}
@@ -1307,7 +1325,9 @@ func (s *Server) handleSyncMappingsStream(ctx context.Context, stream proto.Prox
batchStart := time.Now()
s.Logger.Debug("Received mapping update, starting processing")
s.processMappings(ctx, msg.GetMapping())
if err := s.processMappingsGuarded(ctx, msg.GetMapping()); err != nil {
return err
}
s.Logger.Debug("Processing mapping update completed")
tracker.recordBatch(ctx, s, msg.GetMapping(), msg.GetInitialSyncComplete(), batchStart)
@@ -1391,7 +1411,9 @@ func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.Pr
batchStart := time.Now()
s.Logger.Debug("Received mapping update, starting processing")
s.processMappings(ctx, msg.GetMapping())
if err := s.processMappingsGuarded(ctx, msg.GetMapping()); err != nil {
return err
}
s.Logger.Debug("Processing mapping update completed")
tracker.recordBatch(ctx, s, msg.GetMapping(), msg.GetInitialSyncComplete(), batchStart)
}
@@ -1456,6 +1478,44 @@ func redactMappingForLog(m *proto.ProxyMapping) *proto.ProxyMapping {
return c
}
const defaultMappingBatchWatchdog = 2 * time.Minute
// mappingBatchWatchdog returns the configured batch watchdog or the default.
func (s *Server) mappingBatchWatchdog() time.Duration {
if s.MappingBatchWatchdog > 0 {
return s.MappingBatchWatchdog
}
return defaultMappingBatchWatchdog
}
// processMappingsGuarded applies a batch under a watchdog, returning an error
// if processing exceeds the watchdog so the caller reconnects and resyncs
// instead of wedging silently.
func (s *Server) processMappingsGuarded(ctx context.Context, mappings []*proto.ProxyMapping) error {
batchCtx, cancel := context.WithCancel(ctx)
defer cancel()
done := make(chan struct{})
go func() {
defer close(done)
s.processMappings(batchCtx, mappings)
}()
watchdog := s.mappingBatchWatchdog()
timer := time.NewTimer(watchdog)
defer timer.Stop()
select {
case <-done:
return nil
case <-ctx.Done():
return ctx.Err()
case <-timer.C:
s.Logger.Errorf("processing mapping batch exceeded %s, cancelling and reconnecting to resync", watchdog)
return fmt.Errorf("mapping batch processing stalled after %s", watchdog)
}
}
func (s *Server) processMappings(ctx context.Context, mappings []*proto.ProxyMapping) {
debug := s.Logger != nil && s.Logger.IsLevelEnabled(log.DebugLevel)
for _, mapping := range mappings {
@@ -1566,6 +1626,8 @@ func (s *Server) setupHTTPMapping(ctx context.Context, mapping *proto.ProxyMappi
var wildcardHit bool
if s.acme != nil {
wildcardHit = s.acme.AddDomain(d, accountID, svcID)
} else {
wildcardHit = s.staticCertCovers(d)
}
httpRoute := nbtcp.Route{
Type: nbtcp.RouteHTTP,
@@ -1590,6 +1652,26 @@ func (s *Server) setupHTTPMapping(ctx context.Context, mapping *proto.ProxyMappi
return nil
}
// staticCertCovers reports whether the static certificate loaded when ACME is
// disabled covers the given domain, making it certificate-ready immediately —
// the equivalent of a wildcard hit in the ACME path. Domains the certificate
// does not cover are logged: clients connecting to them will get TLS errors.
func (s *Server) staticCertCovers(d domain.Domain) bool {
if s.staticCertWatcher == nil {
return false
}
leaf := s.staticCertWatcher.Leaf()
if leaf == nil {
return false
}
name := d.PunycodeString()
if err := leaf.VerifyHostname(name); err != nil {
s.Logger.Warnf("static certificate (SANs %v) does not cover domain %q: %v", leaf.DNSNames, name, err)
return false
}
return true
}
// setupTCPMapping sets up a TCP port-forwarding fallback route on the listen port.
func (s *Server) setupTCPMapping(ctx context.Context, mapping *proto.ProxyMapping) error {
svcID := types.ServiceID(mapping.GetId())
@@ -1951,7 +2033,11 @@ func (s *Server) updateMapping(ctx context.Context, mapping *proto.ProxyMapping)
func (s *Server) removeMapping(ctx context.Context, mapping *proto.ProxyMapping) {
accountID := types.AccountID(mapping.GetAccountId())
svcKey := s.serviceKeyForMapping(mapping)
if err := s.netbird.RemovePeer(ctx, accountID, svcKey); err != nil {
removePeer := s.removePeer
if removePeer == nil {
removePeer = s.netbird.RemovePeer
}
if err := removePeer(ctx, accountID, svcKey); err != nil {
s.Logger.WithFields(log.Fields{
"account_id": accountID,
"service_id": mapping.GetId(),

89
proxy/static_cert_test.go Normal file
View File

@@ -0,0 +1,89 @@
package proxy
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"math/big"
"os"
"path/filepath"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/proxy/internal/certwatch"
"github.com/netbirdio/netbird/shared/management/domain"
)
func generateCertWithSANs(t *testing.T, dnsNames []string) (certPEM, keyPEM []byte) {
t.Helper()
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)
template := &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{CommonName: dnsNames[0]},
DNSNames: dnsNames,
NotBefore: time.Now().Add(-time.Hour),
NotAfter: time.Now().Add(24 * time.Hour),
}
certDER, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key)
require.NoError(t, err)
certPEM = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
keyDER, err := x509.MarshalECPrivateKey(key)
require.NoError(t, err)
keyPEM = pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})
return certPEM, keyPEM
}
func newStaticWatcher(t *testing.T, dnsNames []string) *certwatch.Watcher {
t.Helper()
dir := t.TempDir()
certPEM, keyPEM := generateCertWithSANs(t, dnsNames)
certPath := filepath.Join(dir, "tls.crt")
keyPath := filepath.Join(dir, "tls.key")
require.NoError(t, os.WriteFile(certPath, certPEM, 0o600))
require.NoError(t, os.WriteFile(keyPath, keyPEM, 0o600))
w, err := certwatch.NewWatcher(certPath, keyPath, quietLifecycleLogger())
require.NoError(t, err)
return w
}
func TestStaticCertCovers(t *testing.T) {
s := &Server{
Logger: quietLifecycleLogger(),
staticCertWatcher: newStaticWatcher(t, []string{"*.p.example.com", "exact.example.com"}),
}
cases := []struct {
domain string
covered bool
}{
{"svc.p.example.com", true},
{"exact.example.com", true},
{"a.b.p.example.com", false}, // wildcard does not span labels
{"p.example.com", false},
{"other.example.com", false},
}
for _, tc := range cases {
t.Run(tc.domain, func(t *testing.T) {
assert.Equal(t, tc.covered, s.staticCertCovers(domain.Domain(tc.domain)))
})
}
}
func TestStaticCertCoversNoWatcher(t *testing.T) {
s := &Server{Logger: quietLifecycleLogger()}
assert.False(t, s.staticCertCovers(domain.Domain("svc.p.example.com")))
}

View File

@@ -6,4 +6,5 @@ const (
RoleKey = "role"
UserIDKey = "userID"
PeerIDKey = "peerID"
UserAgentKey = "userAgent"
)

View File

@@ -322,15 +322,21 @@ func TestClient_Sync(t *testing.T) {
if resp.GetNetbirdConfig() == nil {
t.Error("expecting non nil NetbirdConfig got nil")
}
if len(resp.GetRemotePeers()) != 1 {
t.Errorf("expecting RemotePeers size %d got %d", 1, len(resp.GetRemotePeers()))
// we test network map peers from 0.29.3 and dev builds
if len(resp.GetRemotePeers()) != 0 {
t.Error("expecting top-level RemotePeers to be empty for v0.29.3+ clients")
}
networkMap := resp.GetNetworkMap()
if len(networkMap.GetRemotePeers()) != 1 {
t.Errorf("expecting RemotePeers size %d got %d", 1, len(networkMap.GetRemotePeers()))
return
}
if resp.GetRemotePeersIsEmpty() == true {
if networkMap.GetRemotePeersIsEmpty() {
t.Error("expecting RemotePeers property to be false, got true")
}
if resp.GetRemotePeers()[0].GetWgPubKey() != remoteKey.PublicKey().String() {
t.Errorf("expecting RemotePeer public key %s got %s", remoteKey.PublicKey().String(), resp.GetRemotePeers()[0].GetWgPubKey())
if networkMap.GetRemotePeers()[0].GetWgPubKey() != remoteKey.PublicKey().String() {
t.Errorf("expecting RemotePeer public key %s got %s", remoteKey.PublicKey().String(), networkMap.GetRemotePeers()[0].GetWgPubKey())
}
case <-time.After(3 * time.Second):
t.Error("timeout waiting for test to finish")

View File

@@ -5107,31 +5107,63 @@ components:
responses:
not_found:
description: Resource not found
headers:
X-Request-Id:
$ref: '#/components/headers/X-Request-Id'
content: { }
validation_failed_simple:
description: Validation failed
headers:
X-Request-Id:
$ref: '#/components/headers/X-Request-Id'
content: { }
bad_request:
description: Bad Request
headers:
X-Request-Id:
$ref: '#/components/headers/X-Request-Id'
content: { }
internal_error:
description: Internal Server Error
headers:
X-Request-Id:
$ref: '#/components/headers/X-Request-Id'
content: { }
validation_failed:
description: Validation failed
headers:
X-Request-Id:
$ref: '#/components/headers/X-Request-Id'
content: { }
forbidden:
description: Forbidden
headers:
X-Request-Id:
$ref: '#/components/headers/X-Request-Id'
content: { }
requires_authentication:
description: Requires authentication
headers:
X-Request-Id:
$ref: '#/components/headers/X-Request-Id'
content: { }
conflict:
description: Conflict
headers:
X-Request-Id:
$ref: '#/components/headers/X-Request-Id'
content:
application/json:
schema:
$ref: '#/components/schemas/ErrorResponse'
headers:
X-Request-Id:
description: |
Unique identifier assigned to the request by the server and set on every
response. Useful for correlating client requests with server-side logs.
schema:
type: string
example: cot7r4n3l3vh3qj4qveg
securitySchemes:
BearerAuth:
type: http

View File

@@ -9,12 +9,14 @@ import (
"net/url"
"strings"
"sync"
"sync/atomic"
"time"
log "github.com/sirupsen/logrus"
auth "github.com/netbirdio/netbird/shared/relay/auth/hmac"
"github.com/netbirdio/netbird/shared/relay/client/dialer"
netErr "github.com/netbirdio/netbird/shared/relay/client/dialer/net"
"github.com/netbirdio/netbird/shared/relay/healthcheck"
"github.com/netbirdio/netbird/shared/relay/messages"
)
@@ -172,6 +174,19 @@ type Client struct {
stateSubscription *PeersStateSubscription
mtu uint16
// transportFallback, when set, records datagram-too-large failures so a
// datagram-sized transport is avoided on subsequent connects. Shared via
// the manager.
transportFallback *transportFallback
// datagramFallbackTriggered guards a single fallback per connection so a
// burst of oversized datagrams triggers one reconnect, not many.
datagramFallbackTriggered atomic.Bool
}
// SetTransportFallback wires the shared datagram-transport fallback tracker.
func (c *Client) SetTransportFallback(tf *transportFallback) {
c.transportFallback = tf
}
// NewClient creates a new client for the relay server. The client is not connected to the server until the Connect
@@ -361,12 +376,13 @@ func (c *Client) Close() error {
}
func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
dialers := c.getDialers()
mode := transportModeFromEnv()
dialers := c.getDialers(mode)
var conn net.Conn
if c.serverIP.IsValid() {
var err error
conn, err = c.dialRaceDirect(ctx, dialers)
conn, err = c.dialRaceDirect(ctx, mode, dialers)
if err != nil {
c.log.Infof("dial via server IP %s failed, falling back to FQDN: %v", c.serverIP, err)
conn = nil
@@ -375,6 +391,9 @@ func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
if conn == nil {
rd := dialer.NewRaceDial(c.log, dialer.DefaultConnectionTimeout, c.connectionURL, dialers...)
if mode.sequential() {
rd.WithSequential()
}
var err error
conn, err = rd.Dial(ctx)
if err != nil {
@@ -382,6 +401,7 @@ func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
}
}
c.relayConn = conn
c.datagramFallbackTriggered.Store(false)
instanceURL, err := c.handShake(ctx)
if err != nil {
@@ -396,7 +416,7 @@ func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
}
// dialRaceDirect dials c.serverIP, preserving the original FQDN as the TLS ServerName for SNI.
func (c *Client) dialRaceDirect(ctx context.Context, dialers []dialer.DialeFn) (net.Conn, error) {
func (c *Client) dialRaceDirect(ctx context.Context, mode TransportMode, dialers []dialer.DialeFn) (net.Conn, error) {
directURL, serverName, err := substituteHost(c.connectionURL, c.serverIP)
if err != nil {
return nil, fmt.Errorf("substitute host: %w", err)
@@ -406,6 +426,9 @@ func (c *Client) dialRaceDirect(ctx context.Context, dialers []dialer.DialeFn) (
rd := dialer.NewRaceDial(c.log, dialer.DefaultConnectionTimeout, directURL, dialers...).
WithServerName(serverName)
if mode.sequential() {
rd.WithSequential()
}
return rd.Dial(ctx)
}
@@ -631,13 +654,53 @@ func (c *Client) writeTo(containerRef *connContainer, dstID messages.PeerID, pay
}
// the write always return with 0 length because the underling does not support the size feedback.
_, err = c.relayConn.Write(msg)
conn := c.relayConn
_, err = conn.Write(msg)
if err != nil {
c.log.Errorf("failed to write transport message: %s", err)
if errors.Is(err, netErr.ErrDatagramTooLarge) {
c.onDatagramTooLarge(conn, err)
} else {
c.log.Errorf("failed to write transport message: %s", err)
}
}
return len(payload), err
}
// onDatagramTooLarge reacts to a datagram rejected as too large for the path.
// When a non-datagram transport is available, it records a fallback for this
// server and closes the connection so the reconnect avoids datagram-sized
// transports. A single fallback is triggered per connection regardless of how
// many oversized datagrams arrive. cause carries the datagram size and budget.
func (c *Client) onDatagramTooLarge(conn net.Conn, cause error) {
// Handle one oversized datagram per connection; a burst triggers a single
// fallback (and a single log line), not many.
if !c.datagramFallbackTriggered.CompareAndSwap(false, true) {
return
}
// If the selected mode offers no non-datagram transport (e.g. pinned to a
// datagram-sized transport), reconnecting would just re-fail, so leave the
// connection up rather than loop.
if len(nonDatagramSized(c.baseDialers(transportModeFromEnv()))) == 0 {
c.log.Warnf("%s, but no non-datagram transport is available, not falling back", cause)
return
}
// Without the shared tracker a reconnect would just select the same
// transport again and re-fail, so leave the connection up rather than loop.
if c.transportFallback == nil {
c.log.Debugf("%s, but no transport fallback configured, leaving connection up", cause)
return
}
window := c.transportFallback.recordFailure(c.connectionURL)
c.log.Warnf("%s, avoiding datagram-sized transport for %s", cause, window)
if err := conn.Close(); err != nil {
c.log.Debugf("close relay connection for transport fallback: %s", err)
}
}
func (c *Client) listenForStopEvents(ctx context.Context, hc *healthcheck.Receiver, conn net.Conn, internalStopFlag *internalStopFlag) {
for {
select {

View File

@@ -0,0 +1,18 @@
package dialer
// DatagramSized is implemented by dialers whose connections carry each write in
// a single datagram, so a write can be rejected when it exceeds the path's
// datagram budget (e.g. QUIC). Transports without this capability (e.g.
// WebSocket over TCP) impose no per-write size limit, so the relay client can
// fall back to them when a datagram-sized transport rejects a write as too
// large. The capability is advertised per dialer rather than hardcoded, so a
// new transport only needs to declare whether it is datagram-sized.
type DatagramSized interface {
DatagramSized()
}
// IsDatagramSized reports whether d produces datagram-sized connections.
func IsDatagramSized(d DialeFn) bool {
_, ok := d.(DatagramSized)
return ok
}

View File

@@ -4,4 +4,9 @@ import "errors"
var (
ErrClosedByServer = errors.New("closed by server")
// ErrDatagramTooLarge is returned when a transport message exceeds the
// QUIC datagram size the path to the relay can carry. The relay client
// treats it as a signal to fall back to a non-datagram transport.
ErrDatagramTooLarge = errors.New("datagram frame too large")
)

View File

@@ -8,7 +8,6 @@ import (
"time"
"github.com/quic-go/quic-go"
log "github.com/sirupsen/logrus"
netErr "github.com/netbirdio/netbird/shared/relay/client/dialer/net"
)
@@ -52,11 +51,8 @@ func (c *Conn) Read(b []byte) (n int, err error) {
}
func (c *Conn) Write(b []byte) (int, error) {
err := c.session.SendDatagram(b)
if err != nil {
err = c.remoteCloseErrHandling(err)
log.Errorf("failed to write to QUIC stream: %v", err)
return 0, err
if err := c.session.SendDatagram(b); err != nil {
return 0, c.writeErrHandling(err, len(b))
}
return len(b), nil
}
@@ -95,3 +91,15 @@ func (c *Conn) remoteCloseErrHandling(err error) error {
}
return err
}
// writeErrHandling normalizes SendDatagram errors. A datagram that exceeds the
// path's QUIC packet budget is mapped to ErrDatagramTooLarge (annotated with the
// datagram size and path budget) so the relay client can fall back to a
// non-datagram transport.
func (c *Conn) writeErrHandling(err error, size int) error {
var tooLarge *quic.DatagramTooLargeError
if errors.As(err, &tooLarge) {
return fmt.Errorf("%w: %d byte datagram over path budget %d", netErr.ErrDatagramTooLarge, size, tooLarge.MaxDatagramPayloadSize)
}
return c.remoteCloseErrHandling(err)
}

View File

@@ -9,6 +9,7 @@ import (
"time"
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/logging"
log "github.com/sirupsen/logrus"
nbnet "github.com/netbirdio/netbird/client/net"
@@ -23,6 +24,12 @@ func (d Dialer) Protocol() string {
return Network
}
// DatagramSized marks QUIC as a datagram-sized transport: relay traffic is
// carried in QUIC DATAGRAM frames, which must fit a single packet.
func (d Dialer) DatagramSized() {
// Intentional marker method; presence is the capability signal.
}
func (d Dialer) Dial(ctx context.Context, address, serverName string) (net.Conn, error) {
quicURL, err := prepareURL(address)
if err != nil {
@@ -47,6 +54,7 @@ func (d Dialer) Dial(ctx context.Context, address, serverName string) (net.Conn,
MaxIdleTimeout: 4 * time.Minute,
EnableDatagrams: true,
InitialPacketSize: nbRelay.QUICInitialPacketSize,
Tracer: connectionTracer(quicURL),
}
udpConn, err := nbnet.ListenUDP("udp", &net.UDPAddr{Port: 0})
@@ -74,6 +82,28 @@ func (d Dialer) Dial(ctx context.Context, address, serverName string) (net.Conn,
return conn, nil
}
// connectionTracer returns a QUIC tracer that logs the DPLPMTUD result and the
// reason a relay connection closed, so the path MTU settled on and teardown
// cause are visible in logs. Lines carry the relay address as a structured
// field, matching the rest of the relay client logging.
func connectionTracer(addr string) func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
relayLog := log.WithField("relay", addr)
return func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
return &logging.ConnectionTracer{
UpdatedMTU: func(mtu logging.ByteCount, done bool) {
if done {
relayLog.Infof("QUIC path MTU settled at %d", mtu)
return
}
relayLog.Debugf("QUIC path MTU probing at %d", mtu)
},
ClosedConnection: func(err error) {
relayLog.Debugf("QUIC connection closed: %v", err)
},
}
}
}
func prepareURL(address string) (string, error) {
var host string
var defaultPort string

View File

@@ -32,6 +32,7 @@ type RaceDial struct {
serverName string
dialerFns []DialeFn
connectionTimeout time.Duration
sequential bool
}
func NewRaceDial(log *log.Entry, connectionTimeout time.Duration, serverURL string, dialerFns ...DialeFn) *RaceDial {
@@ -53,7 +54,21 @@ func (r *RaceDial) WithServerName(serverName string) *RaceDial {
return r
}
// WithSequential makes Dial try the dialers in order, falling back to the next
// only when one fails to connect, instead of racing them concurrently.
//
// Mutates the receiver and is not safe for concurrent reconfiguration; a
// RaceDial is intended to be constructed per dial and discarded.
func (r *RaceDial) WithSequential() *RaceDial {
r.sequential = true
return r
}
func (r *RaceDial) Dial(ctx context.Context) (net.Conn, error) {
if r.sequential {
return r.dialSequential(ctx)
}
connChan := make(chan dialResult, len(r.dialerFns))
winnerConn := make(chan net.Conn, 1)
abortCtx, abort := context.WithCancel(ctx)
@@ -72,6 +87,30 @@ func (r *RaceDial) Dial(ctx context.Context) (net.Conn, error) {
return conn, nil
}
// dialSequential tries each dialer in order, returning the first connection and
// falling back to the next on failure.
func (r *RaceDial) dialSequential(ctx context.Context) (net.Conn, error) {
for _, dfn := range r.dialerFns {
if err := ctx.Err(); err != nil {
return nil, err
}
attemptCtx, cancel := context.WithTimeout(ctx, r.connectionTimeout)
r.log.Infof("dialing Relay server via %s", dfn.Protocol())
conn, err := dfn.Dial(attemptCtx, r.serverURL, r.serverName)
cancel()
if err != nil {
if errors.Is(err, context.Canceled) {
return nil, err
}
r.log.Errorf("failed to dial via %s: %s", dfn.Protocol(), err)
continue
}
r.log.Infof("successfully dialed via: %s", dfn.Protocol())
return conn, nil
}
return nil, errors.New("failed to dial to Relay server on any protocol")
}
func (r *RaceDial) dial(dfn DialeFn, abortCtx context.Context, connChan chan dialResult) {
ctx, cancel := context.WithTimeout(abortCtx, r.connectionTimeout)
defer cancel()

View File

@@ -250,3 +250,66 @@ func TestRaceDialFirstSuccessfulDialerWins(t *testing.T) {
}
}
}
func TestRaceDialSequentialFallback(t *testing.T) {
logger := logrus.NewEntry(logrus.New())
serverURL := "test.server.com"
var firstDialed, secondDialed bool
preferred := &MockDialer{
protocolStr: "quic",
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
firstDialed = true
return nil, errors.New("quic unreachable")
},
}
fallbackConn := &MockConn{remoteAddr: &MockAddr{network: "ws"}}
fallback := &MockDialer{
protocolStr: "ws",
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
secondDialed = true
return fallbackConn, nil
},
}
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, preferred, fallback).WithSequential()
conn, err := rd.Dial(context.Background())
if err != nil {
t.Fatalf("expected fallback to succeed, got %v", err)
}
if conn != fallbackConn {
t.Errorf("expected fallback connection, got %v", conn)
}
if !firstDialed || !secondDialed {
t.Errorf("expected both dialers attempted in order, first=%v second=%v", firstDialed, secondDialed)
}
}
func TestRaceDialSequentialPreferredWins(t *testing.T) {
logger := logrus.NewEntry(logrus.New())
serverURL := "test.server.com"
preferredConn := &MockConn{remoteAddr: &MockAddr{network: "quic"}}
preferred := &MockDialer{
protocolStr: "quic",
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
return preferredConn, nil
},
}
fallback := &MockDialer{
protocolStr: "ws",
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
t.Errorf("fallback dialer must not be tried when preferred succeeds")
return nil, errors.New("should not happen")
},
}
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, preferred, fallback).WithSequential()
conn, err := rd.Dial(context.Background())
if err != nil {
t.Fatalf("expected preferred to succeed, got %v", err)
}
if conn != preferredConn {
t.Errorf("expected preferred connection, got %v", conn)
}
}

View File

@@ -9,11 +9,42 @@ import (
"github.com/netbirdio/netbird/shared/relay/client/dialer/ws"
)
// getDialers returns the list of dialers to use for connecting to the relay server.
func (c *Client) getDialers() []dialer.DialeFn {
if c.mtu > 0 && c.mtu > iface.DefaultMTU {
c.log.Infof("MTU %d exceeds default (%d), forcing WebSocket transport to avoid DATAGRAM frame size issues", c.mtu, iface.DefaultMTU)
return []dialer.DialeFn{ws.Dialer{}}
// getDialers returns the ordered dialers for connecting to the relay server. It
// applies the datagram fallback generically: if this server recently rejected a
// datagram-sized transport, those dialers are dropped, leaving the rest.
func (c *Client) getDialers(mode TransportMode) []dialer.DialeFn {
dialers := c.baseDialers(mode)
if c.transportFallback != nil && c.transportFallback.avoidDatagramSized(c.connectionURL) {
if filtered := nonDatagramSized(dialers); len(filtered) > 0 {
c.log.Infof("relay recently rejected a datagram-sized transport, avoiding it")
return filtered
}
}
return []dialer.DialeFn{quic.Dialer{}, ws.Dialer{}}
return dialers
}
// baseDialers returns the ordered dialers for the mode, before any datagram
// fallback filtering. For racing modes (auto) the order is irrelevant; for
// prefer modes the first entry is tried before falling back to the second.
func (c *Client) baseDialers(mode TransportMode) []dialer.DialeFn {
switch mode {
case TransportModeWS:
c.log.Infof("%s=ws, using WebSocket transport", EnvRelayTransport)
return []dialer.DialeFn{ws.Dialer{}}
case TransportModeQUIC:
c.log.Infof("%s=quic, using QUIC transport", EnvRelayTransport)
return []dialer.DialeFn{quic.Dialer{}}
}
all := []dialer.DialeFn{quic.Dialer{}, ws.Dialer{}}
if mode == TransportModePreferWS {
all = []dialer.DialeFn{ws.Dialer{}, quic.Dialer{}}
}
if c.mtu > 0 && c.mtu > iface.DefaultMTU {
c.log.Infof("MTU %d exceeds default (%d), avoiding datagram-sized transports", c.mtu, iface.DefaultMTU)
return nonDatagramSized(all)
}
return all
}

View File

@@ -0,0 +1,101 @@
//go:build !js
package client
import (
"os"
"testing"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/shared/relay/client/dialer"
netErr "github.com/netbirdio/netbird/shared/relay/client/dialer/net"
"github.com/netbirdio/netbird/shared/relay/client/dialer/quic"
"github.com/netbirdio/netbird/shared/relay/client/dialer/ws"
)
// TestDatagramSizedCapability locks the capability the generic fallback relies
// on: QUIC is datagram-sized, WebSocket is not.
func TestDatagramSizedCapability(t *testing.T) {
assert.True(t, dialer.IsDatagramSized(quic.Dialer{}), "QUIC must advertise datagram-sized")
assert.False(t, dialer.IsDatagramSized(ws.Dialer{}), "WebSocket must not advertise datagram-sized")
}
func protocols(dialers []dialer.DialeFn) []string {
out := make([]string, len(dialers))
for i, d := range dialers {
out[i] = d.Protocol()
}
return out
}
func TestGetDialers(t *testing.T) {
const url = "rels://relay.example:443"
tests := []struct {
name string
mode string
mtu uint16
preferWS bool
want []string
}{
{name: "auto races quic and ws", mode: "auto", mtu: iface.DefaultMTU, want: []string{"quic", "WS"}},
{name: "ws pinned", mode: "ws", mtu: iface.DefaultMTU, want: []string{"WS"}},
{name: "quic pinned", mode: "quic", mtu: iface.DefaultMTU, want: []string{"quic"}},
{name: "prefer-quic orders quic first", mode: "prefer-quic", mtu: iface.DefaultMTU, want: []string{"quic", "WS"}},
{name: "prefer-ws orders ws first", mode: "prefer-ws", mtu: iface.DefaultMTU, want: []string{"WS", "quic"}},
{name: "mtu above default forces ws", mode: "auto", mtu: iface.DefaultMTU + 100, want: []string{"WS"}},
{name: "sticky fallback forces ws in auto", mode: "auto", mtu: iface.DefaultMTU, preferWS: true, want: []string{"WS"}},
{name: "sticky fallback forces ws in prefer-quic", mode: "prefer-quic", mtu: iface.DefaultMTU, preferWS: true, want: []string{"WS"}},
{name: "quic pin overrides sticky fallback", mode: "quic", mtu: iface.DefaultMTU, preferWS: true, want: []string{"quic"}},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Setenv(EnvRelayTransport, tc.mode)
if tc.mode == "" {
os.Unsetenv(EnvRelayTransport)
}
tf := newTransportFallback()
if tc.preferWS {
tf.recordFailure(url)
}
c := &Client{
log: log.WithField("test", t.Name()),
connectionURL: url,
mtu: tc.mtu,
transportFallback: tf,
}
assert.Equal(t, tc.want, protocols(c.getDialers(transportModeFromEnv())))
})
}
}
// TestStickyFallbackAfterDatagramTooLarge verifies the full chain: an oversized
// datagram records a fallback that makes the next dial pick WebSocket, the way a
// reconnect would after the connection is closed.
func TestStickyFallbackAfterDatagramTooLarge(t *testing.T) {
const url = "rels://relay.example:443"
t.Setenv(EnvRelayTransport, string(TransportModeAuto))
c := &Client{
log: log.WithField("test", t.Name()),
connectionURL: url,
mtu: iface.DefaultMTU,
transportFallback: newTransportFallback(),
}
// First dial races both transports.
assert.Equal(t, []string{"quic", "WS"}, protocols(c.getDialers(transportModeFromEnv())))
// An oversized datagram records the fallback for this server.
c.onDatagramTooLarge(&closeTrackingConn{}, netErr.ErrDatagramTooLarge)
// The reconnect now sticks to WebSocket.
assert.Equal(t, []string{"WS"}, protocols(c.getDialers(transportModeFromEnv())))
}

View File

@@ -7,7 +7,11 @@ import (
"github.com/netbirdio/netbird/shared/relay/client/dialer/ws"
)
func (c *Client) getDialers() []dialer.DialeFn {
func (c *Client) getDialers(_ TransportMode) []dialer.DialeFn {
// JS/WASM build only uses WebSocket transport
return []dialer.DialeFn{ws.Dialer{}}
}
func (c *Client) baseDialers(_ TransportMode) []dialer.DialeFn {
return []dialer.DialeFn{ws.Dialer{}}
}

View File

@@ -79,23 +79,30 @@ type Manager struct {
cleanupInterval time.Duration
keepUnusedServerTime time.Duration
// transportFallback is shared across home and foreign relay clients so a
// datagram-too-large failure makes that server avoid datagram-sized transports across reconnects.
transportFallback *transportFallback
}
// NewManager creates a new manager instance.
// The serverURL address can be empty. In this case, the manager will not serve.
func NewManager(ctx context.Context, serverURLs []string, peerID string, mtu uint16, opts ...ManagerOption) *Manager {
tokenStore := &relayAuth.TokenStore{}
tf := newTransportFallback()
m := &Manager{
ctx: ctx,
peerID: peerID,
tokenStore: tokenStore,
mtu: mtu,
ctx: ctx,
peerID: peerID,
tokenStore: tokenStore,
mtu: mtu,
transportFallback: tf,
serverPicker: &ServerPicker{
TokenStore: tokenStore,
PeerID: peerID,
MTU: mtu,
ConnectionTimeout: defaultConnectionTimeout,
TransportFallback: tf,
},
relayClients: make(map[string]*RelayTrack),
onDisconnectedListeners: make(map[string]*list.List),
@@ -287,6 +294,7 @@ func (m *Manager) openConnVia(ctx context.Context, serverAddress, peerKey string
m.relayClientsMutex.Unlock()
relayClient := NewClientWithServerIP(serverAddress, serverIP, m.tokenStore, m.peerID, m.mtu)
relayClient.SetTransportFallback(m.transportFallback)
err := relayClient.Connect(m.ctx)
if err != nil {
rt.err = err

View File

@@ -29,6 +29,7 @@ type ServerPicker struct {
PeerID string
MTU uint16
ConnectionTimeout time.Duration
TransportFallback *transportFallback
}
func (sp *ServerPicker) PickServer(parentCtx context.Context) (*Client, error) {
@@ -70,6 +71,7 @@ func (sp *ServerPicker) PickServer(parentCtx context.Context) (*Client, error) {
func (sp *ServerPicker) startConnection(ctx context.Context, resultChan chan connResult, url string) {
log.Infof("try to connecting to relay server: %s", url)
relayClient := NewClient(url, sp.TokenStore, sp.PeerID, sp.MTU)
relayClient.SetTransportFallback(sp.TransportFallback)
err := relayClient.Connect(ctx)
resultChan <- connResult{
RelayClient: relayClient,

View File

@@ -0,0 +1,129 @@
package client
import (
"os"
"strings"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/shared/relay/client/dialer"
)
// EnvRelayTransport pins the relay transport. Valid values: "auto" (default,
// race QUIC and WebSocket), "quic" (QUIC only), "ws" (WebSocket only),
// "prefer-quic" / "prefer-ws" (try the preferred transport first, fall back to
// the other only if it fails to connect; no race). The prefer modes trade a
// slower connect when the preferred transport is blackholed for deterministic
// transport selection.
const EnvRelayTransport = "NB_RELAY_TRANSPORT"
const (
// transportFallbackBase is the initial window a relay server avoids
// datagram-sized transports after a datagram is rejected as too large.
transportFallbackBase = 10 * time.Minute
// transportFallbackMax caps the pinned window when failures repeat.
transportFallbackMax = 60 * time.Minute
)
// TransportMode selects which relay dialers are used.
type TransportMode string
const (
TransportModeAuto TransportMode = "auto"
TransportModeQUIC TransportMode = "quic"
TransportModeWS TransportMode = "ws"
TransportModePreferQUIC TransportMode = "prefer-quic"
TransportModePreferWS TransportMode = "prefer-ws"
)
// transportModeFromEnv reads EnvRelayTransport, defaulting to auto for an empty
// or unrecognized value.
func transportModeFromEnv() TransportMode {
switch TransportMode(strings.ToLower(strings.TrimSpace(os.Getenv(EnvRelayTransport)))) {
case "", TransportModeAuto:
return TransportModeAuto
case TransportModeQUIC:
return TransportModeQUIC
case TransportModeWS:
return TransportModeWS
case TransportModePreferQUIC:
return TransportModePreferQUIC
case TransportModePreferWS:
return TransportModePreferWS
default:
log.Warnf("invalid %s value %q, using %q", EnvRelayTransport, os.Getenv(EnvRelayTransport), TransportModeAuto)
return TransportModeAuto
}
}
// sequential reports whether the mode tries dialers in order with fallback
// instead of racing them concurrently.
func (m TransportMode) sequential() bool {
return m == TransportModePreferQUIC || m == TransportModePreferWS
}
// transportFallback tracks relay servers that have rejected a datagram-sized
// transport (a write too large for the path) and should temporarily avoid such
// transports. It is shared across the relay manager so the preference survives
// client recreation (foreign relay clients are evicted and rebuilt on
// disconnect). Entries are keyed by server URL and expire after a window that
// grows on repeated failures.
type transportFallback struct {
mu sync.Mutex
entries map[string]*fallbackEntry
}
type fallbackEntry struct {
until time.Time
duration time.Duration
}
func newTransportFallback() *transportFallback {
return &transportFallback{entries: make(map[string]*fallbackEntry)}
}
// avoidDatagramSized reports whether serverURL is currently within a window
// where datagram-sized transports should be avoided.
func (f *transportFallback) avoidDatagramSized(serverURL string) bool {
f.mu.Lock()
defer f.mu.Unlock()
e := f.entries[serverURL]
return e != nil && time.Now().Before(e.until)
}
// recordFailure makes serverURL avoid datagram-sized transports for a window:
// transportFallbackBase on the first failure, doubling up to transportFallbackMax
// when a datagram transport fails again after a previous window expired. It
// returns the active window duration.
func (f *transportFallback) recordFailure(serverURL string) time.Duration {
f.mu.Lock()
defer f.mu.Unlock()
now := time.Now()
e := f.entries[serverURL]
switch {
case e == nil:
e = &fallbackEntry{duration: transportFallbackBase}
f.entries[serverURL] = e
case now.Before(e.until):
return time.Until(e.until)
default:
e.duration = min(e.duration*2, transportFallbackMax)
}
e.until = now.Add(e.duration)
return e.duration
}
// nonDatagramSized returns the dialers from in that are not datagram-sized,
// preserving order.
func nonDatagramSized(in []dialer.DialeFn) []dialer.DialeFn {
out := make([]dialer.DialeFn, 0, len(in))
for _, d := range in {
if !dialer.IsDatagramSized(d) {
out = append(out, d)
}
}
return out
}

View File

@@ -0,0 +1,140 @@
package client
import (
"net"
"os"
"testing"
"time"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
netErr "github.com/netbirdio/netbird/shared/relay/client/dialer/net"
)
// closeTrackingConn records whether Close was called; only Close is exercised.
type closeTrackingConn struct {
net.Conn
closed bool
}
func (c *closeTrackingConn) Close() error {
c.closed = true
return nil
}
func TestTransportModeFromEnv(t *testing.T) {
tests := []struct {
value string
want TransportMode
}{
{"", TransportModeAuto},
{"auto", TransportModeAuto},
{"quic", TransportModeQUIC},
{"QUIC", TransportModeQUIC},
{"ws", TransportModeWS},
{" Ws ", TransportModeWS},
{"prefer-quic", TransportModePreferQUIC},
{"prefer-ws", TransportModePreferWS},
{"garbage", TransportModeAuto},
}
for _, tc := range tests {
t.Run(tc.value, func(t *testing.T) {
t.Setenv(EnvRelayTransport, tc.value)
if tc.value == "" {
os.Unsetenv(EnvRelayTransport)
}
assert.Equal(t, tc.want, transportModeFromEnv())
})
}
}
func TestTransportFallbackRecordAndExpiry(t *testing.T) {
const url = "rels://relay.example:443"
f := newTransportFallback()
assert.False(t, f.avoidDatagramSized(url), "no fallback recorded yet")
d := f.recordFailure(url)
assert.Equal(t, transportFallbackBase, d, "first failure pins for the base window")
assert.True(t, f.avoidDatagramSized(url), "datagram-sized transport avoided within the window")
// A second failure while still inside the window must not grow the window.
d = f.recordFailure(url)
assert.LessOrEqual(t, d, transportFallbackBase, "still within the active window")
require.NotNil(t, f.entries[url])
assert.Equal(t, transportFallbackBase, f.entries[url].duration, "duration unchanged inside window")
// Expire the window: datagram-sized transport allowed again.
f.entries[url].until = time.Now().Add(-time.Second)
assert.False(t, f.avoidDatagramSized(url), "window expired, datagram-sized transport allowed")
}
func TestTransportFallbackGrowsOnRepeat(t *testing.T) {
const url = "rels://relay.example:443"
f := newTransportFallback()
want := transportFallbackBase
for i := range 6 {
d := f.recordFailure(url)
assert.Equal(t, want, d, "window after %d expiries", i)
// expire the window so the next failure is treated as a repeat
f.entries[url].until = time.Now().Add(-time.Second)
want = min(want*2, transportFallbackMax)
}
assert.Equal(t, transportFallbackMax, f.entries[url].duration, "window caps at the max")
}
func TestOnDatagramTooLargeAuto(t *testing.T) {
const url = "rels://relay.example:443"
t.Setenv(EnvRelayTransport, string(TransportModeAuto))
tf := newTransportFallback()
c := &Client{
log: log.WithField("test", t.Name()),
connectionURL: url,
transportFallback: tf,
}
conn := &closeTrackingConn{}
c.onDatagramTooLarge(conn, netErr.ErrDatagramTooLarge)
assert.True(t, conn.closed, "connection closed to force reconnect")
assert.True(t, tf.avoidDatagramSized(url), "fallback recorded for the server")
// A second oversized datagram on the same connection must not re-close.
conn.closed = false
c.onDatagramTooLarge(conn, netErr.ErrDatagramTooLarge)
assert.False(t, conn.closed, "single fallback per connection")
}
func TestOnDatagramTooLargeQUICPinned(t *testing.T) {
const url = "rels://relay.example:443"
t.Setenv(EnvRelayTransport, string(TransportModeQUIC))
tf := newTransportFallback()
c := &Client{
log: log.WithField("test", t.Name()),
connectionURL: url,
transportFallback: tf,
}
conn := &closeTrackingConn{}
c.onDatagramTooLarge(conn, netErr.ErrDatagramTooLarge)
assert.False(t, conn.closed, "QUIC pin keeps the connection, no fallback redial")
assert.False(t, tf.avoidDatagramSized(url), "QUIC pin records no fallback")
}
func TestTransportFallbackPerServer(t *testing.T) {
f := newTransportFallback()
f.recordFailure("rels://a.example:443")
assert.True(t, f.avoidDatagramSized("rels://a.example:443"))
assert.False(t, f.avoidDatagramSized("rels://b.example:443"), "fallback is scoped to one server")
}