From 0d86de47df5a8ced0f9b1a9e3c288a7abc7fa9aa Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Wed, 15 Apr 2026 18:43:16 +0900 Subject: [PATCH] [client] Add PCP support (#5219) --- client/internal/portforward/env.go | 15 +- client/internal/portforward/manager.go | 82 +++- client/internal/portforward/pcp/client.go | 408 ++++++++++++++++++ .../internal/portforward/pcp/client_test.go | 187 ++++++++ client/internal/portforward/pcp/nat.go | 209 +++++++++ client/internal/portforward/pcp/protocol.go | 225 ++++++++++ client/internal/portforward/state.go | 63 +++ 7 files changed, 1176 insertions(+), 13 deletions(-) create mode 100644 client/internal/portforward/pcp/client.go create mode 100644 client/internal/portforward/pcp/client_test.go create mode 100644 client/internal/portforward/pcp/nat.go create mode 100644 client/internal/portforward/pcp/protocol.go create mode 100644 client/internal/portforward/state.go diff --git a/client/internal/portforward/env.go b/client/internal/portforward/env.go index 444a6b478..ba83c79bf 100644 --- a/client/internal/portforward/env.go +++ b/client/internal/portforward/env.go @@ -8,18 +8,27 @@ import ( ) const ( - envDisableNATMapper = "NB_DISABLE_NAT_MAPPER" + envDisableNATMapper = "NB_DISABLE_NAT_MAPPER" + envDisablePCPHealthCheck = "NB_DISABLE_PCP_HEALTH_CHECK" ) func isDisabledByEnv() bool { - val := os.Getenv(envDisableNATMapper) + return parseBoolEnv(envDisableNATMapper) +} + +func isHealthCheckDisabled() bool { + return parseBoolEnv(envDisablePCPHealthCheck) +} + +func parseBoolEnv(key string) bool { + val := os.Getenv(key) if val == "" { return false } disabled, err := strconv.ParseBool(val) if err != nil { - log.Warnf("failed to parse %s: %v", envDisableNATMapper, err) + log.Warnf("failed to parse %s: %v", key, err) return false } return disabled diff --git a/client/internal/portforward/manager.go b/client/internal/portforward/manager.go index bf7533af9..b0680160c 100644 --- a/client/internal/portforward/manager.go +++ b/client/internal/portforward/manager.go @@ -12,12 +12,15 @@ import ( "github.com/libp2p/go-nat" log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/portforward/pcp" ) const ( - defaultMappingTTL = 2 * time.Hour - discoveryTimeout = 10 * time.Second - mappingDescription = "NetBird" + defaultMappingTTL = 2 * time.Hour + healthCheckInterval = 1 * time.Minute + discoveryTimeout = 10 * time.Second + mappingDescription = "NetBird" ) // upnpErrPermanentLeaseOnly matches UPnP error 725 in SOAP fault XML, @@ -154,7 +157,7 @@ func (m *Manager) setup(ctx context.Context) (nat.NAT, *Mapping, error) { discoverCtx, discoverCancel := context.WithTimeout(ctx, discoveryTimeout) defer discoverCancel() - gateway, err := nat.DiscoverGateway(discoverCtx) + gateway, err := discoverGateway(discoverCtx) if err != nil { return nil, nil, fmt.Errorf("discover gateway: %w", err) } @@ -189,7 +192,6 @@ func (m *Manager) createMapping(ctx context.Context, gateway nat.NAT) (*Mapping, externalIP, err := gateway.GetExternalAddress() if err != nil { log.Debugf("failed to get external address: %v", err) - // todo return with err? } mapping := &Mapping{ @@ -208,27 +210,87 @@ func (m *Manager) createMapping(ctx context.Context, gateway nat.NAT) (*Mapping, func (m *Manager) renewLoop(ctx context.Context, gateway nat.NAT, ttl time.Duration) { if ttl == 0 { - // Permanent mappings don't expire, just wait for cancellation. - <-ctx.Done() + // Permanent mappings don't expire, just wait for cancellation + // but still run health checks for PCP gateways. + m.permanentLeaseLoop(ctx, gateway) return } - ticker := time.NewTicker(ttl / 2) - defer ticker.Stop() + renewTicker := time.NewTicker(ttl / 2) + healthTicker := time.NewTicker(healthCheckInterval) + defer renewTicker.Stop() + defer healthTicker.Stop() for { select { case <-ctx.Done(): return - case <-ticker.C: + case <-renewTicker.C: if err := m.renewMapping(ctx, gateway); err != nil { log.Warnf("failed to renew port mapping: %v", err) continue } + case <-healthTicker.C: + if m.checkHealthAndRecreate(ctx, gateway) { + renewTicker.Reset(ttl / 2) + } } } } +func (m *Manager) permanentLeaseLoop(ctx context.Context, gateway nat.NAT) { + healthTicker := time.NewTicker(healthCheckInterval) + defer healthTicker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-healthTicker.C: + m.checkHealthAndRecreate(ctx, gateway) + } + } +} + +func (m *Manager) checkHealthAndRecreate(ctx context.Context, gateway nat.NAT) bool { + if isHealthCheckDisabled() { + return false + } + + m.mappingLock.Lock() + hasMapping := m.mapping != nil + m.mappingLock.Unlock() + + if !hasMapping { + return false + } + + pcpNAT, ok := gateway.(*pcp.NAT) + if !ok { + return false + } + + ctx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + + epoch, serverRestarted, err := pcpNAT.CheckServerHealth(ctx) + if err != nil { + log.Debugf("PCP health check failed: %v", err) + return false + } + + if serverRestarted { + log.Warnf("PCP server restart detected (epoch=%d), recreating port mapping", epoch) + if err := m.renewMapping(ctx, gateway); err != nil { + log.Errorf("failed to recreate port mapping after server restart: %v", err) + return false + } + return true + } + + return false +} + func (m *Manager) renewMapping(ctx context.Context, gateway nat.NAT) error { ctx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() diff --git a/client/internal/portforward/pcp/client.go b/client/internal/portforward/pcp/client.go new file mode 100644 index 000000000..f6d243ef9 --- /dev/null +++ b/client/internal/portforward/pcp/client.go @@ -0,0 +1,408 @@ +package pcp + +import ( + "context" + "crypto/rand" + "errors" + "fmt" + "net" + "net/netip" + "sync" + "time" + + log "github.com/sirupsen/logrus" +) + +const ( + defaultTimeout = 3 * time.Second + responseBufferSize = 128 + + // RFC 6887 Section 8.1.1 retry timing + initialRetryDelay = 3 * time.Second + maxRetryDelay = 1024 * time.Second + maxRetries = 4 // 3s + 6s + 12s + 24s = 45s total worst case +) + +// Client is a PCP protocol client. +// All methods are safe for concurrent use. +type Client struct { + gateway netip.Addr + timeout time.Duration + + mu sync.Mutex + // localIP caches the resolved local IP address. + localIP netip.Addr + // lastEpoch is the last observed server epoch value. + lastEpoch uint32 + // epochTime tracks when lastEpoch was received for state loss detection. + epochTime time.Time + // externalIP caches the external IP from the last successful MAP response. + externalIP netip.Addr + // epochStateLost is set when epoch indicates server restart. + epochStateLost bool +} + +// NewClient creates a new PCP client for the gateway at the given IP. +func NewClient(gateway net.IP) *Client { + addr, ok := netip.AddrFromSlice(gateway) + if !ok { + log.Debugf("invalid gateway IP: %v", gateway) + } + return &Client{ + gateway: addr.Unmap(), + timeout: defaultTimeout, + } +} + +// NewClientWithTimeout creates a new PCP client with a custom timeout. +func NewClientWithTimeout(gateway net.IP, timeout time.Duration) *Client { + addr, ok := netip.AddrFromSlice(gateway) + if !ok { + log.Debugf("invalid gateway IP: %v", gateway) + } + return &Client{ + gateway: addr.Unmap(), + timeout: timeout, + } +} + +// SetLocalIP sets the local IP address to use in PCP requests. +func (c *Client) SetLocalIP(ip net.IP) { + addr, ok := netip.AddrFromSlice(ip) + if !ok { + log.Debugf("invalid local IP: %v", ip) + } + c.mu.Lock() + c.localIP = addr.Unmap() + c.mu.Unlock() +} + +// Gateway returns the gateway IP address. +func (c *Client) Gateway() net.IP { + return c.gateway.AsSlice() +} + +// Announce sends a PCP ANNOUNCE request to discover PCP support. +// Returns the server's epoch time on success. +func (c *Client) Announce(ctx context.Context) (epoch uint32, err error) { + localIP, err := c.getLocalIP() + if err != nil { + return 0, fmt.Errorf("get local IP: %w", err) + } + + req := buildAnnounceRequest(localIP) + resp, err := c.sendRequest(ctx, req) + if err != nil { + return 0, fmt.Errorf("send announce: %w", err) + } + + parsed, err := parseResponse(resp) + if err != nil { + return 0, fmt.Errorf("parse announce response: %w", err) + } + + if parsed.ResultCode != ResultSuccess { + return 0, fmt.Errorf("PCP ANNOUNCE failed: %s", ResultCodeString(parsed.ResultCode)) + } + + c.mu.Lock() + if c.updateEpochLocked(parsed.Epoch) { + log.Warnf("PCP server epoch indicates state loss - mappings may need refresh") + } + c.mu.Unlock() + return parsed.Epoch, nil +} + +// AddPortMapping requests a port mapping from the PCP server. +func (c *Client) AddPortMapping(ctx context.Context, protocol string, internalPort int, lifetime time.Duration) (*MapResponse, error) { + return c.addPortMappingWithHint(ctx, protocol, internalPort, internalPort, netip.Addr{}, lifetime) +} + +// AddPortMappingWithHint requests a port mapping with suggested external port and IP. +// Use lifetime <= 0 to delete a mapping. +func (c *Client) AddPortMappingWithHint(ctx context.Context, protocol string, internalPort, suggestedExtPort int, suggestedExtIP net.IP, lifetime time.Duration) (*MapResponse, error) { + var extIP netip.Addr + if suggestedExtIP != nil { + var ok bool + extIP, ok = netip.AddrFromSlice(suggestedExtIP) + if !ok { + log.Debugf("invalid suggested external IP: %v", suggestedExtIP) + } + extIP = extIP.Unmap() + } + return c.addPortMappingWithHint(ctx, protocol, internalPort, suggestedExtPort, extIP, lifetime) +} + +func (c *Client) addPortMappingWithHint(ctx context.Context, protocol string, internalPort, suggestedExtPort int, suggestedExtIP netip.Addr, lifetime time.Duration) (*MapResponse, error) { + localIP, err := c.getLocalIP() + if err != nil { + return nil, fmt.Errorf("get local IP: %w", err) + } + + proto, err := protocolNumber(protocol) + if err != nil { + return nil, fmt.Errorf("parse protocol: %w", err) + } + + var nonce [12]byte + if _, err := rand.Read(nonce[:]); err != nil { + return nil, fmt.Errorf("generate nonce: %w", err) + } + + // Convert lifetime to seconds. Lifetime 0 means delete, so only apply + // default for positive durations that round to 0 seconds. + var lifetimeSec uint32 + if lifetime > 0 { + lifetimeSec = uint32(lifetime.Seconds()) + if lifetimeSec == 0 { + lifetimeSec = DefaultLifetime + } + } + + req := buildMapRequest(localIP, nonce, proto, uint16(internalPort), uint16(suggestedExtPort), suggestedExtIP, lifetimeSec) + + resp, err := c.sendRequest(ctx, req) + if err != nil { + return nil, fmt.Errorf("send map request: %w", err) + } + + mapResp, err := parseMapResponse(resp) + if err != nil { + return nil, fmt.Errorf("parse map response: %w", err) + } + + if mapResp.Nonce != nonce { + return nil, fmt.Errorf("nonce mismatch in response") + } + + if mapResp.Protocol != proto { + return nil, fmt.Errorf("protocol mismatch: requested %d, got %d", proto, mapResp.Protocol) + } + if mapResp.InternalPort != uint16(internalPort) { + return nil, fmt.Errorf("internal port mismatch: requested %d, got %d", internalPort, mapResp.InternalPort) + } + + if mapResp.ResultCode != ResultSuccess { + return nil, &Error{ + Code: mapResp.ResultCode, + Message: ResultCodeString(mapResp.ResultCode), + } + } + + c.mu.Lock() + if c.updateEpochLocked(mapResp.Epoch) { + log.Warnf("PCP server epoch indicates state loss - mappings may need refresh") + } + c.cacheExternalIPLocked(mapResp.ExternalIP) + c.mu.Unlock() + return mapResp, nil +} + +// DeletePortMapping removes a port mapping by requesting zero lifetime. +func (c *Client) DeletePortMapping(ctx context.Context, protocol string, internalPort int) error { + if _, err := c.addPortMappingWithHint(ctx, protocol, internalPort, 0, netip.Addr{}, 0); err != nil { + var pcpErr *Error + if errors.As(err, &pcpErr) && pcpErr.Code == ResultNotAuthorized { + return nil + } + return fmt.Errorf("delete mapping: %w", err) + } + return nil +} + +// GetExternalAddress returns the external IP address. +// First checks for a cached value from previous MAP responses. +// If not cached, creates a short-lived mapping to discover the external IP. +func (c *Client) GetExternalAddress(ctx context.Context) (net.IP, error) { + c.mu.Lock() + if c.externalIP.IsValid() { + ip := c.externalIP.AsSlice() + c.mu.Unlock() + return ip, nil + } + c.mu.Unlock() + + // Use an ephemeral port in the dynamic range (49152-65535). + // Port 0 is not valid with UDP/TCP protocols per RFC 6887. + ephemeralPort := 49152 + int(uint16(time.Now().UnixNano()))%(65535-49152) + + // Use minimal lifetime (1 second) for discovery. + resp, err := c.AddPortMapping(ctx, "udp", ephemeralPort, time.Second) + if err != nil { + return nil, fmt.Errorf("create temporary mapping: %w", err) + } + + if err := c.DeletePortMapping(ctx, "udp", ephemeralPort); err != nil { + log.Debugf("cleanup temporary PCP mapping: %v", err) + } + + return resp.ExternalIP.AsSlice(), nil +} + +// LastEpoch returns the last observed server epoch value. +// A decrease in epoch indicates the server may have restarted and mappings may be lost. +func (c *Client) LastEpoch() uint32 { + c.mu.Lock() + defer c.mu.Unlock() + return c.lastEpoch +} + +// EpochStateLost returns true if epoch state loss was detected and clears the flag. +func (c *Client) EpochStateLost() bool { + c.mu.Lock() + defer c.mu.Unlock() + lost := c.epochStateLost + c.epochStateLost = false + return lost +} + +// updateEpoch updates the epoch tracking and detects potential state loss. +// Returns true if state loss was detected (server likely restarted). +// Caller must hold c.mu. +func (c *Client) updateEpochLocked(newEpoch uint32) bool { + now := time.Now() + stateLost := false + + // RFC 6887 Section 8.5: Detect invalid epoch indicating server state loss. + // client_delta = time since last response + // server_delta = epoch change since last response + // Invalid if: client_delta+2 < server_delta - server_delta/16 + // OR: server_delta+2 < client_delta - client_delta/16 + // The +2 handles quantization, /16 (6.25%) handles clock drift. + if !c.epochTime.IsZero() && c.lastEpoch > 0 { + clientDelta := uint32(now.Sub(c.epochTime).Seconds()) + serverDelta := newEpoch - c.lastEpoch + + // Check for epoch going backwards or jumping unexpectedly. + // Subtraction is safe: serverDelta/16 is always <= serverDelta. + if clientDelta+2 < serverDelta-(serverDelta/16) || + serverDelta+2 < clientDelta-(clientDelta/16) { + stateLost = true + c.epochStateLost = true + } + } + + c.lastEpoch = newEpoch + c.epochTime = now + return stateLost +} + +// cacheExternalIP stores the external IP from a successful MAP response. +// Caller must hold c.mu. +func (c *Client) cacheExternalIPLocked(ip netip.Addr) { + if ip.IsValid() && !ip.IsUnspecified() { + c.externalIP = ip + } +} + +// sendRequest sends a PCP request with retries per RFC 6887 Section 8.1.1. +func (c *Client) sendRequest(ctx context.Context, req []byte) ([]byte, error) { + addr := &net.UDPAddr{IP: c.gateway.AsSlice(), Port: Port} + + var lastErr error + delay := initialRetryDelay + + for range maxRetries { + resp, err := c.sendOnce(ctx, addr, req) + if err == nil { + return resp, nil + } + lastErr = err + + if ctx.Err() != nil { + return nil, ctx.Err() + } + + // RFC 6887 Section 8.1.1: RT = (1 + RAND) * MIN(2 * RTprev, MRT) + // RAND is random between -0.1 and +0.1 + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(retryDelayWithJitter(delay)): + } + delay = min(delay*2, maxRetryDelay) + } + + return nil, fmt.Errorf("PCP request failed after %d retries: %w", maxRetries, lastErr) +} + +// retryDelayWithJitter applies RFC 6887 jitter: multiply by (1 + RAND) where RAND is [-0.1, +0.1]. +func retryDelayWithJitter(d time.Duration) time.Duration { + var b [1]byte + _, _ = rand.Read(b[:]) + // Convert byte to range [-0.1, +0.1]: (b/255 * 0.2) - 0.1 + jitter := (float64(b[0])/255.0)*0.2 - 0.1 + return time.Duration(float64(d) * (1 + jitter)) +} + +func (c *Client) sendOnce(ctx context.Context, addr *net.UDPAddr, req []byte) ([]byte, error) { + // Use ListenUDP instead of DialUDP to validate response source address per RFC 6887 §8.3. + conn, err := net.ListenUDP("udp", nil) + if err != nil { + return nil, fmt.Errorf("listen: %w", err) + } + defer func() { + if err := conn.Close(); err != nil { + log.Debugf("close UDP connection: %v", err) + } + }() + + timeout := c.timeout + if deadline, ok := ctx.Deadline(); ok { + if remaining := time.Until(deadline); remaining < timeout { + timeout = remaining + } + } + + if err := conn.SetDeadline(time.Now().Add(timeout)); err != nil { + return nil, fmt.Errorf("set deadline: %w", err) + } + + if _, err := conn.WriteToUDP(req, addr); err != nil { + return nil, fmt.Errorf("write: %w", err) + } + + resp := make([]byte, responseBufferSize) + n, from, err := conn.ReadFromUDP(resp) + if err != nil { + return nil, fmt.Errorf("read: %w", err) + } + + // RFC 6887 §8.3: Validate response came from expected PCP server. + if !from.IP.Equal(addr.IP) { + return nil, fmt.Errorf("response from unexpected source %s (expected %s)", from.IP, addr.IP) + } + + return resp[:n], nil +} + +func (c *Client) getLocalIP() (netip.Addr, error) { + c.mu.Lock() + defer c.mu.Unlock() + + if !c.localIP.IsValid() { + return netip.Addr{}, fmt.Errorf("local IP not set for gateway %s", c.gateway) + } + return c.localIP, nil +} + +func protocolNumber(protocol string) (uint8, error) { + switch protocol { + case "udp", "UDP": + return ProtoUDP, nil + case "tcp", "TCP": + return ProtoTCP, nil + default: + return 0, fmt.Errorf("unsupported protocol: %s", protocol) + } +} + +// Error represents a PCP error response. +type Error struct { + Code uint8 + Message string +} + +func (e *Error) Error() string { + return fmt.Sprintf("PCP error: %s (%d)", e.Message, e.Code) +} diff --git a/client/internal/portforward/pcp/client_test.go b/client/internal/portforward/pcp/client_test.go new file mode 100644 index 000000000..79f44a426 --- /dev/null +++ b/client/internal/portforward/pcp/client_test.go @@ -0,0 +1,187 @@ +package pcp + +import ( + "context" + "net" + "net/netip" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestAddrConversion(t *testing.T) { + tests := []struct { + name string + addr netip.Addr + }{ + {"IPv4", netip.MustParseAddr("192.168.1.100")}, + {"IPv4 loopback", netip.MustParseAddr("127.0.0.1")}, + {"IPv6", netip.MustParseAddr("2001:db8::1")}, + {"IPv6 loopback", netip.MustParseAddr("::1")}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + b16 := addrTo16(tt.addr) + + recovered := addrFrom16(b16) + assert.Equal(t, tt.addr, recovered, "address should round-trip") + }) + } +} + +func TestBuildAnnounceRequest(t *testing.T) { + clientIP := netip.MustParseAddr("192.168.1.100") + req := buildAnnounceRequest(clientIP) + + require.Len(t, req, headerSize) + assert.Equal(t, byte(Version), req[0], "version") + assert.Equal(t, byte(OpAnnounce), req[1], "opcode") + + // Check client IP is properly encoded as IPv4-mapped IPv6 + assert.Equal(t, byte(0xff), req[18], "IPv4-mapped prefix byte 10") + assert.Equal(t, byte(0xff), req[19], "IPv4-mapped prefix byte 11") + assert.Equal(t, byte(192), req[20], "IP octet 1") + assert.Equal(t, byte(168), req[21], "IP octet 2") + assert.Equal(t, byte(1), req[22], "IP octet 3") + assert.Equal(t, byte(100), req[23], "IP octet 4") +} + +func TestBuildMapRequest(t *testing.T) { + clientIP := netip.MustParseAddr("192.168.1.100") + nonce := [12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12} + req := buildMapRequest(clientIP, nonce, ProtoUDP, 51820, 51820, netip.Addr{}, 3600) + + require.Len(t, req, mapRequestSize) + assert.Equal(t, byte(Version), req[0], "version") + assert.Equal(t, byte(OpMap), req[1], "opcode") + + // Lifetime at bytes 4-7 + assert.Equal(t, uint32(3600), (uint32(req[4])<<24)|(uint32(req[5])<<16)|(uint32(req[6])<<8)|uint32(req[7]), "lifetime") + + // Nonce at bytes 24-35 + assert.Equal(t, nonce[:], req[24:36], "nonce") + + // Protocol at byte 36 + assert.Equal(t, byte(ProtoUDP), req[36], "protocol") + + // Internal port at bytes 40-41 + assert.Equal(t, uint16(51820), (uint16(req[40])<<8)|uint16(req[41]), "internal port") + + // External port at bytes 42-43 + assert.Equal(t, uint16(51820), (uint16(req[42])<<8)|uint16(req[43]), "external port") +} + +func TestParseResponse(t *testing.T) { + // Construct a valid ANNOUNCE response + resp := make([]byte, headerSize) + resp[0] = Version + resp[1] = OpAnnounce | OpReply + // Result code = 0 (success) + // Lifetime = 0 + // Epoch = 12345 + resp[8] = 0 + resp[9] = 0 + resp[10] = 0x30 + resp[11] = 0x39 + + parsed, err := parseResponse(resp) + require.NoError(t, err) + assert.Equal(t, uint8(Version), parsed.Version) + assert.Equal(t, uint8(OpAnnounce|OpReply), parsed.Opcode) + assert.Equal(t, uint8(ResultSuccess), parsed.ResultCode) + assert.Equal(t, uint32(12345), parsed.Epoch) +} + +func TestParseResponseErrors(t *testing.T) { + t.Run("too short", func(t *testing.T) { + _, err := parseResponse([]byte{1, 2, 3}) + assert.Error(t, err) + }) + + t.Run("wrong version", func(t *testing.T) { + resp := make([]byte, headerSize) + resp[0] = 1 // Wrong version + resp[1] = OpReply + _, err := parseResponse(resp) + assert.Error(t, err) + }) + + t.Run("missing reply bit", func(t *testing.T) { + resp := make([]byte, headerSize) + resp[0] = Version + resp[1] = OpAnnounce // Missing OpReply bit + _, err := parseResponse(resp) + assert.Error(t, err) + }) +} + +func TestResultCodeString(t *testing.T) { + assert.Equal(t, "SUCCESS", ResultCodeString(ResultSuccess)) + assert.Equal(t, "NOT_AUTHORIZED", ResultCodeString(ResultNotAuthorized)) + assert.Equal(t, "ADDRESS_MISMATCH", ResultCodeString(ResultAddressMismatch)) + assert.Contains(t, ResultCodeString(255), "UNKNOWN") +} + +func TestProtocolNumber(t *testing.T) { + proto, err := protocolNumber("udp") + require.NoError(t, err) + assert.Equal(t, uint8(ProtoUDP), proto) + + proto, err = protocolNumber("tcp") + require.NoError(t, err) + assert.Equal(t, uint8(ProtoTCP), proto) + + proto, err = protocolNumber("UDP") + require.NoError(t, err) + assert.Equal(t, uint8(ProtoUDP), proto) + + _, err = protocolNumber("icmp") + assert.Error(t, err) +} + +func TestClientCreation(t *testing.T) { + gateway := netip.MustParseAddr("192.168.1.1").AsSlice() + + client := NewClient(gateway) + assert.Equal(t, net.IP(gateway), client.Gateway()) + assert.Equal(t, defaultTimeout, client.timeout) + + clientWithTimeout := NewClientWithTimeout(gateway, 5*time.Second) + assert.Equal(t, 5*time.Second, clientWithTimeout.timeout) +} + +func TestNATType(t *testing.T) { + n := NewNAT(netip.MustParseAddr("192.168.1.1").AsSlice(), netip.MustParseAddr("192.168.1.100").AsSlice()) + assert.Equal(t, "PCP", n.Type()) +} + +// Integration test - skipped unless PCP_TEST_GATEWAY env is set +func TestClientIntegration(t *testing.T) { + t.Skip("Integration test - run manually with PCP_TEST_GATEWAY=") + + gateway := netip.MustParseAddr("10.0.1.1").AsSlice() // Change to your test gateway + localIP := netip.MustParseAddr("10.0.1.100").AsSlice() // Change to your local IP + + client := NewClient(gateway) + client.SetLocalIP(localIP) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + // Test ANNOUNCE + epoch, err := client.Announce(ctx) + require.NoError(t, err) + t.Logf("Server epoch: %d", epoch) + + // Test MAP + resp, err := client.AddPortMapping(ctx, "udp", 51820, 1*time.Hour) + require.NoError(t, err) + t.Logf("Mapping: internal=%d external=%d externalIP=%s", + resp.InternalPort, resp.ExternalPort, resp.ExternalIP) + + // Cleanup + err = client.DeletePortMapping(ctx, "udp", 51820) + require.NoError(t, err) +} diff --git a/client/internal/portforward/pcp/nat.go b/client/internal/portforward/pcp/nat.go new file mode 100644 index 000000000..1dc24274b --- /dev/null +++ b/client/internal/portforward/pcp/nat.go @@ -0,0 +1,209 @@ +package pcp + +import ( + "context" + "fmt" + "net" + "net/netip" + "sync" + "time" + + log "github.com/sirupsen/logrus" + + "github.com/libp2p/go-nat" + "github.com/libp2p/go-netroute" +) + +var _ nat.NAT = (*NAT)(nil) + +// NAT implements the go-nat NAT interface using PCP. +// Supports dual-stack (IPv4 and IPv6) when available. +// All methods are safe for concurrent use. +// +// TODO: IPv6 pinholes use the local IPv6 address. If the address changes +// (e.g., due to SLAAC rotation or network change), the pinhole becomes stale +// and needs to be recreated with the new address. +type NAT struct { + client *Client + + mu sync.RWMutex + // client6 is the IPv6 PCP client, nil if IPv6 is unavailable. + client6 *Client + // localIP6 caches the local IPv6 address used for PCP requests. + localIP6 netip.Addr +} + +// NewNAT creates a new NAT instance backed by PCP. +func NewNAT(gateway, localIP net.IP) *NAT { + client := NewClient(gateway) + client.SetLocalIP(localIP) + return &NAT{ + client: client, + } +} + +// Type returns "PCP" as the NAT type. +func (n *NAT) Type() string { + return "PCP" +} + +// GetDeviceAddress returns the gateway IP address. +func (n *NAT) GetDeviceAddress() (net.IP, error) { + return n.client.Gateway(), nil +} + +// GetExternalAddress returns the external IP address. +func (n *NAT) GetExternalAddress() (net.IP, error) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + return n.client.GetExternalAddress(ctx) +} + +// GetInternalAddress returns the local IP address used to communicate with the gateway. +func (n *NAT) GetInternalAddress() (net.IP, error) { + addr, err := n.client.getLocalIP() + if err != nil { + return nil, err + } + return addr.AsSlice(), nil +} + +// AddPortMapping creates a port mapping on both IPv4 and IPv6 (if available). +func (n *NAT) AddPortMapping(ctx context.Context, protocol string, internalPort int, _ string, timeout time.Duration) (int, error) { + resp, err := n.client.AddPortMapping(ctx, protocol, internalPort, timeout) + if err != nil { + return 0, fmt.Errorf("add mapping: %w", err) + } + + n.mu.RLock() + client6 := n.client6 + localIP6 := n.localIP6 + n.mu.RUnlock() + + if client6 == nil { + return int(resp.ExternalPort), nil + } + + if _, err := client6.AddPortMapping(ctx, protocol, internalPort, timeout); err != nil { + log.Warnf("IPv6 PCP mapping failed (continuing with IPv4): %v", err) + return int(resp.ExternalPort), nil + } + + log.Infof("created IPv6 PCP pinhole: %s:%d", localIP6, internalPort) + return int(resp.ExternalPort), nil +} + +// DeletePortMapping removes a port mapping from both IPv4 and IPv6. +func (n *NAT) DeletePortMapping(ctx context.Context, protocol string, internalPort int) error { + err := n.client.DeletePortMapping(ctx, protocol, internalPort) + + n.mu.RLock() + client6 := n.client6 + n.mu.RUnlock() + + if client6 != nil { + if err6 := client6.DeletePortMapping(ctx, protocol, internalPort); err6 != nil { + log.Warnf("IPv6 PCP delete mapping failed: %v", err6) + } + } + + if err != nil { + return fmt.Errorf("delete mapping: %w", err) + } + return nil +} + +// CheckServerHealth sends an ANNOUNCE to verify the server is still responsive. +// Returns the current epoch and whether the server may have restarted (epoch state loss detected). +func (n *NAT) CheckServerHealth(ctx context.Context) (epoch uint32, serverRestarted bool, err error) { + epoch, err = n.client.Announce(ctx) + if err != nil { + return 0, false, fmt.Errorf("announce: %w", err) + } + return epoch, n.client.EpochStateLost(), nil +} + +// DiscoverPCP attempts to discover a PCP-capable gateway. +// Returns a NAT interface if PCP is supported, or an error otherwise. +// Discovers both IPv4 and IPv6 gateways when available. +func DiscoverPCP(ctx context.Context) (nat.NAT, error) { + gateway, localIP, err := getDefaultGateway() + if err != nil { + return nil, fmt.Errorf("get default gateway: %w", err) + } + + client := NewClient(gateway) + client.SetLocalIP(localIP) + if _, err := client.Announce(ctx); err != nil { + return nil, fmt.Errorf("PCP announce: %w", err) + } + + result := &NAT{client: client} + discoverIPv6(ctx, result) + + return result, nil +} + +func discoverIPv6(ctx context.Context, result *NAT) { + gateway6, localIP6, err := getDefaultGateway6() + if err != nil { + log.Debugf("IPv6 gateway discovery failed: %v", err) + return + } + + client6 := NewClient(gateway6) + client6.SetLocalIP(localIP6) + if _, err := client6.Announce(ctx); err != nil { + log.Debugf("PCP IPv6 announce failed: %v", err) + return + } + + addr, ok := netip.AddrFromSlice(localIP6) + if !ok { + log.Debugf("invalid IPv6 local IP: %v", localIP6) + return + } + result.mu.Lock() + result.client6 = client6 + result.localIP6 = addr + result.mu.Unlock() + log.Debugf("PCP IPv6 gateway discovered: %s (local: %s)", gateway6, localIP6) +} + +// getDefaultGateway returns the default IPv4 gateway and local IP using the system routing table. +func getDefaultGateway() (gateway net.IP, localIP net.IP, err error) { + router, err := netroute.New() + if err != nil { + return nil, nil, err + } + + _, gateway, localIP, err = router.Route(net.IPv4zero) + if err != nil { + return nil, nil, err + } + + if gateway == nil { + return nil, nil, nat.ErrNoNATFound + } + + return gateway, localIP, nil +} + +// getDefaultGateway6 returns the default IPv6 gateway IP address using the system routing table. +func getDefaultGateway6() (gateway net.IP, localIP net.IP, err error) { + router, err := netroute.New() + if err != nil { + return nil, nil, err + } + + _, gateway, localIP, err = router.Route(net.IPv6zero) + if err != nil { + return nil, nil, err + } + + if gateway == nil { + return nil, nil, nat.ErrNoNATFound + } + + return gateway, localIP, nil +} diff --git a/client/internal/portforward/pcp/protocol.go b/client/internal/portforward/pcp/protocol.go new file mode 100644 index 000000000..d81c50c8c --- /dev/null +++ b/client/internal/portforward/pcp/protocol.go @@ -0,0 +1,225 @@ +// Package pcp implements the Port Control Protocol (RFC 6887). +// +// # Implemented Features +// +// - ANNOUNCE opcode: Discovers PCP server support +// - MAP opcode: Creates/deletes port mappings (IPv4 NAT) and firewall pinholes (IPv6) +// - Dual-stack: Simultaneous IPv4 and IPv6 support via separate clients +// - Nonce validation: Prevents response spoofing +// - Epoch tracking: Detects server restarts per Section 8.5 +// - RFC-compliant retry timing: 3s initial, exponential backoff to 1024s max (Section 8.1.1) +// +// # Not Implemented +// +// - PEER opcode: For outbound peer connections (not needed for inbound NAT traversal) +// - THIRD_PARTY option: For managing mappings on behalf of other devices +// - PREFER_FAILURE option: Requires exact external port or fail (IPv4 NAT only, not needed for IPv6 pinholing) +// - FILTER option: To restrict remote peer addresses +// +// These optional features are omitted because the primary use case is simple +// port forwarding for WireGuard, which only requires MAP with default behavior. +package pcp + +import ( + "encoding/binary" + "fmt" + "net/netip" +) + +const ( + // Version is the PCP protocol version (RFC 6887). + Version = 2 + + // Port is the standard PCP server port. + Port = 5351 + + // DefaultLifetime is the default requested mapping lifetime in seconds. + DefaultLifetime = 7200 // 2 hours + + // Header sizes + headerSize = 24 + mapPayloadSize = 36 + mapRequestSize = headerSize + mapPayloadSize // 60 bytes +) + +// Opcodes +const ( + OpAnnounce = 0 + OpMap = 1 + OpPeer = 2 + OpReply = 0x80 // OR'd with opcode in responses +) + +// Protocol numbers for MAP requests +const ( + ProtoUDP = 17 + ProtoTCP = 6 +) + +// Result codes (RFC 6887 Section 7.4) +const ( + ResultSuccess = 0 + ResultUnsuppVersion = 1 + ResultNotAuthorized = 2 + ResultMalformedRequest = 3 + ResultUnsuppOpcode = 4 + ResultUnsuppOption = 5 + ResultMalformedOption = 6 + ResultNetworkFailure = 7 + ResultNoResources = 8 + ResultUnsuppProtocol = 9 + ResultUserExQuota = 10 + ResultCannotProvideExt = 11 + ResultAddressMismatch = 12 + ResultExcessiveRemotePeers = 13 +) + +// ResultCodeString returns a human-readable string for a result code. +func ResultCodeString(code uint8) string { + switch code { + case ResultSuccess: + return "SUCCESS" + case ResultUnsuppVersion: + return "UNSUPP_VERSION" + case ResultNotAuthorized: + return "NOT_AUTHORIZED" + case ResultMalformedRequest: + return "MALFORMED_REQUEST" + case ResultUnsuppOpcode: + return "UNSUPP_OPCODE" + case ResultUnsuppOption: + return "UNSUPP_OPTION" + case ResultMalformedOption: + return "MALFORMED_OPTION" + case ResultNetworkFailure: + return "NETWORK_FAILURE" + case ResultNoResources: + return "NO_RESOURCES" + case ResultUnsuppProtocol: + return "UNSUPP_PROTOCOL" + case ResultUserExQuota: + return "USER_EX_QUOTA" + case ResultCannotProvideExt: + return "CANNOT_PROVIDE_EXTERNAL" + case ResultAddressMismatch: + return "ADDRESS_MISMATCH" + case ResultExcessiveRemotePeers: + return "EXCESSIVE_REMOTE_PEERS" + default: + return fmt.Sprintf("UNKNOWN(%d)", code) + } +} + +// Response represents a parsed PCP response header. +type Response struct { + Version uint8 + Opcode uint8 + ResultCode uint8 + Lifetime uint32 + Epoch uint32 +} + +// MapResponse contains the full response to a MAP request. +type MapResponse struct { + Response + Nonce [12]byte + Protocol uint8 + InternalPort uint16 + ExternalPort uint16 + ExternalIP netip.Addr +} + +// addrTo16 converts an address to its 16-byte IPv4-mapped IPv6 representation. +func addrTo16(addr netip.Addr) [16]byte { + if addr.Is4() { + return netip.AddrFrom4(addr.As4()).As16() + } + return addr.As16() +} + +// addrFrom16 extracts an address from a 16-byte representation, unmapping IPv4. +func addrFrom16(b [16]byte) netip.Addr { + return netip.AddrFrom16(b).Unmap() +} + +// buildAnnounceRequest creates a PCP ANNOUNCE request packet. +func buildAnnounceRequest(clientIP netip.Addr) []byte { + req := make([]byte, headerSize) + req[0] = Version + req[1] = OpAnnounce + mapped := addrTo16(clientIP) + copy(req[8:24], mapped[:]) + return req +} + +// buildMapRequest creates a PCP MAP request packet. +func buildMapRequest(clientIP netip.Addr, nonce [12]byte, protocol uint8, internalPort, suggestedExtPort uint16, suggestedExtIP netip.Addr, lifetime uint32) []byte { + req := make([]byte, mapRequestSize) + + // Header + req[0] = Version + req[1] = OpMap + binary.BigEndian.PutUint32(req[4:8], lifetime) + mapped := addrTo16(clientIP) + copy(req[8:24], mapped[:]) + + // MAP payload + copy(req[24:36], nonce[:]) + req[36] = protocol + binary.BigEndian.PutUint16(req[40:42], internalPort) + binary.BigEndian.PutUint16(req[42:44], suggestedExtPort) + if suggestedExtIP.IsValid() { + extMapped := addrTo16(suggestedExtIP) + copy(req[44:60], extMapped[:]) + } + + return req +} + +// parseResponse parses the common PCP response header. +func parseResponse(data []byte) (*Response, error) { + if len(data) < headerSize { + return nil, fmt.Errorf("response too short: %d bytes", len(data)) + } + + resp := &Response{ + Version: data[0], + Opcode: data[1], + ResultCode: data[3], // Byte 2 is reserved, byte 3 is result code (RFC 6887 §7.2) + Lifetime: binary.BigEndian.Uint32(data[4:8]), + Epoch: binary.BigEndian.Uint32(data[8:12]), + } + + if resp.Version != Version { + return nil, fmt.Errorf("unsupported PCP version: %d", resp.Version) + } + + if resp.Opcode&OpReply == 0 { + return nil, fmt.Errorf("response missing reply bit: opcode=0x%02x", resp.Opcode) + } + + return resp, nil +} + +// parseMapResponse parses a complete MAP response. +func parseMapResponse(data []byte) (*MapResponse, error) { + if len(data) < mapRequestSize { + return nil, fmt.Errorf("MAP response too short: %d bytes", len(data)) + } + + resp, err := parseResponse(data) + if err != nil { + return nil, fmt.Errorf("parse header: %w", err) + } + + mapResp := &MapResponse{ + Response: *resp, + Protocol: data[36], + InternalPort: binary.BigEndian.Uint16(data[40:42]), + ExternalPort: binary.BigEndian.Uint16(data[42:44]), + ExternalIP: addrFrom16([16]byte(data[44:60])), + } + copy(mapResp.Nonce[:], data[24:36]) + + return mapResp, nil +} diff --git a/client/internal/portforward/state.go b/client/internal/portforward/state.go new file mode 100644 index 000000000..b1315cdc0 --- /dev/null +++ b/client/internal/portforward/state.go @@ -0,0 +1,63 @@ +//go:build !js + +package portforward + +import ( + "context" + "fmt" + + "github.com/libp2p/go-nat" + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/portforward/pcp" +) + +// discoverGateway is the function used for NAT gateway discovery. +// It can be replaced in tests to avoid real network operations. +// Tries PCP first, then falls back to NAT-PMP/UPnP. +var discoverGateway = defaultDiscoverGateway + +func defaultDiscoverGateway(ctx context.Context) (nat.NAT, error) { + pcpGateway, err := pcp.DiscoverPCP(ctx) + if err == nil { + return pcpGateway, nil + } + log.Debugf("PCP discovery failed: %v, trying NAT-PMP/UPnP", err) + + return nat.DiscoverGateway(ctx) +} + +// State is persisted only for crash recovery cleanup +type State struct { + InternalPort uint16 `json:"internal_port,omitempty"` + Protocol string `json:"protocol,omitempty"` +} + +func (s *State) Name() string { + return "port_forward_state" +} + +// Cleanup implements statemanager.CleanableState for crash recovery +func (s *State) Cleanup() error { + if s.InternalPort == 0 { + return nil + } + + log.Infof("cleaning up stale port mapping for port %d", s.InternalPort) + + ctx, cancel := context.WithTimeout(context.Background(), discoveryTimeout) + defer cancel() + + gateway, err := discoverGateway(ctx) + if err != nil { + // Discovery failure is not an error - gateway may not exist + log.Debugf("cleanup: no gateway found: %v", err) + return nil + } + + if err := gateway.DeletePortMapping(ctx, s.Protocol, int(s.InternalPort)); err != nil { + return fmt.Errorf("delete port mapping: %w", err) + } + + return nil +}