Files
netbird/client/internal/portforward/pcp/client.go
2026-04-15 11:43:16 +02:00

409 lines
12 KiB
Go

package pcp
import (
"context"
"crypto/rand"
"errors"
"fmt"
"net"
"net/netip"
"sync"
"time"
log "github.com/sirupsen/logrus"
)
const (
defaultTimeout = 3 * time.Second
responseBufferSize = 128
// RFC 6887 Section 8.1.1 retry timing
initialRetryDelay = 3 * time.Second
maxRetryDelay = 1024 * time.Second
maxRetries = 4 // 3s + 6s + 12s + 24s = 45s total worst case
)
// Client is a PCP protocol client.
// All methods are safe for concurrent use.
type Client struct {
gateway netip.Addr
timeout time.Duration
mu sync.Mutex
// localIP caches the resolved local IP address.
localIP netip.Addr
// lastEpoch is the last observed server epoch value.
lastEpoch uint32
// epochTime tracks when lastEpoch was received for state loss detection.
epochTime time.Time
// externalIP caches the external IP from the last successful MAP response.
externalIP netip.Addr
// epochStateLost is set when epoch indicates server restart.
epochStateLost bool
}
// NewClient creates a new PCP client for the gateway at the given IP.
func NewClient(gateway net.IP) *Client {
addr, ok := netip.AddrFromSlice(gateway)
if !ok {
log.Debugf("invalid gateway IP: %v", gateway)
}
return &Client{
gateway: addr.Unmap(),
timeout: defaultTimeout,
}
}
// NewClientWithTimeout creates a new PCP client with a custom timeout.
func NewClientWithTimeout(gateway net.IP, timeout time.Duration) *Client {
addr, ok := netip.AddrFromSlice(gateway)
if !ok {
log.Debugf("invalid gateway IP: %v", gateway)
}
return &Client{
gateway: addr.Unmap(),
timeout: timeout,
}
}
// SetLocalIP sets the local IP address to use in PCP requests.
func (c *Client) SetLocalIP(ip net.IP) {
addr, ok := netip.AddrFromSlice(ip)
if !ok {
log.Debugf("invalid local IP: %v", ip)
}
c.mu.Lock()
c.localIP = addr.Unmap()
c.mu.Unlock()
}
// Gateway returns the gateway IP address.
func (c *Client) Gateway() net.IP {
return c.gateway.AsSlice()
}
// Announce sends a PCP ANNOUNCE request to discover PCP support.
// Returns the server's epoch time on success.
func (c *Client) Announce(ctx context.Context) (epoch uint32, err error) {
localIP, err := c.getLocalIP()
if err != nil {
return 0, fmt.Errorf("get local IP: %w", err)
}
req := buildAnnounceRequest(localIP)
resp, err := c.sendRequest(ctx, req)
if err != nil {
return 0, fmt.Errorf("send announce: %w", err)
}
parsed, err := parseResponse(resp)
if err != nil {
return 0, fmt.Errorf("parse announce response: %w", err)
}
if parsed.ResultCode != ResultSuccess {
return 0, fmt.Errorf("PCP ANNOUNCE failed: %s", ResultCodeString(parsed.ResultCode))
}
c.mu.Lock()
if c.updateEpochLocked(parsed.Epoch) {
log.Warnf("PCP server epoch indicates state loss - mappings may need refresh")
}
c.mu.Unlock()
return parsed.Epoch, nil
}
// AddPortMapping requests a port mapping from the PCP server.
func (c *Client) AddPortMapping(ctx context.Context, protocol string, internalPort int, lifetime time.Duration) (*MapResponse, error) {
return c.addPortMappingWithHint(ctx, protocol, internalPort, internalPort, netip.Addr{}, lifetime)
}
// AddPortMappingWithHint requests a port mapping with suggested external port and IP.
// Use lifetime <= 0 to delete a mapping.
func (c *Client) AddPortMappingWithHint(ctx context.Context, protocol string, internalPort, suggestedExtPort int, suggestedExtIP net.IP, lifetime time.Duration) (*MapResponse, error) {
var extIP netip.Addr
if suggestedExtIP != nil {
var ok bool
extIP, ok = netip.AddrFromSlice(suggestedExtIP)
if !ok {
log.Debugf("invalid suggested external IP: %v", suggestedExtIP)
}
extIP = extIP.Unmap()
}
return c.addPortMappingWithHint(ctx, protocol, internalPort, suggestedExtPort, extIP, lifetime)
}
func (c *Client) addPortMappingWithHint(ctx context.Context, protocol string, internalPort, suggestedExtPort int, suggestedExtIP netip.Addr, lifetime time.Duration) (*MapResponse, error) {
localIP, err := c.getLocalIP()
if err != nil {
return nil, fmt.Errorf("get local IP: %w", err)
}
proto, err := protocolNumber(protocol)
if err != nil {
return nil, fmt.Errorf("parse protocol: %w", err)
}
var nonce [12]byte
if _, err := rand.Read(nonce[:]); err != nil {
return nil, fmt.Errorf("generate nonce: %w", err)
}
// Convert lifetime to seconds. Lifetime 0 means delete, so only apply
// default for positive durations that round to 0 seconds.
var lifetimeSec uint32
if lifetime > 0 {
lifetimeSec = uint32(lifetime.Seconds())
if lifetimeSec == 0 {
lifetimeSec = DefaultLifetime
}
}
req := buildMapRequest(localIP, nonce, proto, uint16(internalPort), uint16(suggestedExtPort), suggestedExtIP, lifetimeSec)
resp, err := c.sendRequest(ctx, req)
if err != nil {
return nil, fmt.Errorf("send map request: %w", err)
}
mapResp, err := parseMapResponse(resp)
if err != nil {
return nil, fmt.Errorf("parse map response: %w", err)
}
if mapResp.Nonce != nonce {
return nil, fmt.Errorf("nonce mismatch in response")
}
if mapResp.Protocol != proto {
return nil, fmt.Errorf("protocol mismatch: requested %d, got %d", proto, mapResp.Protocol)
}
if mapResp.InternalPort != uint16(internalPort) {
return nil, fmt.Errorf("internal port mismatch: requested %d, got %d", internalPort, mapResp.InternalPort)
}
if mapResp.ResultCode != ResultSuccess {
return nil, &Error{
Code: mapResp.ResultCode,
Message: ResultCodeString(mapResp.ResultCode),
}
}
c.mu.Lock()
if c.updateEpochLocked(mapResp.Epoch) {
log.Warnf("PCP server epoch indicates state loss - mappings may need refresh")
}
c.cacheExternalIPLocked(mapResp.ExternalIP)
c.mu.Unlock()
return mapResp, nil
}
// DeletePortMapping removes a port mapping by requesting zero lifetime.
func (c *Client) DeletePortMapping(ctx context.Context, protocol string, internalPort int) error {
if _, err := c.addPortMappingWithHint(ctx, protocol, internalPort, 0, netip.Addr{}, 0); err != nil {
var pcpErr *Error
if errors.As(err, &pcpErr) && pcpErr.Code == ResultNotAuthorized {
return nil
}
return fmt.Errorf("delete mapping: %w", err)
}
return nil
}
// GetExternalAddress returns the external IP address.
// First checks for a cached value from previous MAP responses.
// If not cached, creates a short-lived mapping to discover the external IP.
func (c *Client) GetExternalAddress(ctx context.Context) (net.IP, error) {
c.mu.Lock()
if c.externalIP.IsValid() {
ip := c.externalIP.AsSlice()
c.mu.Unlock()
return ip, nil
}
c.mu.Unlock()
// Use an ephemeral port in the dynamic range (49152-65535).
// Port 0 is not valid with UDP/TCP protocols per RFC 6887.
ephemeralPort := 49152 + int(uint16(time.Now().UnixNano()))%(65535-49152)
// Use minimal lifetime (1 second) for discovery.
resp, err := c.AddPortMapping(ctx, "udp", ephemeralPort, time.Second)
if err != nil {
return nil, fmt.Errorf("create temporary mapping: %w", err)
}
if err := c.DeletePortMapping(ctx, "udp", ephemeralPort); err != nil {
log.Debugf("cleanup temporary PCP mapping: %v", err)
}
return resp.ExternalIP.AsSlice(), nil
}
// LastEpoch returns the last observed server epoch value.
// A decrease in epoch indicates the server may have restarted and mappings may be lost.
func (c *Client) LastEpoch() uint32 {
c.mu.Lock()
defer c.mu.Unlock()
return c.lastEpoch
}
// EpochStateLost returns true if epoch state loss was detected and clears the flag.
func (c *Client) EpochStateLost() bool {
c.mu.Lock()
defer c.mu.Unlock()
lost := c.epochStateLost
c.epochStateLost = false
return lost
}
// updateEpoch updates the epoch tracking and detects potential state loss.
// Returns true if state loss was detected (server likely restarted).
// Caller must hold c.mu.
func (c *Client) updateEpochLocked(newEpoch uint32) bool {
now := time.Now()
stateLost := false
// RFC 6887 Section 8.5: Detect invalid epoch indicating server state loss.
// client_delta = time since last response
// server_delta = epoch change since last response
// Invalid if: client_delta+2 < server_delta - server_delta/16
// OR: server_delta+2 < client_delta - client_delta/16
// The +2 handles quantization, /16 (6.25%) handles clock drift.
if !c.epochTime.IsZero() && c.lastEpoch > 0 {
clientDelta := uint32(now.Sub(c.epochTime).Seconds())
serverDelta := newEpoch - c.lastEpoch
// Check for epoch going backwards or jumping unexpectedly.
// Subtraction is safe: serverDelta/16 is always <= serverDelta.
if clientDelta+2 < serverDelta-(serverDelta/16) ||
serverDelta+2 < clientDelta-(clientDelta/16) {
stateLost = true
c.epochStateLost = true
}
}
c.lastEpoch = newEpoch
c.epochTime = now
return stateLost
}
// cacheExternalIP stores the external IP from a successful MAP response.
// Caller must hold c.mu.
func (c *Client) cacheExternalIPLocked(ip netip.Addr) {
if ip.IsValid() && !ip.IsUnspecified() {
c.externalIP = ip
}
}
// sendRequest sends a PCP request with retries per RFC 6887 Section 8.1.1.
func (c *Client) sendRequest(ctx context.Context, req []byte) ([]byte, error) {
addr := &net.UDPAddr{IP: c.gateway.AsSlice(), Port: Port}
var lastErr error
delay := initialRetryDelay
for range maxRetries {
resp, err := c.sendOnce(ctx, addr, req)
if err == nil {
return resp, nil
}
lastErr = err
if ctx.Err() != nil {
return nil, ctx.Err()
}
// RFC 6887 Section 8.1.1: RT = (1 + RAND) * MIN(2 * RTprev, MRT)
// RAND is random between -0.1 and +0.1
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-time.After(retryDelayWithJitter(delay)):
}
delay = min(delay*2, maxRetryDelay)
}
return nil, fmt.Errorf("PCP request failed after %d retries: %w", maxRetries, lastErr)
}
// retryDelayWithJitter applies RFC 6887 jitter: multiply by (1 + RAND) where RAND is [-0.1, +0.1].
func retryDelayWithJitter(d time.Duration) time.Duration {
var b [1]byte
_, _ = rand.Read(b[:])
// Convert byte to range [-0.1, +0.1]: (b/255 * 0.2) - 0.1
jitter := (float64(b[0])/255.0)*0.2 - 0.1
return time.Duration(float64(d) * (1 + jitter))
}
func (c *Client) sendOnce(ctx context.Context, addr *net.UDPAddr, req []byte) ([]byte, error) {
// Use ListenUDP instead of DialUDP to validate response source address per RFC 6887 §8.3.
conn, err := net.ListenUDP("udp", nil)
if err != nil {
return nil, fmt.Errorf("listen: %w", err)
}
defer func() {
if err := conn.Close(); err != nil {
log.Debugf("close UDP connection: %v", err)
}
}()
timeout := c.timeout
if deadline, ok := ctx.Deadline(); ok {
if remaining := time.Until(deadline); remaining < timeout {
timeout = remaining
}
}
if err := conn.SetDeadline(time.Now().Add(timeout)); err != nil {
return nil, fmt.Errorf("set deadline: %w", err)
}
if _, err := conn.WriteToUDP(req, addr); err != nil {
return nil, fmt.Errorf("write: %w", err)
}
resp := make([]byte, responseBufferSize)
n, from, err := conn.ReadFromUDP(resp)
if err != nil {
return nil, fmt.Errorf("read: %w", err)
}
// RFC 6887 §8.3: Validate response came from expected PCP server.
if !from.IP.Equal(addr.IP) {
return nil, fmt.Errorf("response from unexpected source %s (expected %s)", from.IP, addr.IP)
}
return resp[:n], nil
}
func (c *Client) getLocalIP() (netip.Addr, error) {
c.mu.Lock()
defer c.mu.Unlock()
if !c.localIP.IsValid() {
return netip.Addr{}, fmt.Errorf("local IP not set for gateway %s", c.gateway)
}
return c.localIP, nil
}
func protocolNumber(protocol string) (uint8, error) {
switch protocol {
case "udp", "UDP":
return ProtoUDP, nil
case "tcp", "TCP":
return ProtoTCP, nil
default:
return 0, fmt.Errorf("unsupported protocol: %s", protocol)
}
}
// Error represents a PCP error response.
type Error struct {
Code uint8
Message string
}
func (e *Error) Error() string {
return fmt.Sprintf("PCP error: %s (%d)", e.Message, e.Code)
}