Add relayServerIP field to signal for foreign-relay fallback dial

This commit is contained in:
Viktor Liu
2026-04-27 18:01:46 +02:00
parent d6f08e4840
commit e7bd62f58c
23 changed files with 786 additions and 122 deletions

View File

@@ -2,8 +2,12 @@ package client
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"net/url"
"strings"
"sync"
"time"
@@ -146,6 +150,7 @@ func (cc *connContainer) close() {
type Client struct {
log *log.Entry
connectionURL string
fallbackIP netip.Addr
authTokenStore *auth.TokenStore
hashedID messages.PeerID
@@ -170,13 +175,16 @@ type Client struct {
}
// NewClient creates a new client for the relay server. The client is not connected to the server until the Connect
func NewClient(serverURL string, authTokenStore *auth.TokenStore, peerID string, mtu uint16) *Client {
// is called. fallbackIP, when valid, is used as a dial-time fallback if the FQDN-based dial fails. TLS
// verification still uses the FQDN from serverURL via SNI.
func NewClient(serverURL string, fallbackIP netip.Addr, authTokenStore *auth.TokenStore, peerID string, mtu uint16) *Client {
hashedID := messages.HashID(peerID)
relayLog := log.WithFields(log.Fields{"relay": serverURL})
c := &Client{
log: relayLog,
connectionURL: serverURL,
fallbackIP: fallbackIP,
authTokenStore: authTokenStore,
hashedID: hashedID,
mtu: mtu,
@@ -304,6 +312,41 @@ func (c *Client) ServerInstanceURL() (string, error) {
return c.instanceURL.String(), nil
}
// ConnectedIP returns the IP address of the live relay-server connection,
// extracted from the underlying socket's RemoteAddr. Zero value if not
// connected or if the address is not an IP literal.
func (c *Client) ConnectedIP() netip.Addr {
c.mu.Lock()
conn := c.relayConn
c.mu.Unlock()
if conn == nil {
return netip.Addr{}
}
addr := conn.RemoteAddr()
if addr == nil {
return netip.Addr{}
}
return extractIPLiteral(addr.String())
}
// extractIPLiteral returns the IP from address forms produced by the relay
// dialers (URL or host:port). Zero value if the host is not an IP.
func extractIPLiteral(s string) netip.Addr {
if u, err := url.Parse(s); err == nil && u.Host != "" {
s = u.Host
}
host, _, err := net.SplitHostPort(s)
if err != nil {
host = s
}
host = strings.Trim(host, "[]")
ip, err := netip.ParseAddr(host)
if err != nil {
return netip.Addr{}
}
return ip.Unmap()
}
// SetOnDisconnectListener sets a function that will be called when the connection to the relay server is closed.
func (c *Client) SetOnDisconnectListener(fn func(string)) {
c.listenerMutex.Lock()
@@ -335,7 +378,11 @@ func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
rd := dialer.NewRaceDial(c.log, dialer.DefaultConnectionTimeout, c.connectionURL, dialers...)
conn, err := rd.Dial(ctx)
if err != nil {
return nil, err
fallbackConn, fbErr := c.dialFallback(ctx, dialers)
if fbErr != nil {
return nil, fmt.Errorf("primary dial: %w; fallback dial: %w", err, fbErr)
}
conn = fallbackConn
}
c.relayConn = conn
@@ -351,6 +398,58 @@ func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
return instanceURL, nil
}
// dialFallback retries the dial against c.fallbackIP, preserving the
// original FQDN as the TLS ServerName for SNI. Returns an error if no
// fallback IP is configured or if the substituted URL is malformed.
func (c *Client) dialFallback(ctx context.Context, dialers []dialer.DialeFn) (net.Conn, error) {
if !c.fallbackIP.IsValid() || c.fallbackIP.IsUnspecified() {
return nil, errors.New("no usable fallback IP configured")
}
fallbackURL, serverName, err := substituteHost(c.connectionURL, c.fallbackIP)
if err != nil {
return nil, fmt.Errorf("substitute host: %w", err)
}
c.log.Infof("primary dial failed, retrying via fallback IP %s (SNI=%s)", c.fallbackIP, serverName)
rd := dialer.NewRaceDial(c.log, dialer.DefaultConnectionTimeout, fallbackURL, dialers...).
WithServerName(serverName)
return rd.Dial(ctx)
}
// substituteHost replaces the host portion of a rel/rels URL with ip,
// preserving the scheme and port. Returns the rewritten URL and the
// original host to use as the TLS ServerName, or empty if the original
// host is itself an IP literal (SNI requires a DNS name).
func substituteHost(serverURL string, ip netip.Addr) (string, string, error) {
u, err := url.Parse(serverURL)
if err != nil {
return "", "", fmt.Errorf("parse %q: %w", serverURL, err)
}
if u.Scheme == "" || u.Host == "" {
return "", "", fmt.Errorf("invalid relay URL %q", serverURL)
}
if !ip.IsValid() {
return "", "", errors.New("invalid fallback IP")
}
origHost := u.Hostname()
if _, err := netip.ParseAddr(origHost); err == nil {
origHost = ""
}
ip = ip.Unmap()
newHost := ip.String()
if ip.Is6() {
newHost = "[" + newHost + "]"
}
if port := u.Port(); port != "" {
u.Host = newHost + ":" + port
} else {
u.Host = newHost
}
return u.String(), origHost, nil
}
func (c *Client) handShake(ctx context.Context) (*RelayAddr, error) {
msg, err := messages.MarshalAuthMsg(c.hashedID, c.authTokenStore.TokenBinary())
if err != nil {

View File

@@ -0,0 +1,280 @@
package client
import (
"context"
"fmt"
"net"
"net/netip"
"testing"
"time"
"go.opentelemetry.io/otel"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/relay/server"
"github.com/netbirdio/netbird/shared/relay/auth/allow"
)
// TestClient_FallbackIPRecoversFromUnresolvableFQDN verifies that when the
// primary FQDN-based dial fails (unresolvable .invalid host), Connect
// recovers via the fallback IP and SNI still uses the FQDN.
func TestClient_FallbackIPRecoversFromUnresolvableFQDN(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
listenAddr, port := freeAddr(t)
srvCfg := server.Config{
Meter: otel.Meter(""),
ExposedAddress: fmt.Sprintf("rel://test-unresolvable-host.invalid:%d", port),
TLSSupport: false,
AuthValidator: &allow.Auth{},
}
srv, err := server.NewServer(srvCfg)
if err != nil {
t.Fatalf("create server: %s", err)
}
errChan := make(chan error, 1)
go func() {
if err := srv.Listen(server.ListenerConfig{Address: listenAddr}); err != nil {
errChan <- err
}
}()
t.Cleanup(func() {
if err := srv.Shutdown(context.Background()); err != nil {
t.Errorf("shutdown server: %s", err)
}
})
if err := waitForServerToStart(errChan); err != nil {
t.Fatalf("server failed to start: %s", err)
}
t.Run("no fallback IP, primary fails", func(t *testing.T) {
c := NewClient(srvCfg.ExposedAddress, netip.Addr{}, hmacTokenStore, "alice-nofallback", iface.DefaultMTU)
err := c.Connect(ctx)
if err == nil {
_ = c.Close()
t.Fatalf("expected connect to fail without fallback IP, got nil")
}
})
t.Run("fallback IP recovers", func(t *testing.T) {
c := NewClient(srvCfg.ExposedAddress, netip.MustParseAddr("127.0.0.1"), hmacTokenStore, "alice-fallback", iface.DefaultMTU)
if err := c.Connect(ctx); err != nil {
t.Fatalf("connect with fallback IP: %s", err)
}
t.Cleanup(func() { _ = c.Close() })
if !c.Ready() {
t.Fatalf("client not ready after connect")
}
if got := c.ConnectedIP(); got.String() != "127.0.0.1" {
t.Fatalf("ConnectedIP = %q, want 127.0.0.1", got)
}
})
}
// TestClient_ConnectedIPAfterFQDNDial verifies ConnectedIP returns the
// resolved IP after a successful FQDN-based dial. The underlying socket's
// RemoteAddr must be exposed through the dialer wrappers; if it returns
// the dial-time URL instead, ConnectedIP returns empty and the fallback
// IP we advertise to peers is empty too.
func TestClient_ConnectedIPAfterFQDNDial(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
listenAddr, port := freeAddr(t)
srvCfg := server.Config{
Meter: otel.Meter(""),
ExposedAddress: fmt.Sprintf("rel://localhost:%d", port),
TLSSupport: false,
AuthValidator: &allow.Auth{},
}
srv, err := server.NewServer(srvCfg)
if err != nil {
t.Fatalf("create server: %s", err)
}
errChan := make(chan error, 1)
go func() {
if err := srv.Listen(server.ListenerConfig{Address: listenAddr}); err != nil {
errChan <- err
}
}()
t.Cleanup(func() { _ = srv.Shutdown(context.Background()) })
if err := waitForServerToStart(errChan); err != nil {
t.Fatalf("server failed to start: %s", err)
}
c := NewClient(srvCfg.ExposedAddress, netip.Addr{}, hmacTokenStore, "alice-fqdn", iface.DefaultMTU)
if err := c.Connect(ctx); err != nil {
t.Fatalf("connect: %s", err)
}
t.Cleanup(func() { _ = c.Close() })
got := c.ConnectedIP().String()
if got != "127.0.0.1" && got != "::1" {
t.Fatalf("ConnectedIP after FQDN dial = %q, want 127.0.0.1 or ::1", got)
}
}
func TestSubstituteHost(t *testing.T) {
tests := []struct {
name string
serverURL string
ip string
wantURL string
wantServerName string
wantErr bool
}{
{
name: "rels with port",
serverURL: "rels://relay.netbird.io:443",
ip: "10.0.0.5",
wantURL: "rels://10.0.0.5:443",
wantServerName: "relay.netbird.io",
},
{
name: "rel with port",
serverURL: "rel://relay.example.com:80",
ip: "192.0.2.1",
wantURL: "rel://192.0.2.1:80",
wantServerName: "relay.example.com",
},
{
name: "ipv6 fallback bracketed",
serverURL: "rels://relay.example.com:443",
ip: "2001:db8::1",
wantURL: "rels://[2001:db8::1]:443",
wantServerName: "relay.example.com",
},
{
name: "no port",
serverURL: "rels://relay.example.com",
ip: "10.0.0.5",
wantURL: "rels://10.0.0.5",
wantServerName: "relay.example.com",
},
{
name: "ipv6 server with port returns empty SNI",
serverURL: "rels://[2001:db8::5]:443",
ip: "10.0.0.5",
wantURL: "rels://10.0.0.5:443",
wantServerName: "",
},
{
name: "ipv4 server with port returns empty SNI",
serverURL: "rels://10.0.0.5:443",
ip: "10.0.0.6",
wantURL: "rels://10.0.0.6:443",
wantServerName: "",
},
{
name: "ipv6 fallback no port",
serverURL: "rels://relay.example.com",
ip: "2001:db8::1",
wantURL: "rels://[2001:db8::1]",
wantServerName: "relay.example.com",
},
{
name: "missing scheme",
serverURL: "relay.example.com:443",
ip: "10.0.0.5",
wantErr: true,
},
{
name: "empty",
serverURL: "",
ip: "10.0.0.5",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var ip netip.Addr
if tt.ip != "" {
ip = netip.MustParseAddr(tt.ip)
}
gotURL, gotName, err := substituteHost(tt.serverURL, ip)
if tt.wantErr {
if err == nil {
t.Fatalf("expected error, got nil")
}
return
}
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
if gotURL != tt.wantURL {
t.Errorf("URL = %q, want %q", gotURL, tt.wantURL)
}
if gotName != tt.wantServerName {
t.Errorf("ServerName = %q, want %q", gotName, tt.wantServerName)
}
})
}
}
func TestClient_ConnectedIPEmptyWhenNotConnected(t *testing.T) {
c := NewClient("rel://example.invalid:80", netip.Addr{}, hmacTokenStore, "x", iface.DefaultMTU)
if got := c.ConnectedIP(); got.IsValid() {
t.Fatalf("ConnectedIP on disconnected client = %q, want zero", got)
}
}
// staticAddr is a net.Addr that returns a fixed string. Used to verify
// ConnectedIP parses RemoteAddr correctly.
type staticAddr struct{ s string }
func (a staticAddr) Network() string { return "tcp" }
func (a staticAddr) String() string { return a.s }
type stubConn struct {
net.Conn
remote net.Addr
}
func (s stubConn) RemoteAddr() net.Addr { return s.remote }
func TestClient_ConnectedIPParsesRemoteAddr(t *testing.T) {
tests := []struct {
name string
s string
want string
}{
{"hostport ipv4", "127.0.0.1:50301", "127.0.0.1"},
{"hostport ipv6 bracketed", "[::1]:50301", "::1"},
{"url with ipv4", "rel://127.0.0.1:50301", "127.0.0.1"},
{"url with ipv6", "rels://[2001:db8::1]:443", "2001:db8::1"},
{"fqdn url returns empty", "rel://relay.example.com:50301", ""},
{"fqdn hostport returns empty", "relay.example.com:50301", ""},
{"plain ipv4 no port", "10.0.0.1", "10.0.0.1"},
{"empty", "", ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Client{relayConn: stubConn{remote: staticAddr{s: tt.s}}}
got := c.ConnectedIP()
var gotStr string
if got.IsValid() {
gotStr = got.String()
}
if gotStr != tt.want {
t.Errorf("ConnectedIP(%q) = %q, want %q", tt.s, gotStr, tt.want)
}
})
}
}
// freeAddr returns a 127.0.0.1 address with an OS-assigned port. The
// listener is closed before returning, so the port is briefly free for
// the caller to bind. Avoids hardcoded ports that can collide.
func freeAddr(t *testing.T) (string, int) {
t.Helper()
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("get free port: %s", err)
}
addr := l.Addr().(*net.TCPAddr)
_ = l.Close()
return addr.String(), addr.Port
}

View File

@@ -3,6 +3,7 @@ package client
import (
"context"
"net"
"net/netip"
"os"
"testing"
"time"
@@ -68,7 +69,7 @@ func TestClient(t *testing.T) {
t.Fatalf("failed to start server: %s", err)
}
t.Log("alice connecting to server")
clientAlice := NewClient(serverCfg.ExposedAddress, hmacTokenStore, "alice", iface.DefaultMTU)
clientAlice := NewClient(serverCfg.ExposedAddress, netip.Addr{}, hmacTokenStore, "alice", iface.DefaultMTU)
err = clientAlice.Connect(ctx)
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
@@ -76,7 +77,7 @@ func TestClient(t *testing.T) {
defer clientAlice.Close()
t.Log("placeholder connecting to server")
clientPlaceHolder := NewClient(serverCfg.ExposedAddress, hmacTokenStore, "clientPlaceHolder", iface.DefaultMTU)
clientPlaceHolder := NewClient(serverCfg.ExposedAddress, netip.Addr{}, hmacTokenStore, "clientPlaceHolder", iface.DefaultMTU)
err = clientPlaceHolder.Connect(ctx)
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
@@ -84,7 +85,7 @@ func TestClient(t *testing.T) {
defer clientPlaceHolder.Close()
t.Log("Bob connecting to server")
clientBob := NewClient(serverCfg.ExposedAddress, hmacTokenStore, "bob", iface.DefaultMTU)
clientBob := NewClient(serverCfg.ExposedAddress, netip.Addr{}, hmacTokenStore, "bob", iface.DefaultMTU)
err = clientBob.Connect(ctx)
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
@@ -144,7 +145,7 @@ func TestRegistration(t *testing.T) {
t.Fatalf("failed to start server: %s", err)
}
clientAlice := NewClient(serverCfg.ExposedAddress, hmacTokenStore, "alice", iface.DefaultMTU)
clientAlice := NewClient(serverCfg.ExposedAddress, netip.Addr{}, hmacTokenStore, "alice", iface.DefaultMTU)
err = clientAlice.Connect(ctx)
if err != nil {
_ = srv.Shutdown(ctx)
@@ -184,7 +185,7 @@ func TestRegistrationTimeout(t *testing.T) {
_ = fakeTCPListener.Close()
}(fakeTCPListener)
clientAlice := NewClient("127.0.0.1:50201", hmacTokenStore, "alice", iface.DefaultMTU)
clientAlice := NewClient("127.0.0.1:50201", netip.Addr{}, hmacTokenStore, "alice", iface.DefaultMTU)
err = clientAlice.Connect(ctx)
if err == nil {
t.Errorf("failed to connect to server: %s", err)
@@ -227,7 +228,7 @@ func TestEcho(t *testing.T) {
t.Fatalf("failed to start server: %s", err)
}
clientAlice := NewClient(serverCfg.ExposedAddress, hmacTokenStore, idAlice, iface.DefaultMTU)
clientAlice := NewClient(serverCfg.ExposedAddress, netip.Addr{}, hmacTokenStore, idAlice, iface.DefaultMTU)
err = clientAlice.Connect(ctx)
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
@@ -239,7 +240,7 @@ func TestEcho(t *testing.T) {
}
}()
clientBob := NewClient(serverCfg.ExposedAddress, hmacTokenStore, idBob, iface.DefaultMTU)
clientBob := NewClient(serverCfg.ExposedAddress, netip.Addr{}, hmacTokenStore, idBob, iface.DefaultMTU)
err = clientBob.Connect(ctx)
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
@@ -319,7 +320,7 @@ func TestBindToUnavailabePeer(t *testing.T) {
t.Fatalf("failed to start server: %s", err)
}
clientAlice := NewClient(serverCfg.ExposedAddress, hmacTokenStore, "alice", iface.DefaultMTU)
clientAlice := NewClient(serverCfg.ExposedAddress, netip.Addr{}, hmacTokenStore, "alice", iface.DefaultMTU)
err = clientAlice.Connect(ctx)
if err != nil {
t.Errorf("failed to connect to server: %s", err)
@@ -367,13 +368,13 @@ func TestBindReconnect(t *testing.T) {
t.Fatalf("failed to start server: %s", err)
}
clientAlice := NewClient(serverCfg.ExposedAddress, hmacTokenStore, "alice", iface.DefaultMTU)
clientAlice := NewClient(serverCfg.ExposedAddress, netip.Addr{}, hmacTokenStore, "alice", iface.DefaultMTU)
err = clientAlice.Connect(ctx)
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
}
clientBob := NewClient(serverCfg.ExposedAddress, hmacTokenStore, "bob", iface.DefaultMTU)
clientBob := NewClient(serverCfg.ExposedAddress, netip.Addr{}, hmacTokenStore, "bob", iface.DefaultMTU)
err = clientBob.Connect(ctx)
if err != nil {
t.Errorf("failed to connect to server: %s", err)
@@ -395,7 +396,7 @@ func TestBindReconnect(t *testing.T) {
t.Errorf("failed to close client: %s", err)
}
clientAlice = NewClient(serverCfg.ExposedAddress, hmacTokenStore, "alice", iface.DefaultMTU)
clientAlice = NewClient(serverCfg.ExposedAddress, netip.Addr{}, hmacTokenStore, "alice", iface.DefaultMTU)
err = clientAlice.Connect(ctx)
if err != nil {
t.Errorf("failed to connect to server: %s", err)
@@ -470,13 +471,13 @@ func TestCloseConn(t *testing.T) {
t.Fatalf("failed to start server: %s", err)
}
bob := NewClient(serverCfg.ExposedAddress, hmacTokenStore, "bob", iface.DefaultMTU)
bob := NewClient(serverCfg.ExposedAddress, netip.Addr{}, hmacTokenStore, "bob", iface.DefaultMTU)
err = bob.Connect(ctx)
if err != nil {
t.Errorf("failed to connect to server: %s", err)
}
clientAlice := NewClient(serverCfg.ExposedAddress, hmacTokenStore, "alice", iface.DefaultMTU)
clientAlice := NewClient(serverCfg.ExposedAddress, netip.Addr{}, hmacTokenStore, "alice", iface.DefaultMTU)
err = clientAlice.Connect(ctx)
if err != nil {
t.Errorf("failed to connect to server: %s", err)
@@ -534,13 +535,13 @@ func TestCloseRelayConn(t *testing.T) {
t.Fatalf("failed to start server: %s", err)
}
bob := NewClient(serverCfg.ExposedAddress, hmacTokenStore, "bob", iface.DefaultMTU)
bob := NewClient(serverCfg.ExposedAddress, netip.Addr{}, hmacTokenStore, "bob", iface.DefaultMTU)
err = bob.Connect(ctx)
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
}
clientAlice := NewClient(serverCfg.ExposedAddress, hmacTokenStore, "alice", iface.DefaultMTU)
clientAlice := NewClient(serverCfg.ExposedAddress, netip.Addr{}, hmacTokenStore, "alice", iface.DefaultMTU)
err = clientAlice.Connect(ctx)
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
@@ -590,7 +591,7 @@ func TestCloseByServer(t *testing.T) {
idAlice := "alice"
log.Debugf("connect by alice")
relayClient := NewClient(serverCfg.ExposedAddress, hmacTokenStore, idAlice, iface.DefaultMTU)
relayClient := NewClient(serverCfg.ExposedAddress, netip.Addr{}, hmacTokenStore, idAlice, iface.DefaultMTU)
if err = relayClient.Connect(ctx); err != nil {
log.Fatalf("failed to connect to server: %s", err)
}
@@ -648,7 +649,7 @@ func TestCloseByClient(t *testing.T) {
idAlice := "alice"
log.Debugf("connect by alice")
relayClient := NewClient(serverCfg.ExposedAddress, hmacTokenStore, idAlice, iface.DefaultMTU)
relayClient := NewClient(serverCfg.ExposedAddress, netip.Addr{}, hmacTokenStore, idAlice, iface.DefaultMTU)
err = relayClient.Connect(ctx)
if err != nil {
log.Fatalf("failed to connect to server: %s", err)
@@ -701,7 +702,7 @@ func TestCloseNotDrainedChannel(t *testing.T) {
t.Fatalf("failed to start server: %s", err)
}
clientAlice := NewClient(serverCfg.ExposedAddress, hmacTokenStore, idAlice, iface.DefaultMTU)
clientAlice := NewClient(serverCfg.ExposedAddress, netip.Addr{}, hmacTokenStore, idAlice, iface.DefaultMTU)
err = clientAlice.Connect(ctx)
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
@@ -713,7 +714,7 @@ func TestCloseNotDrainedChannel(t *testing.T) {
}
}()
clientBob := NewClient(serverCfg.ExposedAddress, hmacTokenStore, idBob, iface.DefaultMTU)
clientBob := NewClient(serverCfg.ExposedAddress, netip.Addr{}, hmacTokenStore, idBob, iface.DefaultMTU)
err = clientBob.Connect(ctx)
if err != nil {
t.Fatalf("failed to connect to server: %s", err)

View File

@@ -23,7 +23,7 @@ func (d Dialer) Protocol() string {
return Network
}
func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) {
func (d Dialer) Dial(ctx context.Context, address, serverName string) (net.Conn, error) {
quicURL, err := prepareURL(address)
if err != nil {
return nil, err
@@ -32,11 +32,14 @@ func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) {
// Get the base TLS config
tlsClientConfig := quictls.ClientQUICTLSConfig()
// Set ServerName to hostname if not an IP address
host, _, splitErr := net.SplitHostPort(quicURL)
if splitErr == nil && net.ParseIP(host) == nil {
// It's a hostname, not an IP - modify directly
tlsClientConfig.ServerName = host
switch {
case serverName != "" && net.ParseIP(serverName) == nil:
tlsClientConfig.ServerName = serverName
case serverName == "":
host, _, splitErr := net.SplitHostPort(quicURL)
if splitErr == nil && net.ParseIP(host) == nil {
tlsClientConfig.ServerName = host
}
}
quicConfig := &quic.Config{

View File

@@ -14,7 +14,9 @@ const (
)
type DialeFn interface {
Dial(ctx context.Context, address string) (net.Conn, error)
// Dial connects to address. serverName, when non-empty, overrides the TLS
// ServerName used for SNI/cert validation. Empty means derive from address.
Dial(ctx context.Context, address, serverName string) (net.Conn, error)
Protocol() string
}
@@ -27,6 +29,7 @@ type dialResult struct {
type RaceDial struct {
log *log.Entry
serverURL string
serverName string
dialerFns []DialeFn
connectionTimeout time.Duration
}
@@ -40,6 +43,16 @@ func NewRaceDial(log *log.Entry, connectionTimeout time.Duration, serverURL stri
}
}
// WithServerName sets a TLS SNI/cert validation override. Used when serverURL
// contains an IP literal but the cert is issued for a different hostname.
//
// Mutates the receiver and is not safe for concurrent reconfiguration; a
// RaceDial is intended to be constructed per dial and discarded.
func (r *RaceDial) WithServerName(serverName string) *RaceDial {
r.serverName = serverName
return r
}
func (r *RaceDial) Dial(ctx context.Context) (net.Conn, error) {
connChan := make(chan dialResult, len(r.dialerFns))
winnerConn := make(chan net.Conn, 1)
@@ -64,7 +77,7 @@ func (r *RaceDial) dial(dfn DialeFn, abortCtx context.Context, connChan chan dia
defer cancel()
r.log.Infof("dialing Relay server via %s", dfn.Protocol())
conn, err := dfn.Dial(ctx, r.serverURL)
conn, err := dfn.Dial(ctx, r.serverURL, r.serverName)
connChan <- dialResult{Conn: conn, Protocol: dfn.Protocol(), Err: err}
}

View File

@@ -28,7 +28,7 @@ type MockDialer struct {
protocolStr string
}
func (m *MockDialer) Dial(ctx context.Context, address string) (net.Conn, error) {
func (m *MockDialer) Dial(ctx context.Context, address, _ string) (net.Conn, error) {
return m.dialFunc(ctx, address)
}

View File

@@ -12,14 +12,24 @@ import (
type Conn struct {
ctx context.Context
*websocket.Conn
remoteAddr WebsocketAddr
remoteAddr net.Addr
}
func NewConn(wsConn *websocket.Conn, serverAddress string) net.Conn {
// NewConn builds a relay ws.Conn. underlying is the raw TCP/TLS conn captured
// from the http transport's DialContext; when set, RemoteAddr returns its
// peer address (an IP literal). When nil (e.g. wasm), RemoteAddr falls back
// to the dial-time URL.
func NewConn(wsConn *websocket.Conn, serverAddress string, underlying net.Conn) net.Conn {
var addr net.Addr = WebsocketAddr{serverAddress}
if underlying != nil {
if ra := underlying.RemoteAddr(); ra != nil {
addr = ra
}
}
return &Conn{
ctx: context.Background(),
Conn: wsConn,
remoteAddr: WebsocketAddr{serverAddress},
remoteAddr: addr,
}
}

View File

@@ -2,10 +2,14 @@
package ws
import "github.com/coder/websocket"
import (
"net"
func createDialOptions() *websocket.DialOptions {
"github.com/coder/websocket"
)
func createDialOptions(serverName string, underlyingOut *net.Conn) *websocket.DialOptions {
return &websocket.DialOptions{
HTTPClient: httpClientNbDialer(),
HTTPClient: httpClientNbDialer(serverName, underlyingOut),
}
}

View File

@@ -2,9 +2,13 @@
package ws
import "github.com/coder/websocket"
import (
"net"
func createDialOptions() *websocket.DialOptions {
// WASM version doesn't support HTTPClient
"github.com/coder/websocket"
)
func createDialOptions(_ string, _ *net.Conn) *websocket.DialOptions {
// WASM version doesn't support HTTPClient or custom TLS config.
return &websocket.DialOptions{}
}

View File

@@ -26,13 +26,14 @@ func (d Dialer) Protocol() string {
return "WS"
}
func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) {
func (d Dialer) Dial(ctx context.Context, address, serverName string) (net.Conn, error) {
wsURL, err := prepareURL(address)
if err != nil {
return nil, err
}
opts := createDialOptions()
var underlying net.Conn
opts := createDialOptions(serverName, &underlying)
parsedURL, err := url.Parse(wsURL)
if err != nil {
@@ -52,7 +53,7 @@ func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) {
_ = resp.Body.Close()
}
conn := NewConn(wsConn, address)
conn := NewConn(wsConn, address, underlying)
return conn, nil
}
@@ -64,7 +65,10 @@ func prepareURL(address string) (string, error) {
return strings.Replace(address, "rel", "ws", 1), nil
}
func httpClientNbDialer() *http.Client {
// httpClientNbDialer builds the http client used by the websocket library.
// underlyingOut, when non-nil, is populated with the raw conn from the
// transport's DialContext so the caller can read its RemoteAddr.
func httpClientNbDialer(serverName string, underlyingOut *net.Conn) *http.Client {
customDialer := nbnet.NewDialer()
certPool, err := x509.SystemCertPool()
@@ -75,10 +79,15 @@ func httpClientNbDialer() *http.Client {
customTransport := &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return customDialer.DialContext(ctx, network, addr)
c, err := customDialer.DialContext(ctx, network, addr)
if err == nil && underlyingOut != nil {
*underlyingOut = c
}
return c, err
},
TLSClientConfig: &tls.Config{
RootCAs: certPool,
RootCAs: certPool,
ServerName: serverName,
},
}

View File

@@ -5,6 +5,7 @@ import (
"context"
"fmt"
"net"
"net/netip"
"reflect"
"sync"
"time"
@@ -117,7 +118,10 @@ func (m *Manager) Serve() error {
// OpenConn opens a connection to the given peer key. If the peer is on the same relay server, the connection will be
// established via the relay server. If the peer is on a different relay server, the manager will establish a new
// connection to the relay server. It returns back with a net.Conn what represent the remote peer connection.
func (m *Manager) OpenConn(ctx context.Context, serverAddress, peerKey string) (net.Conn, error) {
//
// fallbackIP, when valid and serverAddress is foreign, is used as a dial-time fallback if the FQDN-based
// dial fails. Ignored for the local home-server path. TLS verification still uses the FQDN via SNI.
func (m *Manager) OpenConn(ctx context.Context, serverAddress, peerKey string, fallbackIP netip.Addr) (net.Conn, error) {
m.relayClientMu.RLock()
defer m.relayClientMu.RUnlock()
@@ -138,7 +142,7 @@ func (m *Manager) OpenConn(ctx context.Context, serverAddress, peerKey string) (
netConn, err = m.relayClient.OpenConn(ctx, peerKey)
} else {
log.Debugf("open peer connection via foreign server: %s", serverAddress)
netConn, err = m.openConnVia(ctx, serverAddress, peerKey)
netConn, err = m.openConnVia(ctx, serverAddress, peerKey, fallbackIP)
}
if err != nil {
return nil, err
@@ -202,6 +206,19 @@ func (m *Manager) RelayInstanceAddress() (string, error) {
return m.relayClient.ServerInstanceURL()
}
// RelayInstanceIP returns the IP address of the live home relay connection.
// Zero value if not connected. Sent alongside RelayInstanceAddress so remote
// peers can dial directly without their own DNS lookup.
func (m *Manager) RelayInstanceIP() netip.Addr {
m.relayClientMu.RLock()
defer m.relayClientMu.RUnlock()
if m.relayClient == nil {
return netip.Addr{}
}
return m.relayClient.ConnectedIP()
}
// ServerURLs returns the addresses of the relay servers.
func (m *Manager) ServerURLs() []string {
return m.serverPicker.ServerURLs.Load().([]string)
@@ -223,7 +240,7 @@ func (m *Manager) UpdateToken(token *relayAuth.Token) error {
return m.tokenStore.UpdateToken(token)
}
func (m *Manager) openConnVia(ctx context.Context, serverAddress, peerKey string) (net.Conn, error) {
func (m *Manager) openConnVia(ctx context.Context, serverAddress, peerKey string, fallbackIP netip.Addr) (net.Conn, error) {
// check if already has a connection to the desired relay server
m.relayClientsMutex.RLock()
rt, ok := m.relayClients[serverAddress]
@@ -258,7 +275,7 @@ func (m *Manager) openConnVia(ctx context.Context, serverAddress, peerKey string
m.relayClients[serverAddress] = rt
m.relayClientsMutex.Unlock()
relayClient := NewClient(serverAddress, m.tokenStore, m.peerID, m.mtu)
relayClient := NewClient(serverAddress, fallbackIP, m.tokenStore, m.peerID, m.mtu)
err := relayClient.Connect(m.ctx)
if err != nil {
rt.err = err

View File

@@ -0,0 +1,146 @@
package client
import (
"context"
"net/netip"
"testing"
"time"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/relay/server"
)
// TestManager_ForeignRelayFallbackIP exercises the foreign-relay path
// end-to-end through Manager.OpenConn. Alice and Bob register on different
// relay servers; Alice dials Bob's foreign relay using an unresolvable
// FQDN. Without a fallback IP the dial fails; with Bob's advertised IP it
// recovers and a payload round-trips between the peers.
func TestManager_ForeignRelayFallbackIP(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
defer cancel()
// Alice's home relay
homeCfg := server.ListenerConfig{Address: "127.0.0.1:52401"}
homeSrv, err := server.NewServer(newManagerTestServerConfig(homeCfg.Address))
if err != nil {
t.Fatalf("create home server: %s", err)
}
homeErr := make(chan error, 1)
go func() {
if err := homeSrv.Listen(homeCfg); err != nil {
homeErr <- err
}
}()
t.Cleanup(func() { _ = homeSrv.Shutdown(context.Background()) })
if err := waitForServerToStart(homeErr); err != nil {
t.Fatalf("home server: %s", err)
}
// Bob's foreign relay
foreignCfg := server.ListenerConfig{Address: "127.0.0.1:52402"}
foreignSrv, err := server.NewServer(newManagerTestServerConfig(foreignCfg.Address))
if err != nil {
t.Fatalf("create foreign server: %s", err)
}
foreignErr := make(chan error, 1)
go func() {
if err := foreignSrv.Listen(foreignCfg); err != nil {
foreignErr <- err
}
}()
t.Cleanup(func() { _ = foreignSrv.Shutdown(context.Background()) })
if err := waitForServerToStart(foreignErr); err != nil {
t.Fatalf("foreign server: %s", err)
}
mCtx, mCancel := context.WithCancel(ctx)
t.Cleanup(mCancel)
mgrAlice := NewManager(mCtx, toURL(homeCfg), "alice", iface.DefaultMTU)
if err := mgrAlice.Serve(); err != nil {
t.Fatalf("alice manager serve: %s", err)
}
mgrBob := NewManager(mCtx, toURL(foreignCfg), "bob", iface.DefaultMTU)
if err := mgrBob.Serve(); err != nil {
t.Fatalf("bob manager serve: %s", err)
}
// Bob's real relay URL (what mgrBob.RelayInstanceAddress returns).
bobRealAddr, err := mgrBob.RelayInstanceAddress()
if err != nil {
t.Fatalf("bob relay address: %s", err)
}
// What Bob's RelayInstanceIP() reports — this is the field that
// would ride along in signal as relayServerIP.
bobAdvertisedIP := mgrBob.RelayInstanceIP()
if !bobAdvertisedIP.IsValid() {
t.Fatalf("expected valid RelayInstanceIP for bob, got zero")
}
// .invalid is reserved (RFC 2606), so DNS resolution always fails.
const brokenFQDN = "rel://relay-bob-instance.invalid:52402"
if brokenFQDN == bobRealAddr {
t.Fatalf("broken FQDN must differ from bob's real address (%s)", bobRealAddr)
}
t.Run("no fallback IP, dial fails", func(t *testing.T) {
dialCtx, dialCancel := context.WithTimeout(ctx, 5*time.Second)
defer dialCancel()
_, err := mgrAlice.OpenConn(dialCtx, brokenFQDN, "bob", netip.Addr{})
if err == nil {
t.Fatalf("expected OpenConn to fail without fallback IP, got success")
}
})
t.Run("fallback IP recovers", func(t *testing.T) {
// Bob waits for Alice's incoming peer connection on his side.
bobSideCh := make(chan error, 1)
go func() {
conn, err := mgrBob.OpenConn(ctx, bobRealAddr, "alice", netip.Addr{})
if err != nil {
bobSideCh <- err
return
}
buf := make([]byte, 1024)
n, err := conn.Read(buf)
if err != nil {
bobSideCh <- err
return
}
if _, err := conn.Write(buf[:n]); err != nil {
bobSideCh <- err
return
}
bobSideCh <- nil
}()
aliceConn, err := mgrAlice.OpenConn(ctx, brokenFQDN, "bob", bobAdvertisedIP)
if err != nil {
t.Fatalf("alice OpenConn with fallback IP: %s", err)
}
t.Cleanup(func() { _ = aliceConn.Close() })
payload := []byte("alice-to-bob")
if _, err := aliceConn.Write(payload); err != nil {
t.Fatalf("alice write: %s", err)
}
buf := make([]byte, len(payload))
if _, err := aliceConn.Read(buf); err != nil {
t.Fatalf("alice read echo: %s", err)
}
if string(buf) != string(payload) {
t.Fatalf("echo mismatch: got %q want %q", buf, payload)
}
select {
case err := <-bobSideCh:
if err != nil {
t.Fatalf("bob side: %s", err)
}
case <-time.After(5 * time.Second):
t.Fatalf("timed out waiting for bob side")
}
})
}

View File

@@ -2,6 +2,7 @@ package client
import (
"context"
"net/netip"
"testing"
"time"
@@ -104,11 +105,11 @@ func TestForeignConn(t *testing.T) {
if err != nil {
t.Fatalf("failed to get relay address: %s", err)
}
connAliceToBob, err := clientAlice.OpenConn(ctx, bobsSrvAddr, "bob")
connAliceToBob, err := clientAlice.OpenConn(ctx, bobsSrvAddr, "bob", netip.Addr{})
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}
connBobToAlice, err := clientBob.OpenConn(ctx, bobsSrvAddr, "alice")
connBobToAlice, err := clientBob.OpenConn(ctx, bobsSrvAddr, "alice", netip.Addr{})
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}
@@ -208,7 +209,7 @@ func TestForeginConnClose(t *testing.T) {
if err != nil {
t.Fatalf("failed to serve manager: %s", err)
}
conn, err := mgr.OpenConn(ctx, toURL(srvCfg2)[0], "bob")
conn, err := mgr.OpenConn(ctx, toURL(srvCfg2)[0], "bob", netip.Addr{})
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}
@@ -300,7 +301,7 @@ func TestForeignAutoClose(t *testing.T) {
}
t.Log("open connection to another peer")
if _, err = mgr.OpenConn(ctx, foreignServerURL, "anotherpeer"); err == nil {
if _, err = mgr.OpenConn(ctx, foreignServerURL, "anotherpeer", netip.Addr{}); err == nil {
t.Fatalf("should have failed to open connection to another peer")
}
@@ -369,7 +370,7 @@ func TestAutoReconnect(t *testing.T) {
if err != nil {
t.Errorf("failed to get relay address: %s", err)
}
conn, err := clientAlice.OpenConn(ctx, ra, "bob")
conn, err := clientAlice.OpenConn(ctx, ra, "bob", netip.Addr{})
if err != nil {
t.Errorf("failed to bind channel: %s", err)
}
@@ -387,7 +388,7 @@ func TestAutoReconnect(t *testing.T) {
time.Sleep(reconnectingTimeout + 1*time.Second)
log.Infof("reopent the connection")
_, err = clientAlice.OpenConn(ctx, ra, "bob")
_, err = clientAlice.OpenConn(ctx, ra, "bob", netip.Addr{})
if err != nil {
t.Errorf("failed to open channel: %s", err)
}
@@ -434,7 +435,7 @@ func TestNotifierDoubleAdd(t *testing.T) {
t.Fatalf("failed to serve manager: %s", err)
}
conn1, err := clientAlice.OpenConn(ctx, clientAlice.ServerURLs()[0], "bob")
conn1, err := clientAlice.OpenConn(ctx, clientAlice.ServerURLs()[0], "bob", netip.Addr{})
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}

View File

@@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"net/netip"
"sync/atomic"
"time"
@@ -69,7 +70,7 @@ func (sp *ServerPicker) PickServer(parentCtx context.Context) (*Client, error) {
func (sp *ServerPicker) startConnection(ctx context.Context, resultChan chan connResult, url string) {
log.Infof("try to connecting to relay server: %s", url)
relayClient := NewClient(url, sp.TokenStore, sp.PeerID, sp.MTU)
relayClient := NewClient(url, netip.Addr{}, sp.TokenStore, sp.PeerID, sp.MTU)
err := relayClient.Connect(ctx)
resultChan <- connResult{
RelayClient: relayClient,