mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 07:16:38 +00:00
409 lines
12 KiB
Go
409 lines
12 KiB
Go
package pcp
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"net/netip"
|
|
"sync"
|
|
"time"
|
|
|
|
log "github.com/sirupsen/logrus"
|
|
)
|
|
|
|
const (
|
|
defaultTimeout = 3 * time.Second
|
|
responseBufferSize = 128
|
|
|
|
// RFC 6887 Section 8.1.1 retry timing
|
|
initialRetryDelay = 3 * time.Second
|
|
maxRetryDelay = 1024 * time.Second
|
|
maxRetries = 4 // 3s + 6s + 12s + 24s = 45s total worst case
|
|
)
|
|
|
|
// Client is a PCP protocol client.
|
|
// All methods are safe for concurrent use.
|
|
type Client struct {
|
|
gateway netip.Addr
|
|
timeout time.Duration
|
|
|
|
mu sync.Mutex
|
|
// localIP caches the resolved local IP address.
|
|
localIP netip.Addr
|
|
// lastEpoch is the last observed server epoch value.
|
|
lastEpoch uint32
|
|
// epochTime tracks when lastEpoch was received for state loss detection.
|
|
epochTime time.Time
|
|
// externalIP caches the external IP from the last successful MAP response.
|
|
externalIP netip.Addr
|
|
// epochStateLost is set when epoch indicates server restart.
|
|
epochStateLost bool
|
|
}
|
|
|
|
// NewClient creates a new PCP client for the gateway at the given IP.
|
|
func NewClient(gateway net.IP) *Client {
|
|
addr, ok := netip.AddrFromSlice(gateway)
|
|
if !ok {
|
|
log.Debugf("invalid gateway IP: %v", gateway)
|
|
}
|
|
return &Client{
|
|
gateway: addr.Unmap(),
|
|
timeout: defaultTimeout,
|
|
}
|
|
}
|
|
|
|
// NewClientWithTimeout creates a new PCP client with a custom timeout.
|
|
func NewClientWithTimeout(gateway net.IP, timeout time.Duration) *Client {
|
|
addr, ok := netip.AddrFromSlice(gateway)
|
|
if !ok {
|
|
log.Debugf("invalid gateway IP: %v", gateway)
|
|
}
|
|
return &Client{
|
|
gateway: addr.Unmap(),
|
|
timeout: timeout,
|
|
}
|
|
}
|
|
|
|
// SetLocalIP sets the local IP address to use in PCP requests.
|
|
func (c *Client) SetLocalIP(ip net.IP) {
|
|
addr, ok := netip.AddrFromSlice(ip)
|
|
if !ok {
|
|
log.Debugf("invalid local IP: %v", ip)
|
|
}
|
|
c.mu.Lock()
|
|
c.localIP = addr.Unmap()
|
|
c.mu.Unlock()
|
|
}
|
|
|
|
// Gateway returns the gateway IP address.
|
|
func (c *Client) Gateway() net.IP {
|
|
return c.gateway.AsSlice()
|
|
}
|
|
|
|
// Announce sends a PCP ANNOUNCE request to discover PCP support.
|
|
// Returns the server's epoch time on success.
|
|
func (c *Client) Announce(ctx context.Context) (epoch uint32, err error) {
|
|
localIP, err := c.getLocalIP()
|
|
if err != nil {
|
|
return 0, fmt.Errorf("get local IP: %w", err)
|
|
}
|
|
|
|
req := buildAnnounceRequest(localIP)
|
|
resp, err := c.sendRequest(ctx, req)
|
|
if err != nil {
|
|
return 0, fmt.Errorf("send announce: %w", err)
|
|
}
|
|
|
|
parsed, err := parseResponse(resp)
|
|
if err != nil {
|
|
return 0, fmt.Errorf("parse announce response: %w", err)
|
|
}
|
|
|
|
if parsed.ResultCode != ResultSuccess {
|
|
return 0, fmt.Errorf("PCP ANNOUNCE failed: %s", ResultCodeString(parsed.ResultCode))
|
|
}
|
|
|
|
c.mu.Lock()
|
|
if c.updateEpochLocked(parsed.Epoch) {
|
|
log.Warnf("PCP server epoch indicates state loss - mappings may need refresh")
|
|
}
|
|
c.mu.Unlock()
|
|
return parsed.Epoch, nil
|
|
}
|
|
|
|
// AddPortMapping requests a port mapping from the PCP server.
|
|
func (c *Client) AddPortMapping(ctx context.Context, protocol string, internalPort int, lifetime time.Duration) (*MapResponse, error) {
|
|
return c.addPortMappingWithHint(ctx, protocol, internalPort, internalPort, netip.Addr{}, lifetime)
|
|
}
|
|
|
|
// AddPortMappingWithHint requests a port mapping with suggested external port and IP.
|
|
// Use lifetime <= 0 to delete a mapping.
|
|
func (c *Client) AddPortMappingWithHint(ctx context.Context, protocol string, internalPort, suggestedExtPort int, suggestedExtIP net.IP, lifetime time.Duration) (*MapResponse, error) {
|
|
var extIP netip.Addr
|
|
if suggestedExtIP != nil {
|
|
var ok bool
|
|
extIP, ok = netip.AddrFromSlice(suggestedExtIP)
|
|
if !ok {
|
|
log.Debugf("invalid suggested external IP: %v", suggestedExtIP)
|
|
}
|
|
extIP = extIP.Unmap()
|
|
}
|
|
return c.addPortMappingWithHint(ctx, protocol, internalPort, suggestedExtPort, extIP, lifetime)
|
|
}
|
|
|
|
func (c *Client) addPortMappingWithHint(ctx context.Context, protocol string, internalPort, suggestedExtPort int, suggestedExtIP netip.Addr, lifetime time.Duration) (*MapResponse, error) {
|
|
localIP, err := c.getLocalIP()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("get local IP: %w", err)
|
|
}
|
|
|
|
proto, err := protocolNumber(protocol)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("parse protocol: %w", err)
|
|
}
|
|
|
|
var nonce [12]byte
|
|
if _, err := rand.Read(nonce[:]); err != nil {
|
|
return nil, fmt.Errorf("generate nonce: %w", err)
|
|
}
|
|
|
|
// Convert lifetime to seconds. Lifetime 0 means delete, so only apply
|
|
// default for positive durations that round to 0 seconds.
|
|
var lifetimeSec uint32
|
|
if lifetime > 0 {
|
|
lifetimeSec = uint32(lifetime.Seconds())
|
|
if lifetimeSec == 0 {
|
|
lifetimeSec = DefaultLifetime
|
|
}
|
|
}
|
|
|
|
req := buildMapRequest(localIP, nonce, proto, uint16(internalPort), uint16(suggestedExtPort), suggestedExtIP, lifetimeSec)
|
|
|
|
resp, err := c.sendRequest(ctx, req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("send map request: %w", err)
|
|
}
|
|
|
|
mapResp, err := parseMapResponse(resp)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("parse map response: %w", err)
|
|
}
|
|
|
|
if mapResp.Nonce != nonce {
|
|
return nil, fmt.Errorf("nonce mismatch in response")
|
|
}
|
|
|
|
if mapResp.Protocol != proto {
|
|
return nil, fmt.Errorf("protocol mismatch: requested %d, got %d", proto, mapResp.Protocol)
|
|
}
|
|
if mapResp.InternalPort != uint16(internalPort) {
|
|
return nil, fmt.Errorf("internal port mismatch: requested %d, got %d", internalPort, mapResp.InternalPort)
|
|
}
|
|
|
|
if mapResp.ResultCode != ResultSuccess {
|
|
return nil, &Error{
|
|
Code: mapResp.ResultCode,
|
|
Message: ResultCodeString(mapResp.ResultCode),
|
|
}
|
|
}
|
|
|
|
c.mu.Lock()
|
|
if c.updateEpochLocked(mapResp.Epoch) {
|
|
log.Warnf("PCP server epoch indicates state loss - mappings may need refresh")
|
|
}
|
|
c.cacheExternalIPLocked(mapResp.ExternalIP)
|
|
c.mu.Unlock()
|
|
return mapResp, nil
|
|
}
|
|
|
|
// DeletePortMapping removes a port mapping by requesting zero lifetime.
|
|
func (c *Client) DeletePortMapping(ctx context.Context, protocol string, internalPort int) error {
|
|
if _, err := c.addPortMappingWithHint(ctx, protocol, internalPort, 0, netip.Addr{}, 0); err != nil {
|
|
var pcpErr *Error
|
|
if errors.As(err, &pcpErr) && pcpErr.Code == ResultNotAuthorized {
|
|
return nil
|
|
}
|
|
return fmt.Errorf("delete mapping: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// GetExternalAddress returns the external IP address.
|
|
// First checks for a cached value from previous MAP responses.
|
|
// If not cached, creates a short-lived mapping to discover the external IP.
|
|
func (c *Client) GetExternalAddress(ctx context.Context) (net.IP, error) {
|
|
c.mu.Lock()
|
|
if c.externalIP.IsValid() {
|
|
ip := c.externalIP.AsSlice()
|
|
c.mu.Unlock()
|
|
return ip, nil
|
|
}
|
|
c.mu.Unlock()
|
|
|
|
// Use an ephemeral port in the dynamic range (49152-65535).
|
|
// Port 0 is not valid with UDP/TCP protocols per RFC 6887.
|
|
ephemeralPort := 49152 + int(uint16(time.Now().UnixNano()))%(65535-49152)
|
|
|
|
// Use minimal lifetime (1 second) for discovery.
|
|
resp, err := c.AddPortMapping(ctx, "udp", ephemeralPort, time.Second)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("create temporary mapping: %w", err)
|
|
}
|
|
|
|
if err := c.DeletePortMapping(ctx, "udp", ephemeralPort); err != nil {
|
|
log.Debugf("cleanup temporary PCP mapping: %v", err)
|
|
}
|
|
|
|
return resp.ExternalIP.AsSlice(), nil
|
|
}
|
|
|
|
// LastEpoch returns the last observed server epoch value.
|
|
// A decrease in epoch indicates the server may have restarted and mappings may be lost.
|
|
func (c *Client) LastEpoch() uint32 {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
return c.lastEpoch
|
|
}
|
|
|
|
// EpochStateLost returns true if epoch state loss was detected and clears the flag.
|
|
func (c *Client) EpochStateLost() bool {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
lost := c.epochStateLost
|
|
c.epochStateLost = false
|
|
return lost
|
|
}
|
|
|
|
// updateEpoch updates the epoch tracking and detects potential state loss.
|
|
// Returns true if state loss was detected (server likely restarted).
|
|
// Caller must hold c.mu.
|
|
func (c *Client) updateEpochLocked(newEpoch uint32) bool {
|
|
now := time.Now()
|
|
stateLost := false
|
|
|
|
// RFC 6887 Section 8.5: Detect invalid epoch indicating server state loss.
|
|
// client_delta = time since last response
|
|
// server_delta = epoch change since last response
|
|
// Invalid if: client_delta+2 < server_delta - server_delta/16
|
|
// OR: server_delta+2 < client_delta - client_delta/16
|
|
// The +2 handles quantization, /16 (6.25%) handles clock drift.
|
|
if !c.epochTime.IsZero() && c.lastEpoch > 0 {
|
|
clientDelta := uint32(now.Sub(c.epochTime).Seconds())
|
|
serverDelta := newEpoch - c.lastEpoch
|
|
|
|
// Check for epoch going backwards or jumping unexpectedly.
|
|
// Subtraction is safe: serverDelta/16 is always <= serverDelta.
|
|
if clientDelta+2 < serverDelta-(serverDelta/16) ||
|
|
serverDelta+2 < clientDelta-(clientDelta/16) {
|
|
stateLost = true
|
|
c.epochStateLost = true
|
|
}
|
|
}
|
|
|
|
c.lastEpoch = newEpoch
|
|
c.epochTime = now
|
|
return stateLost
|
|
}
|
|
|
|
// cacheExternalIP stores the external IP from a successful MAP response.
|
|
// Caller must hold c.mu.
|
|
func (c *Client) cacheExternalIPLocked(ip netip.Addr) {
|
|
if ip.IsValid() && !ip.IsUnspecified() {
|
|
c.externalIP = ip
|
|
}
|
|
}
|
|
|
|
// sendRequest sends a PCP request with retries per RFC 6887 Section 8.1.1.
|
|
func (c *Client) sendRequest(ctx context.Context, req []byte) ([]byte, error) {
|
|
addr := &net.UDPAddr{IP: c.gateway.AsSlice(), Port: Port}
|
|
|
|
var lastErr error
|
|
delay := initialRetryDelay
|
|
|
|
for range maxRetries {
|
|
resp, err := c.sendOnce(ctx, addr, req)
|
|
if err == nil {
|
|
return resp, nil
|
|
}
|
|
lastErr = err
|
|
|
|
if ctx.Err() != nil {
|
|
return nil, ctx.Err()
|
|
}
|
|
|
|
// RFC 6887 Section 8.1.1: RT = (1 + RAND) * MIN(2 * RTprev, MRT)
|
|
// RAND is random between -0.1 and +0.1
|
|
select {
|
|
case <-ctx.Done():
|
|
return nil, ctx.Err()
|
|
case <-time.After(retryDelayWithJitter(delay)):
|
|
}
|
|
delay = min(delay*2, maxRetryDelay)
|
|
}
|
|
|
|
return nil, fmt.Errorf("PCP request failed after %d retries: %w", maxRetries, lastErr)
|
|
}
|
|
|
|
// retryDelayWithJitter applies RFC 6887 jitter: multiply by (1 + RAND) where RAND is [-0.1, +0.1].
|
|
func retryDelayWithJitter(d time.Duration) time.Duration {
|
|
var b [1]byte
|
|
_, _ = rand.Read(b[:])
|
|
// Convert byte to range [-0.1, +0.1]: (b/255 * 0.2) - 0.1
|
|
jitter := (float64(b[0])/255.0)*0.2 - 0.1
|
|
return time.Duration(float64(d) * (1 + jitter))
|
|
}
|
|
|
|
func (c *Client) sendOnce(ctx context.Context, addr *net.UDPAddr, req []byte) ([]byte, error) {
|
|
// Use ListenUDP instead of DialUDP to validate response source address per RFC 6887 §8.3.
|
|
conn, err := net.ListenUDP("udp", nil)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("listen: %w", err)
|
|
}
|
|
defer func() {
|
|
if err := conn.Close(); err != nil {
|
|
log.Debugf("close UDP connection: %v", err)
|
|
}
|
|
}()
|
|
|
|
timeout := c.timeout
|
|
if deadline, ok := ctx.Deadline(); ok {
|
|
if remaining := time.Until(deadline); remaining < timeout {
|
|
timeout = remaining
|
|
}
|
|
}
|
|
|
|
if err := conn.SetDeadline(time.Now().Add(timeout)); err != nil {
|
|
return nil, fmt.Errorf("set deadline: %w", err)
|
|
}
|
|
|
|
if _, err := conn.WriteToUDP(req, addr); err != nil {
|
|
return nil, fmt.Errorf("write: %w", err)
|
|
}
|
|
|
|
resp := make([]byte, responseBufferSize)
|
|
n, from, err := conn.ReadFromUDP(resp)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("read: %w", err)
|
|
}
|
|
|
|
// RFC 6887 §8.3: Validate response came from expected PCP server.
|
|
if !from.IP.Equal(addr.IP) {
|
|
return nil, fmt.Errorf("response from unexpected source %s (expected %s)", from.IP, addr.IP)
|
|
}
|
|
|
|
return resp[:n], nil
|
|
}
|
|
|
|
func (c *Client) getLocalIP() (netip.Addr, error) {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
|
|
if !c.localIP.IsValid() {
|
|
return netip.Addr{}, fmt.Errorf("local IP not set for gateway %s", c.gateway)
|
|
}
|
|
return c.localIP, nil
|
|
}
|
|
|
|
func protocolNumber(protocol string) (uint8, error) {
|
|
switch protocol {
|
|
case "udp", "UDP":
|
|
return ProtoUDP, nil
|
|
case "tcp", "TCP":
|
|
return ProtoTCP, nil
|
|
default:
|
|
return 0, fmt.Errorf("unsupported protocol: %s", protocol)
|
|
}
|
|
}
|
|
|
|
// Error represents a PCP error response.
|
|
type Error struct {
|
|
Code uint8
|
|
Message string
|
|
}
|
|
|
|
func (e *Error) Error() string {
|
|
return fmt.Sprintf("PCP error: %s (%d)", e.Message, e.Code)
|
|
}
|