From 50b58a682868851a3666a99662aeb00d7fbb3846 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Mon, 4 May 2026 18:40:25 +0900 Subject: [PATCH 01/15] [client, relay] Advertise relay server IP via signal for foreign-relay fallback dial (#6004) --- client/internal/engine.go | 18 ++ client/internal/peer/handshaker.go | 8 +- client/internal/peer/signaler.go | 20 +- client/internal/peer/status.go | 2 +- client/internal/peer/worker_relay.go | 11 +- shared/relay/client/client.go | 116 +++++++- shared/relay/client/client_serverip_test.go | 280 ++++++++++++++++++ shared/relay/client/dialer/quic/quic.go | 15 +- shared/relay/client/dialer/race_dialer.go | 17 +- .../relay/client/dialer/race_dialer_test.go | 2 +- shared/relay/client/dialer/ws/conn.go | 16 +- .../client/dialer/ws/dialopts_generic.go | 10 +- shared/relay/client/dialer/ws/dialopts_js.go | 10 +- shared/relay/client/dialer/ws/ws.go | 21 +- shared/relay/client/manager.go | 37 ++- shared/relay/client/manager_serverip_test.go | 144 +++++++++ shared/relay/client/manager_test.go | 19 +- shared/signal/client/client.go | 69 +++-- shared/signal/proto/signalexchange.pb.go | 88 +++--- shared/signal/proto/signalexchange.proto | 10 +- 20 files changed, 789 insertions(+), 124 deletions(-) create mode 100644 shared/relay/client/client_serverip_test.go create mode 100644 shared/relay/client/manager_serverip_test.go diff --git a/client/internal/engine.go b/client/internal/engine.go index 8c9553e52..7f19e2d28 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -2454,6 +2454,8 @@ func convertToOfferAnswer(msg *sProto.Message) (*peer.OfferAnswer, error) { } } + relayIP := decodeRelayIP(msg.GetBody().GetRelayServerIP()) + offerAnswer := peer.OfferAnswer{ IceCredentials: peer.IceCredentials{ UFrag: remoteCred.UFrag, @@ -2464,7 +2466,23 @@ func convertToOfferAnswer(msg *sProto.Message) (*peer.OfferAnswer, error) { RosenpassPubKey: rosenpassPubKey, RosenpassAddr: rosenpassAddr, RelaySrvAddress: msg.GetBody().GetRelayServerAddress(), + RelaySrvIP: relayIP, SessionID: sessionID, } return &offerAnswer, nil } + +// decodeRelayIP decodes the proto relayServerIP bytes (4 or 16) into a +// netip.Addr. Returns the zero value for empty input and logs a warning +// for malformed payloads. +func decodeRelayIP(b []byte) netip.Addr { + if len(b) == 0 { + return netip.Addr{} + } + ip, ok := netip.AddrFromSlice(b) + if !ok { + log.Warnf("invalid relayServerIP in signal message (%d bytes), ignoring", len(b)) + return netip.Addr{} + } + return ip.Unmap() +} diff --git a/client/internal/peer/handshaker.go b/client/internal/peer/handshaker.go index 741dfce60..1d44096b6 100644 --- a/client/internal/peer/handshaker.go +++ b/client/internal/peer/handshaker.go @@ -3,6 +3,7 @@ package peer import ( "context" "errors" + "net/netip" "sync" "sync/atomic" @@ -40,6 +41,10 @@ type OfferAnswer struct { // relay server address RelaySrvAddress string + // RelaySrvIP is the IP the remote peer is connected to on its + // relay server. Used as a dial target if DNS for RelaySrvAddress + // fails. Zero value if the peer did not advertise an IP. + RelaySrvIP netip.Addr // SessionID is the unique identifier of the session, used to discard old messages SessionID *ICESessionID } @@ -217,8 +222,9 @@ func (h *Handshaker) buildOfferAnswer() OfferAnswer { answer.SessionID = &sid } - if addr, err := h.relay.RelayInstanceAddress(); err == nil { + if addr, ip, err := h.relay.RelayInstanceAddress(); err == nil { answer.RelaySrvAddress = addr + answer.RelaySrvIP = ip } return answer diff --git a/client/internal/peer/signaler.go b/client/internal/peer/signaler.go index f6eb87cca..5e437d96b 100644 --- a/client/internal/peer/signaler.go +++ b/client/internal/peer/signaler.go @@ -54,19 +54,19 @@ func (s *Signaler) signalOfferAnswer(offerAnswer OfferAnswer, remoteKey string, log.Warnf("failed to get session ID bytes: %v", err) } } - msg, err := signal.MarshalCredential( - s.wgPrivateKey, - offerAnswer.WgListenPort, - remoteKey, - &signal.Credential{ + msg, err := signal.MarshalCredential(s.wgPrivateKey, remoteKey, signal.CredentialPayload{ + Type: bodyType, + WgListenPort: offerAnswer.WgListenPort, + Credential: &signal.Credential{ UFrag: offerAnswer.IceCredentials.UFrag, Pwd: offerAnswer.IceCredentials.Pwd, }, - bodyType, - offerAnswer.RosenpassPubKey, - offerAnswer.RosenpassAddr, - offerAnswer.RelaySrvAddress, - sessionIDBytes) + RosenpassPubKey: offerAnswer.RosenpassPubKey, + RosenpassAddr: offerAnswer.RosenpassAddr, + RelaySrvAddress: offerAnswer.RelaySrvAddress, + RelaySrvIP: offerAnswer.RelaySrvIP, + SessionID: sessionIDBytes, + }) if err != nil { return err } diff --git a/client/internal/peer/status.go b/client/internal/peer/status.go index abedc208e..7bd19b0e1 100644 --- a/client/internal/peer/status.go +++ b/client/internal/peer/status.go @@ -919,7 +919,7 @@ func (d *Status) GetRelayStates() []relay.ProbeResult { // if the server connection is not established then we will use the general address // in case of connection we will use the instance specific address - instanceAddr, err := d.relayMgr.RelayInstanceAddress() + instanceAddr, _, err := d.relayMgr.RelayInstanceAddress() if err != nil { // TODO add their status for _, r := range d.relayMgr.ServerURLs() { diff --git a/client/internal/peer/worker_relay.go b/client/internal/peer/worker_relay.go index 06309fbaf..0402992c9 100644 --- a/client/internal/peer/worker_relay.go +++ b/client/internal/peer/worker_relay.go @@ -4,6 +4,7 @@ import ( "context" "errors" "net" + "net/netip" "sync" "sync/atomic" @@ -53,15 +54,19 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) { w.relaySupportedOnRemotePeer.Store(true) // the relayManager will return with error in case if the connection has lost with relay server - currentRelayAddress, err := w.relayManager.RelayInstanceAddress() + currentRelayAddress, _, err := w.relayManager.RelayInstanceAddress() if err != nil { w.log.Errorf("failed to handle new offer: %s", err) return } srv := w.preferredRelayServer(currentRelayAddress, remoteOfferAnswer.RelaySrvAddress) + var serverIP netip.Addr + if srv == remoteOfferAnswer.RelaySrvAddress { + serverIP = remoteOfferAnswer.RelaySrvIP + } - relayedConn, err := w.relayManager.OpenConn(w.peerCtx, srv, w.config.Key) + relayedConn, err := w.relayManager.OpenConn(w.peerCtx, srv, w.config.Key, serverIP) if err != nil { if errors.Is(err, relayClient.ErrConnAlreadyExists) { w.log.Debugf("handled offer by reusing existing relay connection") @@ -90,7 +95,7 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) { }) } -func (w *WorkerRelay) RelayInstanceAddress() (string, error) { +func (w *WorkerRelay) RelayInstanceAddress() (string, netip.Addr, error) { return w.relayManager.RelayInstanceAddress() } diff --git a/shared/relay/client/client.go b/shared/relay/client/client.go index b10b05617..1800bddb2 100644 --- a/shared/relay/client/client.go +++ b/shared/relay/client/client.go @@ -2,8 +2,12 @@ package client import ( "context" + "errors" "fmt" "net" + "net/netip" + "net/url" + "strings" "sync" "time" @@ -146,6 +150,7 @@ func (cc *connContainer) close() { type Client struct { log *log.Entry connectionURL string + serverIP netip.Addr authTokenStore *auth.TokenStore hashedID messages.PeerID @@ -170,13 +175,22 @@ type Client struct { } // NewClient creates a new client for the relay server. The client is not connected to the server until the Connect +// is called. func NewClient(serverURL string, authTokenStore *auth.TokenStore, peerID string, mtu uint16) *Client { + return NewClientWithServerIP(serverURL, netip.Addr{}, authTokenStore, peerID, mtu) +} + +// NewClientWithServerIP creates a new client for the relay server with a known server IP. serverIP, when valid, is +// dialed directly first; the FQDN is only attempted if the IP-based dial fails. TLS verification still uses the +// FQDN from serverURL via SNI. +func NewClientWithServerIP(serverURL string, serverIP netip.Addr, authTokenStore *auth.TokenStore, peerID string, mtu uint16) *Client { hashedID := messages.HashID(peerID) relayLog := log.WithFields(log.Fields{"relay": serverURL}) c := &Client{ log: relayLog, connectionURL: serverURL, + serverIP: serverIP, authTokenStore: authTokenStore, hashedID: hashedID, mtu: mtu, @@ -304,6 +318,23 @@ func (c *Client) ServerInstanceURL() (string, error) { return c.instanceURL.String(), nil } +// ConnectedIP returns the IP address of the live relay-server connection, +// extracted from the underlying socket's RemoteAddr. Zero value if not +// connected or if the address is not an IP literal. +func (c *Client) ConnectedIP() netip.Addr { + c.mu.Lock() + conn := c.relayConn + c.mu.Unlock() + if conn == nil { + return netip.Addr{} + } + addr := conn.RemoteAddr() + if addr == nil { + return netip.Addr{} + } + return extractIPLiteral(addr.String()) +} + // SetOnDisconnectListener sets a function that will be called when the connection to the relay server is closed. func (c *Client) SetOnDisconnectListener(fn func(string)) { c.listenerMutex.Lock() @@ -332,10 +363,23 @@ func (c *Client) Close() error { func (c *Client) connect(ctx context.Context) (*RelayAddr, error) { dialers := c.getDialers() - rd := dialer.NewRaceDial(c.log, dialer.DefaultConnectionTimeout, c.connectionURL, dialers...) - conn, err := rd.Dial(ctx) - if err != nil { - return nil, err + var conn net.Conn + if c.serverIP.IsValid() { + var err error + conn, err = c.dialRaceDirect(ctx, dialers) + if err != nil { + c.log.Infof("dial via server IP %s failed, falling back to FQDN: %v", c.serverIP, err) + conn = nil + } + } + + if conn == nil { + rd := dialer.NewRaceDial(c.log, dialer.DefaultConnectionTimeout, c.connectionURL, dialers...) + var err error + conn, err = rd.Dial(ctx) + if err != nil { + return nil, fmt.Errorf("dial via FQDN: %w", err) + } } c.relayConn = conn @@ -351,6 +395,52 @@ func (c *Client) connect(ctx context.Context) (*RelayAddr, error) { return instanceURL, nil } +// dialRaceDirect dials c.serverIP, preserving the original FQDN as the TLS ServerName for SNI. +func (c *Client) dialRaceDirect(ctx context.Context, dialers []dialer.DialeFn) (net.Conn, error) { + directURL, serverName, err := substituteHost(c.connectionURL, c.serverIP) + if err != nil { + return nil, fmt.Errorf("substitute host: %w", err) + } + + c.log.Debugf("dialing via server IP %s (SNI=%s)", c.serverIP, serverName) + + rd := dialer.NewRaceDial(c.log, dialer.DefaultConnectionTimeout, directURL, dialers...). + WithServerName(serverName) + return rd.Dial(ctx) +} + +// substituteHost replaces the host portion of a rel/rels URL with ip, +// preserving the scheme and port. Returns the rewritten URL and the +// original host to use as the TLS ServerName, or empty if the original +// host is itself an IP literal (SNI requires a DNS name). +func substituteHost(serverURL string, ip netip.Addr) (string, string, error) { + u, err := url.Parse(serverURL) + if err != nil { + return "", "", fmt.Errorf("parse %q: %w", serverURL, err) + } + if u.Scheme == "" || u.Host == "" { + return "", "", fmt.Errorf("invalid relay URL %q", serverURL) + } + if !ip.IsValid() { + return "", "", errors.New("invalid server IP") + } + origHost := u.Hostname() + if _, err := netip.ParseAddr(origHost); err == nil { + origHost = "" + } + ip = ip.Unmap() + newHost := ip.String() + if ip.Is6() { + newHost = "[" + newHost + "]" + } + if port := u.Port(); port != "" { + u.Host = newHost + ":" + port + } else { + u.Host = newHost + } + return u.String(), origHost, nil +} + func (c *Client) handShake(ctx context.Context) (*RelayAddr, error) { msg, err := messages.MarshalAuthMsg(c.hashedID, c.authTokenStore.TokenBinary()) if err != nil { @@ -716,3 +806,21 @@ func (c *Client) handlePeersWentOfflineMsg(buf []byte) { } c.stateSubscription.OnPeersWentOffline(peersID) } + +// extractIPLiteral returns the IP from address forms produced by the relay +// dialers (URL or host:port). Zero value if the host is not an IP. +func extractIPLiteral(s string) netip.Addr { + if u, err := url.Parse(s); err == nil && u.Host != "" { + s = u.Host + } + host, _, err := net.SplitHostPort(s) + if err != nil { + host = s + } + host = strings.Trim(host, "[]") + ip, err := netip.ParseAddr(host) + if err != nil { + return netip.Addr{} + } + return ip.Unmap() +} diff --git a/shared/relay/client/client_serverip_test.go b/shared/relay/client/client_serverip_test.go new file mode 100644 index 000000000..7e699e37d --- /dev/null +++ b/shared/relay/client/client_serverip_test.go @@ -0,0 +1,280 @@ +package client + +import ( + "context" + "fmt" + "net" + "net/netip" + "testing" + "time" + + "go.opentelemetry.io/otel" + + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/relay/server" + "github.com/netbirdio/netbird/shared/relay/auth/allow" +) + +// TestClient_ServerIPRecoversFromUnresolvableFQDN verifies that when the +// primary FQDN-based dial fails (unresolvable .invalid host), Connect +// recovers via the server IP and SNI still uses the FQDN. +func TestClient_ServerIPRecoversFromUnresolvableFQDN(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + + listenAddr, port := freeAddr(t) + srvCfg := server.Config{ + Meter: otel.Meter(""), + ExposedAddress: fmt.Sprintf("rel://test-unresolvable-host.invalid:%d", port), + TLSSupport: false, + AuthValidator: &allow.Auth{}, + } + srv, err := server.NewServer(srvCfg) + if err != nil { + t.Fatalf("create server: %s", err) + } + + errChan := make(chan error, 1) + go func() { + if err := srv.Listen(server.ListenerConfig{Address: listenAddr}); err != nil { + errChan <- err + } + }() + t.Cleanup(func() { + if err := srv.Shutdown(context.Background()); err != nil { + t.Errorf("shutdown server: %s", err) + } + }) + if err := waitForServerToStart(errChan); err != nil { + t.Fatalf("server failed to start: %s", err) + } + + t.Run("no server IP, primary fails", func(t *testing.T) { + c := NewClient(srvCfg.ExposedAddress, hmacTokenStore, "alice-noip", iface.DefaultMTU) + err := c.Connect(ctx) + if err == nil { + _ = c.Close() + t.Fatalf("expected connect to fail without server IP, got nil") + } + }) + + t.Run("server IP recovers", func(t *testing.T) { + c := NewClientWithServerIP(srvCfg.ExposedAddress, netip.MustParseAddr("127.0.0.1"), hmacTokenStore, "alice-with-ip", iface.DefaultMTU) + if err := c.Connect(ctx); err != nil { + t.Fatalf("connect with server IP: %s", err) + } + t.Cleanup(func() { _ = c.Close() }) + + if !c.Ready() { + t.Fatalf("client not ready after connect") + } + if got := c.ConnectedIP(); got.String() != "127.0.0.1" { + t.Fatalf("ConnectedIP = %q, want 127.0.0.1", got) + } + }) +} + +// TestClient_ConnectedIPAfterFQDNDial verifies ConnectedIP returns the +// resolved IP after a successful FQDN-based dial. The underlying socket's +// RemoteAddr must be exposed through the dialer wrappers; if it returns +// the dial-time URL instead, ConnectedIP returns empty and the dial +// IP we advertise to peers is empty too. +func TestClient_ConnectedIPAfterFQDNDial(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + + listenAddr, port := freeAddr(t) + srvCfg := server.Config{ + Meter: otel.Meter(""), + ExposedAddress: fmt.Sprintf("rel://localhost:%d", port), + TLSSupport: false, + AuthValidator: &allow.Auth{}, + } + srv, err := server.NewServer(srvCfg) + if err != nil { + t.Fatalf("create server: %s", err) + } + errChan := make(chan error, 1) + go func() { + if err := srv.Listen(server.ListenerConfig{Address: listenAddr}); err != nil { + errChan <- err + } + }() + t.Cleanup(func() { _ = srv.Shutdown(context.Background()) }) + if err := waitForServerToStart(errChan); err != nil { + t.Fatalf("server failed to start: %s", err) + } + + c := NewClient(srvCfg.ExposedAddress, hmacTokenStore, "alice-fqdn", iface.DefaultMTU) + if err := c.Connect(ctx); err != nil { + t.Fatalf("connect: %s", err) + } + t.Cleanup(func() { _ = c.Close() }) + + got := c.ConnectedIP().String() + if got != "127.0.0.1" && got != "::1" { + t.Fatalf("ConnectedIP after FQDN dial = %q, want 127.0.0.1 or ::1", got) + } +} + +func TestSubstituteHost(t *testing.T) { + tests := []struct { + name string + serverURL string + ip string + wantURL string + wantServerName string + wantErr bool + }{ + { + name: "rels with port", + serverURL: "rels://relay.netbird.io:443", + ip: "10.0.0.5", + wantURL: "rels://10.0.0.5:443", + wantServerName: "relay.netbird.io", + }, + { + name: "rel with port", + serverURL: "rel://relay.example.com:80", + ip: "192.0.2.1", + wantURL: "rel://192.0.2.1:80", + wantServerName: "relay.example.com", + }, + { + name: "ipv6 server IP bracketed", + serverURL: "rels://relay.example.com:443", + ip: "2001:db8::1", + wantURL: "rels://[2001:db8::1]:443", + wantServerName: "relay.example.com", + }, + { + name: "no port", + serverURL: "rels://relay.example.com", + ip: "10.0.0.5", + wantURL: "rels://10.0.0.5", + wantServerName: "relay.example.com", + }, + { + name: "ipv6 server with port returns empty SNI", + serverURL: "rels://[2001:db8::5]:443", + ip: "10.0.0.5", + wantURL: "rels://10.0.0.5:443", + wantServerName: "", + }, + { + name: "ipv4 server with port returns empty SNI", + serverURL: "rels://10.0.0.5:443", + ip: "10.0.0.6", + wantURL: "rels://10.0.0.6:443", + wantServerName: "", + }, + { + name: "ipv6 server IP no port", + serverURL: "rels://relay.example.com", + ip: "2001:db8::1", + wantURL: "rels://[2001:db8::1]", + wantServerName: "relay.example.com", + }, + { + name: "missing scheme", + serverURL: "relay.example.com:443", + ip: "10.0.0.5", + wantErr: true, + }, + { + name: "empty", + serverURL: "", + ip: "10.0.0.5", + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var ip netip.Addr + if tt.ip != "" { + ip = netip.MustParseAddr(tt.ip) + } + gotURL, gotName, err := substituteHost(tt.serverURL, ip) + if tt.wantErr { + if err == nil { + t.Fatalf("expected error, got nil") + } + return + } + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + if gotURL != tt.wantURL { + t.Errorf("URL = %q, want %q", gotURL, tt.wantURL) + } + if gotName != tt.wantServerName { + t.Errorf("ServerName = %q, want %q", gotName, tt.wantServerName) + } + }) + } +} + +func TestClient_ConnectedIPEmptyWhenNotConnected(t *testing.T) { + c := NewClient("rel://example.invalid:80", hmacTokenStore, "x", iface.DefaultMTU) + if got := c.ConnectedIP(); got.IsValid() { + t.Fatalf("ConnectedIP on disconnected client = %q, want zero", got) + } +} + +// staticAddr is a net.Addr that returns a fixed string. Used to verify +// ConnectedIP parses RemoteAddr correctly. +type staticAddr struct{ s string } + +func (a staticAddr) Network() string { return "tcp" } +func (a staticAddr) String() string { return a.s } + +type stubConn struct { + net.Conn + remote net.Addr +} + +func (s stubConn) RemoteAddr() net.Addr { return s.remote } + +func TestClient_ConnectedIPParsesRemoteAddr(t *testing.T) { + tests := []struct { + name string + s string + want string + }{ + {"hostport ipv4", "127.0.0.1:50301", "127.0.0.1"}, + {"hostport ipv6 bracketed", "[::1]:50301", "::1"}, + {"url with ipv4", "rel://127.0.0.1:50301", "127.0.0.1"}, + {"url with ipv6", "rels://[2001:db8::1]:443", "2001:db8::1"}, + {"fqdn url returns empty", "rel://relay.example.com:50301", ""}, + {"fqdn hostport returns empty", "relay.example.com:50301", ""}, + {"plain ipv4 no port", "10.0.0.1", "10.0.0.1"}, + {"empty", "", ""}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Client{relayConn: stubConn{remote: staticAddr{s: tt.s}}} + got := c.ConnectedIP() + var gotStr string + if got.IsValid() { + gotStr = got.String() + } + if gotStr != tt.want { + t.Errorf("ConnectedIP(%q) = %q, want %q", tt.s, gotStr, tt.want) + } + }) + } +} + +// freeAddr returns a 127.0.0.1 address with an OS-assigned port. The +// listener is closed before returning, so the port is briefly free for +// the caller to bind. Avoids hardcoded ports that can collide. +func freeAddr(t *testing.T) (string, int) { + t.Helper() + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("get free port: %s", err) + } + addr := l.Addr().(*net.TCPAddr) + _ = l.Close() + return addr.String(), addr.Port +} diff --git a/shared/relay/client/dialer/quic/quic.go b/shared/relay/client/dialer/quic/quic.go index 2d7b00a80..602803b19 100644 --- a/shared/relay/client/dialer/quic/quic.go +++ b/shared/relay/client/dialer/quic/quic.go @@ -23,7 +23,7 @@ func (d Dialer) Protocol() string { return Network } -func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) { +func (d Dialer) Dial(ctx context.Context, address, serverName string) (net.Conn, error) { quicURL, err := prepareURL(address) if err != nil { return nil, err @@ -32,11 +32,14 @@ func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) { // Get the base TLS config tlsClientConfig := quictls.ClientQUICTLSConfig() - // Set ServerName to hostname if not an IP address - host, _, splitErr := net.SplitHostPort(quicURL) - if splitErr == nil && net.ParseIP(host) == nil { - // It's a hostname, not an IP - modify directly - tlsClientConfig.ServerName = host + switch { + case serverName != "" && net.ParseIP(serverName) == nil: + tlsClientConfig.ServerName = serverName + default: + host, _, splitErr := net.SplitHostPort(quicURL) + if splitErr == nil && net.ParseIP(host) == nil { + tlsClientConfig.ServerName = host + } } quicConfig := &quic.Config{ diff --git a/shared/relay/client/dialer/race_dialer.go b/shared/relay/client/dialer/race_dialer.go index 34359d17e..15208b858 100644 --- a/shared/relay/client/dialer/race_dialer.go +++ b/shared/relay/client/dialer/race_dialer.go @@ -14,7 +14,9 @@ const ( ) type DialeFn interface { - Dial(ctx context.Context, address string) (net.Conn, error) + // Dial connects to address. serverName, when non-empty, overrides the TLS + // ServerName used for SNI/cert validation. Empty means derive from address. + Dial(ctx context.Context, address, serverName string) (net.Conn, error) Protocol() string } @@ -27,6 +29,7 @@ type dialResult struct { type RaceDial struct { log *log.Entry serverURL string + serverName string dialerFns []DialeFn connectionTimeout time.Duration } @@ -40,6 +43,16 @@ func NewRaceDial(log *log.Entry, connectionTimeout time.Duration, serverURL stri } } +// WithServerName sets a TLS SNI/cert validation override. Used when serverURL +// contains an IP literal but the cert is issued for a different hostname. +// +// Mutates the receiver and is not safe for concurrent reconfiguration; a +// RaceDial is intended to be constructed per dial and discarded. +func (r *RaceDial) WithServerName(serverName string) *RaceDial { + r.serverName = serverName + return r +} + func (r *RaceDial) Dial(ctx context.Context) (net.Conn, error) { connChan := make(chan dialResult, len(r.dialerFns)) winnerConn := make(chan net.Conn, 1) @@ -64,7 +77,7 @@ func (r *RaceDial) dial(dfn DialeFn, abortCtx context.Context, connChan chan dia defer cancel() r.log.Infof("dialing Relay server via %s", dfn.Protocol()) - conn, err := dfn.Dial(ctx, r.serverURL) + conn, err := dfn.Dial(ctx, r.serverURL, r.serverName) connChan <- dialResult{Conn: conn, Protocol: dfn.Protocol(), Err: err} } diff --git a/shared/relay/client/dialer/race_dialer_test.go b/shared/relay/client/dialer/race_dialer_test.go index aa18df578..a53edc00e 100644 --- a/shared/relay/client/dialer/race_dialer_test.go +++ b/shared/relay/client/dialer/race_dialer_test.go @@ -28,7 +28,7 @@ type MockDialer struct { protocolStr string } -func (m *MockDialer) Dial(ctx context.Context, address string) (net.Conn, error) { +func (m *MockDialer) Dial(ctx context.Context, address, _ string) (net.Conn, error) { return m.dialFunc(ctx, address) } diff --git a/shared/relay/client/dialer/ws/conn.go b/shared/relay/client/dialer/ws/conn.go index d5b719f51..9497fab89 100644 --- a/shared/relay/client/dialer/ws/conn.go +++ b/shared/relay/client/dialer/ws/conn.go @@ -12,14 +12,24 @@ import ( type Conn struct { ctx context.Context *websocket.Conn - remoteAddr WebsocketAddr + remoteAddr net.Addr } -func NewConn(wsConn *websocket.Conn, serverAddress string) net.Conn { +// NewConn builds a relay ws.Conn. underlying is the raw TCP/TLS conn captured +// from the http transport's DialContext; when set, RemoteAddr returns its +// peer address (an IP literal). When nil (e.g. wasm), RemoteAddr falls back +// to the dial-time URL. +func NewConn(wsConn *websocket.Conn, serverAddress string, underlying net.Conn) net.Conn { + var addr net.Addr = WebsocketAddr{serverAddress} + if underlying != nil { + if ra := underlying.RemoteAddr(); ra != nil { + addr = ra + } + } return &Conn{ ctx: context.Background(), Conn: wsConn, - remoteAddr: WebsocketAddr{serverAddress}, + remoteAddr: addr, } } diff --git a/shared/relay/client/dialer/ws/dialopts_generic.go b/shared/relay/client/dialer/ws/dialopts_generic.go index 9dfe698d0..8008d89d3 100644 --- a/shared/relay/client/dialer/ws/dialopts_generic.go +++ b/shared/relay/client/dialer/ws/dialopts_generic.go @@ -2,10 +2,14 @@ package ws -import "github.com/coder/websocket" +import ( + "net" -func createDialOptions() *websocket.DialOptions { + "github.com/coder/websocket" +) + +func createDialOptions(serverName string, underlyingOut *net.Conn) *websocket.DialOptions { return &websocket.DialOptions{ - HTTPClient: httpClientNbDialer(), + HTTPClient: httpClientNbDialer(serverName, underlyingOut), } } diff --git a/shared/relay/client/dialer/ws/dialopts_js.go b/shared/relay/client/dialer/ws/dialopts_js.go index 7eac27531..5b11fe765 100644 --- a/shared/relay/client/dialer/ws/dialopts_js.go +++ b/shared/relay/client/dialer/ws/dialopts_js.go @@ -2,9 +2,13 @@ package ws -import "github.com/coder/websocket" +import ( + "net" -func createDialOptions() *websocket.DialOptions { - // WASM version doesn't support HTTPClient + "github.com/coder/websocket" +) + +func createDialOptions(_ string, _ *net.Conn) *websocket.DialOptions { + // WASM version doesn't support HTTPClient or custom TLS config. return &websocket.DialOptions{} } diff --git a/shared/relay/client/dialer/ws/ws.go b/shared/relay/client/dialer/ws/ws.go index 37b189e05..301486514 100644 --- a/shared/relay/client/dialer/ws/ws.go +++ b/shared/relay/client/dialer/ws/ws.go @@ -26,13 +26,14 @@ func (d Dialer) Protocol() string { return "WS" } -func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) { +func (d Dialer) Dial(ctx context.Context, address, serverName string) (net.Conn, error) { wsURL, err := prepareURL(address) if err != nil { return nil, err } - opts := createDialOptions() + var underlying net.Conn + opts := createDialOptions(serverName, &underlying) parsedURL, err := url.Parse(wsURL) if err != nil { @@ -52,7 +53,7 @@ func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) { _ = resp.Body.Close() } - conn := NewConn(wsConn, address) + conn := NewConn(wsConn, address, underlying) return conn, nil } @@ -64,7 +65,10 @@ func prepareURL(address string) (string, error) { return strings.Replace(address, "rel", "ws", 1), nil } -func httpClientNbDialer() *http.Client { +// httpClientNbDialer builds the http client used by the websocket library. +// underlyingOut, when non-nil, is populated with the raw conn from the +// transport's DialContext so the caller can read its RemoteAddr. +func httpClientNbDialer(serverName string, underlyingOut *net.Conn) *http.Client { customDialer := nbnet.NewDialer() certPool, err := x509.SystemCertPool() @@ -75,10 +79,15 @@ func httpClientNbDialer() *http.Client { customTransport := &http.Transport{ DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - return customDialer.DialContext(ctx, network, addr) + c, err := customDialer.DialContext(ctx, network, addr) + if err == nil && underlyingOut != nil { + *underlyingOut = c + } + return c, err }, TLSClientConfig: &tls.Config{ - RootCAs: certPool, + RootCAs: certPool, + ServerName: serverName, }, } diff --git a/shared/relay/client/manager.go b/shared/relay/client/manager.go index 37104bfe7..3858b3c83 100644 --- a/shared/relay/client/manager.go +++ b/shared/relay/client/manager.go @@ -5,6 +5,7 @@ import ( "context" "fmt" "net" + "net/netip" "reflect" "sync" "time" @@ -75,6 +76,9 @@ type Manager struct { mtu uint16 maxBackoffInterval time.Duration + + cleanupInterval time.Duration + keepUnusedServerTime time.Duration } // NewManager creates a new manager instance. @@ -95,6 +99,8 @@ func NewManager(ctx context.Context, serverURLs []string, peerID string, mtu uin }, relayClients: make(map[string]*RelayTrack), onDisconnectedListeners: make(map[string]*list.List), + cleanupInterval: relayCleanupInterval, + keepUnusedServerTime: keepUnusedServerTime, } for _, opt := range opts { opt(m) @@ -130,7 +136,10 @@ func (m *Manager) Serve() error { // OpenConn opens a connection to the given peer key. If the peer is on the same relay server, the connection will be // established via the relay server. If the peer is on a different relay server, the manager will establish a new // connection to the relay server. It returns back with a net.Conn what represent the remote peer connection. -func (m *Manager) OpenConn(ctx context.Context, serverAddress, peerKey string) (net.Conn, error) { +// +// serverIP, when valid and serverAddress is foreign, is used as a dial target if the FQDN-based dial fails. +// Ignored for the local home-server path. TLS verification still uses the FQDN via SNI. +func (m *Manager) OpenConn(ctx context.Context, serverAddress, peerKey string, serverIP netip.Addr) (net.Conn, error) { m.relayClientMu.RLock() defer m.relayClientMu.RUnlock() @@ -151,7 +160,7 @@ func (m *Manager) OpenConn(ctx context.Context, serverAddress, peerKey string) ( netConn, err = m.relayClient.OpenConn(ctx, peerKey) } else { log.Debugf("open peer connection via foreign server: %s", serverAddress) - netConn, err = m.openConnVia(ctx, serverAddress, peerKey) + netConn, err = m.openConnVia(ctx, serverAddress, peerKey, serverIP) } if err != nil { return nil, err @@ -203,16 +212,22 @@ func (m *Manager) AddCloseListener(serverAddress string, onClosedListener OnServ return nil } -// RelayInstanceAddress returns the address of the permanent relay server. It could change if the network connection is -// lost. This address will be sent to the target peer to choose the common relay server for the communication. -func (m *Manager) RelayInstanceAddress() (string, error) { +// RelayInstanceAddress returns the address and resolved IP of the permanent relay server. It could change if the +// network connection is lost. The address is sent to the target peer to choose the common relay server for the +// communication; the IP is sent alongside so remote peers can dial directly without their own DNS lookup. Both +// values are read under the same lock so they cannot diverge across a reconnection. +func (m *Manager) RelayInstanceAddress() (string, netip.Addr, error) { m.relayClientMu.RLock() defer m.relayClientMu.RUnlock() if m.relayClient == nil { - return "", ErrRelayClientNotConnected + return "", netip.Addr{}, ErrRelayClientNotConnected } - return m.relayClient.ServerInstanceURL() + addr, err := m.relayClient.ServerInstanceURL() + if err != nil { + return "", netip.Addr{}, err + } + return addr, m.relayClient.ConnectedIP(), nil } // ServerURLs returns the addresses of the relay servers. @@ -236,7 +251,7 @@ func (m *Manager) UpdateToken(token *relayAuth.Token) error { return m.tokenStore.UpdateToken(token) } -func (m *Manager) openConnVia(ctx context.Context, serverAddress, peerKey string) (net.Conn, error) { +func (m *Manager) openConnVia(ctx context.Context, serverAddress, peerKey string, serverIP netip.Addr) (net.Conn, error) { // check if already has a connection to the desired relay server m.relayClientsMutex.RLock() rt, ok := m.relayClients[serverAddress] @@ -271,7 +286,7 @@ func (m *Manager) openConnVia(ctx context.Context, serverAddress, peerKey string m.relayClients[serverAddress] = rt m.relayClientsMutex.Unlock() - relayClient := NewClient(serverAddress, m.tokenStore, m.peerID, m.mtu) + relayClient := NewClientWithServerIP(serverAddress, serverIP, m.tokenStore, m.peerID, m.mtu) err := relayClient.Connect(m.ctx) if err != nil { rt.err = err @@ -364,7 +379,7 @@ func (m *Manager) isForeignServer(address string) (bool, error) { } func (m *Manager) startCleanupLoop() { - ticker := time.NewTicker(relayCleanupInterval) + ticker := time.NewTicker(m.cleanupInterval) defer ticker.Stop() for { select { @@ -389,7 +404,7 @@ func (m *Manager) cleanUpUnusedRelays() { continue } - if time.Since(rt.created) <= keepUnusedServerTime { + if time.Since(rt.created) <= m.keepUnusedServerTime { rt.Unlock() continue } diff --git a/shared/relay/client/manager_serverip_test.go b/shared/relay/client/manager_serverip_test.go new file mode 100644 index 000000000..a354beade --- /dev/null +++ b/shared/relay/client/manager_serverip_test.go @@ -0,0 +1,144 @@ +package client + +import ( + "context" + "io" + "net/netip" + "testing" + "time" + + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/relay/server" +) + +// TestManager_ForeignRelayServerIP exercises the foreign-relay path +// end-to-end through Manager.OpenConn. Alice and Bob register on different +// relay servers; Alice dials Bob's foreign relay using an unresolvable +// FQDN. Without a server IP the dial fails; with Bob's advertised IP it +// recovers and a payload round-trips between the peers. +func TestManager_ForeignRelayServerIP(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) + defer cancel() + + // Alice's home relay + homeCfg := server.ListenerConfig{Address: "127.0.0.1:52401"} + homeSrv, err := server.NewServer(newManagerTestServerConfig(homeCfg.Address)) + if err != nil { + t.Fatalf("create home server: %s", err) + } + homeErr := make(chan error, 1) + go func() { + if err := homeSrv.Listen(homeCfg); err != nil { + homeErr <- err + } + }() + t.Cleanup(func() { _ = homeSrv.Shutdown(context.Background()) }) + if err := waitForServerToStart(homeErr); err != nil { + t.Fatalf("home server: %s", err) + } + + // Bob's foreign relay + foreignCfg := server.ListenerConfig{Address: "127.0.0.1:52402"} + foreignSrv, err := server.NewServer(newManagerTestServerConfig(foreignCfg.Address)) + if err != nil { + t.Fatalf("create foreign server: %s", err) + } + foreignErr := make(chan error, 1) + go func() { + if err := foreignSrv.Listen(foreignCfg); err != nil { + foreignErr <- err + } + }() + t.Cleanup(func() { _ = foreignSrv.Shutdown(context.Background()) }) + if err := waitForServerToStart(foreignErr); err != nil { + t.Fatalf("foreign server: %s", err) + } + + mCtx, mCancel := context.WithCancel(ctx) + t.Cleanup(mCancel) + + mgrAlice := NewManager(mCtx, toURL(homeCfg), "alice", iface.DefaultMTU) + if err := mgrAlice.Serve(); err != nil { + t.Fatalf("alice manager serve: %s", err) + } + + mgrBob := NewManager(mCtx, toURL(foreignCfg), "bob", iface.DefaultMTU) + if err := mgrBob.Serve(); err != nil { + t.Fatalf("bob manager serve: %s", err) + } + + // Bob's real relay URL and the IP that would ride along in signal as relayServerIP. + bobRealAddr, bobAdvertisedIP, err := mgrBob.RelayInstanceAddress() + if err != nil { + t.Fatalf("bob relay address: %s", err) + } + if !bobAdvertisedIP.IsValid() { + t.Fatalf("expected valid RelayInstanceIP for bob, got zero") + } + + // .invalid is reserved (RFC 2606), so DNS resolution always fails. + const brokenFQDN = "rel://relay-bob-instance.invalid:52402" + if brokenFQDN == bobRealAddr { + t.Fatalf("broken FQDN must differ from bob's real address (%s)", bobRealAddr) + } + + t.Run("no server IP, dial fails", func(t *testing.T) { + dialCtx, dialCancel := context.WithTimeout(ctx, 5*time.Second) + defer dialCancel() + _, err := mgrAlice.OpenConn(dialCtx, brokenFQDN, "bob", netip.Addr{}) + if err == nil { + t.Fatalf("expected OpenConn to fail without server IP, got success") + } + }) + + t.Run("server IP recovers", func(t *testing.T) { + // Bob waits for Alice's incoming peer connection on his side. + bobSideCh := make(chan error, 1) + go func() { + conn, err := mgrBob.OpenConn(ctx, bobRealAddr, "alice", netip.Addr{}) + if err != nil { + bobSideCh <- err + return + } + buf := make([]byte, 1024) + n, err := conn.Read(buf) + if err != nil { + bobSideCh <- err + return + } + if _, err := conn.Write(buf[:n]); err != nil { + bobSideCh <- err + return + } + bobSideCh <- nil + }() + + aliceConn, err := mgrAlice.OpenConn(ctx, brokenFQDN, "bob", bobAdvertisedIP) + if err != nil { + t.Fatalf("alice OpenConn with server IP: %s", err) + } + t.Cleanup(func() { _ = aliceConn.Close() }) + + payload := []byte("alice-to-bob") + if _, err := aliceConn.Write(payload); err != nil { + t.Fatalf("alice write: %s", err) + } + + buf := make([]byte, len(payload)) + if _, err := io.ReadFull(aliceConn, buf); err != nil { + t.Fatalf("alice read echo: %s", err) + } + if string(buf) != string(payload) { + t.Fatalf("echo mismatch: got %q want %q", buf, payload) + } + + select { + case err := <-bobSideCh: + if err != nil { + t.Fatalf("bob side: %s", err) + } + case <-time.After(5 * time.Second): + t.Fatalf("timed out waiting for bob side") + } + }) +} diff --git a/shared/relay/client/manager_test.go b/shared/relay/client/manager_test.go index 5bbcad886..9e964f688 100644 --- a/shared/relay/client/manager_test.go +++ b/shared/relay/client/manager_test.go @@ -3,6 +3,7 @@ package client import ( "context" "fmt" + "net/netip" "testing" "time" @@ -101,15 +102,15 @@ func TestForeignConn(t *testing.T) { if err := clientBob.Serve(); err != nil { t.Fatalf("failed to serve manager: %s", err) } - bobsSrvAddr, err := clientBob.RelayInstanceAddress() + bobsSrvAddr, _, err := clientBob.RelayInstanceAddress() if err != nil { t.Fatalf("failed to get relay address: %s", err) } - connAliceToBob, err := clientAlice.OpenConn(ctx, bobsSrvAddr, "bob") + connAliceToBob, err := clientAlice.OpenConn(ctx, bobsSrvAddr, "bob", netip.Addr{}) if err != nil { t.Fatalf("failed to bind channel: %s", err) } - connBobToAlice, err := clientBob.OpenConn(ctx, bobsSrvAddr, "alice") + connBobToAlice, err := clientBob.OpenConn(ctx, bobsSrvAddr, "alice", netip.Addr{}) if err != nil { t.Fatalf("failed to bind channel: %s", err) } @@ -209,7 +210,7 @@ func TestForeginConnClose(t *testing.T) { if err != nil { t.Fatalf("failed to serve manager: %s", err) } - conn, err := mgr.OpenConn(ctx, toURL(srvCfg2)[0], "bob") + conn, err := mgr.OpenConn(ctx, toURL(srvCfg2)[0], "bob", netip.Addr{}) if err != nil { t.Fatalf("failed to bind channel: %s", err) } @@ -301,7 +302,7 @@ func TestForeignAutoClose(t *testing.T) { } t.Log("open connection to another peer") - if _, err = mgr.OpenConn(ctx, foreignServerURL, "anotherpeer"); err == nil { + if _, err = mgr.OpenConn(ctx, foreignServerURL, "anotherpeer", netip.Addr{}); err == nil { t.Fatalf("should have failed to open connection to another peer") } @@ -367,11 +368,11 @@ func TestAutoReconnect(t *testing.T) { if err != nil { t.Fatalf("failed to serve manager: %s", err) } - ra, err := clientAlice.RelayInstanceAddress() + ra, _, err := clientAlice.RelayInstanceAddress() if err != nil { t.Errorf("failed to get relay address: %s", err) } - conn, err := clientAlice.OpenConn(ctx, ra, "bob") + conn, err := clientAlice.OpenConn(ctx, ra, "bob", netip.Addr{}) if err != nil { t.Errorf("failed to bind channel: %s", err) } @@ -391,7 +392,7 @@ func TestAutoReconnect(t *testing.T) { } log.Infof("reopent the connection") - _, err = clientAlice.OpenConn(ctx, ra, "bob") + _, err = clientAlice.OpenConn(ctx, ra, "bob", netip.Addr{}) if err != nil { t.Errorf("failed to open channel: %s", err) } @@ -453,7 +454,7 @@ func TestNotifierDoubleAdd(t *testing.T) { t.Fatalf("failed to serve manager: %s", err) } - conn1, err := clientAlice.OpenConn(ctx, clientAlice.ServerURLs()[0], "bob") + conn1, err := clientAlice.OpenConn(ctx, clientAlice.ServerURLs()[0], "bob", netip.Addr{}) if err != nil { t.Fatalf("failed to bind channel: %s", err) } diff --git a/shared/signal/client/client.go b/shared/signal/client/client.go index 5347c80e9..9dc6ccd37 100644 --- a/shared/signal/client/client.go +++ b/shared/signal/client/client.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "io" + "net/netip" "strings" "github.com/netbirdio/netbird/shared/signal/proto" @@ -14,17 +15,17 @@ import ( // A set of tools to exchange connection details (Wireguard endpoints) with the remote peer. -// Status is the status of the client -type Status string - -const StreamConnected Status = "Connected" -const StreamDisconnected Status = "Disconnected" - const ( + StreamConnected Status = "Connected" + StreamDisconnected Status = "Disconnected" + // DirectCheck indicates support to direct mode checks DirectCheck uint32 = 1 ) +// Status is the status of the client +type Status string + type Client interface { io.Closer StreamConnected() bool @@ -38,6 +39,24 @@ type Client interface { SetOnReconnectedListener(func()) } +// Credential is an instance of a GrpcClient's Credential +type Credential struct { + UFrag string + Pwd string +} + +// CredentialPayload bundles the fields of a signal Body for MarshalCredential. +type CredentialPayload struct { + Type proto.Body_Type + WgListenPort int + Credential *Credential + RosenpassPubKey []byte + RosenpassAddr string + RelaySrvAddress string + RelaySrvIP netip.Addr + SessionID []byte +} + // UnMarshalCredential parses the credentials from the message and returns a Credential instance func UnMarshalCredential(msg *proto.Message) (*Credential, error) { @@ -52,27 +71,27 @@ func UnMarshalCredential(msg *proto.Message) (*Credential, error) { } // MarshalCredential marshal a Credential instance and returns a Message object -func MarshalCredential(myKey wgtypes.Key, myPort int, remoteKey string, credential *Credential, t proto.Body_Type, rosenpassPubKey []byte, rosenpassAddr string, relaySrvAddress string, sessionID []byte) (*proto.Message, error) { +func MarshalCredential(myKey wgtypes.Key, remoteKey string, p CredentialPayload) (*proto.Message, error) { + body := &proto.Body{ + Type: p.Type, + Payload: fmt.Sprintf("%s:%s", p.Credential.UFrag, p.Credential.Pwd), + WgListenPort: uint32(p.WgListenPort), + NetBirdVersion: version.NetbirdVersion(), + RosenpassConfig: &proto.RosenpassConfig{ + RosenpassPubKey: p.RosenpassPubKey, + RosenpassServerAddr: p.RosenpassAddr, + }, + SessionId: p.SessionID, + } + if p.RelaySrvAddress != "" { + body.RelayServerAddress = &p.RelaySrvAddress + } + if p.RelaySrvIP.IsValid() { + body.RelayServerIP = p.RelaySrvIP.Unmap().AsSlice() + } return &proto.Message{ Key: myKey.PublicKey().String(), RemoteKey: remoteKey, - Body: &proto.Body{ - Type: t, - Payload: fmt.Sprintf("%s:%s", credential.UFrag, credential.Pwd), - WgListenPort: uint32(myPort), - NetBirdVersion: version.NetbirdVersion(), - RosenpassConfig: &proto.RosenpassConfig{ - RosenpassPubKey: rosenpassPubKey, - RosenpassServerAddr: rosenpassAddr, - }, - RelayServerAddress: relaySrvAddress, - SessionId: sessionID, - }, + Body: body, }, nil } - -// Credential is an instance of a GrpcClient's Credential -type Credential struct { - UFrag string - Pwd string -} diff --git a/shared/signal/proto/signalexchange.pb.go b/shared/signal/proto/signalexchange.pb.go index d9c61a846..0c80fb489 100644 --- a/shared/signal/proto/signalexchange.pb.go +++ b/shared/signal/proto/signalexchange.pb.go @@ -229,8 +229,13 @@ type Body struct { // RosenpassConfig is a Rosenpass config of the remote peer our peer tries to connect to RosenpassConfig *RosenpassConfig `protobuf:"bytes,7,opt,name=rosenpassConfig,proto3" json:"rosenpassConfig,omitempty"` // relayServerAddress is url of the relay server - RelayServerAddress string `protobuf:"bytes,8,opt,name=relayServerAddress,proto3" json:"relayServerAddress,omitempty"` - SessionId []byte `protobuf:"bytes,10,opt,name=sessionId,proto3,oneof" json:"sessionId,omitempty"` + RelayServerAddress *string `protobuf:"bytes,8,opt,name=relayServerAddress,proto3,oneof" json:"relayServerAddress,omitempty"` + SessionId []byte `protobuf:"bytes,10,opt,name=sessionId,proto3,oneof" json:"sessionId,omitempty"` + // relayServerIP is the IP the sender is connected to on its relay server, + // encoded as 4 bytes (IPv4) or 16 bytes (IPv6). Receivers may use it as a + // fallback dial target when DNS resolution of relayServerAddress fails. + // SNI/TLS verification still uses relayServerAddress. + RelayServerIP []byte `protobuf:"bytes,11,opt,name=relayServerIP,proto3,oneof" json:"relayServerIP,omitempty"` } func (x *Body) Reset() { @@ -315,8 +320,8 @@ func (x *Body) GetRosenpassConfig() *RosenpassConfig { } func (x *Body) GetRelayServerAddress() string { - if x != nil { - return x.RelayServerAddress + if x != nil && x.RelayServerAddress != nil { + return *x.RelayServerAddress } return "" } @@ -328,6 +333,13 @@ func (x *Body) GetSessionId() []byte { return nil } +func (x *Body) GetRelayServerIP() []byte { + if x != nil { + return x.RelayServerIP + } + return nil +} + // Mode indicates a connection mode type Mode struct { state protoimpl.MessageState @@ -451,7 +463,7 @@ var file_signalexchange_proto_rawDesc = []byte{ 0x52, 0x09, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x4b, 0x65, 0x79, 0x12, 0x28, 0x0a, 0x04, 0x62, 0x6f, 0x64, 0x79, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x42, 0x6f, 0x64, 0x79, 0x52, - 0x04, 0x62, 0x6f, 0x64, 0x79, 0x22, 0xe4, 0x03, 0x0a, 0x04, 0x42, 0x6f, 0x64, 0x79, 0x12, 0x2d, + 0x04, 0x62, 0x6f, 0x64, 0x79, 0x22, 0xc3, 0x04, 0x0a, 0x04, 0x42, 0x6f, 0x64, 0x79, 0x12, 0x2d, 0x0a, 0x04, 0x74, 0x79, 0x70, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x19, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x42, 0x6f, 0x64, 0x79, 0x2e, 0x54, 0x79, 0x70, 0x65, 0x52, 0x04, 0x74, 0x79, 0x70, 0x65, 0x12, 0x18, 0x0a, @@ -471,40 +483,46 @@ var file_signalexchange_proto_rawDesc = []byte{ 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1f, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x52, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, - 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x2e, 0x0a, 0x12, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x53, + 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x33, 0x0a, 0x12, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x08, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x12, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, - 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x21, 0x0a, 0x09, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, - 0x6e, 0x49, 0x64, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x0c, 0x48, 0x00, 0x52, 0x09, 0x73, 0x65, 0x73, - 0x73, 0x69, 0x6f, 0x6e, 0x49, 0x64, 0x88, 0x01, 0x01, 0x22, 0x43, 0x0a, 0x04, 0x54, 0x79, 0x70, + 0x28, 0x09, 0x48, 0x00, 0x52, 0x12, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x65, 0x72, 0x76, 0x65, + 0x72, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x88, 0x01, 0x01, 0x12, 0x21, 0x0a, 0x09, 0x73, + 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x49, 0x64, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x0c, 0x48, 0x01, + 0x52, 0x09, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x49, 0x64, 0x88, 0x01, 0x01, 0x12, 0x29, + 0x0a, 0x0d, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x49, 0x50, 0x18, + 0x0b, 0x20, 0x01, 0x28, 0x0c, 0x48, 0x02, 0x52, 0x0d, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x65, + 0x72, 0x76, 0x65, 0x72, 0x49, 0x50, 0x88, 0x01, 0x01, 0x22, 0x43, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x09, 0x0a, 0x05, 0x4f, 0x46, 0x46, 0x45, 0x52, 0x10, 0x00, 0x12, 0x0a, 0x0a, 0x06, 0x41, 0x4e, 0x53, 0x57, 0x45, 0x52, 0x10, 0x01, 0x12, 0x0d, 0x0a, 0x09, 0x43, 0x41, 0x4e, 0x44, 0x49, 0x44, 0x41, 0x54, 0x45, 0x10, 0x02, 0x12, 0x08, 0x0a, 0x04, 0x4d, 0x4f, 0x44, 0x45, 0x10, - 0x04, 0x12, 0x0b, 0x0a, 0x07, 0x47, 0x4f, 0x5f, 0x49, 0x44, 0x4c, 0x45, 0x10, 0x05, 0x42, 0x0c, - 0x0a, 0x0a, 0x5f, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x49, 0x64, 0x22, 0x2e, 0x0a, 0x04, - 0x4d, 0x6f, 0x64, 0x65, 0x12, 0x1b, 0x0a, 0x06, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x18, 0x01, - 0x20, 0x01, 0x28, 0x08, 0x48, 0x00, 0x52, 0x06, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x88, 0x01, - 0x01, 0x42, 0x09, 0x0a, 0x07, 0x5f, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x22, 0x6d, 0x0a, 0x0f, - 0x52, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, - 0x28, 0x0a, 0x0f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x75, 0x62, 0x4b, - 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, - 0x61, 0x73, 0x73, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x30, 0x0a, 0x13, 0x72, 0x6f, 0x73, - 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64, 0x64, 0x72, - 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, - 0x73, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64, 0x64, 0x72, 0x32, 0xb9, 0x01, 0x0a, 0x0e, - 0x53, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x45, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x12, 0x4c, - 0x0a, 0x04, 0x53, 0x65, 0x6e, 0x64, 0x12, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, - 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, - 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, - 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, - 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x59, 0x0a, 0x0d, - 0x43, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x12, 0x20, 0x2e, - 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, - 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, - 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, - 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, - 0x65, 0x22, 0x00, 0x28, 0x01, 0x30, 0x01, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, - 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x04, 0x12, 0x0b, 0x0a, 0x07, 0x47, 0x4f, 0x5f, 0x49, 0x44, 0x4c, 0x45, 0x10, 0x05, 0x42, 0x15, + 0x0a, 0x13, 0x5f, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64, + 0x64, 0x72, 0x65, 0x73, 0x73, 0x42, 0x0c, 0x0a, 0x0a, 0x5f, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, + 0x6e, 0x49, 0x64, 0x42, 0x10, 0x0a, 0x0e, 0x5f, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x65, 0x72, + 0x76, 0x65, 0x72, 0x49, 0x50, 0x4a, 0x04, 0x08, 0x09, 0x10, 0x0a, 0x22, 0x2e, 0x0a, 0x04, 0x4d, + 0x6f, 0x64, 0x65, 0x12, 0x1b, 0x0a, 0x06, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x08, 0x48, 0x00, 0x52, 0x06, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x88, 0x01, 0x01, + 0x42, 0x09, 0x0a, 0x07, 0x5f, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x22, 0x6d, 0x0a, 0x0f, 0x52, + 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x28, + 0x0a, 0x0f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x75, 0x62, 0x4b, 0x65, + 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, + 0x73, 0x73, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x30, 0x0a, 0x13, 0x72, 0x6f, 0x73, 0x65, + 0x6e, 0x70, 0x61, 0x73, 0x73, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64, 0x64, 0x72, 0x18, + 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, + 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64, 0x64, 0x72, 0x32, 0xb9, 0x01, 0x0a, 0x0e, 0x53, + 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x45, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x12, 0x4c, 0x0a, + 0x04, 0x53, 0x65, 0x6e, 0x64, 0x12, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, + 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, + 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, + 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, + 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x59, 0x0a, 0x0d, 0x43, + 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x12, 0x20, 0x2e, 0x73, + 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, + 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x20, + 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, + 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, + 0x22, 0x00, 0x28, 0x01, 0x30, 0x01, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, + 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( diff --git a/shared/signal/proto/signalexchange.proto b/shared/signal/proto/signalexchange.proto index 0a33ad78b..96a4001e3 100644 --- a/shared/signal/proto/signalexchange.proto +++ b/shared/signal/proto/signalexchange.proto @@ -63,9 +63,17 @@ message Body { RosenpassConfig rosenpassConfig = 7; // relayServerAddress is url of the relay server - string relayServerAddress = 8; + optional string relayServerAddress = 8; + + reserved 9; optional bytes sessionId = 10; + + // relayServerIP is the IP the sender is connected to on its relay server, + // encoded as 4 bytes (IPv4) or 16 bytes (IPv6). Receivers may use it as a + // fallback dial target when DNS resolution of relayServerAddress fails. + // SNI/TLS verification still uses relayServerAddress. + optional bytes relayServerIP = 11; } // Mode indicates a connection mode From 6262b0d841a5a4c1bd758d45332a6dba51cb09dd Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Mon, 4 May 2026 12:47:13 +0300 Subject: [PATCH 02/15] [management] Track pending approval in peer event metadata (#6040) --- management/server/peer.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/management/server/peer.go b/management/server/peer.go index d1c52002e..25c6ecd8c 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -818,6 +818,9 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe if !addedByUser { opEvent.Meta["setup_key_name"] = peerAddConfig.SetupKeyName } + if newPeer.Status != nil && newPeer.Status.RequiresApproval { + opEvent.Meta["pending_approval"] = true + } if !temporary { am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta) From a21f6ecb0a5d7ba45b4bf570a7af62ba1f66447d Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Mon, 4 May 2026 11:59:01 +0200 Subject: [PATCH 03/15] [client] release Status.mux before invoking notifier callbacks (#6039) The Status recorder used to fire notifier callbacks while holding d.mux: - notifyPeerListChanged / notifyPeerStateChangeListeners ran from inside the locked section of every Update*/AddPeerStateRoute/etc. - notifyAddressChanged ran from UpdateLocalPeerState and CleanLocalPeerState while d.mux was held. - onConnectionChanged was registered with a defer above defer d.mux.Unlock, so it executed before the mutex was released in the Mark*Connected/ Disconnected helpers. - notifyPeerStateChangeListeners did a blocking channel send under d.mux, so a slow subscriber stalled every other d.mux holder. A listener that re-enters the recorder (e.g. calls GetFullStatus from within a callback) deadlocks against d.mux, and any callback that takes longer than expected stalls every other state query for its duration. Capture the values needed for notification under the lock, release d.mux, then call the notifier. Build per-peer router-state snapshots inside the lock and dispatch them via dispatchRouterPeers afterwards. The router-peer channel send stays blocking, but now happens outside d.mux so a slow consumer cannot stall any other d.mux holder, and no peer state transitions are silently dropped. The notifier itself is unchanged: its internal state was already protected by its own locks, and the field d.notifier is set once in NewRecorder and never reassigned, so reading it without d.mux is safe. Also fix a pre-existing race in Test_notifier_RemoveListener / Test_notifier_SetListener: setListener spawns a goroutine that writes listener.peers, but the tests read listener.peers without waiting for it. --- client/internal/peer/notifier_test.go | 17 ++ client/internal/peer/status.go | 229 +++++++++++++++++--------- 2 files changed, 170 insertions(+), 76 deletions(-) diff --git a/client/internal/peer/notifier_test.go b/client/internal/peer/notifier_test.go index bbdc00e13..0b7722b0c 100644 --- a/client/internal/peer/notifier_test.go +++ b/client/internal/peer/notifier_test.go @@ -8,6 +8,7 @@ import ( type mocListener struct { lastState int wg sync.WaitGroup + peersWg sync.WaitGroup peers int } @@ -33,6 +34,7 @@ func (l *mocListener) OnAddressChanged(host, addr string) { } func (l *mocListener) OnPeersListChanged(size int) { l.peers = size + l.peersWg.Done() } func (l *mocListener) setWaiter() { @@ -43,6 +45,14 @@ func (l *mocListener) wait() { l.wg.Wait() } +func (l *mocListener) setPeersWaiter() { + l.peersWg.Add(1) +} + +func (l *mocListener) waitPeers() { + l.peersWg.Wait() +} + func Test_notifier_serverState(t *testing.T) { type scenario struct { @@ -72,11 +82,13 @@ func Test_notifier_serverState(t *testing.T) { func Test_notifier_SetListener(t *testing.T) { listener := &mocListener{} listener.setWaiter() + listener.setPeersWaiter() n := newNotifier() n.lastNotification = stateConnecting n.setListener(listener) listener.wait() + listener.waitPeers() if listener.lastState != n.lastNotification { t.Errorf("invalid state: %d, expected: %d", listener.lastState, n.lastNotification) } @@ -85,9 +97,14 @@ func Test_notifier_SetListener(t *testing.T) { func Test_notifier_RemoveListener(t *testing.T) { listener := &mocListener{} listener.setWaiter() + listener.setPeersWaiter() n := newNotifier() n.lastNotification = stateConnecting n.setListener(listener) + // setListener replays cached state on a goroutine; wait for both the state + // and peers callbacks to finish so we don't race on listener.peers. + listener.wait() + listener.waitPeers() n.removeListener() n.peerListChanged(1) diff --git a/client/internal/peer/status.go b/client/internal/peer/status.go index 7bd19b0e1..e8e61f660 100644 --- a/client/internal/peer/status.go +++ b/client/internal/peer/status.go @@ -320,10 +320,10 @@ func (d *Status) RemovePeer(peerPubKey string) error { // UpdatePeerState updates peer status func (d *Status) UpdatePeerState(receivedState State) error { d.mux.Lock() - defer d.mux.Unlock() peerState, ok := d.peers[receivedState.PubKey] if !ok { + d.mux.Unlock() return errors.New("peer doesn't exist") } @@ -343,23 +343,29 @@ func (d *Status) UpdatePeerState(receivedState State) error { d.peers[receivedState.PubKey] = peerState - if hasConnStatusChanged(oldState, receivedState.ConnStatus) { - d.notifyPeerListChanged() - } - + notifyList := hasConnStatusChanged(oldState, receivedState.ConnStatus) // when we close the connection we will not notify the router manager - if receivedState.ConnStatus == StatusIdle { - d.notifyPeerStateChangeListeners(receivedState.PubKey) + notifyRouter := receivedState.ConnStatus == StatusIdle + routerSnapshot := d.snapshotRouterPeersLocked(receivedState.PubKey, notifyRouter) + numPeers := d.numOfPeers() + + d.mux.Unlock() + + if notifyList { + d.notifier.peerListChanged(numPeers) + } + if notifyRouter { + d.dispatchRouterPeers(receivedState.PubKey, routerSnapshot) } return nil } func (d *Status) AddPeerStateRoute(peer string, route string, resourceId route.ResID) error { d.mux.Lock() - defer d.mux.Unlock() peerState, ok := d.peers[peer] if !ok { + d.mux.Unlock() return errors.New("peer doesn't exist") } @@ -371,17 +377,20 @@ func (d *Status) AddPeerStateRoute(peer string, route string, resourceId route.R d.routeIDLookup.AddRemoteRouteID(resourceId, pref) } + numPeers := d.numOfPeers() + d.mux.Unlock() + // todo: consider to make sense of this notification or not - d.notifyPeerListChanged() + d.notifier.peerListChanged(numPeers) return nil } func (d *Status) RemovePeerStateRoute(peer string, route string) error { d.mux.Lock() - defer d.mux.Unlock() peerState, ok := d.peers[peer] if !ok { + d.mux.Unlock() return errors.New("peer doesn't exist") } @@ -393,8 +402,11 @@ func (d *Status) RemovePeerStateRoute(peer string, route string) error { d.routeIDLookup.RemoveRemoteRouteID(pref) } + numPeers := d.numOfPeers() + d.mux.Unlock() + // todo: consider to make sense of this notification or not - d.notifyPeerListChanged() + d.notifier.peerListChanged(numPeers) return nil } @@ -410,10 +422,10 @@ func (d *Status) CheckRoutes(ip netip.Addr) ([]byte, bool) { func (d *Status) UpdatePeerICEState(receivedState State) error { d.mux.Lock() - defer d.mux.Unlock() peerState, ok := d.peers[receivedState.PubKey] if !ok { + d.mux.Unlock() return errors.New("peer doesn't exist") } @@ -431,22 +443,28 @@ func (d *Status) UpdatePeerICEState(receivedState State) error { d.peers[receivedState.PubKey] = peerState - if hasConnStatusChanged(oldState, receivedState.ConnStatus) { - d.notifyPeerListChanged() - } + notifyList := hasConnStatusChanged(oldState, receivedState.ConnStatus) + notifyRouter := hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed) + routerSnapshot := d.snapshotRouterPeersLocked(receivedState.PubKey, notifyRouter) + numPeers := d.numOfPeers() - if hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed) { - d.notifyPeerStateChangeListeners(receivedState.PubKey) + d.mux.Unlock() + + if notifyList { + d.notifier.peerListChanged(numPeers) + } + if notifyRouter { + d.dispatchRouterPeers(receivedState.PubKey, routerSnapshot) } return nil } func (d *Status) UpdatePeerRelayedState(receivedState State) error { d.mux.Lock() - defer d.mux.Unlock() peerState, ok := d.peers[receivedState.PubKey] if !ok { + d.mux.Unlock() return errors.New("peer doesn't exist") } @@ -461,22 +479,28 @@ func (d *Status) UpdatePeerRelayedState(receivedState State) error { d.peers[receivedState.PubKey] = peerState - if hasConnStatusChanged(oldState, receivedState.ConnStatus) { - d.notifyPeerListChanged() - } + notifyList := hasConnStatusChanged(oldState, receivedState.ConnStatus) + notifyRouter := hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed) + routerSnapshot := d.snapshotRouterPeersLocked(receivedState.PubKey, notifyRouter) + numPeers := d.numOfPeers() - if hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed) { - d.notifyPeerStateChangeListeners(receivedState.PubKey) + d.mux.Unlock() + + if notifyList { + d.notifier.peerListChanged(numPeers) + } + if notifyRouter { + d.dispatchRouterPeers(receivedState.PubKey, routerSnapshot) } return nil } func (d *Status) UpdatePeerRelayedStateToDisconnected(receivedState State) error { d.mux.Lock() - defer d.mux.Unlock() peerState, ok := d.peers[receivedState.PubKey] if !ok { + d.mux.Unlock() return errors.New("peer doesn't exist") } @@ -490,22 +514,28 @@ func (d *Status) UpdatePeerRelayedStateToDisconnected(receivedState State) error d.peers[receivedState.PubKey] = peerState - if hasConnStatusChanged(oldState, receivedState.ConnStatus) { - d.notifyPeerListChanged() - } + notifyList := hasConnStatusChanged(oldState, receivedState.ConnStatus) + notifyRouter := hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed) + routerSnapshot := d.snapshotRouterPeersLocked(receivedState.PubKey, notifyRouter) + numPeers := d.numOfPeers() - if hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed) { - d.notifyPeerStateChangeListeners(receivedState.PubKey) + d.mux.Unlock() + + if notifyList { + d.notifier.peerListChanged(numPeers) + } + if notifyRouter { + d.dispatchRouterPeers(receivedState.PubKey, routerSnapshot) } return nil } func (d *Status) UpdatePeerICEStateToDisconnected(receivedState State) error { d.mux.Lock() - defer d.mux.Unlock() peerState, ok := d.peers[receivedState.PubKey] if !ok { + d.mux.Unlock() return errors.New("peer doesn't exist") } @@ -522,12 +552,18 @@ func (d *Status) UpdatePeerICEStateToDisconnected(receivedState State) error { d.peers[receivedState.PubKey] = peerState - if hasConnStatusChanged(oldState, receivedState.ConnStatus) { - d.notifyPeerListChanged() - } + notifyList := hasConnStatusChanged(oldState, receivedState.ConnStatus) + notifyRouter := hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed) + routerSnapshot := d.snapshotRouterPeersLocked(receivedState.PubKey, notifyRouter) + numPeers := d.numOfPeers() - if hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed) { - d.notifyPeerStateChangeListeners(receivedState.PubKey) + d.mux.Unlock() + + if notifyList { + d.notifier.peerListChanged(numPeers) + } + if notifyRouter { + d.dispatchRouterPeers(receivedState.PubKey, routerSnapshot) } return nil } @@ -594,17 +630,33 @@ func (d *Status) UpdatePeerSSHHostKey(peerPubKey string, sshHostKey []byte) erro // FinishPeerListModifications this event invoke the notification func (d *Status) FinishPeerListModifications() { d.mux.Lock() - defer d.mux.Unlock() if !d.peerListChangedForNotification { + d.mux.Unlock() return } d.peerListChangedForNotification = false - d.notifyPeerListChanged() + numPeers := d.numOfPeers() + // snapshot per-peer router state to deliver after the lock is released + type routerDispatch struct { + peerID string + snapshot map[string]RouterState + } + dispatches := make([]routerDispatch, 0, len(d.peers)) for key := range d.peers { - d.notifyPeerStateChangeListeners(key) + snapshot := d.snapshotRouterPeersLocked(key, true) + if snapshot != nil { + dispatches = append(dispatches, routerDispatch{peerID: key, snapshot: snapshot}) + } + } + + d.mux.Unlock() + + d.notifier.peerListChanged(numPeers) + for _, rd := range dispatches { + d.dispatchRouterPeers(rd.peerID, rd.snapshot) } } @@ -655,10 +707,12 @@ func (d *Status) GetLocalPeerState() LocalPeerState { // UpdateLocalPeerState updates local peer status func (d *Status) UpdateLocalPeerState(localPeerState LocalPeerState) { d.mux.Lock() - defer d.mux.Unlock() - d.localPeer = localPeerState - d.notifyAddressChanged() + fqdn := d.localPeer.FQDN + ip := d.localPeer.IP + d.mux.Unlock() + + d.notifier.localAddressChanged(fqdn, ip) } // AddLocalPeerStateRoute adds a route to the local peer state @@ -721,30 +775,36 @@ func (d *Status) CleanLocalPeerStateRoutes() { // CleanLocalPeerState cleans local peer status func (d *Status) CleanLocalPeerState() { d.mux.Lock() - defer d.mux.Unlock() - d.localPeer = LocalPeerState{} - d.notifyAddressChanged() + fqdn := d.localPeer.FQDN + ip := d.localPeer.IP + d.mux.Unlock() + + d.notifier.localAddressChanged(fqdn, ip) } // MarkManagementDisconnected sets ManagementState to disconnected func (d *Status) MarkManagementDisconnected(err error) { d.mux.Lock() - defer d.mux.Unlock() - defer d.onConnectionChanged() - d.managementState = false d.managementError = err + mgm := d.managementState + sig := d.signalState + d.mux.Unlock() + + d.notifier.updateServerStates(mgm, sig) } // MarkManagementConnected sets ManagementState to connected func (d *Status) MarkManagementConnected() { d.mux.Lock() - defer d.mux.Unlock() - defer d.onConnectionChanged() - d.managementState = true d.managementError = nil + mgm := d.managementState + sig := d.signalState + d.mux.Unlock() + + d.notifier.updateServerStates(mgm, sig) } // UpdateSignalAddress update the address of the signal server @@ -778,21 +838,25 @@ func (d *Status) UpdateLazyConnection(enabled bool) { // MarkSignalDisconnected sets SignalState to disconnected func (d *Status) MarkSignalDisconnected(err error) { d.mux.Lock() - defer d.mux.Unlock() - defer d.onConnectionChanged() - d.signalState = false d.signalError = err + mgm := d.managementState + sig := d.signalState + d.mux.Unlock() + + d.notifier.updateServerStates(mgm, sig) } // MarkSignalConnected sets SignalState to connected func (d *Status) MarkSignalConnected() { d.mux.Lock() - defer d.mux.Unlock() - defer d.onConnectionChanged() - d.signalState = true d.signalError = nil + mgm := d.managementState + sig := d.signalState + d.mux.Unlock() + + d.notifier.updateServerStates(mgm, sig) } func (d *Status) UpdateRelayStates(relayResults []relay.ProbeResult) { @@ -1012,18 +1076,17 @@ func (d *Status) RemoveConnectionListener() { d.notifier.removeListener() } -func (d *Status) onConnectionChanged() { - d.notifier.updateServerStates(d.managementState, d.signalState) -} - -// notifyPeerStateChangeListeners notifies route manager about the change in peer state -func (d *Status) notifyPeerStateChangeListeners(peerID string) { - subs, ok := d.changeNotify[peerID] - if !ok { - return +// snapshotRouterPeersLocked builds the RouterState map for a peer's subscribers. +// Caller MUST hold d.mux. Returns nil when there are no subscribers for peerID +// or when notify is false. The snapshot is consumed later by dispatchRouterPeers +// outside the lock so the channel send cannot stall any d.mux holder. +func (d *Status) snapshotRouterPeersLocked(peerID string, notify bool) map[string]RouterState { + if !notify { + return nil + } + if _, ok := d.changeNotify[peerID]; !ok { + return nil } - - // collect the relevant data for router peers routerPeers := make(map[string]RouterState, len(d.changeNotify)) for pid := range d.changeNotify { s, ok := d.peers[pid] @@ -1031,13 +1094,35 @@ func (d *Status) notifyPeerStateChangeListeners(peerID string) { log.Warnf("router peer not found in peers list: %s", pid) continue } - routerPeers[pid] = RouterState{ Status: s.ConnStatus, Relayed: s.Relayed, Latency: s.Latency, } } + return routerPeers +} + +// dispatchRouterPeers delivers a previously snapshotted router-state map to +// the peer's subscribers. Caller MUST NOT hold d.mux. The method takes a +// fresh, short read of d.changeNotify under the lock to grab subscriber +// channels, then sends outside the lock so a slow consumer cannot block other +// d.mux holders. The send itself stays blocking (only short-circuited by the +// subscriber's context) so peer state transitions are not silently dropped. +func (d *Status) dispatchRouterPeers(peerID string, routerPeers map[string]RouterState) { + if routerPeers == nil { + return + } + + d.mux.Lock() + subsMap, ok := d.changeNotify[peerID] + subs := make([]*StatusChangeSubscription, 0, len(subsMap)) + if ok { + for _, sub := range subsMap { + subs = append(subs, sub) + } + } + d.mux.Unlock() for _, sub := range subs { select { @@ -1047,14 +1132,6 @@ func (d *Status) notifyPeerStateChangeListeners(peerID string) { } } -func (d *Status) notifyPeerListChanged() { - d.notifier.peerListChanged(d.numOfPeers()) -} - -func (d *Status) notifyAddressChanged() { - d.notifier.localAddressChanged(d.localPeer.FQDN, d.localPeer.IP) -} - func (d *Status) numOfPeers() int { return len(d.peers) + len(d.offlinePeers) } From a547fc74edd71268767d258a6ebe8513fa65f467 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Mon, 4 May 2026 11:59:25 +0200 Subject: [PATCH 04/15] [client] Use ctx.Err() instead of gRPC codes.Canceled to detect shutdown (#6019) Detecting shutdown by inspecting the gRPC status code conflates a local context cancellation with a server- or proxy-sent codes.Canceled. When the latter occurs (e.g. an intermediary proxy resets the stream), the retry loop silently terminates and the client never reconnects. Switch to ctx.Err() in the signal Receive loop and management Sync/Job handlers, and stop matching gRPC Canceled/DeadlineExceeded in the flow client's isContextDone helper. With this change, a server-sent Canceled is treated as a transient error and the backoff retry loop continues. --- flow/client/client.go | 15 +++++------- shared/management/client/grpc.go | 39 ++++++++++++-------------------- shared/signal/client/grpc.go | 2 +- 3 files changed, 21 insertions(+), 35 deletions(-) diff --git a/flow/client/client.go b/flow/client/client.go index 8ad637974..180a4b441 100644 --- a/flow/client/client.go +++ b/flow/client/client.go @@ -13,11 +13,9 @@ import ( "github.com/cenkalti/backoff/v4" log "github.com/sirupsen/logrus" "google.golang.org/grpc" - "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/keepalive" - "google.golang.org/grpc/status" nbgrpc "github.com/netbirdio/netbird/client/grpc" "github.com/netbirdio/netbird/flow/proto" @@ -301,12 +299,11 @@ func defaultBackoff(ctx context.Context, interval time.Duration) backoff.BackOff }, ctx) } +// isContextDone reports whether the local context has been canceled or has +// exceeded its deadline. It deliberately does not inspect gRPC status codes: +// a server- or proxy-sent codes.Canceled / codes.DeadlineExceeded must not +// short-circuit our retry loop, since retrying is the correct response when +// the local context is still alive. func isContextDone(err error) bool { - if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { - return true - } - if s, ok := status.FromError(err); ok { - return s.Code() == codes.Canceled || s.Code() == codes.DeadlineExceeded - } - return false + return errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) } diff --git a/shared/management/client/grpc.go b/shared/management/client/grpc.go index 2a51a777d..80625fe06 100644 --- a/shared/management/client/grpc.go +++ b/shared/management/client/grpc.go @@ -246,27 +246,23 @@ func (c *GrpcClient) handleJobStream( for { jobReq, err := c.receiveJobRequest(ctx, stream, serverPubKey) if err != nil { + if ctx.Err() != nil { + log.Debugf("job stream context has been canceled, this usually indicates shutdown") + return nil + } if s, ok := gstatus.FromError(err); ok { switch s.Code() { case codes.PermissionDenied: c.notifyDisconnected(err) return backoff.Permanent(err) // unrecoverable error, propagate to the upper layer - case codes.Canceled: - log.Debugf("job stream context has been canceled, this usually indicates shutdown") - return err case codes.Unimplemented: log.Warn("Job feature is not supported by the current management server version. " + "Please update the management service to use this feature.") return nil - default: - log.Warnf("job stream disconnected, will retry silently. Reason: %v", err) - return err } - } else { - // non-gRPC error - log.Warnf("job stream disconnected, will retry silently. Reason: %v", err) - return err } + log.Warnf("job stream disconnected, will retry silently. Reason: %v", err) + return err } if jobReq == nil || len(jobReq.ID) == 0 { @@ -381,22 +377,15 @@ func (c *GrpcClient) handleSyncStream(ctx context.Context, serverPubKey wgtypes. err = c.receiveUpdatesEvents(stream, serverPubKey, msgHandler) if err != nil { c.notifyDisconnected(err) - if s, ok := gstatus.FromError(err); ok { - switch s.Code() { - case codes.PermissionDenied: - return backoff.Permanent(err) // unrecoverable error, propagate to the upper layer - case codes.Canceled: - log.Debugf("management connection context has been canceled, this usually indicates shutdown") - return nil - default: - log.Warnf("disconnected from the Management service but will retry silently. Reason: %v", err) - return err - } - } else { - // non-gRPC error - log.Warnf("disconnected from the Management service but will retry silently. Reason: %v", err) - return err + if ctx.Err() != nil { + log.Debugf("management connection context has been canceled, this usually indicates shutdown") + return nil } + if s, ok := gstatus.FromError(err); ok && s.Code() == codes.PermissionDenied { + return backoff.Permanent(err) // unrecoverable error, propagate to the upper layer + } + log.Warnf("disconnected from the Management service but will retry silently. Reason: %v", err) + return err } return nil diff --git a/shared/signal/client/grpc.go b/shared/signal/client/grpc.go index d0f598dd7..b245b2296 100644 --- a/shared/signal/client/grpc.go +++ b/shared/signal/client/grpc.go @@ -167,7 +167,7 @@ func (c *GrpcClient) Receive(ctx context.Context, msgHandler func(msg *proto.Mes // start receiving messages from the Signal stream (from other peers through signal) err = c.receive(stream) if err != nil { - if s, ok := status.FromError(err); ok && s.Code() == codes.Canceled { + if ctx.Err() != nil { log.Debugf("signal connection context has been canceled, this usually indicates shutdown") return nil } From 4268a5cfb7046bf713feac4b862539d20769d944 Mon Sep 17 00:00:00 2001 From: Lauri Tirkkonen Date: Tue, 5 May 2026 01:24:52 +0900 Subject: [PATCH 05/15] [client] Use atomic write/rename pattern for ssh config --- client/ssh/config/manager.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/client/ssh/config/manager.go b/client/ssh/config/manager.go index 6e584b2c3..5d69fd35c 100644 --- a/client/ssh/config/manager.go +++ b/client/ssh/config/manager.go @@ -224,15 +224,20 @@ func (m *Manager) buildHostPatterns(peer PeerSSHInfo) []string { func (m *Manager) writeSSHConfig(sshConfig string) error { sshConfigPath := filepath.Join(m.sshConfigDir, m.sshConfigFile) + sshConfigPathTmp := sshConfigPath + ".tmp" if err := os.MkdirAll(m.sshConfigDir, 0755); err != nil { return fmt.Errorf("create SSH config directory %s: %w", m.sshConfigDir, err) } - if err := writeFileWithTimeout(sshConfigPath, []byte(sshConfig), 0644); err != nil { + if err := writeFileWithTimeout(sshConfigPathTmp, []byte(sshConfig), 0644); err != nil { return fmt.Errorf("write SSH config file %s: %w", sshConfigPath, err) } + if err := os.Rename(sshConfigPathTmp, sshConfigPath); err != nil { + return fmt.Errorf("rename ssh config %s -> %s: %w", sshConfigPathTmp, sshConfigPath, err) + } + log.Infof("Created NetBird SSH client config: %s", sshConfigPath) return nil } From bde632c3b2f4fbf1252f0b37209f000661f466d9 Mon Sep 17 00:00:00 2001 From: alexsavio Date: Mon, 4 May 2026 18:49:39 +0200 Subject: [PATCH 06/15] [client] Replace WG interface monitor polling with netlink subscription on Linux (#5857) --- client/internal/wg_iface_monitor.go | 31 +---- client/internal/wg_iface_monitor_linux.go | 134 ++++++++++++++++++++++ client/internal/wg_iface_monitor_other.go | 56 +++++++++ 3 files changed, 195 insertions(+), 26 deletions(-) create mode 100644 client/internal/wg_iface_monitor_linux.go create mode 100644 client/internal/wg_iface_monitor_other.go diff --git a/client/internal/wg_iface_monitor.go b/client/internal/wg_iface_monitor.go index a870c1145..2a2fa2366 100644 --- a/client/internal/wg_iface_monitor.go +++ b/client/internal/wg_iface_monitor.go @@ -6,7 +6,6 @@ import ( "fmt" "net" "runtime" - "time" log "github.com/sirupsen/logrus" @@ -28,6 +27,10 @@ func NewWGIfaceMonitor() *WGIfaceMonitor { // Start begins monitoring the WireGuard interface. // It relies on the provided context cancellation to stop. +// +// On Linux the watcher is event-driven (RTNLGRP_LINK netlink subscription) +// to avoid the allocation churn of repeatedly dumping the kernel link +// table; on other platforms it falls back to a low-frequency poll. func (m *WGIfaceMonitor) Start(ctx context.Context, ifaceName string) (shouldRestart bool, err error) { defer close(m.done) @@ -56,31 +59,7 @@ func (m *WGIfaceMonitor) Start(ctx context.Context, ifaceName string) (shouldRes log.Infof("Interface monitor: watching %s (index: %d)", ifaceName, expectedIndex) - ticker := time.NewTicker(2 * time.Second) - defer ticker.Stop() - - for { - select { - case <-ctx.Done(): - log.Infof("Interface monitor: stopped for %s", ifaceName) - return false, fmt.Errorf("wg interface monitor stopped: %v", ctx.Err()) - case <-ticker.C: - currentIndex, err := getInterfaceIndex(ifaceName) - if err != nil { - // Interface was deleted - log.Infof("Interface monitor: %s deleted", ifaceName) - return true, fmt.Errorf("interface %s deleted: %w", ifaceName, err) - } - - // Check if interface index changed (interface was recreated) - if currentIndex != expectedIndex { - log.Infof("Interface monitor: %s recreated (index changed from %d to %d), restarting engine", - ifaceName, expectedIndex, currentIndex) - return true, nil - } - } - } - + return watchInterface(ctx, ifaceName, expectedIndex) } // getInterfaceIndex returns the index of a network interface by name. diff --git a/client/internal/wg_iface_monitor_linux.go b/client/internal/wg_iface_monitor_linux.go new file mode 100644 index 000000000..2662b99d6 --- /dev/null +++ b/client/internal/wg_iface_monitor_linux.go @@ -0,0 +1,134 @@ +//go:build linux + +package internal + +import ( + "context" + "fmt" + "syscall" + + log "github.com/sirupsen/logrus" + "github.com/vishvananda/netlink" +) + +// watchInterface uses an RTNLGRP_LINK netlink subscription to detect +// deletion or recreation of the WireGuard interface. +// +// The previous implementation polled net.InterfaceByName every 2 s, which +// on Linux issues syscall.NetlinkRIB(RTM_GETLINK, ...) and dumps the +// entire kernel link table on every call. On hosts with many veth +// interfaces (containers, bridges) the resulting allocation churn was on +// the order of ~1 GB/day from this single ticker, which on small ARM +// hosts manifested as a slow RSS climb (see netbirdio/netbird#3678). +// +// The event-driven version below allocates only when the kernel actually +// publishes a link event for the tracked interface — typically zero +// allocations between events. +func watchInterface(ctx context.Context, ifaceName string, expectedIndex int) (bool, error) { + done := make(chan struct{}) + defer close(done) + + // Buffer the channel to absorb event bursts (e.g. when many veth + // pairs are created/destroyed at once by container runtimes). + linkChan := make(chan netlink.LinkUpdate, 32) + if err := netlink.LinkSubscribe(linkChan, done); err != nil { + // Return shouldRestart=true so the engine recovers monitoring + // via triggerClientRestart instead of silently losing it for + // the rest of the process lifetime. + return true, fmt.Errorf("subscribe to link updates: %w", err) + } + + // Race window: the interface could have been deleted (or recreated) + // between the initial getInterfaceIndex() in Start and LinkSubscribe + // completing its handshake with the kernel. Re-check explicitly so we + // do not block forever waiting for an event that already fired. + if currentIndex, err := getInterfaceIndex(ifaceName); err != nil { + log.Infof("Interface monitor: %s deleted before subscription completed", ifaceName) + return true, fmt.Errorf("interface %s deleted: %w", ifaceName, err) + } else if currentIndex != expectedIndex { + log.Infof("Interface monitor: %s recreated (index changed from %d to %d) before subscription completed", + ifaceName, expectedIndex, currentIndex) + return true, nil + } + + for { + select { + case <-ctx.Done(): + log.Infof("Interface monitor: stopped for %s", ifaceName) + return false, fmt.Errorf("wg interface monitor stopped: %w", ctx.Err()) + + case update, ok := <-linkChan: + if !ok { + // The vishvananda/netlink subscription goroutine closes + // the channel on receive errors. Signal the engine to + // restart so monitoring is re-established instead of + // silently ending. + log.Warnf("Interface monitor: link subscription channel closed unexpectedly for %s", ifaceName) + return true, fmt.Errorf("link subscription channel closed unexpectedly") + } + if restart, err := inspectLinkEvent(update, ifaceName, expectedIndex); restart { + return true, err + } + } + } +} + +// inspectLinkEvent classifies a single netlink link update against the +// tracked WireGuard interface. It returns (true, err) when the engine +// should restart monitoring; (false, nil) means the event is unrelated +// and the caller should keep waiting. +// +// The error component, when non-nil, describes the kernel-side reason +// (deletion or rename); the recreation case returns (true, nil) since +// no error condition is reported. +func inspectLinkEvent(update netlink.LinkUpdate, ifaceName string, expectedIndex int) (bool, error) { + eventIndex := int(update.Index) + eventName := "" + if attrs := update.Attrs(); attrs != nil { + eventName = attrs.Name + } + + switch update.Header.Type { + case syscall.RTM_DELLINK: + return inspectDelLink(eventIndex, ifaceName, expectedIndex) + case syscall.RTM_NEWLINK: + return inspectNewLink(eventIndex, eventName, ifaceName, expectedIndex) + } + return false, nil +} + +// inspectDelLink reports a restart when an RTM_DELLINK arrives for the +// tracked interface index. +func inspectDelLink(eventIndex int, ifaceName string, expectedIndex int) (bool, error) { + if eventIndex != expectedIndex { + return false, nil + } + log.Infof("Interface monitor: %s deleted", ifaceName) + return true, fmt.Errorf("interface %s deleted", ifaceName) +} + +// inspectNewLink reports a restart when an RTM_NEWLINK either: +// +// 1. Introduces a link with our name at a different index (recreation +// after a delete), or +// +// 2. Reports a link still at our index but with a different name +// (in-place rename). The previous polling implementation caught +// this implicitly because net.InterfaceByName(ifaceName) would +// start failing; the event-driven version has to test it. +// +// Same name + same index is just a flag/state change on the existing +// interface and is ignored. +func inspectNewLink(eventIndex int, eventName, ifaceName string, expectedIndex int) (bool, error) { + if eventName == ifaceName && eventIndex != expectedIndex { + log.Infof("Interface monitor: %s recreated (index changed from %d to %d), restarting engine", + ifaceName, expectedIndex, eventIndex) + return true, nil + } + if eventIndex == expectedIndex && eventName != "" && eventName != ifaceName { + log.Infof("Interface monitor: %s renamed to %s (index %d), restarting engine", + ifaceName, eventName, expectedIndex) + return true, fmt.Errorf("interface %s renamed to %s", ifaceName, eventName) + } + return false, nil +} diff --git a/client/internal/wg_iface_monitor_other.go b/client/internal/wg_iface_monitor_other.go new file mode 100644 index 000000000..afebbf4df --- /dev/null +++ b/client/internal/wg_iface_monitor_other.go @@ -0,0 +1,56 @@ +//go:build !linux + +package internal + +import ( + "context" + "fmt" + "time" + + log "github.com/sirupsen/logrus" +) + +// watchInterface polls net.InterfaceByName at a fixed interval to detect +// deletion or recreation of the WireGuard interface. +// +// This is the fallback used on non-Linux desktop and server platforms +// (darwin, windows, freebsd). It is also compiled on android and ios so +// the package builds on every supported GOOS, but it is never reached +// at runtime there because Start() in wg_iface_monitor.go exits early +// on mobile platforms. +// +// The Linux build (see wg_iface_monitor_linux.go) uses an event-driven +// RTNLGRP_LINK netlink subscription instead, because on Linux +// net.InterfaceByName issues syscall.NetlinkRIB(RTM_GETLINK, ...) which +// dumps the entire kernel link table on every call and produces +// significant allocation churn (netbirdio/netbird#3678). +// +// Windows is also reported in #3678 as affected by RSS climb. A future +// follow-up could implement an event-driven watcher there using +// NotifyIpInterfaceChange from iphlpapi. +func watchInterface(ctx context.Context, ifaceName string, expectedIndex int) (bool, error) { + ticker := time.NewTicker(2 * time.Second) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + log.Infof("Interface monitor: stopped for %s", ifaceName) + return false, fmt.Errorf("wg interface monitor stopped: %w", ctx.Err()) + case <-ticker.C: + currentIndex, err := getInterfaceIndex(ifaceName) + if err != nil { + // Interface was deleted + log.Infof("Interface monitor: %s deleted", ifaceName) + return true, fmt.Errorf("interface %s deleted: %w", ifaceName, err) + } + + // Check if interface index changed (interface was recreated) + if currentIndex != expectedIndex { + log.Infof("Interface monitor: %s recreated (index changed from %d to %d), restarting engine", + ifaceName, expectedIndex, currentIndex) + return true, nil + } + } + } +} From 104990dfdd5d9eae1760da5f57ba47a2df7052a6 Mon Sep 17 00:00:00 2001 From: JungwooShin <166088609+typhoon1217@users.noreply.github.com> Date: Tue, 5 May 2026 01:59:29 +0900 Subject: [PATCH 07/15] [client] Display QR code for device auth login URL (#5415) --- client/cmd/login.go | 14 +++++++++++--- client/cmd/qr.go | 25 +++++++++++++++++++++++++ client/cmd/qr_test.go | 26 ++++++++++++++++++++++++++ client/cmd/up.go | 5 +++++ go.mod | 2 ++ go.sum | 4 ++++ 6 files changed, 73 insertions(+), 3 deletions(-) create mode 100644 client/cmd/qr.go create mode 100644 client/cmd/qr_test.go diff --git a/client/cmd/login.go b/client/cmd/login.go index 4521a67c9..bd37e30f1 100644 --- a/client/cmd/login.go +++ b/client/cmd/login.go @@ -10,6 +10,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/spf13/cobra" + "golang.org/x/term" "google.golang.org/grpc/codes" gstatus "google.golang.org/grpc/status" @@ -23,6 +24,7 @@ import ( func init() { loginCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc) + loginCmd.PersistentFlags().BoolVar(&showQR, showQRFlag, false, showQRDesc) loginCmd.PersistentFlags().StringVar(&profileName, profileNameFlag, "", profileNameDesc) loginCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "(DEPRECATED) Netbird config file location") } @@ -256,7 +258,7 @@ func doForegroundLogin(ctx context.Context, cmd *cobra.Command, setupKey string, } func handleSSOLogin(ctx context.Context, cmd *cobra.Command, loginResp *proto.LoginResponse, client proto.DaemonServiceClient, pm *profilemanager.ProfileManager) error { - openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser) + openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser, showQR) resp, err := client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName}) if err != nil { @@ -324,7 +326,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *pro return nil, fmt.Errorf("getting a request OAuth flow info failed: %v", err) } - openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode, noBrowser) + openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode, noBrowser, showQR) tokenInfo, err := oAuthFlow.WaitToken(context.TODO(), flowInfo) if err != nil { @@ -334,7 +336,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *pro return &tokenInfo, nil } -func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBrowser bool) { +func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBrowser, showQR bool) { var codeMsg string if userCode != "" && !strings.Contains(verificationURIComplete, userCode) { codeMsg = fmt.Sprintf("and enter the code %s to authenticate.", userCode) @@ -348,6 +350,12 @@ func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBro verificationURIComplete + " " + codeMsg) } + if showQR { + if f, ok := cmd.OutOrStdout().(*os.File); ok && term.IsTerminal(int(f.Fd())) { + printQRCode(f, verificationURIComplete) + } + } + cmd.Println("") if !noBrowser { diff --git a/client/cmd/qr.go b/client/cmd/qr.go new file mode 100644 index 000000000..8b2c489ff --- /dev/null +++ b/client/cmd/qr.go @@ -0,0 +1,25 @@ +package cmd + +import ( + "io" + + "github.com/mdp/qrterminal/v3" +) + +// printQRCode prints a QR code for the given URL to the writer. +// Called only when the user explicitly requests QR output via --qr. +func printQRCode(w io.Writer, url string) { + if url == "" { + return + } + qrterminal.GenerateWithConfig(url, qrterminal.Config{ + Level: qrterminal.M, + Writer: w, + HalfBlocks: true, + BlackChar: qrterminal.BLACK_BLACK, + WhiteChar: qrterminal.WHITE_WHITE, + BlackWhiteChar: qrterminal.BLACK_WHITE, + WhiteBlackChar: qrterminal.WHITE_BLACK, + QuietZone: qrterminal.QUIET_ZONE, + }) +} diff --git a/client/cmd/qr_test.go b/client/cmd/qr_test.go new file mode 100644 index 000000000..d12705b9e --- /dev/null +++ b/client/cmd/qr_test.go @@ -0,0 +1,26 @@ +package cmd + +import ( + "bytes" + "testing" +) + +func TestPrintQRCode_EmptyURL(t *testing.T) { + var buf bytes.Buffer + + printQRCode(&buf, "") + + if buf.Len() != 0 { + t.Error("expected no output for empty URL") + } +} + +func TestPrintQRCode_WritesOutput(t *testing.T) { + var buf bytes.Buffer + + printQRCode(&buf, "https://example.com/auth") + + if buf.Len() == 0 { + t.Error("expected QR code output for non-empty URL") + } +} diff --git a/client/cmd/up.go b/client/cmd/up.go index f5766522a..f4136cb23 100644 --- a/client/cmd/up.go +++ b/client/cmd/up.go @@ -39,6 +39,9 @@ const ( noBrowserFlag = "no-browser" noBrowserDesc = "do not open the browser for SSO login" + showQRFlag = "qr" + showQRDesc = "show QR code for the SSO login URL (useful for headless machines without browser access)" + profileNameFlag = "profile" profileNameDesc = "profile name to use for the login. If not specified, the last used profile will be used." ) @@ -48,6 +51,7 @@ var ( dnsLabels []string dnsLabelsValidated domain.List noBrowser bool + showQR bool profileName string configPath string @@ -80,6 +84,7 @@ func init() { ) upCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc) + upCmd.PersistentFlags().BoolVar(&showQR, showQRFlag, false, showQRDesc) upCmd.PersistentFlags().StringVar(&profileName, profileNameFlag, "", profileNameDesc) upCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "(DEPRECATED) NetBird config file location. ") diff --git a/go.mod b/go.mod index 8e6a481d2..e82e6b10d 100644 --- a/go.mod +++ b/go.mod @@ -71,6 +71,7 @@ require ( github.com/libp2p/go-netroute v0.2.1 github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81 github.com/mdlayher/socket v0.5.1 + github.com/mdp/qrterminal/v3 v3.2.1 github.com/miekg/dns v1.1.59 github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/netbirdio/management-integrations/integrations v0.0.0-20260416123949-2355d972be42 @@ -308,6 +309,7 @@ require ( google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9 // indirect gopkg.in/square/go-jose.v2 v2.6.0 // indirect gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect + rsc.io/qr v0.2.0 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect ) diff --git a/go.sum b/go.sum index 2abf55142..a71f47d8d 100644 --- a/go.sum +++ b/go.sum @@ -415,6 +415,8 @@ github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42 h1:A1Cq6Ysb0GM0 github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42/go.mod h1:BB4YCPDOzfy7FniQ/lxuYQ3dgmM2cZumHbK8RpTjN2o= github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos= github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ= +github.com/mdp/qrterminal/v3 v3.2.1 h1:6+yQjiiOsSuXT5n9/m60E54vdgFsw0zhADHhHLrFet4= +github.com/mdp/qrterminal/v3 v3.2.1/go.mod h1:jOTmXvnBsMy5xqLniO0R++Jmjs2sTm9dFSuQ5kpz/SU= github.com/mholt/acmez/v2 v2.0.1 h1:3/3N0u1pLjMK4sNEAFSI+bcvzbPhRpY383sy1kLHJ6k= github.com/mholt/acmez/v2 v2.0.1/go.mod h1:fX4c9r5jYwMyMsC+7tkYRxHibkOTgta5DIFGoe67e1U= github.com/miekg/dns v1.1.59 h1:C9EXc/UToRwKLhK5wKU/I4QVsBUc8kE6MkHBkeypWZs= @@ -915,3 +917,5 @@ gotest.tools/v3 v3.5.1 h1:EENdUnS3pdur5nybKYIh2Vfgc8IUNBjxDPSjtiJcOzU= gotest.tools/v3 v3.5.1/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU= gvisor.dev/gvisor v0.0.0-20260219192049-0f2374377e89 h1:mGJaeA61P8dEHTqdvAgc70ZIV3QoUoJcXCRyyjO26OA= gvisor.dev/gvisor v0.0.0-20260219192049-0f2374377e89/go.mod h1:QkHjoMIBaYtpVufgwv3keYAbln78mBoCuShZrPrer1Q= +rsc.io/qr v0.2.0 h1:6vBLea5/NRMVTz8V66gipeLycZMl/+UlFmk8DvqQ6WY= +rsc.io/qr v0.2.0/go.mod h1:IF+uZjkb9fqyeF/4tlBoynqmQxUoPfWEKh921coOuXs= From 77a0992dc21644b2b041b7ff45000e232b928158 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Tue, 5 May 2026 02:59:41 +0900 Subject: [PATCH 08/15] [misc] Disable govet inline analyzer and tidy go.mod (#6066) --- .golangci.yaml | 5 +++++ go.mod | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/.golangci.yaml b/.golangci.yaml index d81ad1377..900af4ac0 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -58,6 +58,11 @@ linters: govet: enable: - nilness + disable: + # The inline analyzer flags x/exp/maps Clone/Clear with //go:fix inline + # directives but cannot perform the rewrite due to generic type + # parameter inference limitations in the Go inliner. + - inline enable-all: false revive: rules: diff --git a/go.mod b/go.mod index e82e6b10d..e24312a1a 100644 --- a/go.mod +++ b/go.mod @@ -309,8 +309,8 @@ require ( google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9 // indirect gopkg.in/square/go-jose.v2 v2.6.0 // indirect gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect - rsc.io/qr v0.2.0 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect + rsc.io/qr v0.2.0 // indirect ) replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 From 97db82492946bce8db6a4ef820ef2d59365f1512 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Mon, 4 May 2026 20:43:25 +0200 Subject: [PATCH 09/15] [management] fix proxy reconnect (#6063) --- .../modules/reverseproxy/proxy/manager.go | 6 +- .../reverseproxy/proxy/manager/manager.go | 40 +++++---- .../reverseproxy/proxy/manager_mock.go | 29 +++---- .../modules/reverseproxy/proxy/proxy.go | 3 +- management/internals/shared/grpc/proxy.go | 47 ++++++++--- management/server/store/sql_store.go | 44 ++++++---- management/server/store/store.go | 3 +- management/server/store/store_mock.go | 23 +++++- proxy/management_integration_test.go | 81 +++++++++++++++++-- 9 files changed, 199 insertions(+), 77 deletions(-) diff --git a/management/internals/modules/reverseproxy/proxy/manager.go b/management/internals/modules/reverseproxy/proxy/manager.go index aa7cd8630..53c52b3aa 100644 --- a/management/internals/modules/reverseproxy/proxy/manager.go +++ b/management/internals/modules/reverseproxy/proxy/manager.go @@ -11,9 +11,9 @@ import ( // Manager defines the interface for proxy operations type Manager interface { - Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, capabilities *Capabilities) error - Disconnect(ctx context.Context, proxyID string) error - Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error + Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, capabilities *Capabilities) (*Proxy, error) + Disconnect(ctx context.Context, proxyID, sessionID string) error + Heartbeat(ctx context.Context, p *Proxy) error GetActiveClusterAddresses(ctx context.Context) ([]string, error) GetActiveClusters(ctx context.Context) ([]Cluster, error) ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool diff --git a/management/internals/modules/reverseproxy/proxy/manager/manager.go b/management/internals/modules/reverseproxy/proxy/manager/manager.go index d13334e83..341e8c943 100644 --- a/management/internals/modules/reverseproxy/proxy/manager/manager.go +++ b/management/internals/modules/reverseproxy/proxy/manager/manager.go @@ -13,7 +13,8 @@ import ( // store defines the interface for proxy persistence operations type store interface { SaveProxy(ctx context.Context, p *proxy.Proxy) error - UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error + DisconnectProxy(ctx context.Context, proxyID, sessionID string) error + UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error) GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool @@ -43,7 +44,7 @@ func NewManager(store store, meter metric.Meter) (*Manager, error) { // Connect registers a new proxy connection in the database. // capabilities may be nil for old proxies that do not report them. -func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, capabilities *proxy.Capabilities) error { +func (m Manager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, capabilities *proxy.Capabilities) (*proxy.Proxy, error) { now := time.Now() var caps proxy.Capabilities if capabilities != nil { @@ -51,6 +52,7 @@ func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress } p := &proxy.Proxy{ ID: proxyID, + SessionID: sessionID, ClusterAddress: clusterAddress, IPAddress: ipAddress, LastSeen: now, @@ -61,48 +63,42 @@ func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress if err := m.store.SaveProxy(ctx, p); err != nil { log.WithContext(ctx).Errorf("failed to register proxy %s: %v", proxyID, err) - return err + return nil, err } log.WithContext(ctx).WithFields(log.Fields{ "proxyID": proxyID, + "sessionID": sessionID, "clusterAddress": clusterAddress, "ipAddress": ipAddress, }).Info("proxy connected") - return nil + return p, nil } -// Disconnect marks a proxy as disconnected in the database -func (m Manager) Disconnect(ctx context.Context, proxyID string) error { - now := time.Now() - p := &proxy.Proxy{ - ID: proxyID, - Status: "disconnected", - DisconnectedAt: &now, - LastSeen: now, - } - - if err := m.store.SaveProxy(ctx, p); err != nil { - log.WithContext(ctx).Errorf("failed to disconnect proxy %s: %v", proxyID, err) +// Disconnect marks a proxy as disconnected in the database. +func (m Manager) Disconnect(ctx context.Context, proxyID, sessionID string) error { + if err := m.store.DisconnectProxy(ctx, proxyID, sessionID); err != nil { + log.WithContext(ctx).Errorf("failed to disconnect proxy %s session %s: %v", proxyID, sessionID, err) return err } log.WithContext(ctx).WithFields(log.Fields{ - "proxyID": proxyID, + "proxyID": proxyID, + "sessionID": sessionID, }).Info("proxy disconnected") return nil } -// Heartbeat updates the proxy's last seen timestamp -func (m Manager) Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error { - if err := m.store.UpdateProxyHeartbeat(ctx, proxyID, clusterAddress, ipAddress); err != nil { - log.WithContext(ctx).Debugf("failed to update proxy %s heartbeat: %v", proxyID, err) +// Heartbeat updates the proxy's last seen timestamp. +func (m Manager) Heartbeat(ctx context.Context, p *proxy.Proxy) error { + if err := m.store.UpdateProxyHeartbeat(ctx, p); err != nil { + log.WithContext(ctx).Debugf("failed to update proxy %s heartbeat: %v", p.ID, err) return err } - log.WithContext(ctx).Tracef("updated heartbeat for proxy %s", proxyID) + log.WithContext(ctx).Tracef("updated heartbeat for proxy %s session %s", p.ID, p.SessionID) m.metrics.IncrementProxyHeartbeatCount() return nil } diff --git a/management/internals/modules/reverseproxy/proxy/manager_mock.go b/management/internals/modules/reverseproxy/proxy/manager_mock.go index 282ca0ba5..98d97b3c6 100644 --- a/management/internals/modules/reverseproxy/proxy/manager_mock.go +++ b/management/internals/modules/reverseproxy/proxy/manager_mock.go @@ -93,31 +93,32 @@ func (mr *MockManagerMockRecorder) ClusterSupportsCrowdSec(ctx, clusterAddr inte } // Connect mocks base method. -func (m *MockManager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, capabilities *Capabilities) error { +func (m *MockManager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, capabilities *Capabilities) (*Proxy, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Connect", ctx, proxyID, clusterAddress, ipAddress, capabilities) - ret0, _ := ret[0].(error) - return ret0 + ret := m.ctrl.Call(m, "Connect", ctx, proxyID, sessionID, clusterAddress, ipAddress, capabilities) + ret0, _ := ret[0].(*Proxy) + ret1, _ := ret[1].(error) + return ret0, ret1 } // Connect indicates an expected call of Connect. -func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, clusterAddress, ipAddress, capabilities interface{}) *gomock.Call { +func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, sessionID, clusterAddress, ipAddress, capabilities interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, clusterAddress, ipAddress, capabilities) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, sessionID, clusterAddress, ipAddress, capabilities) } // Disconnect mocks base method. -func (m *MockManager) Disconnect(ctx context.Context, proxyID string) error { +func (m *MockManager) Disconnect(ctx context.Context, proxyID, sessionID string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Disconnect", ctx, proxyID) + ret := m.ctrl.Call(m, "Disconnect", ctx, proxyID, sessionID) ret0, _ := ret[0].(error) return ret0 } // Disconnect indicates an expected call of Disconnect. -func (mr *MockManagerMockRecorder) Disconnect(ctx, proxyID interface{}) *gomock.Call { +func (mr *MockManagerMockRecorder) Disconnect(ctx, proxyID, sessionID interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Disconnect", reflect.TypeOf((*MockManager)(nil).Disconnect), ctx, proxyID) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Disconnect", reflect.TypeOf((*MockManager)(nil).Disconnect), ctx, proxyID, sessionID) } // GetActiveClusterAddresses mocks base method. @@ -151,17 +152,17 @@ func (mr *MockManagerMockRecorder) GetActiveClusters(ctx interface{}) *gomock.Ca } // Heartbeat mocks base method. -func (m *MockManager) Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error { +func (m *MockManager) Heartbeat(ctx context.Context, p *Proxy) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Heartbeat", ctx, proxyID, clusterAddress, ipAddress) + ret := m.ctrl.Call(m, "Heartbeat", ctx, p) ret0, _ := ret[0].(error) return ret0 } // Heartbeat indicates an expected call of Heartbeat. -func (mr *MockManagerMockRecorder) Heartbeat(ctx, proxyID, clusterAddress, ipAddress interface{}) *gomock.Call { +func (mr *MockManagerMockRecorder) Heartbeat(ctx, p interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Heartbeat", reflect.TypeOf((*MockManager)(nil).Heartbeat), ctx, proxyID, clusterAddress, ipAddress) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Heartbeat", reflect.TypeOf((*MockManager)(nil).Heartbeat), ctx, p) } // MockController is a mock of Controller interface. diff --git a/management/internals/modules/reverseproxy/proxy/proxy.go b/management/internals/modules/reverseproxy/proxy/proxy.go index 339c82446..dcedb8811 100644 --- a/management/internals/modules/reverseproxy/proxy/proxy.go +++ b/management/internals/modules/reverseproxy/proxy/proxy.go @@ -18,12 +18,13 @@ type Capabilities struct { // Proxy represents a reverse proxy instance type Proxy struct { ID string `gorm:"primaryKey;type:varchar(255)"` + SessionID string `gorm:"type:varchar(36)"` ClusterAddress string `gorm:"type:varchar(255);not null;index:idx_proxy_cluster_status"` IPAddress string `gorm:"type:varchar(45)"` LastSeen time.Time `gorm:"not null;index:idx_proxy_last_seen"` ConnectedAt *time.Time DisconnectedAt *time.Time - Status string `gorm:"type:varchar(20);not null;index:idx_proxy_cluster_status"` + Status string `gorm:"type:varchar(20);not null;index:idx_proxy_cluster_status"` Capabilities Capabilities `gorm:"embedded"` CreatedAt time.Time UpdatedAt time.Time diff --git a/management/internals/shared/grpc/proxy.go b/management/internals/shared/grpc/proxy.go index a5e352e75..d811a0f69 100644 --- a/management/internals/shared/grpc/proxy.go +++ b/management/internals/shared/grpc/proxy.go @@ -16,6 +16,7 @@ import ( "time" "github.com/coreos/go-oidc/v3/oidc" + "github.com/google/uuid" log "github.com/sirupsen/logrus" "golang.org/x/oauth2" "google.golang.org/grpc/codes" @@ -89,6 +90,7 @@ const pkceVerifierTTL = 10 * time.Minute // proxyConnection represents a connected proxy type proxyConnection struct { proxyID string + sessionID string address string capabilities *proto.ProxyCapabilities stream proto.ProxyService_GetMappingUpdateServer @@ -166,9 +168,22 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest return status.Errorf(codes.InvalidArgument, "proxy address is invalid") } + sessionID := uuid.NewString() + + if old, loaded := s.connectedProxies.Load(proxyID); loaded { + oldConn := old.(*proxyConnection) + log.WithFields(log.Fields{ + "proxy_id": proxyID, + "old_session_id": oldConn.sessionID, + "new_session_id": sessionID, + }).Info("Superseding existing proxy connection") + oldConn.cancel() + } + connCtx, cancel := context.WithCancel(ctx) conn := &proxyConnection{ proxyID: proxyID, + sessionID: sessionID, address: proxyAddress, capabilities: req.GetCapabilities(), stream: stream, @@ -188,12 +203,13 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest caps = &proxy.Capabilities{ SupportsCustomPorts: c.SupportsCustomPorts, RequireSubdomain: c.RequireSubdomain, - SupportsCrowdsec: c.SupportsCrowdsec, + SupportsCrowdsec: c.SupportsCrowdsec, } } - if err := s.proxyManager.Connect(ctx, proxyID, proxyAddress, peerInfo, caps); err != nil { + proxyRecord, err := s.proxyManager.Connect(ctx, proxyID, sessionID, proxyAddress, peerInfo, caps) + if err != nil { log.WithContext(ctx).Warnf("failed to register proxy %s in database: %v", proxyID, err) - s.connectedProxies.Delete(proxyID) + s.connectedProxies.CompareAndDelete(proxyID, conn) if unregErr := s.proxyController.UnregisterProxyFromCluster(ctx, conn.address, proxyID); unregErr != nil { log.WithContext(ctx).Debugf("cleanup after Connect failure for proxy %s: %v", proxyID, unregErr) } @@ -202,22 +218,27 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest log.WithFields(log.Fields{ "proxy_id": proxyID, + "session_id": sessionID, "address": proxyAddress, "cluster_addr": proxyAddress, "total_proxies": len(s.GetConnectedProxies()), }).Info("Proxy registered in cluster") defer func() { - if err := s.proxyManager.Disconnect(context.Background(), proxyID); err != nil { - log.Warnf("Failed to mark proxy %s as disconnected: %v", proxyID, err) + if !s.connectedProxies.CompareAndDelete(proxyID, conn) { + log.Infof("Proxy %s session %s: skipping cleanup, superseded by new connection", proxyID, sessionID) + cancel() + return } - s.connectedProxies.Delete(proxyID) if err := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, proxyID); err != nil { log.Warnf("Failed to unregister proxy %s from cluster: %v", proxyID, err) } + if err := s.proxyManager.Disconnect(context.Background(), proxyID, sessionID); err != nil { + log.Warnf("Failed to mark proxy %s as disconnected: %v", proxyID, err) + } cancel() - log.Infof("Proxy %s disconnected", proxyID) + log.Infof("Proxy %s session %s disconnected", proxyID, sessionID) }() if err := s.sendSnapshot(ctx, conn); err != nil { @@ -227,29 +248,31 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest errChan := make(chan error, 2) go s.sender(conn, errChan) - // Start heartbeat goroutine - go s.heartbeat(connCtx, proxyID, proxyAddress, peerInfo) + go s.heartbeat(connCtx, proxyRecord) select { case err := <-errChan: + log.WithContext(ctx).Warnf("Failed to send update: %v", err) return fmt.Errorf("send update to proxy %s: %w", proxyID, err) case <-connCtx.Done(): + log.WithContext(ctx).Infof("Proxy %s context canceled", proxyID) return connCtx.Err() } } // heartbeat updates the proxy's last_seen timestamp every minute -func (s *ProxyServiceServer) heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) { +func (s *ProxyServiceServer) heartbeat(ctx context.Context, p *proxy.Proxy) { ticker := time.NewTicker(1 * time.Minute) defer ticker.Stop() for { select { case <-ticker.C: - if err := s.proxyManager.Heartbeat(ctx, proxyID, clusterAddress, ipAddress); err != nil { - log.WithContext(ctx).Debugf("Failed to update proxy %s heartbeat: %v", proxyID, err) + if err := s.proxyManager.Heartbeat(ctx, p); err != nil { + log.WithContext(ctx).Debugf("Failed to update proxy %s heartbeat: %v", p.ID, err) } case <-ctx.Done(): + log.WithContext(ctx).Infof("proxy %s heartbeat stopped: context canceled", p.ID) return } } diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 0a716d08d..1fa3d08ee 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -5437,13 +5437,35 @@ func (s *SqlStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error { return nil } -// UpdateProxyHeartbeat updates the last_seen timestamp for a proxy or creates a new entry if it doesn't exist -func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error { +// DisconnectProxy marks a proxy as disconnected only if the session ID matches. +// This prevents a slow-to-close old session from overwriting a newer reconnection. +func (s *SqlStore) DisconnectProxy(ctx context.Context, proxyID, sessionID string) error { + now := time.Now() + result := s.db. + Model(&proxy.Proxy{}). + Where("id = ? AND session_id = ?", proxyID, sessionID). + Updates(map[string]any{ + "status": "disconnected", + "disconnected_at": now, + "last_seen": now, + }) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to disconnect proxy %s session %s: %v", proxyID, sessionID, result.Error) + return status.Errorf(status.Internal, "failed to disconnect proxy") + } + if result.RowsAffected == 0 { + log.WithContext(ctx).Debugf("proxy %s session %s: no row updated (superseded by newer session)", proxyID, sessionID) + } + return nil +} + +// UpdateProxyHeartbeat updates the last_seen timestamp for the proxy's current session. +func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error { now := time.Now() result := s.db. Model(&proxy.Proxy{}). - Where("id = ? AND status = ?", proxyID, "connected"). + Where("id = ? AND session_id = ?", p.ID, p.SessionID). Update("last_seen", now) if result.Error != nil { @@ -5452,17 +5474,11 @@ func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAdd } if result.RowsAffected == 0 { - p := &proxy.Proxy{ - ID: proxyID, - ClusterAddress: clusterAddress, - IPAddress: ipAddress, - LastSeen: now, - ConnectedAt: &now, - Status: "connected", - } - if err := s.db.Save(p).Error; err != nil { - log.WithContext(ctx).Errorf("failed to create proxy on heartbeat: %v", err) - return status.Errorf(status.Internal, "failed to create proxy on heartbeat") + p.LastSeen = now + p.ConnectedAt = &now + p.Status = "connected" + if err := s.db.Create(p).Error; err != nil { + log.WithContext(ctx).Debugf("proxy %s session %s: heartbeat fallback insert skipped: %v", p.ID, p.SessionID, err) } } diff --git a/management/server/store/store.go b/management/server/store/store.go index 0d8b0678a..447c85547 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -284,7 +284,8 @@ type Store interface { DeleteServiceTargets(ctx context.Context, accountID string, serviceID string) error SaveProxy(ctx context.Context, proxy *proxy.Proxy) error - UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error + DisconnectProxy(ctx context.Context, proxyID, sessionID string) error + UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error) GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool diff --git a/management/server/store/store_mock.go b/management/server/store/store_mock.go index beee13d96..d8bd826a8 100644 --- a/management/server/store/store_mock.go +++ b/management/server/store/store_mock.go @@ -178,6 +178,7 @@ func (mr *MockStoreMockRecorder) GetClusterSupportsCrowdSec(ctx, clusterAddr int mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterSupportsCrowdSec", reflect.TypeOf((*MockStore)(nil).GetClusterSupportsCrowdSec), ctx, clusterAddr) } + // Close mocks base method. func (m *MockStore) Close(ctx context.Context) error { m.ctrl.T.Helper() @@ -2799,6 +2800,20 @@ func (mr *MockStoreMockRecorder) SaveProxy(ctx, proxy interface{}) *gomock.Call return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveProxy", reflect.TypeOf((*MockStore)(nil).SaveProxy), ctx, proxy) } +// DisconnectProxy mocks base method. +func (m *MockStore) DisconnectProxy(ctx context.Context, proxyID, sessionID string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DisconnectProxy", ctx, proxyID, sessionID) + ret0, _ := ret[0].(error) + return ret0 +} + +// DisconnectProxy indicates an expected call of DisconnectProxy. +func (mr *MockStoreMockRecorder) DisconnectProxy(ctx, proxyID, sessionID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisconnectProxy", reflect.TypeOf((*MockStore)(nil).DisconnectProxy), ctx, proxyID, sessionID) +} + // SaveProxyAccessToken mocks base method. func (m *MockStore) SaveProxyAccessToken(ctx context.Context, token *types2.ProxyAccessToken) error { m.ctrl.T.Helper() @@ -2995,17 +3010,17 @@ func (mr *MockStoreMockRecorder) UpdateGroups(ctx, accountID, groups interface{} } // UpdateProxyHeartbeat mocks base method. -func (m *MockStore) UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error { +func (m *MockStore) UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateProxyHeartbeat", ctx, proxyID, clusterAddress, ipAddress) + ret := m.ctrl.Call(m, "UpdateProxyHeartbeat", ctx, p) ret0, _ := ret[0].(error) return ret0 } // UpdateProxyHeartbeat indicates an expected call of UpdateProxyHeartbeat. -func (mr *MockStoreMockRecorder) UpdateProxyHeartbeat(ctx, proxyID, clusterAddress, ipAddress interface{}) *gomock.Call { +func (mr *MockStoreMockRecorder) UpdateProxyHeartbeat(ctx, p interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateProxyHeartbeat", reflect.TypeOf((*MockStore)(nil).UpdateProxyHeartbeat), ctx, proxyID, clusterAddress, ipAddress) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateProxyHeartbeat", reflect.TypeOf((*MockStore)(nil).UpdateProxyHeartbeat), ctx, p) } // UpdateService mocks base method. diff --git a/proxy/management_integration_test.go b/proxy/management_integration_test.go index 4b1ecf922..e9eae3210 100644 --- a/proxy/management_integration_test.go +++ b/proxy/management_integration_test.go @@ -201,15 +201,15 @@ func (m *testAccessLogManager) GetAllAccessLogs(_ context.Context, _, _ string, // testProxyManager is a mock implementation of proxy.Manager for testing. type testProxyManager struct{} -func (m *testProxyManager) Connect(_ context.Context, _, _, _ string, _ *nbproxy.Capabilities) error { +func (m *testProxyManager) Connect(_ context.Context, proxyID, sessionID, _, _ string, _ *nbproxy.Capabilities) (*nbproxy.Proxy, error) { + return &nbproxy.Proxy{ID: proxyID, SessionID: sessionID, Status: "connected"}, nil +} + +func (m *testProxyManager) Disconnect(_ context.Context, _, _ string) error { return nil } -func (m *testProxyManager) Disconnect(_ context.Context, _ string) error { - return nil -} - -func (m *testProxyManager) Heartbeat(_ context.Context, _, _, _ string) error { +func (m *testProxyManager) Heartbeat(_ context.Context, _ *nbproxy.Proxy) error { return nil } @@ -656,3 +656,72 @@ func TestIntegration_ProxyConnection_MultipleProxiesReceiveUpdates(t *testing.T) assert.Equal(t, 2, count, "Proxy %s should receive 2 mappings", proxyID) } } + +// TestIntegration_ProxyConnection_FastReconnectDoesNotLoseState verifies that +// when a proxy reconnects before the old stream's cleanup runs, the new +// connection is NOT removed by the stale defer. +func TestIntegration_ProxyConnection_FastReconnectDoesNotLoseState(t *testing.T) { + setup := setupIntegrationTest(t) + defer setup.cleanup() + + clusterAddress := "test.proxy.io" + proxyID := "test-proxy-race" + + conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + defer conn.Close() + + client := proto.NewProxyServiceClient(conn) + + ctx1, cancel1 := context.WithCancel(context.Background()) + stream1, err := client.GetMappingUpdate(ctx1, &proto.GetMappingUpdateRequest{ + ProxyId: proxyID, + Version: "test-v1", + Address: clusterAddress, + }) + require.NoError(t, err) + + for i := 0; i < 2; i++ { + _, err := stream1.Recv() + require.NoError(t, err) + } + + require.Contains(t, setup.proxyService.GetConnectedProxies(), proxyID, + "proxy should be registered after first connection") + + ctx2, cancel2 := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel2() + + stream2, err := client.GetMappingUpdate(ctx2, &proto.GetMappingUpdateRequest{ + ProxyId: proxyID, + Version: "test-v1", + Address: clusterAddress, + }) + require.NoError(t, err) + + for i := 0; i < 2; i++ { + _, err := stream2.Recv() + require.NoError(t, err) + } + + cancel1() + + time.Sleep(200 * time.Millisecond) + + assert.Contains(t, setup.proxyService.GetConnectedProxies(), proxyID, + "proxy should still be registered after old connection cleanup — old defer must not remove new connection") + + setup.proxyService.SendServiceUpdate(&proto.GetMappingUpdateResponse{ + Mapping: []*proto.ProxyMapping{{ + Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED, + Id: "rp-1", + AccountId: "test-account-1", + Domain: "app1.test.proxy.io", + }}, + }) + + msg, err := stream2.Recv() + require.NoError(t, err, "new stream should still receive updates") + require.NotEmpty(t, msg.GetMapping(), "update should contain the mapping") + assert.Equal(t, "rp-1", msg.GetMapping()[0].GetId()) +} From cd8e71002fe6c8bf1197ca30abe5d2f5e813adc1 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Tue, 5 May 2026 22:26:27 +0900 Subject: [PATCH 10/15] [client] Bump go-netroute to v0.4.0 and drop fork (#6062) --- client/internal/portforward/pcp/nat.go | 15 ++++++- .../systemops/systemops_generic.go | 16 ++++++++ .../systemops/systemops_generic_test.go | 6 ++- .../systemops/v6route_bsd_test.go | 30 ++++++++++++++ .../systemops/v6route_linux_test.go | 41 +++++++++++++++++++ .../systemops/v6route_windows_test.go | 34 +++++++++++++++ go.mod | 18 ++++---- go.sum | 32 +++++++-------- 8 files changed, 163 insertions(+), 29 deletions(-) create mode 100644 client/internal/routemanager/systemops/v6route_bsd_test.go create mode 100644 client/internal/routemanager/systemops/v6route_linux_test.go create mode 100644 client/internal/routemanager/systemops/v6route_windows_test.go diff --git a/client/internal/portforward/pcp/nat.go b/client/internal/portforward/pcp/nat.go index 1dc24274b..6491e7367 100644 --- a/client/internal/portforward/pcp/nat.go +++ b/client/internal/portforward/pcp/nat.go @@ -5,6 +5,7 @@ import ( "fmt" "net" "net/netip" + "runtime" "sync" "time" @@ -177,7 +178,12 @@ func getDefaultGateway() (gateway net.IP, localIP net.IP, err error) { return nil, nil, err } - _, gateway, localIP, err = router.Route(net.IPv4zero) + dst := net.IPv4zero + if runtime.GOOS == "linux" { + // go-netroute v0.4.0 rejects unspecified destinations client-side on Linux. + dst = net.IPv4(0, 0, 0, 1) + } + _, gateway, localIP, err = router.Route(dst) if err != nil { return nil, nil, err } @@ -196,7 +202,12 @@ func getDefaultGateway6() (gateway net.IP, localIP net.IP, err error) { return nil, nil, err } - _, gateway, localIP, err = router.Route(net.IPv6zero) + dst := net.IPv6zero + if runtime.GOOS == "linux" { + // ::2 + dst = net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2} + } + _, gateway, localIP, err = router.Route(dst) if err != nil { return nil, nil, err } diff --git a/client/internal/routemanager/systemops/systemops_generic.go b/client/internal/routemanager/systemops/systemops_generic.go index 4211eb057..bf7b95a28 100644 --- a/client/internal/routemanager/systemops/systemops_generic.go +++ b/client/internal/routemanager/systemops/systemops_generic.go @@ -342,6 +342,22 @@ func GetNextHop(ip netip.Addr) (Nexthop, error) { if err != nil { return Nexthop{}, fmt.Errorf("new netroute: %w", err) } + + // go-netroute v0.4.0 rejects unspecified destinations on Linux with a hard + // client-side check. Substitute the lowest non-loopback address so the + // lookup falls through to the default route (::1 / 127.0.0.1 would match + // loopback, ::/0.0.0.0 are unspec). BSD/Windows pass the query straight to + // the kernel and need no substitution. + if runtime.GOOS == "linux" && ip.IsUnspecified() { + if ip.Is6() { + // ::2 + ip = netip.AddrFrom16([16]byte{15: 2}) + } else { + // 0.0.0.1 + ip = netip.AddrFrom4([4]byte{0, 0, 0, 1}) + } + } + intf, gateway, preferredSrc, err := r.Route(ip.AsSlice()) if err != nil { log.Debugf("Failed to get route for %s: %v", ip, err) diff --git a/client/internal/routemanager/systemops/systemops_generic_test.go b/client/internal/routemanager/systemops/systemops_generic_test.go index 01916fbe3..08e354a78 100644 --- a/client/internal/routemanager/systemops/systemops_generic_test.go +++ b/client/internal/routemanager/systemops/systemops_generic_test.go @@ -354,9 +354,13 @@ func TestAddRouteToNonVPNIntf(t *testing.T) { require.NoError(t, err, "Should be able to get IPv4 default route") t.Logf("Initial IPv4 next hop: %s", initialNextHopV4) + if testCase.prefix.Addr().Is6() && !testCase.expectError { + ensureIPv6DefaultRoute(t) + } + initialNextHopV6, err := GetNextHop(netip.IPv6Unspecified()) if testCase.prefix.Addr().Is6() && - (errors.Is(err, vars.ErrRouteNotFound) || initialNextHopV6.Intf != nil && strings.HasPrefix(initialNextHopV6.Intf.Name, "utun")) { + initialNextHopV6.Intf != nil && strings.HasPrefix(initialNextHopV6.Intf.Name, "utun") { t.Skip("Skipping test as no ipv6 default route is available") } if err != nil && !errors.Is(err, vars.ErrRouteNotFound) { diff --git a/client/internal/routemanager/systemops/v6route_bsd_test.go b/client/internal/routemanager/systemops/v6route_bsd_test.go new file mode 100644 index 000000000..98ce29c6d --- /dev/null +++ b/client/internal/routemanager/systemops/v6route_bsd_test.go @@ -0,0 +1,30 @@ +//go:build darwin || dragonfly || freebsd || netbsd || openbsd + +package systemops + +import ( + "bytes" + "os/exec" + "testing" +) + +// ensureIPv6DefaultRoute installs an IPv6 default route via the loopback +// interface so route lookups for global IPv6 prefixes resolve in environments +// without v6 connectivity. If a default already exists it is left alone. +func ensureIPv6DefaultRoute(t *testing.T) { + t.Helper() + + out, err := exec.Command("route", "-6", "add", "default", "-iface", "lo0").CombinedOutput() + if err != nil { + // Existing default; nothing to install or clean up. + if bytes.Contains(out, []byte("route already in table")) { + return + } + t.Skipf("install IPv6 fallback default route: %v: %s", err, out) + } + t.Cleanup(func() { + if out, err := exec.Command("route", "-6", "delete", "default").CombinedOutput(); err != nil { + t.Logf("delete IPv6 fallback default route: %v: %s", err, out) + } + }) +} diff --git a/client/internal/routemanager/systemops/v6route_linux_test.go b/client/internal/routemanager/systemops/v6route_linux_test.go new file mode 100644 index 000000000..0b17cefff --- /dev/null +++ b/client/internal/routemanager/systemops/v6route_linux_test.go @@ -0,0 +1,41 @@ +//go:build linux && !android + +package systemops + +import ( + "errors" + "net" + "syscall" + "testing" + + "github.com/stretchr/testify/require" + "github.com/vishvananda/netlink" +) + +// ensureIPv6DefaultRoute installs a low-preference IPv6 default route via the +// loopback interface so route lookups for global IPv6 prefixes resolve in +// environments without v6 connectivity. Any pre-existing default route wins +// because of its lower metric. +func ensureIPv6DefaultRoute(t *testing.T) { + t.Helper() + + lo, err := netlink.LinkByName("lo") + require.NoError(t, err, "find loopback interface") + + route := &netlink.Route{ + Dst: &net.IPNet{IP: net.IPv6zero, Mask: net.CIDRMask(0, 128)}, + LinkIndex: lo.Attrs().Index, + Priority: 1 << 20, + } + if err := netlink.RouteAdd(route); err != nil { + if errors.Is(err, syscall.EEXIST) { + return + } + t.Skipf("install IPv6 fallback default route: %v", err) + } + t.Cleanup(func() { + if err := netlink.RouteDel(route); err != nil && !errors.Is(err, syscall.ESRCH) { + t.Logf("delete IPv6 fallback default route: %v", err) + } + }) +} diff --git a/client/internal/routemanager/systemops/v6route_windows_test.go b/client/internal/routemanager/systemops/v6route_windows_test.go new file mode 100644 index 000000000..f79277b87 --- /dev/null +++ b/client/internal/routemanager/systemops/v6route_windows_test.go @@ -0,0 +1,34 @@ +//go:build windows + +package systemops + +import ( + "bytes" + "os/exec" + "testing" +) + +const loopbackIfaceWindows = "Loopback Pseudo-Interface 1" + +// ensureIPv6DefaultRoute installs an IPv6 default route via the loopback +// interface so route lookups for global IPv6 prefixes resolve in environments +// without v6 connectivity. If a default already exists it is left alone. +func ensureIPv6DefaultRoute(t *testing.T) { + t.Helper() + + script := `New-NetRoute -DestinationPrefix "::/0" -InterfaceAlias "` + loopbackIfaceWindows + `" -RouteMetric 9999 -PolicyStore ActiveStore -ErrorAction Stop` + out, err := exec.Command("powershell", "-Command", script).CombinedOutput() + if err != nil { + // Existing default; nothing to install or clean up. + if bytes.Contains(out, []byte("already exists")) { + return + } + t.Skipf("install IPv6 fallback default route: %v: %s", err, out) + } + t.Cleanup(func() { + script := `Remove-NetRoute -DestinationPrefix "::/0" -InterfaceAlias "` + loopbackIfaceWindows + `" -Confirm:$false -ErrorAction Stop` + if out, err := exec.Command("powershell", "-Command", script).CombinedOutput(); err != nil { + t.Logf("delete IPv6 fallback default route: %v: %s", err, out) + } + }) +} diff --git a/go.mod b/go.mod index e24312a1a..bc4e8af15 100644 --- a/go.mod +++ b/go.mod @@ -17,8 +17,8 @@ require ( github.com/spf13/cobra v1.10.1 github.com/spf13/pflag v1.0.9 github.com/vishvananda/netlink v1.3.1 - golang.org/x/crypto v0.49.0 - golang.org/x/sys v0.42.0 + golang.org/x/crypto v0.50.0 + golang.org/x/sys v0.43.0 golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 golang.zx2c4.com/wireguard/windows v0.5.3 @@ -68,7 +68,7 @@ require ( github.com/jackc/pgx/v5 v5.5.5 github.com/libdns/route53 v1.5.0 github.com/libp2p/go-nat v0.2.0 - github.com/libp2p/go-netroute v0.2.1 + github.com/libp2p/go-netroute v0.4.0 github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81 github.com/mdlayher/socket v0.5.1 github.com/mdp/qrterminal/v3 v3.2.1 @@ -118,11 +118,11 @@ require ( goauthentik.io/api/v3 v3.2023051.3 golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b golang.org/x/mobile v0.0.0-20251113184115-a159579294ab - golang.org/x/mod v0.33.0 - golang.org/x/net v0.52.0 + golang.org/x/mod v0.34.0 + golang.org/x/net v0.53.0 golang.org/x/oauth2 v0.36.0 golang.org/x/sync v0.20.0 - golang.org/x/term v0.41.0 + golang.org/x/term v0.42.0 golang.org/x/time v0.15.0 google.golang.org/api v0.276.0 gopkg.in/yaml.v3 v3.0.1 @@ -303,8 +303,8 @@ require ( go.uber.org/multierr v1.11.0 // indirect go.yaml.in/yaml/v2 v2.4.3 // indirect golang.org/x/image v0.33.0 // indirect - golang.org/x/text v0.35.0 // indirect - golang.org/x/tools v0.42.0 // indirect + golang.org/x/text v0.36.0 // indirect + golang.org/x/tools v0.43.0 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9 // indirect gopkg.in/square/go-jose.v2 v2.6.0 // indirect @@ -323,8 +323,6 @@ replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-2023080111 replace github.com/pion/ice/v4 => github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 -replace github.com/libp2p/go-netroute => github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 - replace github.com/dexidp/dex => github.com/netbirdio/dex v0.244.0 replace github.com/mailru/easyjson => github.com/netbirdio/easyjson v0.9.0 diff --git a/go.sum b/go.sum index a71f47d8d..d54dc01e6 100644 --- a/go.sum +++ b/go.sum @@ -395,6 +395,8 @@ github.com/libdns/route53 v1.5.0 h1:2SKdpPFl/qgWsXQvsLNJJAoX7rSxlk7zgoL4jnWdXVA= github.com/libdns/route53 v1.5.0/go.mod h1:joT4hKmaTNKHEwb7GmZ65eoDz1whTu7KKYPS8ZqIh6Q= github.com/libp2p/go-nat v0.2.0 h1:Tyz+bUFAYqGyJ/ppPPymMGbIgNRH+WqC5QrT5fKrrGk= github.com/libp2p/go-nat v0.2.0/go.mod h1:3MJr+GRpRkyT65EpVPBstXLvOlAPzUVlG6Pwg9ohLJk= +github.com/libp2p/go-netroute v0.4.0 h1:sZZx9hyANYUx9PZyqcgE/E1GUG3iEtTZHUEvdtXT7/Q= +github.com/libp2p/go-netroute v0.4.0/go.mod h1:Nkd5ShYgSMS5MUKy/MU2T57xFoOKvvLR92Lic48LEyA= github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81 h1:J56rFEfUTFT9j9CiRXhi1r8lUJ4W5idG3CiaBZGojNU= github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81/go.mod h1:RD8ML/YdXctQ7qbcizZkw5mZ6l8Ogrl1dodBzVJduwI= github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I= @@ -453,8 +455,6 @@ github.com/netbirdio/dex v0.244.0 h1:1GOvi8wnXYassnKGildzNqRHq0RbcfEUw7LKYpKIN7U github.com/netbirdio/dex v0.244.0/go.mod h1:STGInJhPcAflrHmDO7vyit2kSq03PdL+8zQPoGALtcU= github.com/netbirdio/easyjson v0.9.0 h1:6Nw2lghSVuy8RSkAYDhDv1thBVEmfVbKZnV7T7Z6Aus= github.com/netbirdio/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU= -github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6Sf8uYFx/dMeqNOL90KUoRscdfpFZ3Im89uk= -github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8= github.com/netbirdio/management-integrations/integrations v0.0.0-20260416123949-2355d972be42 h1:F3zS5fT9xzD1OFLfcdAE+3FfyiwjGukF1hvj0jErgs8= @@ -711,8 +711,8 @@ golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1m golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= -golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4= -golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA= +golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI= +golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q= golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b h1:M2rDM6z3Fhozi9O7NWsxAkg/yqS/lQJ6PmkyIV3YP+o= golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b/go.mod h1:3//PLf8L/X+8b4vuAfHzxeRUl04Adcb341+IGKfnqS8= golang.org/x/image v0.33.0 h1:LXRZRnv1+zGd5XBUVRFmYEphyyKJjQjCRiOuAP3sZfQ= @@ -729,8 +729,8 @@ golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= -golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8= -golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w= +golang.org/x/mod v0.34.0 h1:xIHgNUUnW6sYkcM5Jleh05DvLOtwc6RitGHbDk4akRI= +golang.org/x/mod v0.34.0/go.mod h1:ykgH52iCZe79kzLLMhyCUzhMci+nQj+0XkbXpNYtVjY= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= @@ -749,8 +749,8 @@ golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk= golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY= golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= -golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0= -golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw= +golang.org/x/net v0.53.0 h1:d+qAbo5L0orcWAr0a9JweQpjXF19LMXJE8Ey7hwOdUA= +golang.org/x/net v0.53.0/go.mod h1:JvMuJH7rrdiCfbeHoo3fCQU24Lf5JJwT9W3sJFulfgs= golang.org/x/oauth2 v0.8.0/go.mod h1:yr7u4HXZRm1R1kBWqr/xKNqewf0plRYoB7sla+BCIXE= golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs= golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q= @@ -801,8 +801,8 @@ golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo= -golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= +golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI= +golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= @@ -815,8 +815,8 @@ golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY= golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY= golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM= -golang.org/x/term v0.41.0 h1:QCgPso/Q3RTJx2Th4bDLqML4W6iJiaXFq2/ftQF13YU= -golang.org/x/term v0.41.0/go.mod h1:3pfBgksrReYfZ5lvYM0kSO0LIkAl4Yl2bXOkKP7Ec2A= +golang.org/x/term v0.42.0 h1:UiKe+zDFmJobeJ5ggPwOshJIVt6/Ft0rcfrXZDLWAWY= +golang.org/x/term v0.42.0/go.mod h1:Dq/D+snpsbazcBG5+F9Q1n2rXV8Ma+71xEjTRufARgY= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= @@ -828,8 +828,8 @@ golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= -golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8= -golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA= +golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg= +golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164= golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U= golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -843,8 +843,8 @@ golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58= golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= -golang.org/x/tools v0.42.0 h1:uNgphsn75Tdz5Ji2q36v/nsFSfR/9BRFvqhGBaJGd5k= -golang.org/x/tools v0.42.0/go.mod h1:Ma6lCIwGZvHK6XtgbswSoWroEkhugApmsXyrUmBhfr0= +golang.org/x/tools v0.43.0 h1:12BdW9CeB3Z+J/I/wj34VMl8X+fEXBxVR90JeMX5E7s= +golang.org/x/tools v0.43.0/go.mod h1:uHkMso649BX2cZK6+RpuIPXS3ho2hZo4FVwfoy1vIk0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= From 31395f8bd2b8e2200077217c755500f2d7e9618c Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Tue, 5 May 2026 23:18:22 +0900 Subject: [PATCH 11/15] [client] Use fwmark-aware route lookup for raw socket UDP checksum source (#6070) * Use fwmark-aware route lookup for raw socket UDP checksum source * Guard nil raw socket in sharedsock WriteTo --- sharedsock/sock_linux.go | 56 +++++++++++++++------------------------- 1 file changed, 21 insertions(+), 35 deletions(-) diff --git a/sharedsock/sock_linux.go b/sharedsock/sock_linux.go index 523beb32b..4855e1aed 100644 --- a/sharedsock/sock_linux.go +++ b/sharedsock/sock_linux.go @@ -10,15 +10,13 @@ import ( "context" "fmt" "net" - "sync" "time" "github.com/google/gopacket" "github.com/google/gopacket/layers" - "github.com/google/gopacket/routing" - "github.com/libp2p/go-netroute" "github.com/mdlayher/socket" log "github.com/sirupsen/logrus" + "github.com/vishvananda/netlink" "golang.org/x/sync/errgroup" "golang.org/x/sys/unix" @@ -37,8 +35,6 @@ type SharedSocket struct { conn6 *socket.Conn port int mtu uint16 - routerMux sync.RWMutex - router routing.Router packetDemux chan rcvdPacket cancel context.CancelFunc } @@ -82,11 +78,6 @@ func Listen(port int, filter BPFFilter, mtu uint16) (_ net.PacketConn, err error } }() - rawSock.router, err = netroute.New() - if err != nil { - return nil, fmt.Errorf("failed to create raw socket router: %w", err) - } - rawSock.conn4, err = socket.Socket(unix.AF_INET, unix.SOCK_RAW, unix.IPPROTO_UDP, "raw_udp4", nil) if err != nil { return nil, fmt.Errorf("failed to create ipv4 raw socket: %w", err) @@ -127,31 +118,26 @@ func Listen(port int, filter BPFFilter, mtu uint16) (_ net.PacketConn, err error go rawSock.read(rawSock.conn6.Recvfrom) } - go rawSock.updateRouter() - return rawSock, nil } -// updateRouter updates the listener routing table client -// this is needed to avoid outdated information across different client networks -func (s *SharedSocket) updateRouter() { - ticker := time.NewTicker(15 * time.Second) - defer ticker.Stop() - for { - select { - case <-s.ctx.Done(): - return - case <-ticker.C: - router, err := netroute.New() - if err != nil { - log.Errorf("Failed to create and update packet router for stunListener: %s", err) - continue - } - s.routerMux.Lock() - s.router = router - s.routerMux.Unlock() +// resolveSrc returns the source IP the kernel will pick for a packet sent to +// dst by these raw sockets, mirroring the fwmark the kernel will see on send. +func (s *SharedSocket) resolveSrc(dst net.IP) (net.IP, error) { + opts := &netlink.RouteGetOptions{} + if nbnet.AdvancedRouting() { + opts.Mark = nbnet.ControlPlaneMark + } + routes, err := netlink.RouteGetWithOptions(dst, opts) + if err != nil { + return nil, fmt.Errorf("route get %s: %w", dst, err) + } + for _, r := range routes { + if r.Src != nil { + return r.Src, nil } } + return nil, fmt.Errorf("no source IP for %s", dst) } // LocalAddr returns the local address, preferring IPv4 for backward compatibility. @@ -310,15 +296,15 @@ func (s *SharedSocket) WriteTo(buf []byte, rAddr net.Addr) (n int, err error) { DstPort: layers.UDPPort(rUDPAddr.Port), } - s.routerMux.RLock() - defer s.routerMux.RUnlock() - - _, _, src, err := s.router.Route(rUDPAddr.IP) + src, err := s.resolveSrc(rUDPAddr.IP) if err != nil { - return 0, fmt.Errorf("got an error while checking route, err: %w", err) + return 0, fmt.Errorf("resolve source for %s: %w", rUDPAddr.IP, err) } rSockAddr, conn, nwLayer := s.getWriterObjects(src, rUDPAddr.IP) + if conn == nil { + return 0, fmt.Errorf("no raw socket for %s", rUDPAddr.IP) + } if err := udp.SetNetworkLayerForChecksum(nwLayer); err != nil { return -1, fmt.Errorf("failed to set network layer for checksum: %w", err) From 1795bc801d3832af13846a1030ea824eac9b0a5c Mon Sep 17 00:00:00 2001 From: Nicolas Frati Date: Tue, 5 May 2026 16:53:01 +0200 Subject: [PATCH 12/15] chores: updated discussions and issues templates (#6073) --- .../ideas-feature-requests.yml | 130 ++++++++++ .github/DISCUSSION_TEMPLATE/issue-triage.yml | 237 ++++++++++++++++++ .github/DISCUSSION_TEMPLATE/q-a-support.yml | 146 +++++++++++ .github/ISSUE_TEMPLATE/bug-issue-report.md | 71 ------ .github/ISSUE_TEMPLATE/config.yml | 26 +- .github/ISSUE_TEMPLATE/feature_request.md | 20 -- .github/ISSUE_TEMPLATE/validated_issue.yml | 128 ++++++++++ 7 files changed, 660 insertions(+), 98 deletions(-) create mode 100644 .github/DISCUSSION_TEMPLATE/ideas-feature-requests.yml create mode 100644 .github/DISCUSSION_TEMPLATE/issue-triage.yml create mode 100644 .github/DISCUSSION_TEMPLATE/q-a-support.yml delete mode 100644 .github/ISSUE_TEMPLATE/bug-issue-report.md delete mode 100644 .github/ISSUE_TEMPLATE/feature_request.md create mode 100644 .github/ISSUE_TEMPLATE/validated_issue.yml diff --git a/.github/DISCUSSION_TEMPLATE/ideas-feature-requests.yml b/.github/DISCUSSION_TEMPLATE/ideas-feature-requests.yml new file mode 100644 index 000000000..f57a62107 --- /dev/null +++ b/.github/DISCUSSION_TEMPLATE/ideas-feature-requests.yml @@ -0,0 +1,130 @@ +body: + - type: markdown + attributes: + value: | + ## Ideas & Feature Requests + + Use this category for feature requests, enhancements, integrations, and product ideas. + + NetBird uses community traction in discussions — upvotes, replies, affected users, and use-case detail — as an input when deciding what should become a maintainer-curated issue or roadmap item. A clear problem statement is more useful than a solution-only request. + + Please search first and add your use case to an existing discussion when one already exists. + + - type: checkboxes + id: preflight + attributes: + label: Before posting + options: + - label: I searched existing discussions and issues for similar requests. + required: true + - label: I checked the documentation to confirm this is not already supported. + required: true + - label: This is a product idea or enhancement request, not a support question. + required: true + - label: I removed or anonymized sensitive details from examples and screenshots. + required: true + + - type: dropdown + id: area + attributes: + label: Product area + description: Select every area this request touches. + multiple: true + options: + - Client / Agent + - CLI + - Desktop UI + - Mobile app + - Dashboard / Admin UI + - Management service / API + - Signal service + - Relay + - DNS + - Routes / Exit nodes + - NetBird SSH + - Access control policies + - Posture checks + - Identity provider / SSO + - Self-hosting / Deployment + - Kubernetes / Operator + - Terraform / Automation + - Documentation + - Other / not sure + validations: + required: true + + - type: textarea + id: problem + attributes: + label: Problem or use case + description: What are you trying to accomplish, and what is difficult or impossible today? + placeholder: | + As a ... + I want to ... + Because ... + validations: + required: true + + - type: textarea + id: proposal + attributes: + label: Proposed solution + description: Describe the behavior, workflow, API, UI, or integration you would like to see. + validations: + required: true + + - type: textarea + id: alternatives + attributes: + label: Alternatives or workarounds considered + description: What have you tried today? Why is the current workaround not enough? + + - type: textarea + id: impact + attributes: + label: Community impact and priority + description: Help us understand who benefits and how urgent this is. + placeholder: | + - Number of users/teams/peers affected: + - Deployment type: Cloud / self-hosted / both + - Frequency: daily / weekly / occasional + - Blocking production adoption? yes/no + - Related comments, discussions, or customer requests: + validations: + required: true + + - type: textarea + id: examples + attributes: + label: Examples from other tools or products + description: If another tool solves this well, link or describe the behavior. + + - type: textarea + id: security + attributes: + label: Security, privacy, and compatibility considerations + description: Note any access-control, audit, data retention, network, platform, or backward-compatibility concerns. + + - type: textarea + id: implementation + attributes: + label: Implementation ideas + description: Optional. If you are familiar with the codebase or API, share possible implementation notes. + + - type: dropdown + id: contribution + attributes: + label: Are you willing to help? + options: + - Yes, I can submit a PR if the approach is accepted. + - Yes, I can test or validate a proposed implementation. + - Yes, I can provide more use-case details. + - Not at this time. + validations: + required: true + + - type: textarea + id: additional-context + attributes: + label: Additional context + description: Add screenshots, diagrams, links, or anything else that helps explain the request. diff --git a/.github/DISCUSSION_TEMPLATE/issue-triage.yml b/.github/DISCUSSION_TEMPLATE/issue-triage.yml new file mode 100644 index 000000000..b13ec066e --- /dev/null +++ b/.github/DISCUSSION_TEMPLATE/issue-triage.yml @@ -0,0 +1,237 @@ +body: + - type: markdown + attributes: + value: | + ## Issue Triage + + Use this category for reproducible bugs and regressions in NetBird. + + The more context you include, the faster we can validate and act on your report. If you're not sure whether something is a bug, **Q&A / Support** is a good starting point — we can always move the conversation here once we've confirmed it's a product issue. + + Intermittent issues are useful too. Include the trigger, frequency, timing, and any logs or debug evidence you have, and we'll work from there. + + Please don't include secrets, tokens, private keys, internal hostnames, or public IPs. Security vulnerabilities should be reported through the repository security policy rather than a public discussion. + + - type: checkboxes + id: preflight + attributes: + label: Before posting + options: + - label: I searched existing discussions and issues, including closed ones, and checked the relevant docs. + required: true + - label: I believe this is a product bug rather than a configuration or setup question. + required: true + - label: I can reproduce this issue, or for intermittent issues I've included trigger, frequency, and timing details below. + required: true + - label: I removed or anonymized sensitive data from logs, screenshots, and configuration. + required: true + + - type: dropdown + id: area + attributes: + label: Affected area + description: Select every area this report touches. + multiple: true + options: + - Client / Agent + - Reverse Proxy + - CLI + - Desktop UI + - Mobile app + - Peer connectivity + - DNS + - Routes / Exit nodes + - NetBird SSH + - Relay / Signal / NAT traversal + - Login / Authentication / IdP + - Dashboard / Admin UI + - Management service / API + - Access control policies / Posture checks + - Self-hosting / Deployment + - Kubernetes / Operator + - Documentation + - Other / not sure + validations: + required: true + + - type: dropdown + id: deployment + attributes: + label: Deployment type + options: + - NetBird Cloud + - Self-hosted - quickstart script + - Self-hosted - advanced/custom deployment + - Local development build + - Not sure / environment I do not fully control + validations: + required: true + + - type: dropdown + id: platform + attributes: + label: Operating system or environment + description: Select every environment involved in the reproduction. + multiple: true + options: + - Linux + - macOS + - Windows + - Android + - iOS + - FreeBSD + - OpenWRT + - Docker + - Kubernetes + - Synology + - Browser + - Other / not sure + validations: + required: true + + - type: textarea + id: version + attributes: + label: NetBird version and upgrade status + description: Run `netbird version` where applicable. For self-hosted deployments, include management, signal, relay, and dashboard versions if available. If you cannot test on a current/supported version, explain why. + placeholder: | + Example: + - Client: 0.30.2 + - Management: 0.30.2 + - Signal: 0.30.2 + - Relay: 0.30.2 + - Dashboard: 0.30.2 + - Upgrade status: reproduced on current version / cannot upgrade because ... + validations: + required: true + + - type: dropdown + id: regression + attributes: + label: Did this work before? + options: + - Yes, this worked before + - No, this never worked + - Not sure + validations: + required: true + + - type: textarea + id: regression-details + attributes: + label: Regression details + description: If this worked before, include the last known working version, first known broken version, and any recent upgrade, configuration, network, or IdP changes. + placeholder: | + - Last known working version: + - First known broken version: + - Recent changes: + + - type: textarea + id: summary + attributes: + label: Summary + description: Briefly describe the reproducible bug. + placeholder: What is broken? + validations: + required: true + + - type: textarea + id: current-behavior + attributes: + label: Current behavior + description: What happens now? Include exact errors, timeouts, UI messages, or failed commands when possible. + validations: + required: true + + - type: textarea + id: expected-behavior + attributes: + label: Expected behavior + description: What did you expect to happen instead? + validations: + required: true + + - type: textarea + id: reproduction + attributes: + label: Steps to reproduce + description: Provide the smallest set of steps that reliably reproduces the bug. If the issue is intermittent, include the trigger, frequency, timing, and relevant timestamps. + placeholder: | + 1. Configure ... + 2. Run ... + 3. Observe ... + + For intermittent issues: + - Trigger: + - Frequency: + - Timing/timestamps: + validations: + required: true + + - type: textarea + id: environment + attributes: + label: Environment and topology + description: Include the relevant topology and software involved in the reproduction. For UI/docs-only reports, write `N/A` if this does not apply. Use `None`, `Unknown`, or `N/A` where appropriate. + placeholder: | + - Peer A: + - Peer B: + - Same LAN or different networks: + - NAT/CGNAT/corporate firewall/mobile network: + - Other VPN software: + - Firewall, DNS, or endpoint security software: + - Routes, DNS, policies, posture checks, or SSH rules involved: + - IdP, reverse proxy, or browser involved: + validations: + required: true + + - type: textarea + id: self-hosted-details + attributes: + label: Self-hosted details, if available + description: Optional. If you use self-hosting and have access to these details, include them. If you do not administer the environment, provide what you know and say what you cannot access. + placeholder: | + - Deployment method: quickstart / Docker Compose / Helm / operator / custom + - Management/signal/relay/dashboard versions: + - Reverse proxy: + - IdP/provider: + - STUN/TURN/coturn/relay details: + - Relevant component logs: + + - type: textarea + id: logs + attributes: + label: Logs, status output, or debug evidence + description: | + For client, connectivity, DNS, route, relay/signal, or self-hosted reports, logs are essential — please include anonymized output from `netbird status -dA`, or a debug bundle via `netbird debug for 1m -AS -U`. Debug bundles are automatically deleted after 30 days. + + For UI, dashboard, or documentation reports, leave the pre-filled `N/A`. + value: "N/A" + render: shell + validations: + required: true + + - type: textarea + id: related-reports + attributes: + label: Related issues or discussions + description: Optional. Link similar reports you found while searching, if any. + placeholder: | + - Related issue/discussion: + - Why this may be the same or different: + + - type: textarea + id: impact + attributes: + label: Impact + description: Optional. Help us understand priority. How many users, peers, environments, or workflows are affected? Is there a workaround? + placeholder: | + - Affected users/peers: + - Business or production impact: + - Workaround available: + + - type: textarea + id: additional-context + attributes: + label: Additional context + description: Add links to related discussions, issues, docs, screenshots, recordings, or anything else that may help validation. diff --git a/.github/DISCUSSION_TEMPLATE/q-a-support.yml b/.github/DISCUSSION_TEMPLATE/q-a-support.yml new file mode 100644 index 000000000..725f8737c --- /dev/null +++ b/.github/DISCUSSION_TEMPLATE/q-a-support.yml @@ -0,0 +1,146 @@ +body: + - type: markdown + attributes: + value: | + ## Q&A / Support + + Use this category for questions about configuration, setup, self-hosted deployments, troubleshooting, and general NetBird usage. + + This is community support and does not provide an SLA. For NetBird Cloud support, use the official support channel linked from the issue creation page. Please do not post secrets, tokens, private keys, internal hostnames, or public IPs unless you intentionally want them public. + + If your question turns into a reproducible product defect, DevRel or a maintainer may ask you to open or move the conversation to Issue Triage. + + - type: checkboxes + id: preflight + attributes: + label: Before posting + options: + - label: I searched existing discussions and issues for similar questions. + required: true + - label: I reviewed the relevant NetBird documentation or troubleshooting guide. + required: true + - label: I removed or anonymized sensitive data from logs, screenshots, and configuration. + required: true + + - type: dropdown + id: topic + attributes: + label: Topic + multiple: true + options: + - Getting started + - Self-hosting + - Client / Agent + - CLI + - Desktop UI + - Mobile app + - Dashboard / Admin UI + - DNS + - Routes / Exit nodes + - NetBird SSH + - Relay + - Access control policies + - Posture checks + - Identity provider / SSO + - API + - Kubernetes / Operator + - Terraform / Automation + - Documentation + - Other / not sure + validations: + required: true + + - type: dropdown + id: deployment + attributes: + label: Deployment type + options: + - NetBird Cloud + - Self-hosted - quickstart script + - Self-hosted - advanced/custom deployment + - Local development build + - Not sure + validations: + required: true + + - type: dropdown + id: platform + attributes: + label: Operating system or environment + multiple: true + options: + - Linux + - macOS + - Windows + - Android + - iOS + - FreeBSD + - OpenWRT + - Docker + - Kubernetes + - Synology + - Browser + - Other / not sure + validations: + required: true + + - type: input + id: version + attributes: + label: NetBird version + description: Run `netbird version` where applicable. For self-hosted deployments, include component versions if relevant. + placeholder: "Example: client 0.30.2, management 0.30.2" + + - type: textarea + id: question + attributes: + label: Question + description: What are you trying to understand or accomplish? + placeholder: Describe your question clearly. + validations: + required: true + + - type: textarea + id: goal + attributes: + label: Desired outcome + description: What would a successful answer help you do? + placeholder: | + I want to configure ... + I expected ... + I need help deciding ... + + - type: textarea + id: attempted + attributes: + label: What have you tried? + description: Include commands, documentation links, configuration attempts, or troubleshooting steps already tried. + placeholder: | + - Read ... + - Ran ... + - Changed ... + - Observed ... + + - type: textarea + id: environment + attributes: + label: Relevant environment details + description: Include redacted topology, IdP/provider, reverse proxy, firewall, DNS, route, policy, or self-hosted setup details that may affect the answer. + placeholder: | + - Deployment: + - Components involved: + - Network/topology: + - Related config: + + - type: textarea + id: logs + attributes: + label: Logs or output + description: Optional. Include anonymized logs, command output, screenshots, or `netbird status -dA` if relevant. + render: shell + + - type: textarea + id: additional-context + attributes: + label: Additional context + description: Add links, diagrams, screenshots, or other details that may help the community answer. diff --git a/.github/ISSUE_TEMPLATE/bug-issue-report.md b/.github/ISSUE_TEMPLATE/bug-issue-report.md deleted file mode 100644 index df670db06..000000000 --- a/.github/ISSUE_TEMPLATE/bug-issue-report.md +++ /dev/null @@ -1,71 +0,0 @@ ---- -name: Bug/Issue report -about: Create a report to help us improve -title: '' -labels: ['triage-needed'] -assignees: '' - ---- - -**Describe the problem** - -A clear and concise description of what the problem is. - -**To Reproduce** - -Steps to reproduce the behavior: -1. Go to '...' -2. Click on '....' -3. Scroll down to '....' -4. See error - -**Expected behavior** - -A clear and concise description of what you expected to happen. - -**Are you using NetBird Cloud?** - -Please specify whether you use NetBird Cloud or self-host NetBird's control plane. - -**NetBird version** - -`netbird version` - -**Is any other VPN software installed?** - -If yes, which one? - -**Debug output** - -To help us resolve the problem, please attach the following anonymized status output - - netbird status -dA - -Create and upload a debug bundle, and share the returned file key: - - netbird debug for 1m -AS -U - -*Uploaded files are automatically deleted after 30 days.* - - -Alternatively, create the file only and attach it here manually: - - netbird debug for 1m -AS - - -**Screenshots** - -If applicable, add screenshots to help explain your problem. - -**Additional context** - -Add any other context about the problem here. - -**Have you tried these troubleshooting steps?** -- [ ] Reviewed [client troubleshooting](https://docs.netbird.io/how-to/troubleshooting-client) (if applicable) -- [ ] Checked for newer NetBird versions -- [ ] Searched for similar issues on GitHub (including closed ones) -- [ ] Restarted the NetBird client -- [ ] Disabled other VPN software -- [ ] Checked firewall settings - diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml index e9ffaf8a3..ee3e84df6 100644 --- a/.github/ISSUE_TEMPLATE/config.yml +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -1,14 +1,26 @@ -blank_issues_enabled: true +blank_issues_enabled: false contact_links: - - name: Community Support + - name: Start an Issue Triage discussion + url: https://github.com/netbirdio/netbird/discussions/new?category=issue-triage + about: Report a bug, regression, or unexpected behavior so DevRel can validate it before it becomes an issue. + - name: Propose an idea or feature request + url: https://github.com/netbirdio/netbird/discussions/new?category=ideas-feature-requests + about: Share feature requests, enhancements, and integration ideas for community feedback and prioritization. + - name: Ask a Q&A / Support question + url: https://github.com/netbirdio/netbird/discussions/new?category=q-a-support + about: Get help with setup, configuration, self-hosting, troubleshooting, and general usage. + - name: Security vulnerability disclosure + url: https://github.com/netbirdio/netbird/security/policy + about: Please do not report security vulnerabilities in public issues or discussions. + - name: Community Support Forum url: https://forum.netbird.io/ - about: Community support forum + about: Community support forum. - name: Cloud Support url: https://docs.netbird.io/help/report-bug-issues - about: Contact us for support - - name: Client/Connection Troubleshooting + about: Contact NetBird for Cloud support. + - name: Client / Connection Troubleshooting url: https://docs.netbird.io/help/troubleshooting-client - about: See our client troubleshooting guide for help addressing common issues + about: See the client troubleshooting guide for common connectivity issues. - name: Self-host Troubleshooting url: https://docs.netbird.io/selfhosted/troubleshooting - about: See our self-host troubleshooting guide for help addressing common issues + about: See the self-host troubleshooting guide for common deployment issues. diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md deleted file mode 100644 index 4a3e5782c..000000000 --- a/.github/ISSUE_TEMPLATE/feature_request.md +++ /dev/null @@ -1,20 +0,0 @@ ---- -name: Feature request -about: Suggest an idea for this project -title: '' -labels: ['feature-request'] -assignees: '' - ---- - -**Is your feature request related to a problem? Please describe.** -A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] - -**Describe the solution you'd like** -A clear and concise description of what you want to happen. - -**Describe alternatives you've considered** -A clear and concise description of any alternative solutions or features you've considered. - -**Additional context** -Add any other context or screenshots about the feature request here. diff --git a/.github/ISSUE_TEMPLATE/validated_issue.yml b/.github/ISSUE_TEMPLATE/validated_issue.yml new file mode 100644 index 000000000..2a21b73b2 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/validated_issue.yml @@ -0,0 +1,128 @@ +name: Validated issue +description: Maintainer/DevRel only. Create an issue after a discussion has been validated or for internally validated work. +title: "[Validated]: " +body: + - type: markdown + attributes: + value: | + ## Discussion-first issue policy + + Issues are maintainer-curated work items. Community reports and feature requests should start in [Discussions](https://github.com/netbirdio/netbird/discussions) so DevRel can validate, reproduce, and route them before engineering time is committed. + + Use this form when: + - A discussion has been validated and should become actionable work. + - A maintainer is opening internally validated work that can bypass the discussion-first flow. + + Issues opened without a relevant validated discussion or maintainer context may be closed and redirected to Discussions. + + - type: checkboxes + id: validation-checks + attributes: + label: Validation checklist + options: + - label: This issue is linked to a validated discussion, or it is being opened directly by a maintainer. + required: true + - label: The report has enough context for engineering to act on it without re-triaging from scratch. + required: true + - label: Sensitive data, secrets, private keys, internal hostnames, and public IPs have been removed or intentionally disclosed. + required: true + + - type: dropdown + id: issue-type + attributes: + label: Issue type + options: + - Bug / Regression + - Feature / Enhancement + - Documentation + - Maintenance / Refactor + - Cross-repository coordination + - Other + validations: + required: true + + - type: input + id: source-discussion + attributes: + label: Source discussion + description: Link the GitHub Discussion that was validated. Maintainers bypassing the flow can write "Maintainer-created" and explain why below. + placeholder: https://github.com/netbirdio/netbird/discussions/1234 + validations: + required: true + + - type: input + id: validation-owner + attributes: + label: Validation owner + description: GitHub handle of the DevRel team member or maintainer who validated this work. + placeholder: "@username" + validations: + required: true + + - type: dropdown + id: target-repository + attributes: + label: Target repository + description: Where should the implementation work happen? + options: + - netbirdio/netbird + - netbirdio/dashboard + - netbirdio/kubernetes-operator + - netbirdio/docs + - Multiple repositories + - Unknown / needs routing + validations: + required: true + + - type: textarea + id: summary + attributes: + label: Summary + description: Concise description of the validated work. + placeholder: What needs to be fixed, changed, documented, or built? + validations: + required: true + + - type: textarea + id: evidence + attributes: + label: Validation evidence + description: For bugs, include reproduction status, affected versions, logs, and environment. For features, include community traction, affected users, and alignment notes. + placeholder: | + - Reproduced by: + - Affected versions / platforms: + - Community signal: + - Related logs or screenshots: + validations: + required: true + + - type: textarea + id: scope + attributes: + label: Proposed scope + description: Describe what is in scope and, if helpful, what is explicitly out of scope. + placeholder: | + In scope: + - ... + + Out of scope: + - ... + validations: + required: true + + - type: textarea + id: acceptance-criteria + attributes: + label: Acceptance criteria + description: What must be true for this issue to be closed? + placeholder: | + - [ ] ... + - [ ] ... + validations: + required: true + + - type: textarea + id: additional-context + attributes: + label: Additional context + description: Links to related PRs, docs, issues in other repositories, roadmap items, or implementation notes. From 3c28d297252e59ebbbdc00eaa6cb0f880807ba68 Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Tue, 5 May 2026 18:12:18 +0300 Subject: [PATCH 13/15] [management] Map Entra oid claim as Dex user ID (#6067) --- idp/dex/connector.go | 62 ++++++++---- idp/dex/connector_test.go | 205 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 246 insertions(+), 21 deletions(-) create mode 100644 idp/dex/connector_test.go diff --git a/idp/dex/connector.go b/idp/dex/connector.go index 8aba92999..fb20fdcc3 100644 --- a/idp/dex/connector.go +++ b/idp/dex/connector.go @@ -89,21 +89,33 @@ func (p *Provider) ListConnectors(ctx context.Context) ([]*ConnectorConfig, erro } // UpdateConnector updates an existing connector in Dex storage. -// It merges incoming updates with existing values to prevent data loss on partial updates. +// It overlays user-mutable config fields (issuer, clientID, clientSecret, +// redirectURI) onto the stored connector config, and updates the connector name +// when cfg.Name is set. Empty fields on cfg leave stored values unchanged, so +// partial updates preserve create-time defaults such as scopes, claimMapping, +// and userIDKey. func (p *Provider) UpdateConnector(ctx context.Context, cfg *ConnectorConfig) error { if err := p.storage.UpdateConnector(ctx, cfg.ID, func(old storage.Connector) (storage.Connector, error) { - oldCfg, err := p.parseStorageConnector(old) - if err != nil { - return storage.Connector{}, fmt.Errorf("failed to parse existing connector: %w", err) + if cfg.Type != "" && cfg.Type != inferIdentityProviderType(old.Type, cfg.ID, nil) { + return storage.Connector{}, errors.New("connector type change not allowed") } - mergeConnectorConfig(cfg, oldCfg) - - storageConn, err := p.buildStorageConnector(cfg) + configData, err := overlayConnectorConfig(old.Config, cfg) if err != nil { - return storage.Connector{}, fmt.Errorf("failed to build connector: %w", err) + return storage.Connector{}, fmt.Errorf("failed to overlay connector config: %w", err) } - return storageConn, nil + + name := cfg.Name + if name == "" { + name = old.Name + } + + return storage.Connector{ + ID: cfg.ID, + Type: old.Type, + Name: name, + Config: configData, + }, nil }); err != nil { return fmt.Errorf("failed to update connector: %w", err) } @@ -112,23 +124,27 @@ func (p *Provider) UpdateConnector(ctx context.Context, cfg *ConnectorConfig) er return nil } -// mergeConnectorConfig preserves existing values for empty fields in the update. -func mergeConnectorConfig(cfg, oldCfg *ConnectorConfig) { - if cfg.ClientSecret == "" { - cfg.ClientSecret = oldCfg.ClientSecret +// overlayConnectorConfig writes only the user-mutable fields onto the existing +// stored config, preserving every other field (scopes, claimMapping, userIDKey, +// insecure flags, etc.). Empty fields on cfg leave the existing value alone. +func overlayConnectorConfig(oldConfig []byte, cfg *ConnectorConfig) ([]byte, error) { + var m map[string]any + if err := decodeConnectorConfig(oldConfig, &m); err != nil { + return nil, err } - if cfg.RedirectURI == "" { - cfg.RedirectURI = oldCfg.RedirectURI + if cfg.Issuer != "" { + m["issuer"] = cfg.Issuer } - if cfg.Issuer == "" && cfg.Type == oldCfg.Type { - cfg.Issuer = oldCfg.Issuer + if cfg.ClientID != "" { + m["clientID"] = cfg.ClientID } - if cfg.ClientID == "" { - cfg.ClientID = oldCfg.ClientID + if cfg.ClientSecret != "" { + m["clientSecret"] = cfg.ClientSecret } - if cfg.Name == "" { - cfg.Name = oldCfg.Name + if cfg.RedirectURI != "" { + m["redirectURI"] = cfg.RedirectURI } + return encodeConnectorConfig(m) } // DeleteConnector removes a connector from Dex storage. @@ -216,6 +232,10 @@ func buildOIDCConnectorConfig(cfg *ConnectorConfig, redirectURI string) ([]byte, oidcConfig["getUserInfo"] = true case "entra": oidcConfig["claimMapping"] = map[string]string{"email": "preferred_username"} + // Use the Entra Object ID (oid) instead of the default OIDC sub claim. + // Entra issues sub as a per-app pairwise identifier that does not match + // the stable Object ID. + oidcConfig["userIDKey"] = "oid" case "okta": oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"} case "pocketid": diff --git a/idp/dex/connector_test.go b/idp/dex/connector_test.go new file mode 100644 index 000000000..4253e02b7 --- /dev/null +++ b/idp/dex/connector_test.go @@ -0,0 +1,205 @@ +package dex + +import ( + "context" + "encoding/json" + "log/slog" + "os" + "path/filepath" + "testing" + + "github.com/dexidp/dex/storage" + "github.com/dexidp/dex/storage/sql" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func newTestProvider(t *testing.T) (*Provider, func()) { + t.Helper() + tmpDir, err := os.MkdirTemp("", "dex-connector-test-*") + require.NoError(t, err) + + logger := slog.New(slog.NewTextHandler(os.Stderr, nil)) + s, err := (&sql.SQLite3{File: filepath.Join(tmpDir, "dex.db")}).Open(logger) + require.NoError(t, err) + + return &Provider{storage: s, logger: logger}, func() { + _ = s.Close() + _ = os.RemoveAll(tmpDir) + } +} + +func TestBuildOIDCConnectorConfig_EntraSetsUserIDKey(t *testing.T) { + cfg := &ConnectorConfig{ + ID: "entra-test", + Name: "Entra", + Type: "entra", + Issuer: "https://login.microsoftonline.com/tid/v2.0", + ClientID: "client-id", + ClientSecret: "client-secret", + } + data, err := buildOIDCConnectorConfig(cfg, "https://example.com/oauth2/callback") + require.NoError(t, err) + + var m map[string]any + require.NoError(t, json.Unmarshal(data, &m)) + + assert.Equal(t, "oid", m["userIDKey"], "entra connectors must default userIDKey to oid") + assert.Equal(t, map[string]any{"email": "preferred_username"}, m["claimMapping"]) +} + +func TestBuildOIDCConnectorConfig_NonEntraDoesNotSetUserIDKey(t *testing.T) { + // ensures the Entra userIDKey override does not leak into other OIDC providers, + // which already use a stable sub claim. + for _, typ := range []string{"oidc", "zitadel", "okta", "pocketid", "authentik", "keycloak", "adfs"} { + t.Run(typ, func(t *testing.T) { + data, err := buildOIDCConnectorConfig(&ConnectorConfig{Type: typ}, "https://example.com/oauth2/callback") + require.NoError(t, err) + var m map[string]any + require.NoError(t, json.Unmarshal(data, &m)) + _, ok := m["userIDKey"] + assert.False(t, ok, "%s connectors must not have userIDKey set", typ) + }) + } +} + +func TestUpdateConnector_PreservesCreateTimeDefaults(t *testing.T) { + ctx := context.Background() + p, cleanup := newTestProvider(t) + defer cleanup() + + created, err := p.CreateConnector(ctx, &ConnectorConfig{ + ID: "entra-test", + Name: "Entra", + Type: "entra", + Issuer: "https://login.microsoftonline.com/tid/v2.0", + ClientID: "client-id", + ClientSecret: "old-secret", + RedirectURI: "https://example.com/oauth2/callback", + }) + require.NoError(t, err) + require.Equal(t, "entra-test", created.ID) + + // Rotate only the client secret. + err = p.UpdateConnector(ctx, &ConnectorConfig{ + ID: "entra-test", + Type: "entra", + ClientSecret: "new-secret", + }) + require.NoError(t, err) + + conn, err := p.storage.GetConnector(ctx, "entra-test") + require.NoError(t, err) + var m map[string]any + require.NoError(t, json.Unmarshal(conn.Config, &m)) + + assert.Equal(t, "new-secret", m["clientSecret"], "clientSecret should be rotated") + assert.Equal(t, "client-id", m["clientID"], "clientID must survive (overlay should leave it alone)") + assert.Equal(t, "https://login.microsoftonline.com/tid/v2.0", m["issuer"]) + assert.Equal(t, "oid", m["userIDKey"], "userIDKey must survive update") + assert.Equal(t, map[string]any{"email": "preferred_username"}, m["claimMapping"], "claimMapping must survive update") +} + +func TestUpdateConnector_DoesNotAddUserIDKeyToExistingConnector(t *testing.T) { + ctx := context.Background() + p, cleanup := newTestProvider(t) + defer cleanup() + + // Seed a connector directly into storage without userIDKey + preFixConfig, err := json.Marshal(map[string]any{ + "issuer": "https://login.microsoftonline.com/tid/v2.0", + "clientID": "client-id", + "clientSecret": "old-secret", + "redirectURI": "https://example.com/oauth2/callback", + "scopes": []string{"openid", "profile", "email"}, + "claimMapping": map[string]string{"email": "preferred_username"}, + }) + require.NoError(t, err) + + require.NoError(t, p.storage.CreateConnector(ctx, storage.Connector{ + ID: "entra-prefix", + Type: "oidc", + Name: "Entra", + Config: preFixConfig, + })) + + // Rotate client secret via UpdateConnector. + err = p.UpdateConnector(ctx, &ConnectorConfig{ + ID: "entra-prefix", + Type: "entra", + ClientSecret: "new-secret", + }) + require.NoError(t, err) + + conn, err := p.storage.GetConnector(ctx, "entra-prefix") + require.NoError(t, err) + var m map[string]any + require.NoError(t, json.Unmarshal(conn.Config, &m)) + + assert.Equal(t, "new-secret", m["clientSecret"]) + _, has := m["userIDKey"] + assert.False(t, has, "userIDKey must not be auto-added to a connector that did not have it before") +} + +func TestUpdateConnector_RejectsTypeChange(t *testing.T) { + ctx := context.Background() + p, cleanup := newTestProvider(t) + defer cleanup() + + _, err := p.CreateConnector(ctx, &ConnectorConfig{ + ID: "entra-test", + Name: "Entra", + Type: "entra", + Issuer: "https://login.microsoftonline.com/tid/v2.0", + ClientID: "client-id", + ClientSecret: "secret", + RedirectURI: "https://example.com/oauth2/callback", + }) + require.NoError(t, err) + + // Attempt to switch the connector to okta. + err = p.UpdateConnector(ctx, &ConnectorConfig{ + ID: "entra-test", + Type: "okta", + }) + require.Error(t, err) + assert.Contains(t, err.Error(), "connector type change not allowed") + + // stored connector type/config unchanged after the rejected update. + conn, err := p.storage.GetConnector(ctx, "entra-test") + require.NoError(t, err) + assert.Equal(t, "oidc", conn.Type) + var m map[string]any + require.NoError(t, json.Unmarshal(conn.Config, &m)) + assert.Equal(t, "oid", m["userIDKey"]) +} + +func TestUpdateConnector_AllowsSameTypeUpdate(t *testing.T) { + ctx := context.Background() + p, cleanup := newTestProvider(t) + defer cleanup() + + _, err := p.CreateConnector(ctx, &ConnectorConfig{ + ID: "entra-test", + Name: "Entra", + Type: "entra", + Issuer: "https://login.microsoftonline.com/old/v2.0", + ClientID: "client-id", + ClientSecret: "secret", + RedirectURI: "https://example.com/oauth2/callback", + }) + require.NoError(t, err) + + err = p.UpdateConnector(ctx, &ConnectorConfig{ + ID: "entra-test", + Type: "entra", + Issuer: "https://login.microsoftonline.com/new/v2.0", + }) + require.NoError(t, err) + + conn, err := p.storage.GetConnector(ctx, "entra-test") + require.NoError(t, err) + var m map[string]any + require.NoError(t, json.Unmarshal(conn.Config, &m)) + assert.Equal(t, "https://login.microsoftonline.com/new/v2.0", m["issuer"]) +} From cfb1b3fe31c37db79a67434ab620ddb0eca41faf Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Tue, 5 May 2026 18:40:42 +0200 Subject: [PATCH 14/15] [proxy] consolidate mapping update (#6072) --- management/internals/shared/grpc/proxy.go | 118 ++++++--- .../shared/grpc/proxy_snapshot_test.go | 174 ++++++++++++++ .../internals/shared/grpc/proxy_test.go | 3 + proxy/management_integration_test.go | 50 ++-- proxy/server.go | 45 +++- proxy/snapshot_reconcile_test.go | 227 ++++++++++++++++++ 6 files changed, 559 insertions(+), 58 deletions(-) create mode 100644 management/internals/shared/grpc/proxy_snapshot_test.go create mode 100644 proxy/snapshot_reconcile_test.go diff --git a/management/internals/shared/grpc/proxy.go b/management/internals/shared/grpc/proxy.go index d811a0f69..6763a3ba3 100644 --- a/management/internals/shared/grpc/proxy.go +++ b/management/internals/shared/grpc/proxy.go @@ -11,6 +11,8 @@ import ( "fmt" "net/http" "net/url" + "os" + "strconv" "strings" "sync" "time" @@ -82,11 +84,40 @@ type ProxyServiceServer struct { // Store for PKCE verifiers pkceVerifierStore *PKCEVerifierStore + // tokenTTL is the lifetime of one-time tokens generated for proxy + // authentication. Defaults to defaultProxyTokenTTL when zero. + tokenTTL time.Duration + + // snapshotBatchSize is the number of mappings per gRPC message during + // initial snapshot delivery. Configurable via NB_PROXY_SNAPSHOT_BATCH_SIZE. + snapshotBatchSize int + cancel context.CancelFunc } const pkceVerifierTTL = 10 * time.Minute +const defaultProxyTokenTTL = 5 * time.Minute + +const defaultSnapshotBatchSize = 500 + +func snapshotBatchSizeFromEnv() int { + if v := os.Getenv("NB_PROXY_SNAPSHOT_BATCH_SIZE"); v != "" { + if n, err := strconv.Atoi(v); err == nil && n > 0 { + return n + } + } + return defaultSnapshotBatchSize +} + +// proxyTokenTTL returns the configured token TTL or the default when unset. +func (s *ProxyServiceServer) proxyTokenTTL() time.Duration { + if s.tokenTTL > 0 { + return s.tokenTTL + } + return defaultProxyTokenTTL +} + // proxyConnection represents a connected proxy type proxyConnection struct { proxyID string @@ -110,6 +141,7 @@ func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeT peersManager: peersManager, usersManager: usersManager, proxyManager: proxyMgr, + snapshotBatchSize: snapshotBatchSizeFromEnv(), cancel: cancel, } go s.cleanupStaleProxies(ctx) @@ -192,11 +224,6 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest cancel: cancel, } - s.connectedProxies.Store(proxyID, conn) - if err := s.proxyController.RegisterProxyToCluster(ctx, conn.address, proxyID); err != nil { - log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", proxyID, err) - } - // Register proxy in database with capabilities var caps *proxy.Capabilities if c := req.GetCapabilities(); c != nil { @@ -209,13 +236,31 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest proxyRecord, err := s.proxyManager.Connect(ctx, proxyID, sessionID, proxyAddress, peerInfo, caps) if err != nil { log.WithContext(ctx).Warnf("failed to register proxy %s in database: %v", proxyID, err) - s.connectedProxies.CompareAndDelete(proxyID, conn) - if unregErr := s.proxyController.UnregisterProxyFromCluster(ctx, conn.address, proxyID); unregErr != nil { - log.WithContext(ctx).Debugf("cleanup after Connect failure for proxy %s: %v", proxyID, unregErr) - } + cancel() return status.Errorf(codes.Internal, "register proxy in database: %v", err) } + s.connectedProxies.Store(proxyID, conn) + if err := s.proxyController.RegisterProxyToCluster(ctx, conn.address, proxyID); err != nil { + log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", proxyID, err) + } + + if err := s.sendSnapshot(ctx, conn); err != nil { + if s.connectedProxies.CompareAndDelete(proxyID, conn) { + if unregErr := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, proxyID); unregErr != nil { + log.WithContext(ctx).Debugf("cleanup after snapshot failure for proxy %s: %v", proxyID, unregErr) + } + } + cancel() + if disconnErr := s.proxyManager.Disconnect(context.Background(), proxyID, sessionID); disconnErr != nil { + log.WithContext(ctx).Debugf("cleanup after snapshot failure for proxy %s: %v", proxyID, disconnErr) + } + return fmt.Errorf("send snapshot to proxy %s: %w", proxyID, err) + } + + errChan := make(chan error, 2) + go s.sender(conn, errChan) + log.WithFields(log.Fields{ "proxy_id": proxyID, "session_id": sessionID, @@ -241,13 +286,6 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest log.Infof("Proxy %s session %s disconnected", proxyID, sessionID) }() - if err := s.sendSnapshot(ctx, conn); err != nil { - return fmt.Errorf("send snapshot to proxy %s: %w", proxyID, err) - } - - errChan := make(chan error, 2) - go s.sender(conn, errChan) - go s.heartbeat(connCtx, proxyRecord) select { @@ -290,22 +328,27 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec return err } + // Send mappings in batches to reduce per-message gRPC overhead while + // staying well within the default 4 MB message size limit. + for i := 0; i < len(mappings); i += s.snapshotBatchSize { + end := i + s.snapshotBatchSize + if end > len(mappings) { + end = len(mappings) + } + if err := conn.stream.Send(&proto.GetMappingUpdateResponse{ + Mapping: mappings[i:end], + InitialSyncComplete: end == len(mappings), + }); err != nil { + return fmt.Errorf("send snapshot batch: %w", err) + } + } + if len(mappings) == 0 { if err := conn.stream.Send(&proto.GetMappingUpdateResponse{ InitialSyncComplete: true, }); err != nil { return fmt.Errorf("send snapshot completion: %w", err) } - return nil - } - - for i, m := range mappings { - if err := conn.stream.Send(&proto.GetMappingUpdateResponse{ - Mapping: []*proto.ProxyMapping{m}, - InitialSyncComplete: i == len(mappings)-1, - }); err != nil { - return fmt.Errorf("send proxy mapping: %w", err) - } } return nil @@ -323,13 +366,9 @@ func (s *ProxyServiceServer) snapshotServiceMappings(ctx context.Context, conn * continue } - token, err := s.tokenStore.GenerateToken(service.AccountID, service.ID, 5*time.Minute) + token, err := s.tokenStore.GenerateToken(service.AccountID, service.ID, s.proxyTokenTTL()) if err != nil { - log.WithFields(log.Fields{ - "service": service.Name, - "account": service.AccountID, - }).WithError(err).Error("failed to generate auth token for snapshot") - continue + return nil, fmt.Errorf("generate auth token for service %s: %w", service.ID, err) } m := service.ToProtoMapping(rpservice.Create, token, s.GetOIDCValidationConfig()) @@ -409,13 +448,16 @@ func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateRes conn := value.(*proxyConnection) resp := s.perProxyMessage(update, conn.proxyID) if resp == nil { + log.Warnf("Token generation failed for proxy %s, disconnecting to force resync", conn.proxyID) + conn.cancel() return true } select { case conn.sendChan <- resp: log.Debugf("Sent service update to proxy server %s", conn.proxyID) default: - log.Warnf("Failed to send service update to proxy server %s (channel full)", conn.proxyID) + log.Warnf("Send channel full for proxy %s, disconnecting to force resync", conn.proxyID) + conn.cancel() } return true }) @@ -495,13 +537,16 @@ func (s *ProxyServiceServer) SendServiceUpdateToCluster(ctx context.Context, upd } msg := s.perProxyMessage(updateResponse, proxyID) if msg == nil { + log.WithContext(ctx).Warnf("Token generation failed for proxy %s in cluster %s, disconnecting to force resync", proxyID, clusterAddr) + conn.cancel() continue } select { case conn.sendChan <- msg: log.WithContext(ctx).Debugf("Sent service update with id %s to proxy %s in cluster %s", update.Id, proxyID, clusterAddr) default: - log.WithContext(ctx).Warnf("Failed to send service update to proxy %s in cluster %s (channel full)", proxyID, clusterAddr) + log.WithContext(ctx).Warnf("Send channel full for proxy %s in cluster %s, disconnecting to force resync", proxyID, clusterAddr) + conn.cancel() } } } @@ -527,7 +572,8 @@ func proxyAcceptsMapping(conn *proxyConnection, mapping *proto.ProxyMapping) boo // perProxyMessage returns a copy of update with a fresh one-time token for // create/update operations. For delete operations the original mapping is // used unchanged because proxies do not need to authenticate for removal. -// Returns nil if token generation fails (the proxy should be skipped). +// Returns nil if token generation fails; the caller must disconnect the +// proxy so it can resync via a fresh snapshot on reconnect. func (s *ProxyServiceServer) perProxyMessage(update *proto.GetMappingUpdateResponse, proxyID string) *proto.GetMappingUpdateResponse { resp := make([]*proto.ProxyMapping, 0, len(update.Mapping)) for _, mapping := range update.Mapping { @@ -536,7 +582,7 @@ func (s *ProxyServiceServer) perProxyMessage(update *proto.GetMappingUpdateRespo continue } - token, err := s.tokenStore.GenerateToken(mapping.AccountId, mapping.Id, 5*time.Minute) + token, err := s.tokenStore.GenerateToken(mapping.AccountId, mapping.Id, s.proxyTokenTTL()) if err != nil { log.Warnf("Failed to generate token for proxy %s: %v", proxyID, err) return nil diff --git a/management/internals/shared/grpc/proxy_snapshot_test.go b/management/internals/shared/grpc/proxy_snapshot_test.go new file mode 100644 index 000000000..e0c7425c5 --- /dev/null +++ b/management/internals/shared/grpc/proxy_snapshot_test.go @@ -0,0 +1,174 @@ +package grpc + +import ( + "context" + "fmt" + "testing" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/metadata" + + rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" + "github.com/netbirdio/netbird/shared/management/proto" +) + +// recordingStream captures all messages sent via Send so tests can inspect +// batching behaviour without a real gRPC transport. +type recordingStream struct { + grpc.ServerStream + messages []*proto.GetMappingUpdateResponse +} + +func (s *recordingStream) Send(m *proto.GetMappingUpdateResponse) error { + s.messages = append(s.messages, m) + return nil +} + +func (s *recordingStream) Context() context.Context { return context.Background() } +func (s *recordingStream) SetHeader(metadata.MD) error { return nil } +func (s *recordingStream) SendHeader(metadata.MD) error { return nil } +func (s *recordingStream) SetTrailer(metadata.MD) {} +func (s *recordingStream) SendMsg(any) error { return nil } +func (s *recordingStream) RecvMsg(any) error { return nil } + +// makeServices creates n enabled services assigned to the given cluster. +func makeServices(n int, cluster string) []*rpservice.Service { + services := make([]*rpservice.Service, n) + for i := range n { + services[i] = &rpservice.Service{ + ID: fmt.Sprintf("svc-%d", i), + AccountID: "acct-1", + Name: fmt.Sprintf("svc-%d", i), + Domain: fmt.Sprintf("svc-%d.example.com", i), + ProxyCluster: cluster, + Enabled: true, + Targets: []*rpservice.Target{ + {TargetType: rpservice.TargetTypeHost, TargetId: "host-1"}, + }, + } + } + return services +} + +func newSnapshotTestServer(t *testing.T, batchSize int) *ProxyServiceServer { + t.Helper() + s := &ProxyServiceServer{ + tokenStore: NewOneTimeTokenStore(context.Background(), testCacheStore(t)), + snapshotBatchSize: batchSize, + } + s.SetProxyController(newTestProxyController()) + return s +} + +func TestSendSnapshot_BatchesMappings(t *testing.T) { + const cluster = "cluster.example.com" + const batchSize = 3 + const totalServices = 7 // 3 + 3 + 1 + + ctrl := gomock.NewController(t) + mgr := rpservice.NewMockManager(ctrl) + mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil) + + s := newSnapshotTestServer(t, batchSize) + s.serviceManager = mgr + + stream := &recordingStream{} + conn := &proxyConnection{ + proxyID: "proxy-a", + address: cluster, + stream: stream, + } + + err := s.sendSnapshot(context.Background(), conn) + require.NoError(t, err) + + // Expect ceil(7/3) = 3 messages + require.Len(t, stream.messages, 3, "should send ceil(totalServices/batchSize) messages") + + assert.Len(t, stream.messages[0].Mapping, 3) + assert.False(t, stream.messages[0].InitialSyncComplete, "first batch should not be sync-complete") + + assert.Len(t, stream.messages[1].Mapping, 3) + assert.False(t, stream.messages[1].InitialSyncComplete, "middle batch should not be sync-complete") + + assert.Len(t, stream.messages[2].Mapping, 1) + assert.True(t, stream.messages[2].InitialSyncComplete, "last batch must be sync-complete") + + // Verify all service IDs are present exactly once + seen := make(map[string]bool) + for _, msg := range stream.messages { + for _, m := range msg.Mapping { + assert.False(t, seen[m.Id], "duplicate service ID %s", m.Id) + seen[m.Id] = true + } + } + assert.Len(t, seen, totalServices) +} + +func TestSendSnapshot_ExactBatchMultiple(t *testing.T) { + const cluster = "cluster.example.com" + const batchSize = 3 + const totalServices = 6 // exactly 2 batches + + ctrl := gomock.NewController(t) + mgr := rpservice.NewMockManager(ctrl) + mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil) + + s := newSnapshotTestServer(t, batchSize) + s.serviceManager = mgr + + stream := &recordingStream{} + conn := &proxyConnection{proxyID: "proxy-a", address: cluster, stream: stream} + + require.NoError(t, s.sendSnapshot(context.Background(), conn)) + require.Len(t, stream.messages, 2) + + assert.Len(t, stream.messages[0].Mapping, 3) + assert.False(t, stream.messages[0].InitialSyncComplete) + + assert.Len(t, stream.messages[1].Mapping, 3) + assert.True(t, stream.messages[1].InitialSyncComplete) +} + +func TestSendSnapshot_SingleBatch(t *testing.T) { + const cluster = "cluster.example.com" + const batchSize = 100 + const totalServices = 5 + + ctrl := gomock.NewController(t) + mgr := rpservice.NewMockManager(ctrl) + mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil) + + s := newSnapshotTestServer(t, batchSize) + s.serviceManager = mgr + + stream := &recordingStream{} + conn := &proxyConnection{proxyID: "proxy-a", address: cluster, stream: stream} + + require.NoError(t, s.sendSnapshot(context.Background(), conn)) + require.Len(t, stream.messages, 1, "all mappings should fit in one batch") + assert.Len(t, stream.messages[0].Mapping, totalServices) + assert.True(t, stream.messages[0].InitialSyncComplete) +} + +func TestSendSnapshot_EmptySnapshot(t *testing.T) { + const cluster = "cluster.example.com" + + ctrl := gomock.NewController(t) + mgr := rpservice.NewMockManager(ctrl) + mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(nil, nil) + + s := newSnapshotTestServer(t, 500) + s.serviceManager = mgr + + stream := &recordingStream{} + conn := &proxyConnection{proxyID: "proxy-a", address: cluster, stream: stream} + + require.NoError(t, s.sendSnapshot(context.Background(), conn)) + require.Len(t, stream.messages, 1, "empty snapshot must still send sync-complete") + assert.Empty(t, stream.messages[0].Mapping) + assert.True(t, stream.messages[0].InitialSyncComplete) +} diff --git a/management/internals/shared/grpc/proxy_test.go b/management/internals/shared/grpc/proxy_test.go index de4e96d93..5a7a457df 100644 --- a/management/internals/shared/grpc/proxy_test.go +++ b/management/internals/shared/grpc/proxy_test.go @@ -85,11 +85,14 @@ func registerFakeProxy(s *ProxyServiceServer, proxyID, clusterAddr string) chan // registerFakeProxyWithCaps adds a fake proxy connection with explicit capabilities. func registerFakeProxyWithCaps(s *ProxyServiceServer, proxyID, clusterAddr string, caps *proto.ProxyCapabilities) chan *proto.GetMappingUpdateResponse { ch := make(chan *proto.GetMappingUpdateResponse, 10) + ctx, cancel := context.WithCancel(context.Background()) conn := &proxyConnection{ proxyID: proxyID, address: clusterAddr, capabilities: caps, sendChan: ch, + ctx: ctx, + cancel: cancel, } s.connectedProxies.Store(proxyID, conn) diff --git a/proxy/management_integration_test.go b/proxy/management_integration_test.go index e9eae3210..99bbdad0c 100644 --- a/proxy/management_integration_test.go +++ b/proxy/management_integration_test.go @@ -364,14 +364,16 @@ func TestIntegration_ProxyConnection_HappyPath(t *testing.T) { }) require.NoError(t, err) - // Receive all mappings from the snapshot - server sends each mapping individually mappingsByID := make(map[string]*proto.ProxyMapping) - for i := 0; i < 2; i++ { + for { msg, err := stream.Recv() require.NoError(t, err) for _, m := range msg.GetMapping() { mappingsByID[m.GetId()] = m } + if msg.GetInitialSyncComplete() { + break + } } // Should receive 2 mappings total @@ -411,12 +413,14 @@ func TestIntegration_ProxyConnection_SendsClusterAddress(t *testing.T) { }) require.NoError(t, err) - // Receive all mappings - server sends each mapping individually mappings := make([]*proto.ProxyMapping, 0) - for i := 0; i < 2; i++ { + for { msg, err := stream.Recv() require.NoError(t, err) mappings = append(mappings, msg.GetMapping()...) + if msg.GetInitialSyncComplete() { + break + } } // Should receive the 2 mappings matching the cluster @@ -440,13 +444,15 @@ func TestIntegration_ProxyConnection_Reconnect_ReceivesSameConfig(t *testing.T) clusterAddress := "test.proxy.io" proxyID := "test-proxy-reconnect" - // Helper to receive all mappings from a stream - receiveMappings := func(stream proto.ProxyService_GetMappingUpdateClient, count int) []*proto.ProxyMapping { + receiveMappings := func(stream proto.ProxyService_GetMappingUpdateClient) []*proto.ProxyMapping { var mappings []*proto.ProxyMapping - for i := 0; i < count; i++ { + for { msg, err := stream.Recv() require.NoError(t, err) mappings = append(mappings, msg.GetMapping()...) + if msg.GetInitialSyncComplete() { + break + } } return mappings } @@ -460,7 +466,7 @@ func TestIntegration_ProxyConnection_Reconnect_ReceivesSameConfig(t *testing.T) }) require.NoError(t, err) - firstMappings := receiveMappings(stream1, 2) + firstMappings := receiveMappings(stream1) cancel1() time.Sleep(100 * time.Millisecond) @@ -476,7 +482,7 @@ func TestIntegration_ProxyConnection_Reconnect_ReceivesSameConfig(t *testing.T) }) require.NoError(t, err) - secondMappings := receiveMappings(stream2, 2) + secondMappings := receiveMappings(stream2) // Should receive the same mappings assert.Equal(t, len(firstMappings), len(secondMappings), @@ -542,12 +548,14 @@ func TestIntegration_ProxyConnection_ReconnectDoesNotDuplicateState(t *testing.T } } - // Helper to receive and apply all mappings receiveAndApply := func(stream proto.ProxyService_GetMappingUpdateClient) { - for i := 0; i < 2; i++ { + for { msg, err := stream.Recv() require.NoError(t, err) applyMappings(msg.GetMapping()) + if msg.GetInitialSyncComplete() { + break + } } } @@ -636,12 +644,14 @@ func TestIntegration_ProxyConnection_MultipleProxiesReceiveUpdates(t *testing.T) }) require.NoError(t, err) - // Receive all mappings - server sends each mapping individually count := 0 - for i := 0; i < 2; i++ { + for { msg, err := stream.Recv() require.NoError(t, err) count += len(msg.GetMapping()) + if msg.GetInitialSyncComplete() { + break + } } mu.Lock() @@ -681,9 +691,12 @@ func TestIntegration_ProxyConnection_FastReconnectDoesNotLoseState(t *testing.T) }) require.NoError(t, err) - for i := 0; i < 2; i++ { - _, err := stream1.Recv() + for { + msg, err := stream1.Recv() require.NoError(t, err) + if msg.GetInitialSyncComplete() { + break + } } require.Contains(t, setup.proxyService.GetConnectedProxies(), proxyID, @@ -699,9 +712,12 @@ func TestIntegration_ProxyConnection_FastReconnectDoesNotLoseState(t *testing.T) }) require.NoError(t, err) - for i := 0; i < 2; i++ { - _, err := stream2.Recv() + for { + msg, err := stream2.Recv() require.NoError(t, err) + if msg.GetInitialSyncComplete() { + break + } } cancel1() diff --git a/proxy/server.go b/proxy/server.go index fbd0d058e..6980e1df1 100644 --- a/proxy/server.go +++ b/proxy/server.go @@ -943,6 +943,8 @@ func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.Pr operation := func() error { s.Logger.Debug("connecting to management mapping stream") + initialSyncDone = false + if s.healthChecker != nil { s.healthChecker.SetManagementConnected(false) } @@ -1000,6 +1002,11 @@ func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.Pr return ctx.Err() } + var snapshotIDs map[types.ServiceID]struct{} + if !*initialSyncDone { + snapshotIDs = make(map[types.ServiceID]struct{}) + } + for { // Check for context completion to gracefully shutdown. select { @@ -1020,17 +1027,45 @@ func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.Pr s.processMappings(ctx, msg.GetMapping()) s.Logger.Debug("Processing mapping update completed") - if !*initialSyncDone && msg.GetInitialSyncComplete() { - if s.healthChecker != nil { - s.healthChecker.SetInitialSyncComplete() + if !*initialSyncDone { + for _, m := range msg.GetMapping() { + snapshotIDs[types.ServiceID(m.GetId())] = struct{}{} + } + if msg.GetInitialSyncComplete() { + s.reconcileSnapshot(ctx, snapshotIDs) + snapshotIDs = nil + if s.healthChecker != nil { + s.healthChecker.SetInitialSyncComplete() + } + *initialSyncDone = true + s.Logger.Info("Initial mapping sync complete") } - *initialSyncDone = true - s.Logger.Info("Initial mapping sync complete") } } } } +// reconcileSnapshot removes local mappings that are absent from the snapshot. +// This ensures services deleted while the proxy was disconnected get cleaned up. +func (s *Server) reconcileSnapshot(ctx context.Context, snapshotIDs map[types.ServiceID]struct{}) { + s.portMu.RLock() + var stale []*proto.ProxyMapping + for svcID, mapping := range s.lastMappings { + if _, ok := snapshotIDs[svcID]; !ok { + stale = append(stale, mapping) + } + } + s.portMu.RUnlock() + + for _, mapping := range stale { + s.Logger.WithFields(log.Fields{ + "service_id": mapping.GetId(), + "domain": mapping.GetDomain(), + }).Info("Removing stale mapping absent from snapshot") + s.removeMapping(ctx, mapping) + } +} + func (s *Server) processMappings(ctx context.Context, mappings []*proto.ProxyMapping) { for _, mapping := range mappings { s.Logger.WithFields(log.Fields{ diff --git a/proxy/snapshot_reconcile_test.go b/proxy/snapshot_reconcile_test.go new file mode 100644 index 000000000..042d8df77 --- /dev/null +++ b/proxy/snapshot_reconcile_test.go @@ -0,0 +1,227 @@ +package proxy + +import ( + "context" + "io" + "testing" + + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/proxy/internal/health" + "github.com/netbirdio/netbird/proxy/internal/types" + "github.com/netbirdio/netbird/shared/management/proto" +) + +// collectStaleIDs mirrors the stale-detection logic in reconcileSnapshot +// so we can verify it without triggering removeMapping (which requires full +// server wiring). This keeps the test focused on the detection algorithm. +func collectStaleIDs(lastMappings map[types.ServiceID]*proto.ProxyMapping, snapshotIDs map[types.ServiceID]struct{}) []types.ServiceID { + var stale []types.ServiceID + for svcID := range lastMappings { + if _, ok := snapshotIDs[svcID]; !ok { + stale = append(stale, svcID) + } + } + return stale +} + +// TestStaleDetection_PartialOverlap verifies that only services absent from +// the snapshot are flagged as stale. +func TestStaleDetection_PartialOverlap(t *testing.T) { + local := map[types.ServiceID]*proto.ProxyMapping{ + "svc-1": {Id: "svc-1"}, + "svc-2": {Id: "svc-2"}, + "svc-stale-a": {Id: "svc-stale-a"}, + "svc-stale-b": {Id: "svc-stale-b"}, + } + snapshot := map[types.ServiceID]struct{}{ + "svc-1": {}, + "svc-2": {}, + "svc-3": {}, // new service, not in local + } + + stale := collectStaleIDs(local, snapshot) + assert.Len(t, stale, 2) + staleSet := make(map[types.ServiceID]struct{}) + for _, id := range stale { + staleSet[id] = struct{}{} + } + assert.Contains(t, staleSet, types.ServiceID("svc-stale-a")) + assert.Contains(t, staleSet, types.ServiceID("svc-stale-b")) +} + +// TestStaleDetection_AllStale verifies an empty snapshot flags everything. +func TestStaleDetection_AllStale(t *testing.T) { + local := map[types.ServiceID]*proto.ProxyMapping{ + "svc-1": {Id: "svc-1"}, + "svc-2": {Id: "svc-2"}, + } + stale := collectStaleIDs(local, map[types.ServiceID]struct{}{}) + assert.Len(t, stale, 2) +} + +// TestStaleDetection_NoneStale verifies full overlap produces no stale entries. +func TestStaleDetection_NoneStale(t *testing.T) { + local := map[types.ServiceID]*proto.ProxyMapping{ + "svc-1": {Id: "svc-1"}, + "svc-2": {Id: "svc-2"}, + } + snapshot := map[types.ServiceID]struct{}{ + "svc-1": {}, + "svc-2": {}, + } + stale := collectStaleIDs(local, snapshot) + assert.Empty(t, stale) +} + +// TestStaleDetection_EmptyLocal verifies no stale entries when local is empty. +func TestStaleDetection_EmptyLocal(t *testing.T) { + stale := collectStaleIDs( + map[types.ServiceID]*proto.ProxyMapping{}, + map[types.ServiceID]struct{}{"svc-1": {}}, + ) + assert.Empty(t, stale) +} + +// TestReconcileSnapshot_NoStale verifies reconciliation is a no-op when all +// local mappings are present in the snapshot (removeMapping is never called). +func TestReconcileSnapshot_NoStale(t *testing.T) { + s := &Server{ + Logger: log.StandardLogger(), + lastMappings: make(map[types.ServiceID]*proto.ProxyMapping), + } + s.lastMappings["svc-1"] = &proto.ProxyMapping{Id: "svc-1"} + s.lastMappings["svc-2"] = &proto.ProxyMapping{Id: "svc-2"} + + snapshotIDs := map[types.ServiceID]struct{}{ + "svc-1": {}, + "svc-2": {}, + } + // This should not panic — no stale entries means removeMapping is never called. + s.reconcileSnapshot(context.Background(), snapshotIDs) + + assert.Len(t, s.lastMappings, 2, "no mappings should be removed when all are in snapshot") +} + +// TestReconcileSnapshot_EmptyLocal verifies reconciliation is a no-op with +// no local mappings. +func TestReconcileSnapshot_EmptyLocal(t *testing.T) { + s := &Server{ + Logger: log.StandardLogger(), + lastMappings: make(map[types.ServiceID]*proto.ProxyMapping), + } + s.reconcileSnapshot(context.Background(), map[types.ServiceID]struct{}{"svc-1": {}}) + assert.Empty(t, s.lastMappings) +} + +// --- handleMappingStream tests for batched snapshot ID accumulation --- + +// TestHandleMappingStream_BatchedSnapshotSyncComplete verifies that sync is +// marked done only after the final InitialSyncComplete message, even when +// the snapshot arrives in multiple batches. +func TestHandleMappingStream_BatchedSnapshotSyncComplete(t *testing.T) { + checker := health.NewChecker(nil, nil) + s := &Server{ + Logger: log.StandardLogger(), + healthChecker: checker, + routerReady: closedChan(), + lastMappings: make(map[types.ServiceID]*proto.ProxyMapping), + } + + stream := &mockMappingStream{ + messages: []*proto.GetMappingUpdateResponse{ + {}, // batch 1: no sync-complete + {}, // batch 2: no sync-complete + {InitialSyncComplete: true}, // batch 3: sync done + }, + } + + syncDone := false + err := s.handleMappingStream(context.Background(), stream, &syncDone) + assert.NoError(t, err) + assert.True(t, syncDone, "sync should be marked done after final batch") +} + +// TestHandleMappingStream_PostSyncDoesNotReconcile verifies that messages +// arriving after InitialSyncComplete do not trigger a second reconciliation. +func TestHandleMappingStream_PostSyncDoesNotReconcile(t *testing.T) { + s := &Server{ + Logger: log.StandardLogger(), + routerReady: closedChan(), + lastMappings: make(map[types.ServiceID]*proto.ProxyMapping), + } + + // Simulate state left over from a previous sync. + s.lastMappings["svc-1"] = &proto.ProxyMapping{Id: "svc-1", AccountId: "acct-1"} + s.lastMappings["svc-2"] = &proto.ProxyMapping{Id: "svc-2", AccountId: "acct-1"} + + stream := &mockMappingStream{ + messages: []*proto.GetMappingUpdateResponse{ + {}, // post-sync empty message — must not reconcile + }, + } + + syncDone := true // sync already completed in a previous stream + err := s.handleMappingStream(context.Background(), stream, &syncDone) + require.NoError(t, err) + + assert.Len(t, s.lastMappings, 2, + "post-sync messages must not trigger reconciliation — all entries should survive") +} + +// TestHandleMappingStream_ImmediateEOF_NoReconciliation verifies that if the +// stream closes before sync completes, no reconciliation occurs. +func TestHandleMappingStream_ImmediateEOF_NoReconciliation(t *testing.T) { + s := &Server{ + Logger: log.StandardLogger(), + routerReady: closedChan(), + lastMappings: make(map[types.ServiceID]*proto.ProxyMapping), + } + + s.lastMappings["svc-stale"] = &proto.ProxyMapping{Id: "svc-stale", AccountId: "acct-1"} + + stream := &mockMappingStream{} // no messages → immediate EOF + + syncDone := false + err := s.handleMappingStream(context.Background(), stream, &syncDone) + assert.NoError(t, err) + assert.False(t, syncDone, "sync should not be marked done on immediate EOF") + + _, hasStale := s.lastMappings["svc-stale"] + assert.True(t, hasStale, "stale mapping should remain when sync never completed") +} + +// mockErrRecvStream returns an error on the second Recv to verify +// handleMappingStream returns without completing sync. +type mockErrRecvStream struct { + mockMappingStream + calls int +} + +func (m *mockErrRecvStream) Recv() (*proto.GetMappingUpdateResponse, error) { + m.calls++ + if m.calls == 1 { + return &proto.GetMappingUpdateResponse{}, nil + } + return nil, io.ErrUnexpectedEOF +} + +func TestHandleMappingStream_ErrorMidSync_NoReconciliation(t *testing.T) { + s := &Server{ + Logger: log.StandardLogger(), + routerReady: closedChan(), + lastMappings: make(map[types.ServiceID]*proto.ProxyMapping), + } + + s.lastMappings["svc-stale"] = &proto.ProxyMapping{Id: "svc-stale", AccountId: "acct-1"} + + syncDone := false + err := s.handleMappingStream(context.Background(), &mockErrRecvStream{}, &syncDone) + assert.Error(t, err) + assert.False(t, syncDone) + + _, hasStale := s.lastMappings["svc-stale"] + assert.True(t, hasStale, "stale mapping should remain when sync was interrupted by error") +} From b19b7464eac5c58bb6a6780a033398a27f3d772f Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Tue, 5 May 2026 18:48:51 +0200 Subject: [PATCH 15/15] [management] fix flaky invite token test (#6077) --- management/server/types/user_invite_test.go | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/management/server/types/user_invite_test.go b/management/server/types/user_invite_test.go index 09dae3800..c77fb89e2 100644 --- a/management/server/types/user_invite_test.go +++ b/management/server/types/user_invite_test.go @@ -144,8 +144,11 @@ func TestValidateInviteToken_ModifiedToken(t *testing.T) { _, plainToken, err := GenerateInviteToken() require.NoError(t, err) - // Modify one character in the secret part - modifiedToken := plainToken[:5] + "X" + plainToken[6:] + replacement := "X" + if plainToken[5] == 'X' { + replacement = "Y" + } + modifiedToken := plainToken[:5] + replacement + plainToken[6:] err = ValidateInviteToken(modifiedToken) require.Error(t, err) }