Files
netbird/shared/relay/client/client_fallback_test.go

281 lines
7.9 KiB
Go

package client
import (
"context"
"fmt"
"net"
"net/netip"
"testing"
"time"
"go.opentelemetry.io/otel"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/relay/server"
"github.com/netbirdio/netbird/shared/relay/auth/allow"
)
// TestClient_FallbackIPRecoversFromUnresolvableFQDN verifies that when the
// primary FQDN-based dial fails (unresolvable .invalid host), Connect
// recovers via the fallback IP and SNI still uses the FQDN.
func TestClient_FallbackIPRecoversFromUnresolvableFQDN(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
listenAddr, port := freeAddr(t)
srvCfg := server.Config{
Meter: otel.Meter(""),
ExposedAddress: fmt.Sprintf("rel://test-unresolvable-host.invalid:%d", port),
TLSSupport: false,
AuthValidator: &allow.Auth{},
}
srv, err := server.NewServer(srvCfg)
if err != nil {
t.Fatalf("create server: %s", err)
}
errChan := make(chan error, 1)
go func() {
if err := srv.Listen(server.ListenerConfig{Address: listenAddr}); err != nil {
errChan <- err
}
}()
t.Cleanup(func() {
if err := srv.Shutdown(context.Background()); err != nil {
t.Errorf("shutdown server: %s", err)
}
})
if err := waitForServerToStart(errChan); err != nil {
t.Fatalf("server failed to start: %s", err)
}
t.Run("no fallback IP, primary fails", func(t *testing.T) {
c := NewClient(srvCfg.ExposedAddress, netip.Addr{}, hmacTokenStore, "alice-nofallback", iface.DefaultMTU)
err := c.Connect(ctx)
if err == nil {
_ = c.Close()
t.Fatalf("expected connect to fail without fallback IP, got nil")
}
})
t.Run("fallback IP recovers", func(t *testing.T) {
c := NewClient(srvCfg.ExposedAddress, netip.MustParseAddr("127.0.0.1"), hmacTokenStore, "alice-fallback", iface.DefaultMTU)
if err := c.Connect(ctx); err != nil {
t.Fatalf("connect with fallback IP: %s", err)
}
t.Cleanup(func() { _ = c.Close() })
if !c.Ready() {
t.Fatalf("client not ready after connect")
}
if got := c.ConnectedIP(); got.String() != "127.0.0.1" {
t.Fatalf("ConnectedIP = %q, want 127.0.0.1", got)
}
})
}
// TestClient_ConnectedIPAfterFQDNDial verifies ConnectedIP returns the
// resolved IP after a successful FQDN-based dial. The underlying socket's
// RemoteAddr must be exposed through the dialer wrappers; if it returns
// the dial-time URL instead, ConnectedIP returns empty and the fallback
// IP we advertise to peers is empty too.
func TestClient_ConnectedIPAfterFQDNDial(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
listenAddr, port := freeAddr(t)
srvCfg := server.Config{
Meter: otel.Meter(""),
ExposedAddress: fmt.Sprintf("rel://localhost:%d", port),
TLSSupport: false,
AuthValidator: &allow.Auth{},
}
srv, err := server.NewServer(srvCfg)
if err != nil {
t.Fatalf("create server: %s", err)
}
errChan := make(chan error, 1)
go func() {
if err := srv.Listen(server.ListenerConfig{Address: listenAddr}); err != nil {
errChan <- err
}
}()
t.Cleanup(func() { _ = srv.Shutdown(context.Background()) })
if err := waitForServerToStart(errChan); err != nil {
t.Fatalf("server failed to start: %s", err)
}
c := NewClient(srvCfg.ExposedAddress, netip.Addr{}, hmacTokenStore, "alice-fqdn", iface.DefaultMTU)
if err := c.Connect(ctx); err != nil {
t.Fatalf("connect: %s", err)
}
t.Cleanup(func() { _ = c.Close() })
got := c.ConnectedIP().String()
if got != "127.0.0.1" && got != "::1" {
t.Fatalf("ConnectedIP after FQDN dial = %q, want 127.0.0.1 or ::1", got)
}
}
func TestSubstituteHost(t *testing.T) {
tests := []struct {
name string
serverURL string
ip string
wantURL string
wantServerName string
wantErr bool
}{
{
name: "rels with port",
serverURL: "rels://relay.netbird.io:443",
ip: "10.0.0.5",
wantURL: "rels://10.0.0.5:443",
wantServerName: "relay.netbird.io",
},
{
name: "rel with port",
serverURL: "rel://relay.example.com:80",
ip: "192.0.2.1",
wantURL: "rel://192.0.2.1:80",
wantServerName: "relay.example.com",
},
{
name: "ipv6 fallback bracketed",
serverURL: "rels://relay.example.com:443",
ip: "2001:db8::1",
wantURL: "rels://[2001:db8::1]:443",
wantServerName: "relay.example.com",
},
{
name: "no port",
serverURL: "rels://relay.example.com",
ip: "10.0.0.5",
wantURL: "rels://10.0.0.5",
wantServerName: "relay.example.com",
},
{
name: "ipv6 server with port returns empty SNI",
serverURL: "rels://[2001:db8::5]:443",
ip: "10.0.0.5",
wantURL: "rels://10.0.0.5:443",
wantServerName: "",
},
{
name: "ipv4 server with port returns empty SNI",
serverURL: "rels://10.0.0.5:443",
ip: "10.0.0.6",
wantURL: "rels://10.0.0.6:443",
wantServerName: "",
},
{
name: "ipv6 fallback no port",
serverURL: "rels://relay.example.com",
ip: "2001:db8::1",
wantURL: "rels://[2001:db8::1]",
wantServerName: "relay.example.com",
},
{
name: "missing scheme",
serverURL: "relay.example.com:443",
ip: "10.0.0.5",
wantErr: true,
},
{
name: "empty",
serverURL: "",
ip: "10.0.0.5",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var ip netip.Addr
if tt.ip != "" {
ip = netip.MustParseAddr(tt.ip)
}
gotURL, gotName, err := substituteHost(tt.serverURL, ip)
if tt.wantErr {
if err == nil {
t.Fatalf("expected error, got nil")
}
return
}
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
if gotURL != tt.wantURL {
t.Errorf("URL = %q, want %q", gotURL, tt.wantURL)
}
if gotName != tt.wantServerName {
t.Errorf("ServerName = %q, want %q", gotName, tt.wantServerName)
}
})
}
}
func TestClient_ConnectedIPEmptyWhenNotConnected(t *testing.T) {
c := NewClient("rel://example.invalid:80", netip.Addr{}, hmacTokenStore, "x", iface.DefaultMTU)
if got := c.ConnectedIP(); got.IsValid() {
t.Fatalf("ConnectedIP on disconnected client = %q, want zero", got)
}
}
// staticAddr is a net.Addr that returns a fixed string. Used to verify
// ConnectedIP parses RemoteAddr correctly.
type staticAddr struct{ s string }
func (a staticAddr) Network() string { return "tcp" }
func (a staticAddr) String() string { return a.s }
type stubConn struct {
net.Conn
remote net.Addr
}
func (s stubConn) RemoteAddr() net.Addr { return s.remote }
func TestClient_ConnectedIPParsesRemoteAddr(t *testing.T) {
tests := []struct {
name string
s string
want string
}{
{"hostport ipv4", "127.0.0.1:50301", "127.0.0.1"},
{"hostport ipv6 bracketed", "[::1]:50301", "::1"},
{"url with ipv4", "rel://127.0.0.1:50301", "127.0.0.1"},
{"url with ipv6", "rels://[2001:db8::1]:443", "2001:db8::1"},
{"fqdn url returns empty", "rel://relay.example.com:50301", ""},
{"fqdn hostport returns empty", "relay.example.com:50301", ""},
{"plain ipv4 no port", "10.0.0.1", "10.0.0.1"},
{"empty", "", ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Client{relayConn: stubConn{remote: staticAddr{s: tt.s}}}
got := c.ConnectedIP()
var gotStr string
if got.IsValid() {
gotStr = got.String()
}
if gotStr != tt.want {
t.Errorf("ConnectedIP(%q) = %q, want %q", tt.s, gotStr, tt.want)
}
})
}
}
// freeAddr returns a 127.0.0.1 address with an OS-assigned port. The
// listener is closed before returning, so the port is briefly free for
// the caller to bind. Avoids hardcoded ports that can collide.
func freeAddr(t *testing.T) (string, int) {
t.Helper()
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("get free port: %s", err)
}
addr := l.Addr().(*net.TCPAddr)
_ = l.Close()
return addr.String(), addr.Port
}