mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-06 00:56:39 +00:00
[client, relay] Advertise relay server IP via signal for foreign-relay fallback dial (#6004)
This commit is contained in:
@@ -2,8 +2,12 @@ package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -146,6 +150,7 @@ func (cc *connContainer) close() {
|
||||
type Client struct {
|
||||
log *log.Entry
|
||||
connectionURL string
|
||||
serverIP netip.Addr
|
||||
authTokenStore *auth.TokenStore
|
||||
hashedID messages.PeerID
|
||||
|
||||
@@ -170,13 +175,22 @@ type Client struct {
|
||||
}
|
||||
|
||||
// NewClient creates a new client for the relay server. The client is not connected to the server until the Connect
|
||||
// is called.
|
||||
func NewClient(serverURL string, authTokenStore *auth.TokenStore, peerID string, mtu uint16) *Client {
|
||||
return NewClientWithServerIP(serverURL, netip.Addr{}, authTokenStore, peerID, mtu)
|
||||
}
|
||||
|
||||
// NewClientWithServerIP creates a new client for the relay server with a known server IP. serverIP, when valid, is
|
||||
// dialed directly first; the FQDN is only attempted if the IP-based dial fails. TLS verification still uses the
|
||||
// FQDN from serverURL via SNI.
|
||||
func NewClientWithServerIP(serverURL string, serverIP netip.Addr, authTokenStore *auth.TokenStore, peerID string, mtu uint16) *Client {
|
||||
hashedID := messages.HashID(peerID)
|
||||
relayLog := log.WithFields(log.Fields{"relay": serverURL})
|
||||
|
||||
c := &Client{
|
||||
log: relayLog,
|
||||
connectionURL: serverURL,
|
||||
serverIP: serverIP,
|
||||
authTokenStore: authTokenStore,
|
||||
hashedID: hashedID,
|
||||
mtu: mtu,
|
||||
@@ -304,6 +318,23 @@ func (c *Client) ServerInstanceURL() (string, error) {
|
||||
return c.instanceURL.String(), nil
|
||||
}
|
||||
|
||||
// ConnectedIP returns the IP address of the live relay-server connection,
|
||||
// extracted from the underlying socket's RemoteAddr. Zero value if not
|
||||
// connected or if the address is not an IP literal.
|
||||
func (c *Client) ConnectedIP() netip.Addr {
|
||||
c.mu.Lock()
|
||||
conn := c.relayConn
|
||||
c.mu.Unlock()
|
||||
if conn == nil {
|
||||
return netip.Addr{}
|
||||
}
|
||||
addr := conn.RemoteAddr()
|
||||
if addr == nil {
|
||||
return netip.Addr{}
|
||||
}
|
||||
return extractIPLiteral(addr.String())
|
||||
}
|
||||
|
||||
// SetOnDisconnectListener sets a function that will be called when the connection to the relay server is closed.
|
||||
func (c *Client) SetOnDisconnectListener(fn func(string)) {
|
||||
c.listenerMutex.Lock()
|
||||
@@ -332,10 +363,23 @@ func (c *Client) Close() error {
|
||||
func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
|
||||
dialers := c.getDialers()
|
||||
|
||||
rd := dialer.NewRaceDial(c.log, dialer.DefaultConnectionTimeout, c.connectionURL, dialers...)
|
||||
conn, err := rd.Dial(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
var conn net.Conn
|
||||
if c.serverIP.IsValid() {
|
||||
var err error
|
||||
conn, err = c.dialRaceDirect(ctx, dialers)
|
||||
if err != nil {
|
||||
c.log.Infof("dial via server IP %s failed, falling back to FQDN: %v", c.serverIP, err)
|
||||
conn = nil
|
||||
}
|
||||
}
|
||||
|
||||
if conn == nil {
|
||||
rd := dialer.NewRaceDial(c.log, dialer.DefaultConnectionTimeout, c.connectionURL, dialers...)
|
||||
var err error
|
||||
conn, err = rd.Dial(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dial via FQDN: %w", err)
|
||||
}
|
||||
}
|
||||
c.relayConn = conn
|
||||
|
||||
@@ -351,6 +395,52 @@ func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
|
||||
return instanceURL, nil
|
||||
}
|
||||
|
||||
// dialRaceDirect dials c.serverIP, preserving the original FQDN as the TLS ServerName for SNI.
|
||||
func (c *Client) dialRaceDirect(ctx context.Context, dialers []dialer.DialeFn) (net.Conn, error) {
|
||||
directURL, serverName, err := substituteHost(c.connectionURL, c.serverIP)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("substitute host: %w", err)
|
||||
}
|
||||
|
||||
c.log.Debugf("dialing via server IP %s (SNI=%s)", c.serverIP, serverName)
|
||||
|
||||
rd := dialer.NewRaceDial(c.log, dialer.DefaultConnectionTimeout, directURL, dialers...).
|
||||
WithServerName(serverName)
|
||||
return rd.Dial(ctx)
|
||||
}
|
||||
|
||||
// substituteHost replaces the host portion of a rel/rels URL with ip,
|
||||
// preserving the scheme and port. Returns the rewritten URL and the
|
||||
// original host to use as the TLS ServerName, or empty if the original
|
||||
// host is itself an IP literal (SNI requires a DNS name).
|
||||
func substituteHost(serverURL string, ip netip.Addr) (string, string, error) {
|
||||
u, err := url.Parse(serverURL)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("parse %q: %w", serverURL, err)
|
||||
}
|
||||
if u.Scheme == "" || u.Host == "" {
|
||||
return "", "", fmt.Errorf("invalid relay URL %q", serverURL)
|
||||
}
|
||||
if !ip.IsValid() {
|
||||
return "", "", errors.New("invalid server IP")
|
||||
}
|
||||
origHost := u.Hostname()
|
||||
if _, err := netip.ParseAddr(origHost); err == nil {
|
||||
origHost = ""
|
||||
}
|
||||
ip = ip.Unmap()
|
||||
newHost := ip.String()
|
||||
if ip.Is6() {
|
||||
newHost = "[" + newHost + "]"
|
||||
}
|
||||
if port := u.Port(); port != "" {
|
||||
u.Host = newHost + ":" + port
|
||||
} else {
|
||||
u.Host = newHost
|
||||
}
|
||||
return u.String(), origHost, nil
|
||||
}
|
||||
|
||||
func (c *Client) handShake(ctx context.Context) (*RelayAddr, error) {
|
||||
msg, err := messages.MarshalAuthMsg(c.hashedID, c.authTokenStore.TokenBinary())
|
||||
if err != nil {
|
||||
@@ -716,3 +806,21 @@ func (c *Client) handlePeersWentOfflineMsg(buf []byte) {
|
||||
}
|
||||
c.stateSubscription.OnPeersWentOffline(peersID)
|
||||
}
|
||||
|
||||
// extractIPLiteral returns the IP from address forms produced by the relay
|
||||
// dialers (URL or host:port). Zero value if the host is not an IP.
|
||||
func extractIPLiteral(s string) netip.Addr {
|
||||
if u, err := url.Parse(s); err == nil && u.Host != "" {
|
||||
s = u.Host
|
||||
}
|
||||
host, _, err := net.SplitHostPort(s)
|
||||
if err != nil {
|
||||
host = s
|
||||
}
|
||||
host = strings.Trim(host, "[]")
|
||||
ip, err := netip.ParseAddr(host)
|
||||
if err != nil {
|
||||
return netip.Addr{}
|
||||
}
|
||||
return ip.Unmap()
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user