mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-30 14:16:38 +00:00
Add relayServerIP field to signal for foreign-relay fallback dial
This commit is contained in:
@@ -2,8 +2,12 @@ package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -146,6 +150,7 @@ func (cc *connContainer) close() {
|
||||
type Client struct {
|
||||
log *log.Entry
|
||||
connectionURL string
|
||||
fallbackIP netip.Addr
|
||||
authTokenStore *auth.TokenStore
|
||||
hashedID messages.PeerID
|
||||
|
||||
@@ -170,13 +175,16 @@ type Client struct {
|
||||
}
|
||||
|
||||
// NewClient creates a new client for the relay server. The client is not connected to the server until the Connect
|
||||
func NewClient(serverURL string, authTokenStore *auth.TokenStore, peerID string, mtu uint16) *Client {
|
||||
// is called. fallbackIP, when valid, is used as a dial-time fallback if the FQDN-based dial fails. TLS
|
||||
// verification still uses the FQDN from serverURL via SNI.
|
||||
func NewClient(serverURL string, fallbackIP netip.Addr, authTokenStore *auth.TokenStore, peerID string, mtu uint16) *Client {
|
||||
hashedID := messages.HashID(peerID)
|
||||
relayLog := log.WithFields(log.Fields{"relay": serverURL})
|
||||
|
||||
c := &Client{
|
||||
log: relayLog,
|
||||
connectionURL: serverURL,
|
||||
fallbackIP: fallbackIP,
|
||||
authTokenStore: authTokenStore,
|
||||
hashedID: hashedID,
|
||||
mtu: mtu,
|
||||
@@ -304,6 +312,41 @@ func (c *Client) ServerInstanceURL() (string, error) {
|
||||
return c.instanceURL.String(), nil
|
||||
}
|
||||
|
||||
// ConnectedIP returns the IP address of the live relay-server connection,
|
||||
// extracted from the underlying socket's RemoteAddr. Zero value if not
|
||||
// connected or if the address is not an IP literal.
|
||||
func (c *Client) ConnectedIP() netip.Addr {
|
||||
c.mu.Lock()
|
||||
conn := c.relayConn
|
||||
c.mu.Unlock()
|
||||
if conn == nil {
|
||||
return netip.Addr{}
|
||||
}
|
||||
addr := conn.RemoteAddr()
|
||||
if addr == nil {
|
||||
return netip.Addr{}
|
||||
}
|
||||
return extractIPLiteral(addr.String())
|
||||
}
|
||||
|
||||
// extractIPLiteral returns the IP from address forms produced by the relay
|
||||
// dialers (URL or host:port). Zero value if the host is not an IP.
|
||||
func extractIPLiteral(s string) netip.Addr {
|
||||
if u, err := url.Parse(s); err == nil && u.Host != "" {
|
||||
s = u.Host
|
||||
}
|
||||
host, _, err := net.SplitHostPort(s)
|
||||
if err != nil {
|
||||
host = s
|
||||
}
|
||||
host = strings.Trim(host, "[]")
|
||||
ip, err := netip.ParseAddr(host)
|
||||
if err != nil {
|
||||
return netip.Addr{}
|
||||
}
|
||||
return ip.Unmap()
|
||||
}
|
||||
|
||||
// SetOnDisconnectListener sets a function that will be called when the connection to the relay server is closed.
|
||||
func (c *Client) SetOnDisconnectListener(fn func(string)) {
|
||||
c.listenerMutex.Lock()
|
||||
@@ -335,7 +378,11 @@ func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
|
||||
rd := dialer.NewRaceDial(c.log, dialer.DefaultConnectionTimeout, c.connectionURL, dialers...)
|
||||
conn, err := rd.Dial(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
fallbackConn, fbErr := c.dialFallback(ctx, dialers)
|
||||
if fbErr != nil {
|
||||
return nil, fmt.Errorf("primary dial: %w; fallback dial: %w", err, fbErr)
|
||||
}
|
||||
conn = fallbackConn
|
||||
}
|
||||
c.relayConn = conn
|
||||
|
||||
@@ -351,6 +398,58 @@ func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
|
||||
return instanceURL, nil
|
||||
}
|
||||
|
||||
// dialFallback retries the dial against c.fallbackIP, preserving the
|
||||
// original FQDN as the TLS ServerName for SNI. Returns an error if no
|
||||
// fallback IP is configured or if the substituted URL is malformed.
|
||||
func (c *Client) dialFallback(ctx context.Context, dialers []dialer.DialeFn) (net.Conn, error) {
|
||||
if !c.fallbackIP.IsValid() || c.fallbackIP.IsUnspecified() {
|
||||
return nil, errors.New("no usable fallback IP configured")
|
||||
}
|
||||
|
||||
fallbackURL, serverName, err := substituteHost(c.connectionURL, c.fallbackIP)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("substitute host: %w", err)
|
||||
}
|
||||
|
||||
c.log.Infof("primary dial failed, retrying via fallback IP %s (SNI=%s)", c.fallbackIP, serverName)
|
||||
|
||||
rd := dialer.NewRaceDial(c.log, dialer.DefaultConnectionTimeout, fallbackURL, dialers...).
|
||||
WithServerName(serverName)
|
||||
return rd.Dial(ctx)
|
||||
}
|
||||
|
||||
// substituteHost replaces the host portion of a rel/rels URL with ip,
|
||||
// preserving the scheme and port. Returns the rewritten URL and the
|
||||
// original host to use as the TLS ServerName, or empty if the original
|
||||
// host is itself an IP literal (SNI requires a DNS name).
|
||||
func substituteHost(serverURL string, ip netip.Addr) (string, string, error) {
|
||||
u, err := url.Parse(serverURL)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("parse %q: %w", serverURL, err)
|
||||
}
|
||||
if u.Scheme == "" || u.Host == "" {
|
||||
return "", "", fmt.Errorf("invalid relay URL %q", serverURL)
|
||||
}
|
||||
if !ip.IsValid() {
|
||||
return "", "", errors.New("invalid fallback IP")
|
||||
}
|
||||
origHost := u.Hostname()
|
||||
if _, err := netip.ParseAddr(origHost); err == nil {
|
||||
origHost = ""
|
||||
}
|
||||
ip = ip.Unmap()
|
||||
newHost := ip.String()
|
||||
if ip.Is6() {
|
||||
newHost = "[" + newHost + "]"
|
||||
}
|
||||
if port := u.Port(); port != "" {
|
||||
u.Host = newHost + ":" + port
|
||||
} else {
|
||||
u.Host = newHost
|
||||
}
|
||||
return u.String(), origHost, nil
|
||||
}
|
||||
|
||||
func (c *Client) handShake(ctx context.Context) (*RelayAddr, error) {
|
||||
msg, err := messages.MarshalAuthMsg(c.hashedID, c.authTokenStore.TokenBinary())
|
||||
if err != nil {
|
||||
|
||||
280
shared/relay/client/client_fallback_test.go
Normal file
280
shared/relay/client/client_fallback_test.go
Normal file
@@ -0,0 +1,280 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"go.opentelemetry.io/otel"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/relay/server"
|
||||
"github.com/netbirdio/netbird/shared/relay/auth/allow"
|
||||
)
|
||||
|
||||
// TestClient_FallbackIPRecoversFromUnresolvableFQDN verifies that when the
|
||||
// primary FQDN-based dial fails (unresolvable .invalid host), Connect
|
||||
// recovers via the fallback IP and SNI still uses the FQDN.
|
||||
func TestClient_FallbackIPRecoversFromUnresolvableFQDN(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
listenAddr, port := freeAddr(t)
|
||||
srvCfg := server.Config{
|
||||
Meter: otel.Meter(""),
|
||||
ExposedAddress: fmt.Sprintf("rel://test-unresolvable-host.invalid:%d", port),
|
||||
TLSSupport: false,
|
||||
AuthValidator: &allow.Auth{},
|
||||
}
|
||||
srv, err := server.NewServer(srvCfg)
|
||||
if err != nil {
|
||||
t.Fatalf("create server: %s", err)
|
||||
}
|
||||
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
if err := srv.Listen(server.ListenerConfig{Address: listenAddr}); err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
}()
|
||||
t.Cleanup(func() {
|
||||
if err := srv.Shutdown(context.Background()); err != nil {
|
||||
t.Errorf("shutdown server: %s", err)
|
||||
}
|
||||
})
|
||||
if err := waitForServerToStart(errChan); err != nil {
|
||||
t.Fatalf("server failed to start: %s", err)
|
||||
}
|
||||
|
||||
t.Run("no fallback IP, primary fails", func(t *testing.T) {
|
||||
c := NewClient(srvCfg.ExposedAddress, netip.Addr{}, hmacTokenStore, "alice-nofallback", iface.DefaultMTU)
|
||||
err := c.Connect(ctx)
|
||||
if err == nil {
|
||||
_ = c.Close()
|
||||
t.Fatalf("expected connect to fail without fallback IP, got nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("fallback IP recovers", func(t *testing.T) {
|
||||
c := NewClient(srvCfg.ExposedAddress, netip.MustParseAddr("127.0.0.1"), hmacTokenStore, "alice-fallback", iface.DefaultMTU)
|
||||
if err := c.Connect(ctx); err != nil {
|
||||
t.Fatalf("connect with fallback IP: %s", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = c.Close() })
|
||||
|
||||
if !c.Ready() {
|
||||
t.Fatalf("client not ready after connect")
|
||||
}
|
||||
if got := c.ConnectedIP(); got.String() != "127.0.0.1" {
|
||||
t.Fatalf("ConnectedIP = %q, want 127.0.0.1", got)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestClient_ConnectedIPAfterFQDNDial verifies ConnectedIP returns the
|
||||
// resolved IP after a successful FQDN-based dial. The underlying socket's
|
||||
// RemoteAddr must be exposed through the dialer wrappers; if it returns
|
||||
// the dial-time URL instead, ConnectedIP returns empty and the fallback
|
||||
// IP we advertise to peers is empty too.
|
||||
func TestClient_ConnectedIPAfterFQDNDial(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
listenAddr, port := freeAddr(t)
|
||||
srvCfg := server.Config{
|
||||
Meter: otel.Meter(""),
|
||||
ExposedAddress: fmt.Sprintf("rel://localhost:%d", port),
|
||||
TLSSupport: false,
|
||||
AuthValidator: &allow.Auth{},
|
||||
}
|
||||
srv, err := server.NewServer(srvCfg)
|
||||
if err != nil {
|
||||
t.Fatalf("create server: %s", err)
|
||||
}
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
if err := srv.Listen(server.ListenerConfig{Address: listenAddr}); err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
}()
|
||||
t.Cleanup(func() { _ = srv.Shutdown(context.Background()) })
|
||||
if err := waitForServerToStart(errChan); err != nil {
|
||||
t.Fatalf("server failed to start: %s", err)
|
||||
}
|
||||
|
||||
c := NewClient(srvCfg.ExposedAddress, netip.Addr{}, hmacTokenStore, "alice-fqdn", iface.DefaultMTU)
|
||||
if err := c.Connect(ctx); err != nil {
|
||||
t.Fatalf("connect: %s", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = c.Close() })
|
||||
|
||||
got := c.ConnectedIP().String()
|
||||
if got != "127.0.0.1" && got != "::1" {
|
||||
t.Fatalf("ConnectedIP after FQDN dial = %q, want 127.0.0.1 or ::1", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubstituteHost(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
serverURL string
|
||||
ip string
|
||||
wantURL string
|
||||
wantServerName string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "rels with port",
|
||||
serverURL: "rels://relay.netbird.io:443",
|
||||
ip: "10.0.0.5",
|
||||
wantURL: "rels://10.0.0.5:443",
|
||||
wantServerName: "relay.netbird.io",
|
||||
},
|
||||
{
|
||||
name: "rel with port",
|
||||
serverURL: "rel://relay.example.com:80",
|
||||
ip: "192.0.2.1",
|
||||
wantURL: "rel://192.0.2.1:80",
|
||||
wantServerName: "relay.example.com",
|
||||
},
|
||||
{
|
||||
name: "ipv6 fallback bracketed",
|
||||
serverURL: "rels://relay.example.com:443",
|
||||
ip: "2001:db8::1",
|
||||
wantURL: "rels://[2001:db8::1]:443",
|
||||
wantServerName: "relay.example.com",
|
||||
},
|
||||
{
|
||||
name: "no port",
|
||||
serverURL: "rels://relay.example.com",
|
||||
ip: "10.0.0.5",
|
||||
wantURL: "rels://10.0.0.5",
|
||||
wantServerName: "relay.example.com",
|
||||
},
|
||||
{
|
||||
name: "ipv6 server with port returns empty SNI",
|
||||
serverURL: "rels://[2001:db8::5]:443",
|
||||
ip: "10.0.0.5",
|
||||
wantURL: "rels://10.0.0.5:443",
|
||||
wantServerName: "",
|
||||
},
|
||||
{
|
||||
name: "ipv4 server with port returns empty SNI",
|
||||
serverURL: "rels://10.0.0.5:443",
|
||||
ip: "10.0.0.6",
|
||||
wantURL: "rels://10.0.0.6:443",
|
||||
wantServerName: "",
|
||||
},
|
||||
{
|
||||
name: "ipv6 fallback no port",
|
||||
serverURL: "rels://relay.example.com",
|
||||
ip: "2001:db8::1",
|
||||
wantURL: "rels://[2001:db8::1]",
|
||||
wantServerName: "relay.example.com",
|
||||
},
|
||||
{
|
||||
name: "missing scheme",
|
||||
serverURL: "relay.example.com:443",
|
||||
ip: "10.0.0.5",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "empty",
|
||||
serverURL: "",
|
||||
ip: "10.0.0.5",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var ip netip.Addr
|
||||
if tt.ip != "" {
|
||||
ip = netip.MustParseAddr(tt.ip)
|
||||
}
|
||||
gotURL, gotName, err := substituteHost(tt.serverURL, ip)
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Fatalf("expected error, got nil")
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
if gotURL != tt.wantURL {
|
||||
t.Errorf("URL = %q, want %q", gotURL, tt.wantURL)
|
||||
}
|
||||
if gotName != tt.wantServerName {
|
||||
t.Errorf("ServerName = %q, want %q", gotName, tt.wantServerName)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_ConnectedIPEmptyWhenNotConnected(t *testing.T) {
|
||||
c := NewClient("rel://example.invalid:80", netip.Addr{}, hmacTokenStore, "x", iface.DefaultMTU)
|
||||
if got := c.ConnectedIP(); got.IsValid() {
|
||||
t.Fatalf("ConnectedIP on disconnected client = %q, want zero", got)
|
||||
}
|
||||
}
|
||||
|
||||
// staticAddr is a net.Addr that returns a fixed string. Used to verify
|
||||
// ConnectedIP parses RemoteAddr correctly.
|
||||
type staticAddr struct{ s string }
|
||||
|
||||
func (a staticAddr) Network() string { return "tcp" }
|
||||
func (a staticAddr) String() string { return a.s }
|
||||
|
||||
type stubConn struct {
|
||||
net.Conn
|
||||
remote net.Addr
|
||||
}
|
||||
|
||||
func (s stubConn) RemoteAddr() net.Addr { return s.remote }
|
||||
|
||||
func TestClient_ConnectedIPParsesRemoteAddr(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
s string
|
||||
want string
|
||||
}{
|
||||
{"hostport ipv4", "127.0.0.1:50301", "127.0.0.1"},
|
||||
{"hostport ipv6 bracketed", "[::1]:50301", "::1"},
|
||||
{"url with ipv4", "rel://127.0.0.1:50301", "127.0.0.1"},
|
||||
{"url with ipv6", "rels://[2001:db8::1]:443", "2001:db8::1"},
|
||||
{"fqdn url returns empty", "rel://relay.example.com:50301", ""},
|
||||
{"fqdn hostport returns empty", "relay.example.com:50301", ""},
|
||||
{"plain ipv4 no port", "10.0.0.1", "10.0.0.1"},
|
||||
{"empty", "", ""},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Client{relayConn: stubConn{remote: staticAddr{s: tt.s}}}
|
||||
got := c.ConnectedIP()
|
||||
var gotStr string
|
||||
if got.IsValid() {
|
||||
gotStr = got.String()
|
||||
}
|
||||
if gotStr != tt.want {
|
||||
t.Errorf("ConnectedIP(%q) = %q, want %q", tt.s, gotStr, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// freeAddr returns a 127.0.0.1 address with an OS-assigned port. The
|
||||
// listener is closed before returning, so the port is briefly free for
|
||||
// the caller to bind. Avoids hardcoded ports that can collide.
|
||||
func freeAddr(t *testing.T) (string, int) {
|
||||
t.Helper()
|
||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("get free port: %s", err)
|
||||
}
|
||||
addr := l.Addr().(*net.TCPAddr)
|
||||
_ = l.Close()
|
||||
return addr.String(), addr.Port
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package client
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -68,7 +69,7 @@ func TestClient(t *testing.T) {
|
||||
t.Fatalf("failed to start server: %s", err)
|
||||
}
|
||||
t.Log("alice connecting to server")
|
||||
clientAlice := NewClient(serverCfg.ExposedAddress, hmacTokenStore, "alice", iface.DefaultMTU)
|
||||
clientAlice := NewClient(serverCfg.ExposedAddress, netip.Addr{}, hmacTokenStore, "alice", iface.DefaultMTU)
|
||||
err = clientAlice.Connect(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect to server: %s", err)
|
||||
@@ -76,7 +77,7 @@ func TestClient(t *testing.T) {
|
||||
defer clientAlice.Close()
|
||||
|
||||
t.Log("placeholder connecting to server")
|
||||
clientPlaceHolder := NewClient(serverCfg.ExposedAddress, hmacTokenStore, "clientPlaceHolder", iface.DefaultMTU)
|
||||
clientPlaceHolder := NewClient(serverCfg.ExposedAddress, netip.Addr{}, hmacTokenStore, "clientPlaceHolder", iface.DefaultMTU)
|
||||
err = clientPlaceHolder.Connect(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect to server: %s", err)
|
||||
@@ -84,7 +85,7 @@ func TestClient(t *testing.T) {
|
||||
defer clientPlaceHolder.Close()
|
||||
|
||||
t.Log("Bob connecting to server")
|
||||
clientBob := NewClient(serverCfg.ExposedAddress, hmacTokenStore, "bob", iface.DefaultMTU)
|
||||
clientBob := NewClient(serverCfg.ExposedAddress, netip.Addr{}, hmacTokenStore, "bob", iface.DefaultMTU)
|
||||
err = clientBob.Connect(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect to server: %s", err)
|
||||
@@ -144,7 +145,7 @@ func TestRegistration(t *testing.T) {
|
||||
t.Fatalf("failed to start server: %s", err)
|
||||
}
|
||||
|
||||
clientAlice := NewClient(serverCfg.ExposedAddress, hmacTokenStore, "alice", iface.DefaultMTU)
|
||||
clientAlice := NewClient(serverCfg.ExposedAddress, netip.Addr{}, hmacTokenStore, "alice", iface.DefaultMTU)
|
||||
err = clientAlice.Connect(ctx)
|
||||
if err != nil {
|
||||
_ = srv.Shutdown(ctx)
|
||||
@@ -184,7 +185,7 @@ func TestRegistrationTimeout(t *testing.T) {
|
||||
_ = fakeTCPListener.Close()
|
||||
}(fakeTCPListener)
|
||||
|
||||
clientAlice := NewClient("127.0.0.1:50201", hmacTokenStore, "alice", iface.DefaultMTU)
|
||||
clientAlice := NewClient("127.0.0.1:50201", netip.Addr{}, hmacTokenStore, "alice", iface.DefaultMTU)
|
||||
err = clientAlice.Connect(ctx)
|
||||
if err == nil {
|
||||
t.Errorf("failed to connect to server: %s", err)
|
||||
@@ -227,7 +228,7 @@ func TestEcho(t *testing.T) {
|
||||
t.Fatalf("failed to start server: %s", err)
|
||||
}
|
||||
|
||||
clientAlice := NewClient(serverCfg.ExposedAddress, hmacTokenStore, idAlice, iface.DefaultMTU)
|
||||
clientAlice := NewClient(serverCfg.ExposedAddress, netip.Addr{}, hmacTokenStore, idAlice, iface.DefaultMTU)
|
||||
err = clientAlice.Connect(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect to server: %s", err)
|
||||
@@ -239,7 +240,7 @@ func TestEcho(t *testing.T) {
|
||||
}
|
||||
}()
|
||||
|
||||
clientBob := NewClient(serverCfg.ExposedAddress, hmacTokenStore, idBob, iface.DefaultMTU)
|
||||
clientBob := NewClient(serverCfg.ExposedAddress, netip.Addr{}, hmacTokenStore, idBob, iface.DefaultMTU)
|
||||
err = clientBob.Connect(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect to server: %s", err)
|
||||
@@ -319,7 +320,7 @@ func TestBindToUnavailabePeer(t *testing.T) {
|
||||
t.Fatalf("failed to start server: %s", err)
|
||||
}
|
||||
|
||||
clientAlice := NewClient(serverCfg.ExposedAddress, hmacTokenStore, "alice", iface.DefaultMTU)
|
||||
clientAlice := NewClient(serverCfg.ExposedAddress, netip.Addr{}, hmacTokenStore, "alice", iface.DefaultMTU)
|
||||
err = clientAlice.Connect(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("failed to connect to server: %s", err)
|
||||
@@ -367,13 +368,13 @@ func TestBindReconnect(t *testing.T) {
|
||||
t.Fatalf("failed to start server: %s", err)
|
||||
}
|
||||
|
||||
clientAlice := NewClient(serverCfg.ExposedAddress, hmacTokenStore, "alice", iface.DefaultMTU)
|
||||
clientAlice := NewClient(serverCfg.ExposedAddress, netip.Addr{}, hmacTokenStore, "alice", iface.DefaultMTU)
|
||||
err = clientAlice.Connect(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect to server: %s", err)
|
||||
}
|
||||
|
||||
clientBob := NewClient(serverCfg.ExposedAddress, hmacTokenStore, "bob", iface.DefaultMTU)
|
||||
clientBob := NewClient(serverCfg.ExposedAddress, netip.Addr{}, hmacTokenStore, "bob", iface.DefaultMTU)
|
||||
err = clientBob.Connect(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("failed to connect to server: %s", err)
|
||||
@@ -395,7 +396,7 @@ func TestBindReconnect(t *testing.T) {
|
||||
t.Errorf("failed to close client: %s", err)
|
||||
}
|
||||
|
||||
clientAlice = NewClient(serverCfg.ExposedAddress, hmacTokenStore, "alice", iface.DefaultMTU)
|
||||
clientAlice = NewClient(serverCfg.ExposedAddress, netip.Addr{}, hmacTokenStore, "alice", iface.DefaultMTU)
|
||||
err = clientAlice.Connect(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("failed to connect to server: %s", err)
|
||||
@@ -470,13 +471,13 @@ func TestCloseConn(t *testing.T) {
|
||||
t.Fatalf("failed to start server: %s", err)
|
||||
}
|
||||
|
||||
bob := NewClient(serverCfg.ExposedAddress, hmacTokenStore, "bob", iface.DefaultMTU)
|
||||
bob := NewClient(serverCfg.ExposedAddress, netip.Addr{}, hmacTokenStore, "bob", iface.DefaultMTU)
|
||||
err = bob.Connect(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("failed to connect to server: %s", err)
|
||||
}
|
||||
|
||||
clientAlice := NewClient(serverCfg.ExposedAddress, hmacTokenStore, "alice", iface.DefaultMTU)
|
||||
clientAlice := NewClient(serverCfg.ExposedAddress, netip.Addr{}, hmacTokenStore, "alice", iface.DefaultMTU)
|
||||
err = clientAlice.Connect(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("failed to connect to server: %s", err)
|
||||
@@ -534,13 +535,13 @@ func TestCloseRelayConn(t *testing.T) {
|
||||
t.Fatalf("failed to start server: %s", err)
|
||||
}
|
||||
|
||||
bob := NewClient(serverCfg.ExposedAddress, hmacTokenStore, "bob", iface.DefaultMTU)
|
||||
bob := NewClient(serverCfg.ExposedAddress, netip.Addr{}, hmacTokenStore, "bob", iface.DefaultMTU)
|
||||
err = bob.Connect(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect to server: %s", err)
|
||||
}
|
||||
|
||||
clientAlice := NewClient(serverCfg.ExposedAddress, hmacTokenStore, "alice", iface.DefaultMTU)
|
||||
clientAlice := NewClient(serverCfg.ExposedAddress, netip.Addr{}, hmacTokenStore, "alice", iface.DefaultMTU)
|
||||
err = clientAlice.Connect(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect to server: %s", err)
|
||||
@@ -590,7 +591,7 @@ func TestCloseByServer(t *testing.T) {
|
||||
|
||||
idAlice := "alice"
|
||||
log.Debugf("connect by alice")
|
||||
relayClient := NewClient(serverCfg.ExposedAddress, hmacTokenStore, idAlice, iface.DefaultMTU)
|
||||
relayClient := NewClient(serverCfg.ExposedAddress, netip.Addr{}, hmacTokenStore, idAlice, iface.DefaultMTU)
|
||||
if err = relayClient.Connect(ctx); err != nil {
|
||||
log.Fatalf("failed to connect to server: %s", err)
|
||||
}
|
||||
@@ -648,7 +649,7 @@ func TestCloseByClient(t *testing.T) {
|
||||
|
||||
idAlice := "alice"
|
||||
log.Debugf("connect by alice")
|
||||
relayClient := NewClient(serverCfg.ExposedAddress, hmacTokenStore, idAlice, iface.DefaultMTU)
|
||||
relayClient := NewClient(serverCfg.ExposedAddress, netip.Addr{}, hmacTokenStore, idAlice, iface.DefaultMTU)
|
||||
err = relayClient.Connect(ctx)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to connect to server: %s", err)
|
||||
@@ -701,7 +702,7 @@ func TestCloseNotDrainedChannel(t *testing.T) {
|
||||
t.Fatalf("failed to start server: %s", err)
|
||||
}
|
||||
|
||||
clientAlice := NewClient(serverCfg.ExposedAddress, hmacTokenStore, idAlice, iface.DefaultMTU)
|
||||
clientAlice := NewClient(serverCfg.ExposedAddress, netip.Addr{}, hmacTokenStore, idAlice, iface.DefaultMTU)
|
||||
err = clientAlice.Connect(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect to server: %s", err)
|
||||
@@ -713,7 +714,7 @@ func TestCloseNotDrainedChannel(t *testing.T) {
|
||||
}
|
||||
}()
|
||||
|
||||
clientBob := NewClient(serverCfg.ExposedAddress, hmacTokenStore, idBob, iface.DefaultMTU)
|
||||
clientBob := NewClient(serverCfg.ExposedAddress, netip.Addr{}, hmacTokenStore, idBob, iface.DefaultMTU)
|
||||
err = clientBob.Connect(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect to server: %s", err)
|
||||
|
||||
@@ -23,7 +23,7 @@ func (d Dialer) Protocol() string {
|
||||
return Network
|
||||
}
|
||||
|
||||
func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) {
|
||||
func (d Dialer) Dial(ctx context.Context, address, serverName string) (net.Conn, error) {
|
||||
quicURL, err := prepareURL(address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -32,11 +32,14 @@ func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) {
|
||||
// Get the base TLS config
|
||||
tlsClientConfig := quictls.ClientQUICTLSConfig()
|
||||
|
||||
// Set ServerName to hostname if not an IP address
|
||||
host, _, splitErr := net.SplitHostPort(quicURL)
|
||||
if splitErr == nil && net.ParseIP(host) == nil {
|
||||
// It's a hostname, not an IP - modify directly
|
||||
tlsClientConfig.ServerName = host
|
||||
switch {
|
||||
case serverName != "" && net.ParseIP(serverName) == nil:
|
||||
tlsClientConfig.ServerName = serverName
|
||||
case serverName == "":
|
||||
host, _, splitErr := net.SplitHostPort(quicURL)
|
||||
if splitErr == nil && net.ParseIP(host) == nil {
|
||||
tlsClientConfig.ServerName = host
|
||||
}
|
||||
}
|
||||
|
||||
quicConfig := &quic.Config{
|
||||
|
||||
@@ -14,7 +14,9 @@ const (
|
||||
)
|
||||
|
||||
type DialeFn interface {
|
||||
Dial(ctx context.Context, address string) (net.Conn, error)
|
||||
// Dial connects to address. serverName, when non-empty, overrides the TLS
|
||||
// ServerName used for SNI/cert validation. Empty means derive from address.
|
||||
Dial(ctx context.Context, address, serverName string) (net.Conn, error)
|
||||
Protocol() string
|
||||
}
|
||||
|
||||
@@ -27,6 +29,7 @@ type dialResult struct {
|
||||
type RaceDial struct {
|
||||
log *log.Entry
|
||||
serverURL string
|
||||
serverName string
|
||||
dialerFns []DialeFn
|
||||
connectionTimeout time.Duration
|
||||
}
|
||||
@@ -40,6 +43,16 @@ func NewRaceDial(log *log.Entry, connectionTimeout time.Duration, serverURL stri
|
||||
}
|
||||
}
|
||||
|
||||
// WithServerName sets a TLS SNI/cert validation override. Used when serverURL
|
||||
// contains an IP literal but the cert is issued for a different hostname.
|
||||
//
|
||||
// Mutates the receiver and is not safe for concurrent reconfiguration; a
|
||||
// RaceDial is intended to be constructed per dial and discarded.
|
||||
func (r *RaceDial) WithServerName(serverName string) *RaceDial {
|
||||
r.serverName = serverName
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *RaceDial) Dial(ctx context.Context) (net.Conn, error) {
|
||||
connChan := make(chan dialResult, len(r.dialerFns))
|
||||
winnerConn := make(chan net.Conn, 1)
|
||||
@@ -64,7 +77,7 @@ func (r *RaceDial) dial(dfn DialeFn, abortCtx context.Context, connChan chan dia
|
||||
defer cancel()
|
||||
|
||||
r.log.Infof("dialing Relay server via %s", dfn.Protocol())
|
||||
conn, err := dfn.Dial(ctx, r.serverURL)
|
||||
conn, err := dfn.Dial(ctx, r.serverURL, r.serverName)
|
||||
connChan <- dialResult{Conn: conn, Protocol: dfn.Protocol(), Err: err}
|
||||
}
|
||||
|
||||
|
||||
@@ -28,7 +28,7 @@ type MockDialer struct {
|
||||
protocolStr string
|
||||
}
|
||||
|
||||
func (m *MockDialer) Dial(ctx context.Context, address string) (net.Conn, error) {
|
||||
func (m *MockDialer) Dial(ctx context.Context, address, _ string) (net.Conn, error) {
|
||||
return m.dialFunc(ctx, address)
|
||||
}
|
||||
|
||||
|
||||
@@ -12,14 +12,24 @@ import (
|
||||
type Conn struct {
|
||||
ctx context.Context
|
||||
*websocket.Conn
|
||||
remoteAddr WebsocketAddr
|
||||
remoteAddr net.Addr
|
||||
}
|
||||
|
||||
func NewConn(wsConn *websocket.Conn, serverAddress string) net.Conn {
|
||||
// NewConn builds a relay ws.Conn. underlying is the raw TCP/TLS conn captured
|
||||
// from the http transport's DialContext; when set, RemoteAddr returns its
|
||||
// peer address (an IP literal). When nil (e.g. wasm), RemoteAddr falls back
|
||||
// to the dial-time URL.
|
||||
func NewConn(wsConn *websocket.Conn, serverAddress string, underlying net.Conn) net.Conn {
|
||||
var addr net.Addr = WebsocketAddr{serverAddress}
|
||||
if underlying != nil {
|
||||
if ra := underlying.RemoteAddr(); ra != nil {
|
||||
addr = ra
|
||||
}
|
||||
}
|
||||
return &Conn{
|
||||
ctx: context.Background(),
|
||||
Conn: wsConn,
|
||||
remoteAddr: WebsocketAddr{serverAddress},
|
||||
remoteAddr: addr,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -2,10 +2,14 @@
|
||||
|
||||
package ws
|
||||
|
||||
import "github.com/coder/websocket"
|
||||
import (
|
||||
"net"
|
||||
|
||||
func createDialOptions() *websocket.DialOptions {
|
||||
"github.com/coder/websocket"
|
||||
)
|
||||
|
||||
func createDialOptions(serverName string, underlyingOut *net.Conn) *websocket.DialOptions {
|
||||
return &websocket.DialOptions{
|
||||
HTTPClient: httpClientNbDialer(),
|
||||
HTTPClient: httpClientNbDialer(serverName, underlyingOut),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,9 +2,13 @@
|
||||
|
||||
package ws
|
||||
|
||||
import "github.com/coder/websocket"
|
||||
import (
|
||||
"net"
|
||||
|
||||
func createDialOptions() *websocket.DialOptions {
|
||||
// WASM version doesn't support HTTPClient
|
||||
"github.com/coder/websocket"
|
||||
)
|
||||
|
||||
func createDialOptions(_ string, _ *net.Conn) *websocket.DialOptions {
|
||||
// WASM version doesn't support HTTPClient or custom TLS config.
|
||||
return &websocket.DialOptions{}
|
||||
}
|
||||
|
||||
@@ -26,13 +26,14 @@ func (d Dialer) Protocol() string {
|
||||
return "WS"
|
||||
}
|
||||
|
||||
func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) {
|
||||
func (d Dialer) Dial(ctx context.Context, address, serverName string) (net.Conn, error) {
|
||||
wsURL, err := prepareURL(address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
opts := createDialOptions()
|
||||
var underlying net.Conn
|
||||
opts := createDialOptions(serverName, &underlying)
|
||||
|
||||
parsedURL, err := url.Parse(wsURL)
|
||||
if err != nil {
|
||||
@@ -52,7 +53,7 @@ func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) {
|
||||
_ = resp.Body.Close()
|
||||
}
|
||||
|
||||
conn := NewConn(wsConn, address)
|
||||
conn := NewConn(wsConn, address, underlying)
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
@@ -64,7 +65,10 @@ func prepareURL(address string) (string, error) {
|
||||
return strings.Replace(address, "rel", "ws", 1), nil
|
||||
}
|
||||
|
||||
func httpClientNbDialer() *http.Client {
|
||||
// httpClientNbDialer builds the http client used by the websocket library.
|
||||
// underlyingOut, when non-nil, is populated with the raw conn from the
|
||||
// transport's DialContext so the caller can read its RemoteAddr.
|
||||
func httpClientNbDialer(serverName string, underlyingOut *net.Conn) *http.Client {
|
||||
customDialer := nbnet.NewDialer()
|
||||
|
||||
certPool, err := x509.SystemCertPool()
|
||||
@@ -75,10 +79,15 @@ func httpClientNbDialer() *http.Client {
|
||||
|
||||
customTransport := &http.Transport{
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return customDialer.DialContext(ctx, network, addr)
|
||||
c, err := customDialer.DialContext(ctx, network, addr)
|
||||
if err == nil && underlyingOut != nil {
|
||||
*underlyingOut = c
|
||||
}
|
||||
return c, err
|
||||
},
|
||||
TLSClientConfig: &tls.Config{
|
||||
RootCAs: certPool,
|
||||
RootCAs: certPool,
|
||||
ServerName: serverName,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"reflect"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -117,7 +118,10 @@ func (m *Manager) Serve() error {
|
||||
// OpenConn opens a connection to the given peer key. If the peer is on the same relay server, the connection will be
|
||||
// established via the relay server. If the peer is on a different relay server, the manager will establish a new
|
||||
// connection to the relay server. It returns back with a net.Conn what represent the remote peer connection.
|
||||
func (m *Manager) OpenConn(ctx context.Context, serverAddress, peerKey string) (net.Conn, error) {
|
||||
//
|
||||
// fallbackIP, when valid and serverAddress is foreign, is used as a dial-time fallback if the FQDN-based
|
||||
// dial fails. Ignored for the local home-server path. TLS verification still uses the FQDN via SNI.
|
||||
func (m *Manager) OpenConn(ctx context.Context, serverAddress, peerKey string, fallbackIP netip.Addr) (net.Conn, error) {
|
||||
m.relayClientMu.RLock()
|
||||
defer m.relayClientMu.RUnlock()
|
||||
|
||||
@@ -138,7 +142,7 @@ func (m *Manager) OpenConn(ctx context.Context, serverAddress, peerKey string) (
|
||||
netConn, err = m.relayClient.OpenConn(ctx, peerKey)
|
||||
} else {
|
||||
log.Debugf("open peer connection via foreign server: %s", serverAddress)
|
||||
netConn, err = m.openConnVia(ctx, serverAddress, peerKey)
|
||||
netConn, err = m.openConnVia(ctx, serverAddress, peerKey, fallbackIP)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -202,6 +206,19 @@ func (m *Manager) RelayInstanceAddress() (string, error) {
|
||||
return m.relayClient.ServerInstanceURL()
|
||||
}
|
||||
|
||||
// RelayInstanceIP returns the IP address of the live home relay connection.
|
||||
// Zero value if not connected. Sent alongside RelayInstanceAddress so remote
|
||||
// peers can dial directly without their own DNS lookup.
|
||||
func (m *Manager) RelayInstanceIP() netip.Addr {
|
||||
m.relayClientMu.RLock()
|
||||
defer m.relayClientMu.RUnlock()
|
||||
|
||||
if m.relayClient == nil {
|
||||
return netip.Addr{}
|
||||
}
|
||||
return m.relayClient.ConnectedIP()
|
||||
}
|
||||
|
||||
// ServerURLs returns the addresses of the relay servers.
|
||||
func (m *Manager) ServerURLs() []string {
|
||||
return m.serverPicker.ServerURLs.Load().([]string)
|
||||
@@ -223,7 +240,7 @@ func (m *Manager) UpdateToken(token *relayAuth.Token) error {
|
||||
return m.tokenStore.UpdateToken(token)
|
||||
}
|
||||
|
||||
func (m *Manager) openConnVia(ctx context.Context, serverAddress, peerKey string) (net.Conn, error) {
|
||||
func (m *Manager) openConnVia(ctx context.Context, serverAddress, peerKey string, fallbackIP netip.Addr) (net.Conn, error) {
|
||||
// check if already has a connection to the desired relay server
|
||||
m.relayClientsMutex.RLock()
|
||||
rt, ok := m.relayClients[serverAddress]
|
||||
@@ -258,7 +275,7 @@ func (m *Manager) openConnVia(ctx context.Context, serverAddress, peerKey string
|
||||
m.relayClients[serverAddress] = rt
|
||||
m.relayClientsMutex.Unlock()
|
||||
|
||||
relayClient := NewClient(serverAddress, m.tokenStore, m.peerID, m.mtu)
|
||||
relayClient := NewClient(serverAddress, fallbackIP, m.tokenStore, m.peerID, m.mtu)
|
||||
err := relayClient.Connect(m.ctx)
|
||||
if err != nil {
|
||||
rt.err = err
|
||||
|
||||
146
shared/relay/client/manager_fallback_test.go
Normal file
146
shared/relay/client/manager_fallback_test.go
Normal file
@@ -0,0 +1,146 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/relay/server"
|
||||
)
|
||||
|
||||
// TestManager_ForeignRelayFallbackIP exercises the foreign-relay path
|
||||
// end-to-end through Manager.OpenConn. Alice and Bob register on different
|
||||
// relay servers; Alice dials Bob's foreign relay using an unresolvable
|
||||
// FQDN. Without a fallback IP the dial fails; with Bob's advertised IP it
|
||||
// recovers and a payload round-trips between the peers.
|
||||
func TestManager_ForeignRelayFallbackIP(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Alice's home relay
|
||||
homeCfg := server.ListenerConfig{Address: "127.0.0.1:52401"}
|
||||
homeSrv, err := server.NewServer(newManagerTestServerConfig(homeCfg.Address))
|
||||
if err != nil {
|
||||
t.Fatalf("create home server: %s", err)
|
||||
}
|
||||
homeErr := make(chan error, 1)
|
||||
go func() {
|
||||
if err := homeSrv.Listen(homeCfg); err != nil {
|
||||
homeErr <- err
|
||||
}
|
||||
}()
|
||||
t.Cleanup(func() { _ = homeSrv.Shutdown(context.Background()) })
|
||||
if err := waitForServerToStart(homeErr); err != nil {
|
||||
t.Fatalf("home server: %s", err)
|
||||
}
|
||||
|
||||
// Bob's foreign relay
|
||||
foreignCfg := server.ListenerConfig{Address: "127.0.0.1:52402"}
|
||||
foreignSrv, err := server.NewServer(newManagerTestServerConfig(foreignCfg.Address))
|
||||
if err != nil {
|
||||
t.Fatalf("create foreign server: %s", err)
|
||||
}
|
||||
foreignErr := make(chan error, 1)
|
||||
go func() {
|
||||
if err := foreignSrv.Listen(foreignCfg); err != nil {
|
||||
foreignErr <- err
|
||||
}
|
||||
}()
|
||||
t.Cleanup(func() { _ = foreignSrv.Shutdown(context.Background()) })
|
||||
if err := waitForServerToStart(foreignErr); err != nil {
|
||||
t.Fatalf("foreign server: %s", err)
|
||||
}
|
||||
|
||||
mCtx, mCancel := context.WithCancel(ctx)
|
||||
t.Cleanup(mCancel)
|
||||
|
||||
mgrAlice := NewManager(mCtx, toURL(homeCfg), "alice", iface.DefaultMTU)
|
||||
if err := mgrAlice.Serve(); err != nil {
|
||||
t.Fatalf("alice manager serve: %s", err)
|
||||
}
|
||||
|
||||
mgrBob := NewManager(mCtx, toURL(foreignCfg), "bob", iface.DefaultMTU)
|
||||
if err := mgrBob.Serve(); err != nil {
|
||||
t.Fatalf("bob manager serve: %s", err)
|
||||
}
|
||||
|
||||
// Bob's real relay URL (what mgrBob.RelayInstanceAddress returns).
|
||||
bobRealAddr, err := mgrBob.RelayInstanceAddress()
|
||||
if err != nil {
|
||||
t.Fatalf("bob relay address: %s", err)
|
||||
}
|
||||
// What Bob's RelayInstanceIP() reports — this is the field that
|
||||
// would ride along in signal as relayServerIP.
|
||||
bobAdvertisedIP := mgrBob.RelayInstanceIP()
|
||||
if !bobAdvertisedIP.IsValid() {
|
||||
t.Fatalf("expected valid RelayInstanceIP for bob, got zero")
|
||||
}
|
||||
|
||||
// .invalid is reserved (RFC 2606), so DNS resolution always fails.
|
||||
const brokenFQDN = "rel://relay-bob-instance.invalid:52402"
|
||||
if brokenFQDN == bobRealAddr {
|
||||
t.Fatalf("broken FQDN must differ from bob's real address (%s)", bobRealAddr)
|
||||
}
|
||||
|
||||
t.Run("no fallback IP, dial fails", func(t *testing.T) {
|
||||
dialCtx, dialCancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer dialCancel()
|
||||
_, err := mgrAlice.OpenConn(dialCtx, brokenFQDN, "bob", netip.Addr{})
|
||||
if err == nil {
|
||||
t.Fatalf("expected OpenConn to fail without fallback IP, got success")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("fallback IP recovers", func(t *testing.T) {
|
||||
// Bob waits for Alice's incoming peer connection on his side.
|
||||
bobSideCh := make(chan error, 1)
|
||||
go func() {
|
||||
conn, err := mgrBob.OpenConn(ctx, bobRealAddr, "alice", netip.Addr{})
|
||||
if err != nil {
|
||||
bobSideCh <- err
|
||||
return
|
||||
}
|
||||
buf := make([]byte, 1024)
|
||||
n, err := conn.Read(buf)
|
||||
if err != nil {
|
||||
bobSideCh <- err
|
||||
return
|
||||
}
|
||||
if _, err := conn.Write(buf[:n]); err != nil {
|
||||
bobSideCh <- err
|
||||
return
|
||||
}
|
||||
bobSideCh <- nil
|
||||
}()
|
||||
|
||||
aliceConn, err := mgrAlice.OpenConn(ctx, brokenFQDN, "bob", bobAdvertisedIP)
|
||||
if err != nil {
|
||||
t.Fatalf("alice OpenConn with fallback IP: %s", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = aliceConn.Close() })
|
||||
|
||||
payload := []byte("alice-to-bob")
|
||||
if _, err := aliceConn.Write(payload); err != nil {
|
||||
t.Fatalf("alice write: %s", err)
|
||||
}
|
||||
|
||||
buf := make([]byte, len(payload))
|
||||
if _, err := aliceConn.Read(buf); err != nil {
|
||||
t.Fatalf("alice read echo: %s", err)
|
||||
}
|
||||
if string(buf) != string(payload) {
|
||||
t.Fatalf("echo mismatch: got %q want %q", buf, payload)
|
||||
}
|
||||
|
||||
select {
|
||||
case err := <-bobSideCh:
|
||||
if err != nil {
|
||||
t.Fatalf("bob side: %s", err)
|
||||
}
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatalf("timed out waiting for bob side")
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -2,6 +2,7 @@ package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -104,11 +105,11 @@ func TestForeignConn(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get relay address: %s", err)
|
||||
}
|
||||
connAliceToBob, err := clientAlice.OpenConn(ctx, bobsSrvAddr, "bob")
|
||||
connAliceToBob, err := clientAlice.OpenConn(ctx, bobsSrvAddr, "bob", netip.Addr{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind channel: %s", err)
|
||||
}
|
||||
connBobToAlice, err := clientBob.OpenConn(ctx, bobsSrvAddr, "alice")
|
||||
connBobToAlice, err := clientBob.OpenConn(ctx, bobsSrvAddr, "alice", netip.Addr{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind channel: %s", err)
|
||||
}
|
||||
@@ -208,7 +209,7 @@ func TestForeginConnClose(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("failed to serve manager: %s", err)
|
||||
}
|
||||
conn, err := mgr.OpenConn(ctx, toURL(srvCfg2)[0], "bob")
|
||||
conn, err := mgr.OpenConn(ctx, toURL(srvCfg2)[0], "bob", netip.Addr{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind channel: %s", err)
|
||||
}
|
||||
@@ -300,7 +301,7 @@ func TestForeignAutoClose(t *testing.T) {
|
||||
}
|
||||
|
||||
t.Log("open connection to another peer")
|
||||
if _, err = mgr.OpenConn(ctx, foreignServerURL, "anotherpeer"); err == nil {
|
||||
if _, err = mgr.OpenConn(ctx, foreignServerURL, "anotherpeer", netip.Addr{}); err == nil {
|
||||
t.Fatalf("should have failed to open connection to another peer")
|
||||
}
|
||||
|
||||
@@ -369,7 +370,7 @@ func TestAutoReconnect(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Errorf("failed to get relay address: %s", err)
|
||||
}
|
||||
conn, err := clientAlice.OpenConn(ctx, ra, "bob")
|
||||
conn, err := clientAlice.OpenConn(ctx, ra, "bob", netip.Addr{})
|
||||
if err != nil {
|
||||
t.Errorf("failed to bind channel: %s", err)
|
||||
}
|
||||
@@ -387,7 +388,7 @@ func TestAutoReconnect(t *testing.T) {
|
||||
time.Sleep(reconnectingTimeout + 1*time.Second)
|
||||
|
||||
log.Infof("reopent the connection")
|
||||
_, err = clientAlice.OpenConn(ctx, ra, "bob")
|
||||
_, err = clientAlice.OpenConn(ctx, ra, "bob", netip.Addr{})
|
||||
if err != nil {
|
||||
t.Errorf("failed to open channel: %s", err)
|
||||
}
|
||||
@@ -434,7 +435,7 @@ func TestNotifierDoubleAdd(t *testing.T) {
|
||||
t.Fatalf("failed to serve manager: %s", err)
|
||||
}
|
||||
|
||||
conn1, err := clientAlice.OpenConn(ctx, clientAlice.ServerURLs()[0], "bob")
|
||||
conn1, err := clientAlice.OpenConn(ctx, clientAlice.ServerURLs()[0], "bob", netip.Addr{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind channel: %s", err)
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
@@ -69,7 +70,7 @@ func (sp *ServerPicker) PickServer(parentCtx context.Context) (*Client, error) {
|
||||
|
||||
func (sp *ServerPicker) startConnection(ctx context.Context, resultChan chan connResult, url string) {
|
||||
log.Infof("try to connecting to relay server: %s", url)
|
||||
relayClient := NewClient(url, sp.TokenStore, sp.PeerID, sp.MTU)
|
||||
relayClient := NewClient(url, netip.Addr{}, sp.TokenStore, sp.PeerID, sp.MTU)
|
||||
err := relayClient.Connect(ctx)
|
||||
resultChan <- connResult{
|
||||
RelayClient: relayClient,
|
||||
|
||||
Reference in New Issue
Block a user