diff --git a/.gitignore b/.gitignore index d14efa9..1a56bfa 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,4 @@ nohup.out *.iml certs/ newt_arm64 +key \ No newline at end of file diff --git a/README.md b/README.md index 0370a76..2d06abf 100644 --- a/README.md +++ b/README.md @@ -47,13 +47,11 @@ When Newt receives WireGuard control messages, it will use the information encod - `docker-socket` (optional): Set the Docker socket to use the container discovery integration - `docker-enforce-network-validation` (optional): Validate the container target is on the same network as the newt process. Default: false -### Accpet Client Connection +### Client Connections -- `accept-clients` (optional): Enable WireGuard server mode to accept incoming newt client connections. Default: false - - `generateAndSaveKeyTo` (optional): Path to save generated private key - - `native` (optional): Use native WireGuard interface when accepting clients (requires WireGuard kernel module and Linux, must run as root). Default: false (uses userspace netstack) - - `interface` (optional): Name of the WireGuard interface. Default: newt - - `keep-interface` (optional): Keep the WireGuard interface. Default: false +- `disable-clients` (optional): Disable clients on the WireGuard interface. Default: false (clients enabled) +- `native` (optional): Use native WireGuard interface (requires WireGuard kernel module and Linux, must run as root). Default: false (uses userspace netstack) +- `interface` (optional): Name of the WireGuard interface. Default: newt ### Metrics & Observability @@ -73,9 +71,11 @@ When Newt receives WireGuard control messages, it will use the information encod ### Security & TLS - `enforce-hc-cert` (optional): Enforce certificate validation for health checks. Default: false (accepts any cert) -- `tls-client-cert` (optional): Client certificate (p12 or pfx) for mTLS or path to client certificate (PEM format). See [mTLS](#mtls) -- `tls-client-key` (optional): Path to private key for mTLS (PEM format, optional if using PKCS12) -- `tls-ca-cert` (optional): Path to CA certificate to verify server (PEM format, optional if using PKCS12) +- `tls-client-cert-file` (optional): Path to client certificate file (PEM/DER format) for mTLS. See [mTLS](#mtls) +- `tls-client-key` (optional): Path to client private key file (PEM/DER format) for mTLS +- `tls-client-ca` (optional): Path to CA certificate file for validating remote certificates (can be specified multiple times) +- `tls-client-cert` (optional): Path to client certificate (PKCS12 format) - DEPRECATED: use `--tls-client-cert-file` and `--tls-client-key` instead +- `prefer-endpoint` (optional): Prefer this endpoint for the connection (if set, will override the endpoint from the server) ### Monitoring & Health @@ -101,13 +101,11 @@ All CLI arguments can be set using environment variables as an alternative to co - `DOCKER_SOCKET`: Path to Docker socket for container discovery (equivalent to `--docker-socket`) - `DOCKER_ENFORCE_NETWORK_VALIDATION`: Validate container targets are on same network. Default: false (equivalent to `--docker-enforce-network-validation`) -### Accept Client Connections +### Client Connections -- `ACCEPT_CLIENTS`: Enable WireGuard server mode. Default: false (equivalent to `--accept-clients`) -- `GENERATE_AND_SAVE_KEY_TO`: Path to save generated private key (equivalent to `--generateAndSaveKeyTo`) +- `DISABLE_CLIENTS`: Disable clients on the WireGuard interface. Default: false (equivalent to `--disable-clients`) - `USE_NATIVE_INTERFACE`: Use native WireGuard interface (Linux only). Default: false (equivalent to `--native`) - `INTERFACE`: Name of the WireGuard interface. Default: newt (equivalent to `--interface`) -- `KEEP_INTERFACE`: Keep the WireGuard interface after shutdown. Default: false (equivalent to `--keep-interface`) ### Monitoring & Health @@ -132,10 +130,10 @@ All CLI arguments can be set using environment variables as an alternative to co ### Security & TLS - `ENFORCE_HC_CERT`: Enforce certificate validation for health checks. Default: false (equivalent to `--enforce-hc-cert`) -- `TLS_CLIENT_CERT`: Path to client certificate for mTLS (equivalent to `--tls-client-cert`) -- `TLS_CLIENT_KEY`: Path to private key for mTLS (equivalent to `--tls-client-key`) -- `TLS_CA_CERT`: Path to CA certificate to verify server (equivalent to `--tls-ca-cert`) -- `SKIP_TLS_VERIFY`: Skip TLS verification for server connections. Default: false +- `TLS_CLIENT_CERT`: Path to client certificate file (PEM/DER format) for mTLS (equivalent to `--tls-client-cert-file`) +- `TLS_CLIENT_KEY`: Path to client private key file (PEM/DER format) for mTLS (equivalent to `--tls-client-key`) +- `TLS_CLIENT_CAS`: Comma-separated list of CA certificate file paths for validating remote certificates (equivalent to multiple `--tls-client-ca` flags) +- `TLS_CLIENT_CERT_PKCS12`: Path to client certificate (PKCS12 format) - DEPRECATED: use `TLS_CLIENT_CERT` and `TLS_CLIENT_KEY` instead ## Loading secrets from files @@ -202,9 +200,9 @@ services: - --health-file /tmp/healthy ``` -## Accept Client Connections +## Client Connections -When the `--accept-clients` flag is enabled (or `ACCEPT_CLIENTS=true` environment variable is set), Newt operates as a WireGuard server that can accept incoming client connections from other devices. This enables peer-to-peer connectivity through the Newt instance. +By default, Newt can accept incoming client connections from other devices, enabling peer-to-peer connectivity through the Newt instance. This behavior can be disabled with the `--disable-clients` flag (or `DISABLE_CLIENTS=true` environment variable). ### How It Works @@ -260,7 +258,7 @@ To use native mode: 3. Run Newt as root (`sudo`) 4. Ensure the system allows creation of network interfaces -Docker Compose example: +Docker Compose example (with clients enabled by default): ```yaml services: @@ -272,7 +270,6 @@ services: - PANGOLIN_ENDPOINT=https://example.com - NEWT_ID=2ix2t8xk22ubpfy - NEWT_SECRET=nnisrfsdfc7prqsp9ewo1dvtvci50j5uiqotez00dgap0ii2 - - ACCEPT_CLIENTS=true ``` ### Technical Details @@ -394,9 +391,9 @@ newt \ You can now provide separate files for: -* `--tls-client-cert`: client certificate (`.crt` or `.pem`) +* `--tls-client-cert-file`: client certificate (`.crt` or `.pem`) * `--tls-client-key`: client private key (`.key` or `.pem`) -* `--tls-ca-cert`: CA cert to verify the server +* `--tls-client-ca`: CA cert to verify the server (can be specified multiple times) Example: @@ -405,9 +402,9 @@ newt \ --id 31frd0uzbjvp721 \ --secret h51mmlknrvrwv8s4r1i210azhumt6isgbpyavxodibx1k2d6 \ --endpoint https://example.com \ ---tls-client-cert ./client.crt \ +--tls-client-cert-file ./client.crt \ --tls-client-key ./client.key \ ---tls-ca-cert ./ca.crt +--tls-client-ca ./ca.crt ``` diff --git a/bind/shared_bind.go b/bind/shared_bind.go new file mode 100644 index 0000000..f266cb0 --- /dev/null +++ b/bind/shared_bind.go @@ -0,0 +1,675 @@ +//go:build !js + +package bind + +import ( + "bytes" + "fmt" + "net" + "net/netip" + "runtime" + "sync" + "sync/atomic" + + "github.com/fosrl/newt/logger" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" + wgConn "golang.zx2c4.com/wireguard/conn" +) + +// Magic packet constants for connection testing +// These packets are intercepted by SharedBind and responded to directly, +// without being passed to the WireGuard device. +var ( + // MagicTestRequest is the prefix for a test request packet + // Format: PANGOLIN_TEST_REQ + 8 bytes of random data (for echo) + MagicTestRequest = []byte("PANGOLIN_TEST_REQ") + + // MagicTestResponse is the prefix for a test response packet + // Format: PANGOLIN_TEST_RSP + 8 bytes echoed from request + MagicTestResponse = []byte("PANGOLIN_TEST_RSP") +) + +const ( + // MagicPacketDataLen is the length of random data included in test packets + MagicPacketDataLen = 8 + + // MagicTestRequestLen is the total length of a test request packet + MagicTestRequestLen = 17 + MagicPacketDataLen // len("PANGOLIN_TEST_REQ") + 8 + + // MagicTestResponseLen is the total length of a test response packet + MagicTestResponseLen = 17 + MagicPacketDataLen // len("PANGOLIN_TEST_RSP") + 8 +) + +// PacketSource identifies where a packet came from +type PacketSource uint8 + +const ( + SourceSocket PacketSource = iota // From physical UDP socket (hole-punched clients) + SourceNetstack // From netstack (relay through main tunnel) +) + +// SourceAwareEndpoint wraps an endpoint with source information +type SourceAwareEndpoint struct { + wgConn.Endpoint + source PacketSource +} + +// GetSource returns the source of this endpoint +func (e *SourceAwareEndpoint) GetSource() PacketSource { + return e.source +} + +// injectedPacket represents a packet injected into the SharedBind from an internal source +type injectedPacket struct { + data []byte + endpoint wgConn.Endpoint +} + +// Endpoint represents a network endpoint for the SharedBind +type Endpoint struct { + AddrPort netip.AddrPort +} + +// ClearSrc implements the wgConn.Endpoint interface +func (e *Endpoint) ClearSrc() {} + +// DstIP implements the wgConn.Endpoint interface +func (e *Endpoint) DstIP() netip.Addr { + return e.AddrPort.Addr() +} + +// SrcIP implements the wgConn.Endpoint interface +func (e *Endpoint) SrcIP() netip.Addr { + return netip.Addr{} +} + +// DstToBytes implements the wgConn.Endpoint interface +func (e *Endpoint) DstToBytes() []byte { + b, _ := e.AddrPort.MarshalBinary() + return b +} + +// DstToString implements the wgConn.Endpoint interface +func (e *Endpoint) DstToString() string { + return e.AddrPort.String() +} + +// SrcToString implements the wgConn.Endpoint interface +func (e *Endpoint) SrcToString() string { + return "" +} + +// SharedBind is a thread-safe UDP bind that can be shared between WireGuard +// and hole punch senders. It wraps a single UDP connection and implements +// reference counting to prevent premature closure. +// It also supports receiving packets from a netstack and routing responses +// back through the appropriate source. +type SharedBind struct { + mu sync.RWMutex + + // The underlying UDP connection (for hole-punched clients) + udpConn *net.UDPConn + + // IPv4 and IPv6 packet connections for advanced features + ipv4PC *ipv4.PacketConn + ipv6PC *ipv6.PacketConn + + // Reference counting to prevent closing while in use + refCount atomic.Int32 + closed atomic.Bool + + // Channels for receiving data + recvFuncs []wgConn.ReceiveFunc + + // Port binding information + port uint16 + + // Channel for packets from netstack (from direct relay) - larger buffer for throughput + netstackPackets chan injectedPacket + + // Netstack connection for sending responses back through the tunnel + // Using atomic.Pointer for lock-free access in hot path + netstackConn atomic.Pointer[net.PacketConn] + + // Track which endpoints came from netstack (key: netip.AddrPort, value: struct{}) + // Using netip.AddrPort directly as key is more efficient than string + netstackEndpoints sync.Map + + // Pre-allocated message buffers for batch operations (Linux only) + ipv4Msgs []ipv4.Message + + // Shutdown signal for receive goroutines + closeChan chan struct{} + + // Callback for magic test responses (used for holepunch testing) + magicResponseCallback atomic.Pointer[func(addr netip.AddrPort, echoData []byte)] +} + +// MagicResponseCallback is the function signature for magic packet response callbacks +type MagicResponseCallback func(addr netip.AddrPort, echoData []byte) + +// New creates a new SharedBind from an existing UDP connection. +// The SharedBind takes ownership of the connection and will close it +// when all references are released. +func New(udpConn *net.UDPConn) (*SharedBind, error) { + if udpConn == nil { + return nil, fmt.Errorf("udpConn cannot be nil") + } + + bind := &SharedBind{ + udpConn: udpConn, + netstackPackets: make(chan injectedPacket, 1024), // Larger buffer for better throughput + closeChan: make(chan struct{}), + } + + // Initialize reference count to 1 (the creator holds the first reference) + bind.refCount.Store(1) + + // Get the local port + if addr, ok := udpConn.LocalAddr().(*net.UDPAddr); ok { + bind.port = uint16(addr.Port) + } + + return bind, nil +} + +// SetNetstackConn sets the netstack connection for receiving/sending packets through the tunnel. +// This connection is used for relay traffic that should go back through the main tunnel. +func (b *SharedBind) SetNetstackConn(conn net.PacketConn) { + b.netstackConn.Store(&conn) +} + +// GetNetstackConn returns the netstack connection if set +func (b *SharedBind) GetNetstackConn() net.PacketConn { + ptr := b.netstackConn.Load() + if ptr == nil { + return nil + } + return *ptr +} + +// InjectPacket allows injecting a packet directly into the SharedBind's receive path. +// This is used for direct relay from netstack without going through the host network. +// The fromAddr should be the address the packet appears to come from. +func (b *SharedBind) InjectPacket(data []byte, fromAddr netip.AddrPort) error { + if b.closed.Load() { + return net.ErrClosed + } + + // Unmap IPv4-in-IPv6 addresses to ensure consistency with parsed endpoints + if fromAddr.Addr().Is4In6() { + fromAddr = netip.AddrPortFrom(fromAddr.Addr().Unmap(), fromAddr.Port()) + } + + // Track this endpoint as coming from netstack so responses go back the same way + // Use AddrPort directly as key (more efficient than string) + b.netstackEndpoints.Store(fromAddr, struct{}{}) + + // Make a copy of the data to avoid issues with buffer reuse + dataCopy := make([]byte, len(data)) + copy(dataCopy, data) + + select { + case b.netstackPackets <- injectedPacket{ + data: dataCopy, + endpoint: &wgConn.StdNetEndpoint{AddrPort: fromAddr}, + }: + return nil + case <-b.closeChan: + return net.ErrClosed + default: + // Channel full, drop the packet + return fmt.Errorf("netstack packet buffer full") + } +} + +// AddRef increments the reference count. Call this when sharing +// the bind with another component. +func (b *SharedBind) AddRef() { + newCount := b.refCount.Add(1) + // Optional: Add logging for debugging + _ = newCount // Placeholder for potential logging +} + +// Release decrements the reference count. When it reaches zero, +// the underlying UDP connection is closed. +func (b *SharedBind) Release() error { + newCount := b.refCount.Add(-1) + // Optional: Add logging for debugging + _ = newCount // Placeholder for potential logging + + if newCount < 0 { + // This should never happen with proper usage + b.refCount.Store(0) + return fmt.Errorf("SharedBind reference count went negative") + } + + if newCount == 0 { + return b.closeConnection() + } + + return nil +} + +// closeConnection actually closes the UDP connection +func (b *SharedBind) closeConnection() error { + if !b.closed.CompareAndSwap(false, true) { + // Already closed + return nil + } + + // Signal all goroutines to stop + close(b.closeChan) + + b.mu.Lock() + defer b.mu.Unlock() + + var err error + if b.udpConn != nil { + err = b.udpConn.Close() + b.udpConn = nil + } + + b.ipv4PC = nil + b.ipv6PC = nil + + // Clear netstack connection (but don't close it - it's managed externally) + b.netstackConn.Store(nil) + + // Clear tracked netstack endpoints + b.netstackEndpoints = sync.Map{} + + return err +} + +// ClearNetstackConn clears the netstack connection and tracked endpoints. +// Call this when stopping the relay. +func (b *SharedBind) ClearNetstackConn() { + b.netstackConn.Store(nil) + + // Clear tracked netstack endpoints + b.netstackEndpoints = sync.Map{} +} + +// GetUDPConn returns the underlying UDP connection. +// The caller must not close this connection directly. +func (b *SharedBind) GetUDPConn() *net.UDPConn { + b.mu.RLock() + defer b.mu.RUnlock() + return b.udpConn +} + +// GetRefCount returns the current reference count (for debugging) +func (b *SharedBind) GetRefCount() int32 { + return b.refCount.Load() +} + +// IsClosed returns whether the bind is closed +func (b *SharedBind) IsClosed() bool { + return b.closed.Load() +} + +// SetMagicResponseCallback sets a callback function that will be called when +// a magic test response packet is received. This is used for holepunch testing. +// Pass nil to clear the callback. +func (b *SharedBind) SetMagicResponseCallback(callback MagicResponseCallback) { + if callback == nil { + b.magicResponseCallback.Store(nil) + } else { + // Convert to the function type the atomic.Pointer expects + fn := func(addr netip.AddrPort, echoData []byte) { + callback(addr, echoData) + } + b.magicResponseCallback.Store(&fn) + } +} + +// WriteToUDP writes data to a specific UDP address. +// This is thread-safe and can be used by hole punch senders. +func (b *SharedBind) WriteToUDP(data []byte, addr *net.UDPAddr) (int, error) { + if b.closed.Load() { + return 0, net.ErrClosed + } + + b.mu.RLock() + conn := b.udpConn + b.mu.RUnlock() + + if conn == nil { + return 0, net.ErrClosed + } + + return conn.WriteToUDP(data, addr) +} + +// Close implements the WireGuard Bind interface. +// It decrements the reference count and closes the connection if no references remain. +func (b *SharedBind) Close() error { + return b.Release() +} + +// Open implements the WireGuard Bind interface. +// Since the connection is already open, this just sets up the receive functions. +func (b *SharedBind) Open(uport uint16) ([]wgConn.ReceiveFunc, uint16, error) { + if b.closed.Load() { + return nil, 0, net.ErrClosed + } + + b.mu.Lock() + defer b.mu.Unlock() + + if b.udpConn == nil { + return nil, 0, net.ErrClosed + } + + // Set up IPv4 and IPv6 packet connections for advanced features + if runtime.GOOS == "linux" || runtime.GOOS == "android" { + b.ipv4PC = ipv4.NewPacketConn(b.udpConn) + b.ipv6PC = ipv6.NewPacketConn(b.udpConn) + + // Pre-allocate message buffers for batch operations + batchSize := wgConn.IdealBatchSize + b.ipv4Msgs = make([]ipv4.Message, batchSize) + for i := range b.ipv4Msgs { + b.ipv4Msgs[i].OOB = make([]byte, 0) + } + } + + // Create receive functions - one for socket, one for netstack + recvFuncs := make([]wgConn.ReceiveFunc, 0, 2) + + // Add socket receive function (reads from physical UDP socket) + recvFuncs = append(recvFuncs, b.makeReceiveSocket()) + + // Add netstack receive function (reads from injected packets channel) + recvFuncs = append(recvFuncs, b.makeReceiveNetstack()) + + b.recvFuncs = recvFuncs + return recvFuncs, b.port, nil +} + +// makeReceiveSocket creates a receive function for physical UDP socket packets +func (b *SharedBind) makeReceiveSocket() wgConn.ReceiveFunc { + return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) { + if b.closed.Load() { + return 0, net.ErrClosed + } + + b.mu.RLock() + conn := b.udpConn + pc := b.ipv4PC + b.mu.RUnlock() + + if conn == nil { + return 0, net.ErrClosed + } + + // Use batch reading on Linux for performance + if pc != nil && (runtime.GOOS == "linux" || runtime.GOOS == "android") { + return b.receiveIPv4Batch(pc, bufs, sizes, eps) + } + return b.receiveIPv4Simple(conn, bufs, sizes, eps) + } +} + +// makeReceiveNetstack creates a receive function for netstack-injected packets +func (b *SharedBind) makeReceiveNetstack() wgConn.ReceiveFunc { + return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) { + select { + case <-b.closeChan: + return 0, net.ErrClosed + case pkt := <-b.netstackPackets: + if len(pkt.data) <= len(bufs[0]) { + copy(bufs[0], pkt.data) + sizes[0] = len(pkt.data) + eps[0] = pkt.endpoint + return 1, nil + } + // Packet too large for buffer, skip it + return 0, nil + } + } +} + +// receiveIPv4Batch uses batch reading for better performance on Linux +func (b *SharedBind) receiveIPv4Batch(pc *ipv4.PacketConn, bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (int, error) { + // Use pre-allocated messages, just update buffer pointers + numBufs := len(bufs) + if numBufs > len(b.ipv4Msgs) { + numBufs = len(b.ipv4Msgs) + } + + for i := 0; i < numBufs; i++ { + b.ipv4Msgs[i].Buffers = [][]byte{bufs[i]} + } + + numMsgs, err := pc.ReadBatch(b.ipv4Msgs[:numBufs], 0) + if err != nil { + return 0, err + } + + // Process messages and filter out magic packets + writeIdx := 0 + for i := 0; i < numMsgs; i++ { + if b.ipv4Msgs[i].N == 0 { + continue + } + + // Check for magic packet + if b.ipv4Msgs[i].Addr != nil { + if udpAddr, ok := b.ipv4Msgs[i].Addr.(*net.UDPAddr); ok { + data := bufs[i][:b.ipv4Msgs[i].N] + if b.handleMagicPacket(data, udpAddr) { + // Magic packet handled, skip this message + continue + } + } + } + + // Not a magic packet, include in output + if writeIdx != i { + // Need to copy data to the correct position + copy(bufs[writeIdx], bufs[i][:b.ipv4Msgs[i].N]) + } + sizes[writeIdx] = b.ipv4Msgs[i].N + + if b.ipv4Msgs[i].Addr != nil { + if udpAddr, ok := b.ipv4Msgs[i].Addr.(*net.UDPAddr); ok { + addrPort := udpAddr.AddrPort() + // Unmap IPv4-in-IPv6 addresses to ensure consistency with parsed endpoints + if addrPort.Addr().Is4In6() { + addrPort = netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port()) + } + eps[writeIdx] = &wgConn.StdNetEndpoint{AddrPort: addrPort} + } + } + writeIdx++ + } + + return writeIdx, nil +} + +// receiveIPv4Simple uses simple ReadFromUDP for non-Linux platforms +func (b *SharedBind) receiveIPv4Simple(conn *net.UDPConn, bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (int, error) { + for { + n, addr, err := conn.ReadFromUDP(bufs[0]) + if err != nil { + return 0, err + } + + // Check for magic test packet and handle it directly + if b.handleMagicPacket(bufs[0][:n], addr) { + // Magic packet was handled, read another packet + continue + } + + sizes[0] = n + if addr != nil { + addrPort := addr.AddrPort() + // Unmap IPv4-in-IPv6 addresses to ensure consistency with parsed endpoints + if addrPort.Addr().Is4In6() { + addrPort = netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port()) + } + eps[0] = &wgConn.StdNetEndpoint{AddrPort: addrPort} + } + + return 1, nil + } +} + +// handleMagicPacket checks if the packet is a magic test packet and responds if so. +// Returns true if the packet was a magic packet and was handled (should not be passed to WireGuard). +func (b *SharedBind) handleMagicPacket(data []byte, addr *net.UDPAddr) bool { + // Check if this is a test request packet + if len(data) >= MagicTestRequestLen && bytes.HasPrefix(data, MagicTestRequest) { + logger.Debug("Received magic test REQUEST from %s, sending response", addr.String()) + // Extract the random data portion to echo back + echoData := data[len(MagicTestRequest) : len(MagicTestRequest)+MagicPacketDataLen] + + // Build response packet + response := make([]byte, MagicTestResponseLen) + copy(response, MagicTestResponse) + copy(response[len(MagicTestResponse):], echoData) + + // Send response back to sender + b.mu.RLock() + conn := b.udpConn + b.mu.RUnlock() + + if conn != nil { + _, _ = conn.WriteToUDP(response, addr) + } + + return true + } + + // Check if this is a test response packet + if len(data) >= MagicTestResponseLen && bytes.HasPrefix(data, MagicTestResponse) { + logger.Debug("Received magic test RESPONSE from %s", addr.String()) + // Extract the echoed data + echoData := data[len(MagicTestResponse) : len(MagicTestResponse)+MagicPacketDataLen] + + // Call the callback if set + callbackPtr := b.magicResponseCallback.Load() + if callbackPtr != nil { + callback := *callbackPtr + addrPort := addr.AddrPort() + // Unmap IPv4-in-IPv6 addresses to ensure consistency + if addrPort.Addr().Is4In6() { + addrPort = netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port()) + } + callback(addrPort, echoData) + } else { + logger.Debug("Magic response received but no callback registered") + } + + return true + } + + return false +} + +// Send implements the WireGuard Bind interface. +// It sends packets to the specified endpoint, routing through the appropriate +// source (netstack or physical socket) based on where the endpoint's packets came from. +func (b *SharedBind) Send(bufs [][]byte, ep wgConn.Endpoint) error { + if b.closed.Load() { + return net.ErrClosed + } + + // Extract the destination address from the endpoint + var destAddrPort netip.AddrPort + + // Try to cast to StdNetEndpoint first (most common case, avoid allocations) + if stdEp, ok := ep.(*wgConn.StdNetEndpoint); ok { + destAddrPort = stdEp.AddrPort + } else { + // Fallback: construct from DstIP and DstToBytes + dstBytes := ep.DstToBytes() + if len(dstBytes) >= 6 { // Minimum for IPv4 (4 bytes) + port (2 bytes) + var addr netip.Addr + var port uint16 + + if len(dstBytes) >= 18 { // IPv6 (16 bytes) + port (2 bytes) + addr, _ = netip.AddrFromSlice(dstBytes[:16]) + port = uint16(dstBytes[16]) | uint16(dstBytes[17])<<8 + } else { // IPv4 + addr, _ = netip.AddrFromSlice(dstBytes[:4]) + port = uint16(dstBytes[4]) | uint16(dstBytes[5])<<8 + } + + if addr.IsValid() { + destAddrPort = netip.AddrPortFrom(addr, port) + } + } + } + + if !destAddrPort.IsValid() { + return fmt.Errorf("could not extract destination address from endpoint") + } + + // Check if this endpoint came from netstack - if so, send through netstack + // Use AddrPort directly as key (more efficient than string conversion) + if _, isNetstackEndpoint := b.netstackEndpoints.Load(destAddrPort); isNetstackEndpoint { + connPtr := b.netstackConn.Load() + if connPtr != nil && *connPtr != nil { + netstackConn := *connPtr + destAddr := net.UDPAddrFromAddrPort(destAddrPort) + // Send all buffers through netstack + for _, buf := range bufs { + _, err := netstackConn.WriteTo(buf, destAddr) + if err != nil { + return err + } + } + return nil + } + // Fall through to socket if netstack conn not available + } + + // Send through the physical UDP socket (for hole-punched clients) + b.mu.RLock() + conn := b.udpConn + b.mu.RUnlock() + + if conn == nil { + return net.ErrClosed + } + + destAddr := net.UDPAddrFromAddrPort(destAddrPort) + + // Send all buffers to the destination + for _, buf := range bufs { + _, err := conn.WriteToUDP(buf, destAddr) + if err != nil { + return err + } + } + + return nil +} + +// SetMark implements the WireGuard Bind interface. +// It's a no-op for this implementation. +func (b *SharedBind) SetMark(mark uint32) error { + // Not implemented for this use case + return nil +} + +// BatchSize returns the preferred batch size for sending packets. +func (b *SharedBind) BatchSize() int { + if runtime.GOOS == "linux" || runtime.GOOS == "android" { + return wgConn.IdealBatchSize + } + return 1 +} + +// ParseEndpoint creates a new endpoint from a string address. +func (b *SharedBind) ParseEndpoint(s string) (wgConn.Endpoint, error) { + addrPort, err := netip.ParseAddrPort(s) + if err != nil { + return nil, err + } + return &wgConn.StdNetEndpoint{AddrPort: addrPort}, nil +} diff --git a/bind/shared_bind_test.go b/bind/shared_bind_test.go new file mode 100644 index 0000000..0d63e7a --- /dev/null +++ b/bind/shared_bind_test.go @@ -0,0 +1,605 @@ +//go:build !js + +package bind + +import ( + "net" + "net/netip" + "sync" + "testing" + "time" + + wgConn "golang.zx2c4.com/wireguard/conn" +) + +// TestSharedBindCreation tests basic creation and initialization +func TestSharedBindCreation(t *testing.T) { + // Create a UDP connection + udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + t.Fatalf("Failed to create UDP connection: %v", err) + } + defer udpConn.Close() + + // Create SharedBind + bind, err := New(udpConn) + if err != nil { + t.Fatalf("Failed to create SharedBind: %v", err) + } + + if bind == nil { + t.Fatal("SharedBind is nil") + } + + // Verify initial reference count + if bind.refCount.Load() != 1 { + t.Errorf("Expected initial refCount to be 1, got %d", bind.refCount.Load()) + } + + // Clean up + if err := bind.Close(); err != nil { + t.Errorf("Failed to close SharedBind: %v", err) + } +} + +// TestSharedBindReferenceCount tests reference counting +func TestSharedBindReferenceCount(t *testing.T) { + udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + t.Fatalf("Failed to create UDP connection: %v", err) + } + + bind, err := New(udpConn) + if err != nil { + t.Fatalf("Failed to create SharedBind: %v", err) + } + + // Add references + bind.AddRef() + if bind.refCount.Load() != 2 { + t.Errorf("Expected refCount to be 2, got %d", bind.refCount.Load()) + } + + bind.AddRef() + if bind.refCount.Load() != 3 { + t.Errorf("Expected refCount to be 3, got %d", bind.refCount.Load()) + } + + // Release references + bind.Release() + if bind.refCount.Load() != 2 { + t.Errorf("Expected refCount to be 2 after release, got %d", bind.refCount.Load()) + } + + bind.Release() + bind.Release() // This should close the connection + + if !bind.closed.Load() { + t.Error("Expected bind to be closed after all references released") + } +} + +// TestSharedBindWriteToUDP tests the WriteToUDP functionality +func TestSharedBindWriteToUDP(t *testing.T) { + // Create sender + senderConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + t.Fatalf("Failed to create sender UDP connection: %v", err) + } + + senderBind, err := New(senderConn) + if err != nil { + t.Fatalf("Failed to create sender SharedBind: %v", err) + } + defer senderBind.Close() + + // Create receiver + receiverConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + t.Fatalf("Failed to create receiver UDP connection: %v", err) + } + defer receiverConn.Close() + + receiverAddr := receiverConn.LocalAddr().(*net.UDPAddr) + + // Send data + testData := []byte("Hello, SharedBind!") + n, err := senderBind.WriteToUDP(testData, receiverAddr) + if err != nil { + t.Fatalf("WriteToUDP failed: %v", err) + } + + if n != len(testData) { + t.Errorf("Expected to send %d bytes, sent %d", len(testData), n) + } + + // Receive data + buf := make([]byte, 1024) + receiverConn.SetReadDeadline(time.Now().Add(2 * time.Second)) + n, _, err = receiverConn.ReadFromUDP(buf) + if err != nil { + t.Fatalf("Failed to receive data: %v", err) + } + + if string(buf[:n]) != string(testData) { + t.Errorf("Expected to receive %q, got %q", testData, buf[:n]) + } +} + +// TestSharedBindConcurrentWrites tests thread-safety +func TestSharedBindConcurrentWrites(t *testing.T) { + // Create sender + senderConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + t.Fatalf("Failed to create sender UDP connection: %v", err) + } + + senderBind, err := New(senderConn) + if err != nil { + t.Fatalf("Failed to create sender SharedBind: %v", err) + } + defer senderBind.Close() + + // Create receiver + receiverConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + t.Fatalf("Failed to create receiver UDP connection: %v", err) + } + defer receiverConn.Close() + + receiverAddr := receiverConn.LocalAddr().(*net.UDPAddr) + + // Launch concurrent writes + numGoroutines := 100 + var wg sync.WaitGroup + wg.Add(numGoroutines) + + for i := 0; i < numGoroutines; i++ { + go func(id int) { + defer wg.Done() + data := []byte{byte(id)} + _, err := senderBind.WriteToUDP(data, receiverAddr) + if err != nil { + t.Errorf("WriteToUDP failed in goroutine %d: %v", id, err) + } + }(i) + } + + wg.Wait() +} + +// TestSharedBindWireGuardInterface tests WireGuard Bind interface implementation +func TestSharedBindWireGuardInterface(t *testing.T) { + udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + t.Fatalf("Failed to create UDP connection: %v", err) + } + + bind, err := New(udpConn) + if err != nil { + t.Fatalf("Failed to create SharedBind: %v", err) + } + defer bind.Close() + + // Test Open + recvFuncs, port, err := bind.Open(0) + if err != nil { + t.Fatalf("Open failed: %v", err) + } + + if len(recvFuncs) == 0 { + t.Error("Expected at least one receive function") + } + + if port == 0 { + t.Error("Expected non-zero port") + } + + // Test SetMark (should be a no-op) + if err := bind.SetMark(0); err != nil { + t.Errorf("SetMark failed: %v", err) + } + + // Test BatchSize + batchSize := bind.BatchSize() + if batchSize <= 0 { + t.Error("Expected positive batch size") + } +} + +// TestSharedBindSend tests the Send method with WireGuard endpoints +func TestSharedBindSend(t *testing.T) { + // Create sender + senderConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + t.Fatalf("Failed to create sender UDP connection: %v", err) + } + + senderBind, err := New(senderConn) + if err != nil { + t.Fatalf("Failed to create sender SharedBind: %v", err) + } + defer senderBind.Close() + + // Create receiver + receiverConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + t.Fatalf("Failed to create receiver UDP connection: %v", err) + } + defer receiverConn.Close() + + receiverAddr := receiverConn.LocalAddr().(*net.UDPAddr) + + // Create an endpoint + addrPort := receiverAddr.AddrPort() + endpoint := &wgConn.StdNetEndpoint{AddrPort: addrPort} + + // Send data + testData := []byte("WireGuard packet") + bufs := [][]byte{testData} + err = senderBind.Send(bufs, endpoint) + if err != nil { + t.Fatalf("Send failed: %v", err) + } + + // Receive data + buf := make([]byte, 1024) + receiverConn.SetReadDeadline(time.Now().Add(2 * time.Second)) + n, _, err := receiverConn.ReadFromUDP(buf) + if err != nil { + t.Fatalf("Failed to receive data: %v", err) + } + + if string(buf[:n]) != string(testData) { + t.Errorf("Expected to receive %q, got %q", testData, buf[:n]) + } +} + +// TestSharedBindMultipleUsers simulates WireGuard and hole punch using the same bind +func TestSharedBindMultipleUsers(t *testing.T) { + // Create shared bind + udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + t.Fatalf("Failed to create UDP connection: %v", err) + } + + sharedBind, err := New(udpConn) + if err != nil { + t.Fatalf("Failed to create SharedBind: %v", err) + } + + // Add reference for hole punch sender + sharedBind.AddRef() + + // Create receiver + receiverConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + t.Fatalf("Failed to create receiver UDP connection: %v", err) + } + defer receiverConn.Close() + + receiverAddr := receiverConn.LocalAddr().(*net.UDPAddr) + + var wg sync.WaitGroup + + // Simulate WireGuard using the bind + wg.Add(1) + go func() { + defer wg.Done() + addrPort := receiverAddr.AddrPort() + endpoint := &wgConn.StdNetEndpoint{AddrPort: addrPort} + + for i := 0; i < 10; i++ { + data := []byte("WireGuard packet") + bufs := [][]byte{data} + if err := sharedBind.Send(bufs, endpoint); err != nil { + t.Errorf("WireGuard Send failed: %v", err) + } + time.Sleep(10 * time.Millisecond) + } + }() + + // Simulate hole punch sender using the bind + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < 10; i++ { + data := []byte("Hole punch packet") + if _, err := sharedBind.WriteToUDP(data, receiverAddr); err != nil { + t.Errorf("Hole punch WriteToUDP failed: %v", err) + } + time.Sleep(10 * time.Millisecond) + } + }() + + wg.Wait() + + // Release the hole punch reference + sharedBind.Release() + + // Close WireGuard's reference (should close the connection) + sharedBind.Close() + + if !sharedBind.closed.Load() { + t.Error("Expected bind to be closed after all users released it") + } +} + +// TestEndpoint tests the Endpoint implementation +func TestEndpoint(t *testing.T) { + addr := netip.MustParseAddr("192.168.1.1") + addrPort := netip.AddrPortFrom(addr, 51820) + + ep := &Endpoint{AddrPort: addrPort} + + // Test DstIP + if ep.DstIP() != addr { + t.Errorf("Expected DstIP to be %v, got %v", addr, ep.DstIP()) + } + + // Test DstToString + expected := "192.168.1.1:51820" + if ep.DstToString() != expected { + t.Errorf("Expected DstToString to be %q, got %q", expected, ep.DstToString()) + } + + // Test DstToBytes + bytes := ep.DstToBytes() + if len(bytes) == 0 { + t.Error("Expected DstToBytes to return non-empty slice") + } + + // Test SrcIP (should be zero) + if ep.SrcIP().IsValid() { + t.Error("Expected SrcIP to be invalid") + } + + // Test ClearSrc (should not panic) + ep.ClearSrc() +} + +// TestParseEndpoint tests the ParseEndpoint method +func TestParseEndpoint(t *testing.T) { + udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + t.Fatalf("Failed to create UDP connection: %v", err) + } + + bind, err := New(udpConn) + if err != nil { + t.Fatalf("Failed to create SharedBind: %v", err) + } + defer bind.Close() + + tests := []struct { + name string + input string + wantErr bool + checkAddr func(*testing.T, wgConn.Endpoint) + }{ + { + name: "valid IPv4", + input: "192.168.1.1:51820", + wantErr: false, + checkAddr: func(t *testing.T, ep wgConn.Endpoint) { + if ep.DstToString() != "192.168.1.1:51820" { + t.Errorf("Expected 192.168.1.1:51820, got %s", ep.DstToString()) + } + }, + }, + { + name: "valid IPv6", + input: "[::1]:51820", + wantErr: false, + checkAddr: func(t *testing.T, ep wgConn.Endpoint) { + if ep.DstToString() != "[::1]:51820" { + t.Errorf("Expected [::1]:51820, got %s", ep.DstToString()) + } + }, + }, + { + name: "invalid - missing port", + input: "192.168.1.1", + wantErr: true, + }, + { + name: "invalid - bad format", + input: "not-an-address", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ep, err := bind.ParseEndpoint(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("ParseEndpoint() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr && tt.checkAddr != nil { + tt.checkAddr(t, ep) + } + }) + } +} + +// TestNetstackRouting tests that packets from netstack endpoints are routed back through netstack +func TestNetstackRouting(t *testing.T) { + // Create the SharedBind with a physical UDP socket + physicalConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + t.Fatalf("Failed to create physical UDP connection: %v", err) + } + + sharedBind, err := New(physicalConn) + if err != nil { + t.Fatalf("Failed to create SharedBind: %v", err) + } + defer sharedBind.Close() + + // Create a mock "netstack" connection (just another UDP socket for testing) + netstackConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + t.Fatalf("Failed to create netstack UDP connection: %v", err) + } + defer netstackConn.Close() + + // Set the netstack connection + sharedBind.SetNetstackConn(netstackConn) + + // Create a "client" that would receive packets + clientConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + t.Fatalf("Failed to create client UDP connection: %v", err) + } + defer clientConn.Close() + + clientAddr := clientConn.LocalAddr().(*net.UDPAddr) + clientAddrPort := clientAddr.AddrPort() + + // Inject a packet from the "netstack" source - this should track the endpoint + testData := []byte("test packet from netstack") + err = sharedBind.InjectPacket(testData, clientAddrPort) + if err != nil { + t.Fatalf("InjectPacket failed: %v", err) + } + + // Now when we send a response to this endpoint, it should go through netstack + endpoint := &wgConn.StdNetEndpoint{AddrPort: clientAddrPort} + responseData := []byte("response packet") + err = sharedBind.Send([][]byte{responseData}, endpoint) + if err != nil { + t.Fatalf("Send failed: %v", err) + } + + // The packet should be received by the client from the netstack connection + buf := make([]byte, 1024) + clientConn.SetReadDeadline(time.Now().Add(2 * time.Second)) + n, fromAddr, err := clientConn.ReadFromUDP(buf) + if err != nil { + t.Fatalf("Failed to receive response: %v", err) + } + + if string(buf[:n]) != string(responseData) { + t.Errorf("Expected to receive %q, got %q", responseData, buf[:n]) + } + + // Verify the response came from the netstack connection, not the physical one + netstackAddr := netstackConn.LocalAddr().(*net.UDPAddr) + if fromAddr.Port != netstackAddr.Port { + t.Errorf("Expected response from netstack port %d, got %d", netstackAddr.Port, fromAddr.Port) + } +} + +// TestSocketRouting tests that packets from socket endpoints are routed through socket +func TestSocketRouting(t *testing.T) { + // Create the SharedBind with a physical UDP socket + physicalConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + t.Fatalf("Failed to create physical UDP connection: %v", err) + } + + sharedBind, err := New(physicalConn) + if err != nil { + t.Fatalf("Failed to create SharedBind: %v", err) + } + defer sharedBind.Close() + + // Create a mock "netstack" connection + netstackConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + t.Fatalf("Failed to create netstack UDP connection: %v", err) + } + defer netstackConn.Close() + + // Set the netstack connection + sharedBind.SetNetstackConn(netstackConn) + + // Create a "client" that would receive packets (this simulates a hole-punched client) + clientConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + t.Fatalf("Failed to create client UDP connection: %v", err) + } + defer clientConn.Close() + + clientAddr := clientConn.LocalAddr().(*net.UDPAddr) + clientAddrPort := clientAddr.AddrPort() + + // Don't inject from netstack - this endpoint is NOT tracked as netstack-sourced + // So Send should use the physical socket + + endpoint := &wgConn.StdNetEndpoint{AddrPort: clientAddrPort} + responseData := []byte("response packet via socket") + err = sharedBind.Send([][]byte{responseData}, endpoint) + if err != nil { + t.Fatalf("Send failed: %v", err) + } + + // The packet should be received by the client from the physical connection + buf := make([]byte, 1024) + clientConn.SetReadDeadline(time.Now().Add(2 * time.Second)) + n, fromAddr, err := clientConn.ReadFromUDP(buf) + if err != nil { + t.Fatalf("Failed to receive response: %v", err) + } + + if string(buf[:n]) != string(responseData) { + t.Errorf("Expected to receive %q, got %q", responseData, buf[:n]) + } + + // Verify the response came from the physical connection, not the netstack one + physicalAddr := physicalConn.LocalAddr().(*net.UDPAddr) + if fromAddr.Port != physicalAddr.Port { + t.Errorf("Expected response from physical port %d, got %d", physicalAddr.Port, fromAddr.Port) + } +} + +// TestClearNetstackConn tests that clearing the netstack connection works correctly +func TestClearNetstackConn(t *testing.T) { + physicalConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + t.Fatalf("Failed to create physical UDP connection: %v", err) + } + + sharedBind, err := New(physicalConn) + if err != nil { + t.Fatalf("Failed to create SharedBind: %v", err) + } + defer sharedBind.Close() + + // Set a netstack connection + netstackConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + t.Fatalf("Failed to create netstack UDP connection: %v", err) + } + defer netstackConn.Close() + + sharedBind.SetNetstackConn(netstackConn) + + // Inject a packet to track an endpoint + testAddrPort := netip.MustParseAddrPort("192.168.1.100:51820") + err = sharedBind.InjectPacket([]byte("test"), testAddrPort) + if err != nil { + t.Fatalf("InjectPacket failed: %v", err) + } + + // Verify the endpoint is tracked + _, tracked := sharedBind.netstackEndpoints.Load(testAddrPort.String()) + if !tracked { + t.Error("Expected endpoint to be tracked as netstack-sourced") + } + + // Clear the netstack connection + sharedBind.ClearNetstackConn() + + // Verify the netstack connection is cleared + if sharedBind.GetNetstackConn() != nil { + t.Error("Expected netstack connection to be nil after clear") + } + + // Verify the tracked endpoints are cleared + _, stillTracked := sharedBind.netstackEndpoints.Load(testAddrPort.String()) + if stillTracked { + t.Error("Expected endpoint tracking to be cleared") + } +} diff --git a/clients.go b/clients.go index 4b282a7..3f28f4c 100644 --- a/clients.go +++ b/clients.go @@ -1,20 +1,17 @@ package main import ( - "fmt" "strings" + "github.com/fosrl/newt/clients" + wgnetstack "github.com/fosrl/newt/clients" + "github.com/fosrl/newt/clients/permissions" "github.com/fosrl/newt/logger" - "github.com/fosrl/newt/proxy" "github.com/fosrl/newt/websocket" "golang.zx2c4.com/wireguard/tun/netstack" - - "github.com/fosrl/newt/wgnetstack" - "github.com/fosrl/newt/wgtester" ) -var wgService *wgnetstack.WireGuardService -var wgTesterServer *wgtester.Server +var wgService *clients.WireGuardService var ready bool func setupClients(client *websocket.Client) { @@ -27,43 +24,29 @@ func setupClients(client *websocket.Client) { host = strings.TrimSuffix(host, "/") + logger.Info("Setting up clients with netstack2...") + + // if useNativeInterface is true make sure we have permission to use native interface if useNativeInterface { - setupClientsNative(client, host) - } else { - setupClientsNetstack(client, host) + logger.Debug("Checking permissions for native interface") + err := permissions.CheckNativeInterfacePermissions() + if err != nil { + logger.Fatal("Insufficient permissions to create native TUN interface: %v", err) + return + } } - ready = true -} - -func setupClientsNetstack(client *websocket.Client, host string) { - logger.Info("Setting up clients with netstack...") // Create WireGuard service - wgService, err = wgnetstack.NewWireGuardService(interfaceName, mtuInt, generateAndSaveKeyTo, host, id, client, "9.9.9.9") + wgService, err = wgnetstack.NewWireGuardService(interfaceName, mtuInt, host, id, client, dns, useNativeInterface) if err != nil { logger.Fatal("Failed to create WireGuard service: %v", err) } - // // Set up callback to restart wgtester with netstack when WireGuard is ready - wgService.SetOnNetstackReady(func(tnet *netstack.Net) { - - wgTesterServer = wgtester.NewServerWithNetstack("0.0.0.0", wgService.Port, id, tnet) // TODO: maybe make this the same ip of the wg server? - err := wgTesterServer.Start() - if err != nil { - logger.Error("Failed to start WireGuard tester server: %v", err) - } - }) - - wgService.SetOnNetstackClose(func() { - if wgTesterServer != nil { - wgTesterServer.Stop() - wgTesterServer = nil - } - }) - client.OnTokenUpdate(func(token string) { wgService.SetToken(token) }) + + ready = true } func setDownstreamTNetstack(tnet *netstack.Net) { @@ -75,16 +58,9 @@ func setDownstreamTNetstack(tnet *netstack.Net) { func closeClients() { logger.Info("Closing clients...") if wgService != nil { - wgService.Close(!keepInterface) + wgService.Close() wgService = nil } - - closeWgServiceNative() - - if wgTesterServer != nil { - wgTesterServer.Stop() - wgTesterServer = nil - } } func clientsHandleNewtConnection(publicKey string, endpoint string) { @@ -103,8 +79,6 @@ func clientsHandleNewtConnection(publicKey string, endpoint string) { if wgService != nil { wgService.StartHolepunch(publicKey, endpoint) } - - clientsHandleNewtConnectionNative(publicKey, endpoint) } func clientsOnConnect() { @@ -114,19 +88,17 @@ func clientsOnConnect() { if wgService != nil { wgService.LoadRemoteConfig() } - - clientsOnConnectNative() } -func clientsAddProxyTarget(pm *proxy.ProxyManager, tunnelIp string) { +// clientsStartDirectRelay starts a direct UDP relay from the main tunnel netstack +// to the clients' WireGuard, bypassing the proxy for better performance. +func clientsStartDirectRelay(tunnelIP string) { if !ready { return } - // add a udp proxy for localost and the wgService port - // TODO: make sure this port is not used in a target if wgService != nil { - pm.AddTarget("udp", tunnelIp, int(wgService.Port), fmt.Sprintf("127.0.0.1:%d", wgService.Port)) + if err := wgService.StartDirectUDPRelay(tunnelIP); err != nil { + logger.Error("Failed to start direct UDP relay: %v", err) + } } - - clientsAddProxyTargetNative(pm, tunnelIp) } diff --git a/clients/clients.go b/clients/clients.go new file mode 100644 index 0000000..d5fb5f3 --- /dev/null +++ b/clients/clients.go @@ -0,0 +1,1253 @@ +package clients + +import ( + "context" + "encoding/json" + "fmt" + "net" + "net/netip" + "os" + "runtime" + "strconv" + "strings" + "sync" + "time" + + "github.com/fosrl/newt/bind" + "github.com/fosrl/newt/holepunch" + "github.com/fosrl/newt/logger" + "github.com/fosrl/newt/netstack2" + "github.com/fosrl/newt/network" + "github.com/fosrl/newt/util" + "github.com/fosrl/newt/websocket" + "github.com/fosrl/newt/wgtester" + "golang.zx2c4.com/wireguard/device" + "golang.zx2c4.com/wireguard/ipc" + "golang.zx2c4.com/wireguard/tun" + "golang.zx2c4.com/wireguard/tun/netstack" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + "github.com/fosrl/newt/internal/telemetry" +) + +type WgConfig struct { + IpAddress string `json:"ipAddress"` + Peers []Peer `json:"peers"` + Targets []Target `json:"targets"` +} + +type Target struct { + SourcePrefix string `json:"sourcePrefix"` + DestPrefix string `json:"destPrefix"` + RewriteTo string `json:"rewriteTo,omitempty"` + PortRange []PortRange `json:"portRange,omitempty"` +} + +type PortRange struct { + Min uint16 `json:"min"` + Max uint16 `json:"max"` +} + +type Peer struct { + PublicKey string `json:"publicKey"` + AllowedIPs []string `json:"allowedIps"` + Endpoint string `json:"endpoint"` +} + +type PeerBandwidth struct { + PublicKey string `json:"publicKey"` + BytesIn float64 `json:"bytesIn"` + BytesOut float64 `json:"bytesOut"` +} + +type PeerReading struct { + BytesReceived int64 + BytesTransmitted int64 + LastChecked time.Time +} + +type WireGuardService struct { + interfaceName string + mtu int + client *websocket.Client + config WgConfig + key wgtypes.Key + newtId string + lastReadings map[string]PeerReading + mu sync.Mutex + Port uint16 + host string + serverPubKey string + token string + stopGetConfig func() + // Netstack fields + tun tun.Device + tnet *netstack2.Net + device *device.Device + dns []netip.Addr + // Callback for when netstack is ready + onNetstackReady func(*netstack2.Net) + // Callback for when netstack is closed + onNetstackClose func() + othertnet *netstack.Net + // Proxy manager for tunnel + TunnelIP string + // Shared bind and holepunch manager + sharedBind *bind.SharedBind + holePunchManager *holepunch.Manager + useNativeInterface bool + // Direct UDP relay from main tunnel to clients' WireGuard + directRelayStop chan struct{} + directRelayWg sync.WaitGroup + netstackListener net.PacketConn + netstackListenerMu sync.Mutex + wgTesterServer *wgtester.Server +} + +func NewWireGuardService(interfaceName string, mtu int, host string, newtId string, wsClient *websocket.Client, dns string, useNativeInterface bool) (*WireGuardService, error) { + key, err := wgtypes.GeneratePrivateKey() + if err != nil { + return nil, fmt.Errorf("failed to generate private key: %v", err) + } + + // Find an available port + port, err := util.FindAvailableUDPPort(49152, 65535) + + if err != nil { + return nil, fmt.Errorf("error finding available port: %v", err) + } + + // Create shared UDP socket for both holepunch and WireGuard + localAddr := &net.UDPAddr{ + Port: int(port), + IP: net.IPv4zero, + } + + udpConn, err := net.ListenUDP("udp", localAddr) + if err != nil { + return nil, fmt.Errorf("failed to create UDP socket: %v", err) + } + + sharedBind, err := bind.New(udpConn) + if err != nil { + udpConn.Close() + return nil, fmt.Errorf("failed to create shared bind: %v", err) + } + + // Add a reference for the hole punch manager (creator already has one reference for WireGuard) + sharedBind.AddRef() + + logger.Info("Created shared UDP socket on port %d (refcount: %d)", port, sharedBind.GetRefCount()) + + // Parse DNS addresses + dnsAddrs := []netip.Addr{netip.MustParseAddr(dns)} + + service := &WireGuardService{ + interfaceName: interfaceName, + mtu: mtu, + client: wsClient, + key: key, + newtId: newtId, + host: host, + lastReadings: make(map[string]PeerReading), + Port: port, + dns: dnsAddrs, + sharedBind: sharedBind, + useNativeInterface: useNativeInterface, + } + + // Create the holepunch manager with ResolveDomain function + // We'll need to pass a domain resolver function + service.holePunchManager = holepunch.NewManager(sharedBind, newtId, "newt", key.PublicKey().String()) + + // Register websocket handlers + wsClient.RegisterHandler("newt/wg/receive-config", service.handleConfig) + wsClient.RegisterHandler("newt/wg/peer/add", service.handleAddPeer) + wsClient.RegisterHandler("newt/wg/peer/remove", service.handleRemovePeer) + wsClient.RegisterHandler("newt/wg/peer/update", service.handleUpdatePeer) + wsClient.RegisterHandler("newt/wg/targets/add", service.handleAddTarget) + wsClient.RegisterHandler("newt/wg/targets/remove", service.handleRemoveTarget) + wsClient.RegisterHandler("newt/wg/targets/update", service.handleUpdateTarget) + + return service, nil +} + +// ReportRTT allows reporting native RTTs to telemetry, rate-limited externally. +func (s *WireGuardService) ReportRTT(seconds float64) { + if s.serverPubKey == "" { + return + } + telemetry.ObserveTunnelLatency(context.Background(), s.serverPubKey, "wireguard", seconds) +} + +func (s *WireGuardService) SetOthertnet(tnet *netstack.Net) { + s.othertnet = tnet +} + +func (s *WireGuardService) Close() { + if s.stopGetConfig != nil { + s.stopGetConfig() + s.stopGetConfig = nil + } + + // Stop the direct UDP relay first + s.StopDirectUDPRelay() + + // Stop hole punch manager + if s.holePunchManager != nil { + s.holePunchManager.Stop() + } + + s.mu.Lock() + defer s.mu.Unlock() + + // Close WireGuard device first - this will call sharedBind.Close() which releases WireGuard's reference + if s.device != nil { + s.device.Close() + s.device = nil + } + + // Clear references but don't manually close since device.Close() already did it + if s.tnet != nil { + s.tnet = nil + } + if s.tun != nil { + s.tun = nil // Don't call tun.Close() here since device.Close() already closed it + } + + // Release the hole punch reference to the shared bind + if s.sharedBind != nil { + // Release hole punch reference (WireGuard already released its reference via device.Close()) + logger.Debug("Releasing shared bind (refcount before release: %d)", s.sharedBind.GetRefCount()) + s.sharedBind.Release() + s.sharedBind = nil + logger.Info("Released shared UDP bind") + } + + if s.wgTesterServer != nil { + s.wgTesterServer.Stop() + s.wgTesterServer = nil + } +} + +func (s *WireGuardService) SetToken(token string) { + s.token = token + if s.holePunchManager != nil { + s.holePunchManager.SetToken(token) + } +} + +// GetNetstackNet returns the netstack network interface for use by other components +func (s *WireGuardService) GetNetstackNet() *netstack2.Net { + s.mu.Lock() + defer s.mu.Unlock() + return s.tnet +} + +// IsReady returns true if the WireGuard service is ready to use +func (s *WireGuardService) IsReady() bool { + s.mu.Lock() + defer s.mu.Unlock() + return s.device != nil && s.tnet != nil +} + +// GetPublicKey returns the public key of this WireGuard service +func (s *WireGuardService) GetPublicKey() wgtypes.Key { + return s.key.PublicKey() +} + +// SetOnNetstackReady sets a callback function to be called when the netstack interface is ready +func (s *WireGuardService) SetOnNetstackReady(callback func(*netstack2.Net)) { + s.onNetstackReady = callback +} + +func (s *WireGuardService) SetOnNetstackClose(callback func()) { + s.onNetstackClose = callback +} + +// StartHolepunch starts hole punching to a specific endpoint +func (s *WireGuardService) StartHolepunch(publicKey string, endpoint string) { + if s.holePunchManager == nil { + logger.Warn("Hole punch manager not initialized") + return + } + + // Convert websocket.ExitNode to holepunch.ExitNode + hpExitNodes := []holepunch.ExitNode{ + { + Endpoint: endpoint, + PublicKey: publicKey, + }, + } + + // Start hole punching using the manager + if err := s.holePunchManager.StartMultipleExitNodes(hpExitNodes); err != nil { + logger.Warn("Failed to start hole punch: %v", err) + } + + logger.Info("Starting hole punch to %s with public key: %s", endpoint, publicKey) +} + +// StartDirectUDPRelay starts a direct UDP relay from the main tunnel netstack to the clients' WireGuard. +// This bypasses the proxy by listening on the main tunnel's netstack and forwarding packets +// directly to the SharedBind that feeds the clients' WireGuard device. +// Responses are automatically routed back through the netstack by the SharedBind. +// tunnelIP is the IP address to listen on within the main tunnel's netstack. +func (s *WireGuardService) StartDirectUDPRelay(tunnelIP string) error { + if s.othertnet == nil { + return fmt.Errorf("main tunnel netstack (othertnet) not set") + } + if s.sharedBind == nil { + return fmt.Errorf("shared bind not initialized") + } + + // Stop any existing relay + s.StopDirectUDPRelay() + + s.directRelayStop = make(chan struct{}) + + // Parse the tunnel IP + ip := net.ParseIP(tunnelIP) + if ip == nil { + return fmt.Errorf("invalid tunnel IP: %s", tunnelIP) + } + + // Listen on the main tunnel netstack for UDP packets destined for the clients' WireGuard port + listenAddr := &net.UDPAddr{ + IP: ip, + Port: int(s.Port), + } + + // Use othertnet (main tunnel's netstack) to listen + listener, err := s.othertnet.ListenUDP(listenAddr) + if err != nil { + return fmt.Errorf("failed to listen on main tunnel netstack: %v", err) + } + + // Store the listener reference so we can close it later + s.netstackListenerMu.Lock() + s.netstackListener = listener + s.netstackListenerMu.Unlock() + + // Set the netstack connection on the SharedBind so responses go back through the tunnel + s.sharedBind.SetNetstackConn(listener) + + logger.Info("Started direct UDP relay on %s:%d (bidirectional via SharedBind)", tunnelIP, s.Port) + + // Start the relay goroutine to read from netstack and inject into SharedBind + s.directRelayWg.Add(1) + go s.runDirectUDPRelay(listener) + + return nil +} + +// runDirectUDPRelay handles receiving UDP packets from the main tunnel netstack +// and injecting them into the SharedBind for processing by WireGuard. +// Responses are handled automatically by SharedBind.Send() which routes them +// back through the netstack connection. +func (s *WireGuardService) runDirectUDPRelay(listener net.PacketConn) { + defer s.directRelayWg.Done() + // Note: Don't close listener here - it's also used by SharedBind for sending responses + // It will be closed when the relay is stopped + + logger.Info("Direct UDP relay started (bidirectional through SharedBind)") + + buf := make([]byte, 65535) // Max UDP packet size + + for { + select { + case <-s.directRelayStop: + logger.Info("Stopping direct UDP relay") + return + default: + } + + // Set a read deadline so we can check for stop signal periodically + listener.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) + + n, remoteAddr, err := listener.ReadFrom(buf) + if err != nil { + if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + continue // Just a timeout, check for stop and try again + } + if s.directRelayStop != nil { + select { + case <-s.directRelayStop: + return // Stopped + default: + } + } + logger.Debug("Direct UDP relay read error: %v", err) + continue + } + + // Get the source address + var srcAddrPort netip.AddrPort + if udpAddr, ok := remoteAddr.(*net.UDPAddr); ok { + srcAddrPort = udpAddr.AddrPort() + // Unmap IPv4-in-IPv6 addresses to ensure consistency with parsed endpoints + if srcAddrPort.Addr().Is4In6() { + srcAddrPort = netip.AddrPortFrom(srcAddrPort.Addr().Unmap(), srcAddrPort.Port()) + } + } else { + logger.Debug("Unexpected address type in relay: %T", remoteAddr) + continue + } + + // Inject the packet directly into the SharedBind (also tracks this endpoint as netstack-sourced) + if err := s.sharedBind.InjectPacket(buf[:n], srcAddrPort); err != nil { + logger.Debug("Failed to inject packet into SharedBind: %v", err) + continue + } + + // logger.Debug("Relayed %d bytes from %s into WireGuard", n, srcAddrPort.String()) + } +} + +// StopDirectUDPRelay stops the direct UDP relay and closes the netstack listener +func (s *WireGuardService) StopDirectUDPRelay() { + if s.directRelayStop != nil { + close(s.directRelayStop) + s.directRelayWg.Wait() + s.directRelayStop = nil + } + + // Clear the netstack connection from SharedBind so responses don't try to use it + if s.sharedBind != nil { + s.sharedBind.ClearNetstackConn() + } + + // Close the netstack listener + s.netstackListenerMu.Lock() + if s.netstackListener != nil { + s.netstackListener.Close() + s.netstackListener = nil + } + s.netstackListenerMu.Unlock() +} + +func (s *WireGuardService) LoadRemoteConfig() error { + if s.stopGetConfig != nil { + s.stopGetConfig() + s.stopGetConfig = nil + } + s.stopGetConfig = s.client.SendMessageInterval("newt/wg/get-config", map[string]interface{}{ + "publicKey": s.key.PublicKey().String(), + "port": s.Port, + }, 2*time.Second) + + logger.Info("Requesting WireGuard configuration from remote server") + go s.periodicBandwidthCheck() + + return nil +} + +func (s *WireGuardService) handleConfig(msg websocket.WSMessage) { + var config WgConfig + + logger.Debug("Received message: %v", msg) + logger.Info("Received WireGuard clients configuration from remote server") + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Info("Error marshaling data: %v", err) + return + } + + if err := json.Unmarshal(jsonData, &config); err != nil { + logger.Info("Error unmarshaling target data: %v", err) + return + } + s.config = config + + if s.stopGetConfig != nil { + s.stopGetConfig() + s.stopGetConfig = nil + } + + // Ensure the WireGuard interface and peers are configured + if err := s.ensureWireguardInterface(config); err != nil { + logger.Error("Failed to ensure WireGuard interface: %v", err) + } + + if err := s.ensureWireguardPeers(config.Peers); err != nil { + logger.Error("Failed to ensure WireGuard peers: %v", err) + } + + if err := s.ensureTargets(config.Targets); err != nil { + logger.Error("Failed to ensure WireGuard targets: %v", err) + } +} + +func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error { + s.mu.Lock() + + // split off the cidr from the IP address + parts := strings.Split(wgconfig.IpAddress, "/") + if len(parts) != 2 { + s.mu.Unlock() + return fmt.Errorf("invalid IP address format: %s", wgconfig.IpAddress) + } + // Parse the IP address and CIDR mask + tunnelIP := netip.MustParseAddr(parts[0]) + + var err error + + if s.useNativeInterface { + // Create native TUN device + var interfaceName = s.interfaceName + if runtime.GOOS == "darwin" { + interfaceName, err = network.FindUnusedUTUN() + if err != nil { + s.mu.Unlock() + return fmt.Errorf("failed to find unused utun: %v", err) + } + } + + s.tun, err = tun.CreateTUN(interfaceName, s.mtu) + if err != nil { + s.mu.Unlock() + return fmt.Errorf("failed to create native TUN device: %v", err) + } + + // Get the real interface name (may differ on some platforms) + if realName, err := s.tun.Name(); err == nil { + interfaceName = realName + } + + s.TunnelIP = tunnelIP.String() + // s.tnet is nil for native interface - proxy features not available + s.tnet = nil + + // Create WireGuard device using the shared bind + s.device = device.NewDevice(s.tun, s.sharedBind, device.NewLogger( + device.LogLevelSilent, + "wireguard: ", + )) + + fileUAPI, err := func() (*os.File, error) { + return ipc.UAPIOpen(interfaceName) + }() + if err != nil { + logger.Error("UAPI listen error: %v", err) + } + + uapiListener, err := ipc.UAPIListen(interfaceName, fileUAPI) + if err != nil { + logger.Error("Failed to listen on uapi socket: %v", err) + os.Exit(1) + } + + go func() { + for { + conn, err := uapiListener.Accept() + if err != nil { + + return + } + go s.device.IpcHandle(conn) + } + }() + logger.Info("UAPI listener started") + + // Configure WireGuard with private key + config := fmt.Sprintf("private_key=%s", util.FixKey(s.key.String())) + + err = s.device.IpcSet(config) + if err != nil { + s.mu.Unlock() + return fmt.Errorf("failed to configure WireGuard device: %v", err) + } + + // Bring up the device + err = s.device.Up() + if err != nil { + s.mu.Unlock() + return fmt.Errorf("failed to bring up WireGuard device: %v", err) + } + + // Configure the network interface with IP address + if err := network.ConfigureInterface(interfaceName, wgconfig.IpAddress, s.mtu); err != nil { + s.mu.Unlock() + return fmt.Errorf("failed to configure interface: %v", err) + } + + s.wgTesterServer = wgtester.NewServer("0.0.0.0", s.Port, s.newtId) // TODO: maybe make this the same ip of the wg server? + err = s.wgTesterServer.Start() + if err != nil { + logger.Error("Failed to start WireGuard tester server: %v", err) + } + + logger.Info("WireGuard native device created and configured on %s", interfaceName) + + s.mu.Unlock() + return nil + } + + // Create TUN device and network stack using netstack + s.tun, s.tnet, err = netstack2.CreateNetTUNWithOptions( + []netip.Addr{tunnelIP}, + s.dns, + s.mtu, + netstack2.NetTunOptions{ + EnableTCPProxy: true, + EnableUDPProxy: true, + }, + ) + if err != nil { + s.mu.Unlock() + return fmt.Errorf("failed to create TUN device: %v", err) + } + + s.TunnelIP = tunnelIP.String() + + // Create WireGuard device using the shared bind + s.device = device.NewDevice(s.tun, s.sharedBind, device.NewLogger( + device.LogLevelSilent, // Use silent logging by default - could be made configurable + "wireguard: ", + )) + + // Configure WireGuard with private key + config := fmt.Sprintf("private_key=%s", util.FixKey(s.key.String())) + + err = s.device.IpcSet(config) + if err != nil { + s.mu.Unlock() + return fmt.Errorf("failed to configure WireGuard device: %v", err) + } + + // Bring up the device + err = s.device.Up() + if err != nil { + s.mu.Unlock() + return fmt.Errorf("failed to bring up WireGuard device: %v", err) + } + + logger.Info("WireGuard netstack device created and configured") + + // Release the mutex before calling the callback + s.mu.Unlock() + + s.wgTesterServer = wgtester.NewServerWithNetstack("0.0.0.0", s.Port, s.newtId, s.tnet) // TODO: maybe make this the same ip of the wg server? + err = s.wgTesterServer.Start() + if err != nil { + logger.Error("Failed to start WireGuard tester server: %v", err) + } + + // Note: we already unlocked above, so don't use defer unlock + return nil +} + +func (s *WireGuardService) ensureWireguardPeers(peers []Peer) error { + // For netstack, we need to manage peers differently + // We'll configure peers directly on the device using IPC + + // First, clear all existing peers by getting current config and removing them + currentConfig, err := s.device.IpcGet() + if err != nil { + return fmt.Errorf("failed to get current device config: %v", err) + } + + // Parse current peers and remove them + lines := strings.Split(currentConfig, "\n") + var currentPeerKeys []string + for _, line := range lines { + if strings.HasPrefix(line, "public_key=") { + pubKey := strings.TrimPrefix(line, "public_key=") + currentPeerKeys = append(currentPeerKeys, pubKey) + } + } + + // Remove existing peers + for _, pubKey := range currentPeerKeys { + removeConfig := fmt.Sprintf("public_key=%s\nremove=true", pubKey) + if err := s.device.IpcSet(removeConfig); err != nil { + logger.Warn("Failed to remove peer %s: %v", pubKey, err) + } + } + + // Add new peers + for _, peer := range peers { + if err := s.addPeerToDevice(peer); err != nil { + return fmt.Errorf("failed to add peer: %v", err) + } + } + + return nil +} + +func (s *WireGuardService) ensureTargets(targets []Target) error { + if s.tnet == nil { + // Native interface mode - proxy features not available, skip silently + logger.Debug("Skipping target configuration - using native interface (no proxy support)") + return nil + } + + for _, target := range targets { + sourcePrefix, err := netip.ParsePrefix(target.SourcePrefix) + if err != nil { + return fmt.Errorf("invalid CIDR %s: %v", target.SourcePrefix, err) + } + + destPrefix, err := netip.ParsePrefix(target.DestPrefix) + if err != nil { + return fmt.Errorf("invalid CIDR %s: %v", target.DestPrefix, err) + } + + var portRanges []netstack2.PortRange + for _, pr := range target.PortRange { + portRanges = append(portRanges, netstack2.PortRange{ + Min: pr.Min, + Max: pr.Max, + }) + } + + s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges) + + logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", target.SourcePrefix, target.DestPrefix, target.RewriteTo, target.PortRange) + } + + return nil +} + +func (s *WireGuardService) addPeerToDevice(peer Peer) error { + // parse the key first + pubKey, err := wgtypes.ParseKey(peer.PublicKey) + if err != nil { + return fmt.Errorf("failed to parse public key: %v", err) + } + + // Build IPC configuration string for the peer + config := fmt.Sprintf("public_key=%s", util.FixKey(pubKey.String())) + + // Add allowed IPs + for _, allowedIP := range peer.AllowedIPs { + config += fmt.Sprintf("\nallowed_ip=%s", allowedIP) + } + + // Add endpoint if specified + if peer.Endpoint != "" { + config += fmt.Sprintf("\nendpoint=%s", peer.Endpoint) + } + + // Add persistent keepalive + config += "\npersistent_keepalive_interval=25" + + // Apply the configuration + if err := s.device.IpcSet(config); err != nil { + return fmt.Errorf("failed to configure peer: %v", err) + } + + logger.Info("Peer %s added successfully", peer.PublicKey) + return nil +} + +func (s *WireGuardService) handleAddPeer(msg websocket.WSMessage) { + logger.Debug("Received message: %v", msg.Data) + var peer Peer + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Info("Error marshaling data: %v", err) + return + } + + if err := json.Unmarshal(jsonData, &peer); err != nil { + logger.Info("Error unmarshaling target data: %v", err) + return + } + + if s.device == nil { + logger.Info("WireGuard device is not initialized") + return + } + + s.holePunchManager.TriggerHolePunch() + + err = s.addPeerToDevice(peer) + if err != nil { + logger.Info("Error adding peer: %v", err) + return + } +} + +func (s *WireGuardService) handleRemovePeer(msg websocket.WSMessage) { + logger.Debug("Received message: %v", msg.Data) + // parse the publicKey from the message which is json { "publicKey": "asdfasdfl;akjsdf" } + type RemoveRequest struct { + PublicKey string `json:"publicKey"` + } + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Info("Error marshaling data: %v", err) + return + } + + var request RemoveRequest + if err := json.Unmarshal(jsonData, &request); err != nil { + logger.Info("Error unmarshaling data: %v", err) + return + } + + if s.device == nil { + logger.Info("WireGuard device is not initialized") + return + } + + if err := s.removePeer(request.PublicKey); err != nil { + logger.Info("Error removing peer: %v", err) + return + } +} + +func (s *WireGuardService) removePeer(publicKey string) error { + + // Parse the public key + pubKey, err := wgtypes.ParseKey(publicKey) + if err != nil { + return fmt.Errorf("failed to parse public key: %v", err) + } + + // Build IPC configuration string to remove the peer + config := fmt.Sprintf("public_key=%s\nremove=true", util.FixKey(pubKey.String())) + + if err := s.device.IpcSet(config); err != nil { + return fmt.Errorf("failed to remove peer: %v", err) + } + + logger.Info("Peer %s removed successfully", publicKey) + return nil +} + +func (s *WireGuardService) handleUpdatePeer(msg websocket.WSMessage) { + logger.Debug("Received message: %v", msg.Data) + // Define a struct to match the incoming message structure with optional fields + type UpdatePeerRequest struct { + PublicKey string `json:"publicKey"` + AllowedIPs []string `json:"allowedIps,omitempty"` + Endpoint string `json:"endpoint,omitempty"` + } + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Info("Error marshaling data: %v", err) + return + } + + var request UpdatePeerRequest + if err := json.Unmarshal(jsonData, &request); err != nil { + logger.Info("Error unmarshaling peer data: %v", err) + return + } + + s.holePunchManager.TriggerHolePunch() + + // Parse the public key + pubKey, err := wgtypes.ParseKey(request.PublicKey) + if err != nil { + logger.Info("Failed to parse public key: %v", err) + return + } + + if s.device == nil { + logger.Info("WireGuard device is not initialized") + return + } + + // Build IPC configuration string to update the peer + config := fmt.Sprintf("public_key=%s\nupdate_only=true", util.FixKey(pubKey.String())) + + // Handle AllowedIPs update + if len(request.AllowedIPs) > 0 { + config += "\nreplace_allowed_ips=true" + for _, allowedIP := range request.AllowedIPs { + config += fmt.Sprintf("\nallowed_ip=%s", allowedIP) + } + logger.Info("Updating AllowedIPs for peer %s", request.PublicKey) + } + + // Handle Endpoint field special case + endpointSpecified := false + for key := range msg.Data.(map[string]interface{}) { + if key == "endpoint" { + endpointSpecified = true + break + } + } + + if endpointSpecified { + if request.Endpoint != "" { + config += fmt.Sprintf("\nendpoint=%s", request.Endpoint) + logger.Info("Updating Endpoint for peer %s to %s", request.PublicKey, request.Endpoint) + } else { + config += "\nendpoint=0.0.0.0:0" // Remove endpoint + logger.Info("Removing Endpoint for peer %s", request.PublicKey) + } + } + + // Always set persistent keepalive + config += "\npersistent_keepalive_interval=25" + + // Apply the configuration update + if err := s.device.IpcSet(config); err != nil { + logger.Info("Error updating peer configuration: %v", err) + return + } + + logger.Info("Peer %s updated successfully", request.PublicKey) +} + +func (s *WireGuardService) periodicBandwidthCheck() { + ticker := time.NewTicker(10 * time.Second) + defer ticker.Stop() + + for range ticker.C { + if err := s.reportPeerBandwidth(); err != nil { + logger.Info("Failed to report peer bandwidth: %v", err) + } + } +} + +func (s *WireGuardService) calculatePeerBandwidth() ([]PeerBandwidth, error) { + if s.device == nil { + return []PeerBandwidth{}, nil + } + + // Get device statistics using IPC + stats, err := s.device.IpcGet() + if err != nil { + return nil, fmt.Errorf("failed to get device statistics: %v", err) + } + + peerBandwidths := []PeerBandwidth{} + now := time.Now() + + s.mu.Lock() + defer s.mu.Unlock() + + // Parse the IPC response to extract peer statistics + lines := strings.Split(stats, "\n") + var currentPubKey string + var rxBytes, txBytes int64 + + for _, line := range lines { + if strings.HasPrefix(line, "public_key=") { + // Process previous peer if we have one + if currentPubKey != "" { + bandwidth := s.processPeerBandwidth(currentPubKey, rxBytes, txBytes, now) + if bandwidth != nil { + peerBandwidths = append(peerBandwidths, *bandwidth) + } + } + // Start new peer + currentPubKey = strings.TrimPrefix(line, "public_key=") + rxBytes = 0 + txBytes = 0 + } else if strings.HasPrefix(line, "rx_bytes=") { + rxBytes, _ = strconv.ParseInt(strings.TrimPrefix(line, "rx_bytes="), 10, 64) + } else if strings.HasPrefix(line, "tx_bytes=") { + txBytes, _ = strconv.ParseInt(strings.TrimPrefix(line, "tx_bytes="), 10, 64) + } + } + + // Process the last peer + if currentPubKey != "" { + bandwidth := s.processPeerBandwidth(currentPubKey, rxBytes, txBytes, now) + if bandwidth != nil { + peerBandwidths = append(peerBandwidths, *bandwidth) + } + } + + // Clean up old peers + devicePeers := make(map[string]bool) + lines = strings.Split(stats, "\n") + for _, line := range lines { + if strings.HasPrefix(line, "public_key=") { + pubKey := strings.TrimPrefix(line, "public_key=") + devicePeers[pubKey] = true + } + } + + for publicKey := range s.lastReadings { + if !devicePeers[publicKey] { + delete(s.lastReadings, publicKey) + } + } + + // parse the public keys and have them as base64 in the opposite order to fixKey + for i := range peerBandwidths { + peerBandwidths[i].PublicKey = util.UnfixKey(peerBandwidths[i].PublicKey) // its in the long form but we need base64 + } + + return peerBandwidths, nil +} + +func (s *WireGuardService) processPeerBandwidth(publicKey string, rxBytes, txBytes int64, now time.Time) *PeerBandwidth { + currentReading := PeerReading{ + BytesReceived: rxBytes, + BytesTransmitted: txBytes, + LastChecked: now, + } + + var bytesInDiff, bytesOutDiff float64 + lastReading, exists := s.lastReadings[publicKey] + + if exists { + timeDiff := currentReading.LastChecked.Sub(lastReading.LastChecked).Seconds() + if timeDiff > 0 { + // Calculate bytes transferred since last reading + bytesInDiff = float64(currentReading.BytesReceived - lastReading.BytesReceived) + bytesOutDiff = float64(currentReading.BytesTransmitted - lastReading.BytesTransmitted) + + // Handle counter wraparound (if the counter resets or overflows) + if bytesInDiff < 0 { + bytesInDiff = float64(currentReading.BytesReceived) + } + if bytesOutDiff < 0 { + bytesOutDiff = float64(currentReading.BytesTransmitted) + } + + // Convert to MB + bytesInMB := bytesInDiff / (1024 * 1024) + bytesOutMB := bytesOutDiff / (1024 * 1024) + + // Update the last reading + s.lastReadings[publicKey] = currentReading + + return &PeerBandwidth{ + PublicKey: publicKey, + BytesIn: bytesInMB, + BytesOut: bytesOutMB, + } + } + } + + // For first reading or if readings are too close together, report 0 + s.lastReadings[publicKey] = currentReading + return &PeerBandwidth{ + PublicKey: publicKey, + BytesIn: 0, + BytesOut: 0, + } +} + +func (s *WireGuardService) reportPeerBandwidth() error { + bandwidths, err := s.calculatePeerBandwidth() + if err != nil { + return fmt.Errorf("failed to calculate peer bandwidth: %v", err) + } + + err = s.client.SendMessageNoLog("newt/receive-bandwidth", map[string]interface{}{ + "bandwidthData": bandwidths, + }) + if err != nil { + return fmt.Errorf("failed to send bandwidth data: %v", err) + } + + return nil +} + +// filterReadOnlyFields removes read-only fields from WireGuard IPC configuration +func (s *WireGuardService) handleAddTarget(msg websocket.WSMessage) { + logger.Debug("Received message: %v", msg.Data) + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Info("Error marshaling data: %v", err) + return + } + + if s.tnet == nil { + // Native interface mode - proxy features not available, skip silently + logger.Debug("Skipping add target - using native interface (no proxy support)") + return + } + + // Try to unmarshal as array first + var targets []Target + if err := json.Unmarshal(jsonData, &targets); err != nil { + logger.Warn("Error unmarshaling target data: %v", err) + return + } + + // Process all targets + for _, target := range targets { + sourcePrefix, err := netip.ParsePrefix(target.SourcePrefix) + if err != nil { + logger.Info("Invalid CIDR %s: %v", target.SourcePrefix, err) + continue + } + + destPrefix, err := netip.ParsePrefix(target.DestPrefix) + if err != nil { + logger.Info("Invalid CIDR %s: %v", target.DestPrefix, err) + continue + } + + var portRanges []netstack2.PortRange + for _, pr := range target.PortRange { + portRanges = append(portRanges, netstack2.PortRange{ + Min: pr.Min, + Max: pr.Max, + }) + } + + s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges) + + logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", target.SourcePrefix, target.DestPrefix, target.RewriteTo, target.PortRange) + } +} + +// filterReadOnlyFields removes read-only fields from WireGuard IPC configuration +func (s *WireGuardService) handleRemoveTarget(msg websocket.WSMessage) { + logger.Debug("Received message: %v", msg.Data) + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Info("Error marshaling data: %v", err) + return + } + + if s.tnet == nil { + // Native interface mode - proxy features not available, skip silently + logger.Debug("Skipping remove target - using native interface (no proxy support)") + return + } + + // Try to unmarshal as array first + var targets []Target + if err := json.Unmarshal(jsonData, &targets); err != nil { + logger.Warn("Error unmarshaling target data: %v", err) + return + } + + // Process all targets + for _, target := range targets { + sourcePrefix, err := netip.ParsePrefix(target.SourcePrefix) + if err != nil { + logger.Info("Invalid CIDR %s: %v", target.SourcePrefix, err) + continue + } + + destPrefix, err := netip.ParsePrefix(target.DestPrefix) + if err != nil { + logger.Info("Invalid CIDR %s: %v", target.DestPrefix, err) + continue + } + + s.tnet.RemoveProxySubnetRule(sourcePrefix, destPrefix) + + logger.Info("Removed target subnet %s with destination %s", target.SourcePrefix, target.DestPrefix) + } +} + +func (s *WireGuardService) handleUpdateTarget(msg websocket.WSMessage) { + logger.Debug("Received message: %v", msg.Data) + + // you are going to get a oldTarget and a newTarget in the message + type UpdateTargetRequest struct { + OldTargets []Target `json:"oldTargets"` + NewTargets []Target `json:"newTargets"` + } + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Info("Error marshaling data: %v", err) + return + } + + if s.tnet == nil { + // Native interface mode - proxy features not available, skip silently + logger.Debug("Skipping update target - using native interface (no proxy support)") + return + } + + // Try to unmarshal as array first + var requests UpdateTargetRequest + if err := json.Unmarshal(jsonData, &requests); err != nil { + logger.Warn("Error unmarshaling target data: %v", err) + return + } + + // Process all update requests + for _, target := range requests.OldTargets { + sourcePrefix, err := netip.ParsePrefix(target.SourcePrefix) + if err != nil { + logger.Info("Invalid CIDR %s: %v", target.SourcePrefix, err) + continue + } + + destPrefix, err := netip.ParsePrefix(target.DestPrefix) + if err != nil { + logger.Info("Invalid CIDR %s: %v", target.DestPrefix, err) + continue + } + + s.tnet.RemoveProxySubnetRule(sourcePrefix, destPrefix) + logger.Info("Removed target subnet %s with destination %s", target.SourcePrefix, target.DestPrefix) + } + + for _, target := range requests.NewTargets { + // Now add the new target + sourcePrefix, err := netip.ParsePrefix(target.SourcePrefix) + if err != nil { + logger.Info("Invalid CIDR %s: %v", target.SourcePrefix, err) + continue + } + + destPrefix, err := netip.ParsePrefix(target.DestPrefix) + if err != nil { + logger.Info("Invalid CIDR %s: %v", target.DestPrefix, err) + continue + } + + var portRanges []netstack2.PortRange + for _, pr := range target.PortRange { + portRanges = append(portRanges, netstack2.PortRange{ + Min: pr.Min, + Max: pr.Max, + }) + } + + s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges) + logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", target.SourcePrefix, target.DestPrefix, target.RewriteTo, target.PortRange) + } +} + +// filterReadOnlyFields removes read-only fields from WireGuard IPC configuration +func (s *WireGuardService) filterReadOnlyFields(config string) string { + lines := strings.Split(config, "\n") + var filteredLines []string + + // List of read-only fields that should not be included in IpcSet + readOnlyFields := map[string]bool{ + "last_handshake_time_sec": true, + "last_handshake_time_nsec": true, + "rx_bytes": true, + "tx_bytes": true, + "protocol_version": true, + } + + for _, line := range lines { + if line == "" { + continue + } + + // Check if this line contains a read-only field + isReadOnly := false + for field := range readOnlyFields { + if strings.HasPrefix(line, field+"=") { + isReadOnly = true + break + } + } + + // Only include non-read-only lines + if !isReadOnly { + filteredLines = append(filteredLines, line) + } + } + + return strings.Join(filteredLines, "\n") +} diff --git a/clients/permissions/permissions_darwin.go b/clients/permissions/permissions_darwin.go new file mode 100644 index 0000000..d14bef4 --- /dev/null +++ b/clients/permissions/permissions_darwin.go @@ -0,0 +1,18 @@ +//go:build darwin + +package permissions + +import ( + "fmt" + "os" +) + +// CheckNativeInterfacePermissions checks if the process has sufficient +// permissions to create a native TUN interface on macOS. +// This typically requires root privileges. +func CheckNativeInterfacePermissions() error { + if os.Geteuid() == 0 { + return nil + } + return fmt.Errorf("insufficient permissions: need root to create TUN interface on macOS") +} diff --git a/clients/permissions/permissions_linux.go b/clients/permissions/permissions_linux.go new file mode 100644 index 0000000..e97ee6a --- /dev/null +++ b/clients/permissions/permissions_linux.go @@ -0,0 +1,96 @@ +//go:build linux + +package permissions + +import ( + "fmt" + "os" + "unsafe" + + "github.com/fosrl/newt/logger" + "golang.org/x/sys/unix" +) + +const ( + // TUN device constants + tunDevice = "/dev/net/tun" + ifnamsiz = 16 + iffTun = 0x0001 + iffNoPi = 0x1000 + tunSetIff = 0x400454ca +) + +// ifReq is the structure for TUNSETIFF ioctl +type ifReq struct { + Name [ifnamsiz]byte + Flags uint16 + _ [22]byte // padding to match kernel structure +} + +// CheckNativeInterfacePermissions checks if the process has sufficient +// permissions to create a native TUN interface on Linux. +// This requires either root privileges (UID 0) or CAP_NET_ADMIN capability. +func CheckNativeInterfacePermissions() error { + logger.Debug("Checking native interface permissions on Linux") + + // Check if running as root + if os.Geteuid() == 0 { + logger.Debug("Running as root, sufficient permissions for native TUN interface") + return nil + } + + // Check for CAP_NET_ADMIN capability + caps := unix.CapUserHeader{ + Version: unix.LINUX_CAPABILITY_VERSION_3, + Pid: 0, // 0 means current process + } + + var data [2]unix.CapUserData + if err := unix.Capget(&caps, &data[0]); err != nil { + logger.Debug("Failed to get capabilities: %v, will try creating test TUN", err) + } else { + // CAP_NET_ADMIN is capability bit 12 + const CAP_NET_ADMIN = 12 + if data[0].Effective&(1< 0 { avgLatency := totalLatency / time.Duration(successCount) - logger.Debug("Reliable ping succeeded after %d attempts, avg latency: %v", attempt, avgLatency) + // logger.Debug("Reliable ping succeeded after %d attempts, avg latency: %v", attempt, avgLatency) return avgLatency, nil } } @@ -366,89 +348,6 @@ func startPingCheck(tnet *netstack.Net, serverIP string, client *websocket.Clien return pingStopChan } -func parseLogLevel(level string) logger.LogLevel { - switch strings.ToUpper(level) { - case "DEBUG": - return logger.DEBUG - case "INFO": - return logger.INFO - case "WARN": - return logger.WARN - case "ERROR": - return logger.ERROR - case "FATAL": - return logger.FATAL - default: - return logger.INFO // default to INFO if invalid level provided - } -} - -func mapToWireGuardLogLevel(level logger.LogLevel) int { - switch level { - case logger.DEBUG: - return device.LogLevelVerbose - // case logger.INFO: - // return device.LogLevel - case logger.WARN: - return device.LogLevelError - case logger.ERROR, logger.FATAL: - return device.LogLevelSilent - default: - return device.LogLevelSilent - } -} - -func resolveDomain(domain string) (string, error) { - // Check if there's a port in the domain - host, port, err := net.SplitHostPort(domain) - if err != nil { - // No port found, use the domain as is - host = domain - port = "" - } - - // Remove any protocol prefix if present - if strings.HasPrefix(host, "http://") { - host = strings.TrimPrefix(host, "http://") - } else if strings.HasPrefix(host, "https://") { - host = strings.TrimPrefix(host, "https://") - } - - // if there are any trailing slashes, remove them - host = strings.TrimSuffix(host, "/") - - // Lookup IP addresses - ips, err := net.LookupIP(host) - if err != nil { - return "", fmt.Errorf("DNS lookup failed: %v", err) - } - - if len(ips) == 0 { - return "", fmt.Errorf("no IP addresses found for domain %s", host) - } - - // Get the first IPv4 address if available - var ipAddr string - for _, ip := range ips { - if ipv4 := ip.To4(); ipv4 != nil { - ipAddr = ipv4.String() - break - } - } - - // If no IPv4 found, use the first IP (might be IPv6) - if ipAddr == "" { - ipAddr = ips[0].String() - } - - // Add port back if it existed - if port != "" { - ipAddr = net.JoinHostPort(ipAddr, port) - } - - return ipAddr, nil -} - func parseTargetData(data interface{}) (TargetData, error) { var targetData TargetData jsonData, err := json.Marshal(data) diff --git a/docker/client.go b/docker/docker.go similarity index 100% rename from docker/client.go rename to docker/docker.go diff --git a/examples/README.md b/examples/README.md new file mode 100644 index 0000000..b7bd9a1 --- /dev/null +++ b/examples/README.md @@ -0,0 +1,167 @@ +# Extensible Logger + +This logger package provides a flexible logging system that can be extended with custom log writers. + +## Basic Usage (Current Behavior) + +The logger works exactly as before with no changes required: + +```go +package main + +import "your-project/logger" + +func main() { + // Use default logger + logger.Info("This works as before") + logger.Debug("Debug message") + logger.Error("Error message") + + // Or create a custom instance + log := logger.NewLogger() + log.SetLevel(logger.INFO) + log.Info("Custom logger instance") +} +``` + +## Custom Log Writers + +To use a custom log backend, implement the `LogWriter` interface: + +```go +type LogWriter interface { + Write(level LogLevel, timestamp time.Time, message string) +} +``` + +### Example: OS Log Writer (macOS/iOS) + +```go +package main + +import "your-project/logger" + +func main() { + // Create an OS log writer + osWriter := logger.NewOSLogWriter( + "net.pangolin.Pangolin.PacketTunnel", + "PangolinGo", + "MyApp", + ) + + // Create a logger with the OS log writer + log := logger.NewLoggerWithWriter(osWriter) + log.SetLevel(logger.DEBUG) + + // Use it just like the standard logger + log.Info("This message goes to os_log") + log.Error("Error logged to os_log") +} +``` + +### Example: Custom Writer + +```go +package main + +import ( + "fmt" + "time" + "your-project/logger" +) + +// CustomWriter writes logs to a custom destination +type CustomWriter struct { + // your custom fields +} + +func (w *CustomWriter) Write(level logger.LogLevel, timestamp time.Time, message string) { + // Your custom logging logic + fmt.Printf("[CUSTOM] %s [%s] %s\n", timestamp.Format(time.RFC3339), level.String(), message) +} + +func main() { + customWriter := &CustomWriter{} + log := logger.NewLoggerWithWriter(customWriter) + log.Info("Custom logging!") +} +``` + +### Example: Multi-Writer (Log to Multiple Destinations) + +```go +package main + +import ( + "time" + "your-project/logger" +) + +// MultiWriter writes to multiple log writers +type MultiWriter struct { + writers []logger.LogWriter +} + +func NewMultiWriter(writers ...logger.LogWriter) *MultiWriter { + return &MultiWriter{writers: writers} +} + +func (w *MultiWriter) Write(level logger.LogLevel, timestamp time.Time, message string) { + for _, writer := range w.writers { + writer.Write(level, timestamp, message) + } +} + +func main() { + // Log to both standard output and OS log + standardWriter := logger.NewStandardWriter() + osWriter := logger.NewOSLogWriter("com.example.app", "Main", "App") + + multiWriter := NewMultiWriter(standardWriter, osWriter) + log := logger.NewLoggerWithWriter(multiWriter) + + log.Info("This goes to both stdout and os_log!") +} +``` + +## API Reference + +### Creating Loggers + +- `NewLogger()` - Creates a logger with the default StandardWriter +- `NewLoggerWithWriter(writer LogWriter)` - Creates a logger with a custom writer + +### Built-in Writers + +- `NewStandardWriter()` - Standard writer that outputs to stdout (default) +- `NewOSLogWriter(subsystem, category, prefix string)` - OS log writer for macOS/iOS (example) + +### Logger Methods + +- `SetLevel(level LogLevel)` - Set minimum log level +- `SetOutput(output *os.File)` - Set output file (StandardWriter only) +- `Debug(format string, args ...interface{})` - Log debug message +- `Info(format string, args ...interface{})` - Log info message +- `Warn(format string, args ...interface{})` - Log warning message +- `Error(format string, args ...interface{})` - Log error message +- `Fatal(format string, args ...interface{})` - Log fatal message and exit + +### Global Functions + +For convenience, you can use global functions that use the default logger: + +- `logger.Debug(format, args...)` +- `logger.Info(format, args...)` +- `logger.Warn(format, args...)` +- `logger.Error(format, args...)` +- `logger.Fatal(format, args...)` +- `logger.SetOutput(output *os.File)` + +## Migration Guide + +No changes needed! The logger maintains 100% backward compatibility. Your existing code will continue to work without modifications. + +If you want to switch to a custom writer: +1. Create your writer implementing `LogWriter` +2. Use `NewLoggerWithWriter()` instead of `NewLogger()` +3. That's it! diff --git a/examples/logger_examples.go b/examples/logger_examples.go new file mode 100644 index 0000000..81e95e4 --- /dev/null +++ b/examples/logger_examples.go @@ -0,0 +1,161 @@ +// Example usage patterns for the extensible logger +package main + +import ( + "fmt" + "os" + "time" + + "github.com/fosrl/newt/logger" +) + +// Example 1: Using the default logger (works exactly as before) +func exampleDefaultLogger() { + logger.Info("Starting application") + logger.Debug("Debug information") + logger.Warn("Warning message") + logger.Error("Error occurred") +} + +// Example 2: Using a custom logger instance with standard writer +func exampleCustomInstance() { + log := logger.NewLogger() + log.SetLevel(logger.INFO) + log.Info("This is from a custom instance") +} + +// Example 3: Custom writer that adds JSON formatting +type JSONWriter struct{} + +func (w *JSONWriter) Write(level logger.LogLevel, timestamp time.Time, message string) { + fmt.Printf("{\"time\":\"%s\",\"level\":\"%s\",\"message\":\"%s\"}\n", + timestamp.Format(time.RFC3339), + level.String(), + message) +} + +func exampleJSONLogger() { + jsonWriter := &JSONWriter{} + log := logger.NewLoggerWithWriter(jsonWriter) + log.Info("This will be logged as JSON") +} + +// Example 4: File writer +type FileWriter struct { + file *os.File +} + +func NewFileWriter(filename string) (*FileWriter, error) { + file, err := os.OpenFile(filename, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644) + if err != nil { + return nil, err + } + return &FileWriter{file: file}, nil +} + +func (w *FileWriter) Write(level logger.LogLevel, timestamp time.Time, message string) { + fmt.Fprintf(w.file, "[%s] %s: %s\n", + timestamp.Format("2006-01-02 15:04:05"), + level.String(), + message) +} + +func (w *FileWriter) Close() error { + return w.file.Close() +} + +func exampleFileLogger() { + fileWriter, err := NewFileWriter("/tmp/app.log") + if err != nil { + panic(err) + } + defer fileWriter.Close() + + log := logger.NewLoggerWithWriter(fileWriter) + log.Info("This goes to a file") +} + +// Example 5: Multi-writer to log to multiple destinations +type MultiWriter struct { + writers []logger.LogWriter +} + +func NewMultiWriter(writers ...logger.LogWriter) *MultiWriter { + return &MultiWriter{writers: writers} +} + +func (w *MultiWriter) Write(level logger.LogLevel, timestamp time.Time, message string) { + for _, writer := range w.writers { + writer.Write(level, timestamp, message) + } +} + +func exampleMultiWriter() { + // Log to both stdout and a file + standardWriter := logger.NewStandardWriter() + fileWriter, _ := NewFileWriter("/tmp/app.log") + + multiWriter := NewMultiWriter(standardWriter, fileWriter) + log := logger.NewLoggerWithWriter(multiWriter) + + log.Info("This goes to both stdout and file!") +} + +// Example 6: Conditional writer (only log errors to a specific destination) +type ErrorOnlyWriter struct { + writer logger.LogWriter +} + +func NewErrorOnlyWriter(writer logger.LogWriter) *ErrorOnlyWriter { + return &ErrorOnlyWriter{writer: writer} +} + +func (w *ErrorOnlyWriter) Write(level logger.LogLevel, timestamp time.Time, message string) { + if level >= logger.ERROR { + w.writer.Write(level, timestamp, message) + } +} + +func exampleConditionalWriter() { + errorWriter, _ := NewFileWriter("/tmp/errors.log") + errorOnlyWriter := NewErrorOnlyWriter(errorWriter) + + log := logger.NewLoggerWithWriter(errorOnlyWriter) + log.Info("This won't be logged") + log.Error("This will be logged to errors.log") +} + +/* Example 7: OS Log Writer (macOS/iOS only) +// Uncomment on Darwin platforms + +func exampleOSLogWriter() { + osWriter := logger.NewOSLogWriter( + "net.pangolin.Pangolin.PacketTunnel", + "PangolinGo", + "MyApp", + ) + + log := logger.NewLoggerWithWriter(osWriter) + log.Info("This goes to os_log and can be viewed with Console.app") +} +*/ + +func main() { + fmt.Println("=== Example 1: Default Logger ===") + exampleDefaultLogger() + + fmt.Println("\n=== Example 2: Custom Instance ===") + exampleCustomInstance() + + fmt.Println("\n=== Example 3: JSON Logger ===") + exampleJSONLogger() + + fmt.Println("\n=== Example 4: File Logger ===") + exampleFileLogger() + + fmt.Println("\n=== Example 5: Multi-Writer ===") + exampleMultiWriter() + + fmt.Println("\n=== Example 6: Conditional Writer ===") + exampleConditionalWriter() +} diff --git a/examples/oslog_writer_example.go b/examples/oslog_writer_example.go new file mode 100644 index 0000000..2c5d3f7 --- /dev/null +++ b/examples/oslog_writer_example.go @@ -0,0 +1,86 @@ +//go:build darwin +// +build darwin + +package main + +/* +#cgo CFLAGS: -I../PacketTunnel +#include "../PacketTunnel/OSLogBridge.h" +#include +*/ +import "C" +import ( + "fmt" + "runtime" + "time" + "unsafe" +) + +// OSLogWriter is a LogWriter implementation that writes to Apple's os_log +type OSLogWriter struct { + subsystem string + category string + prefix string +} + +// NewOSLogWriter creates a new OSLogWriter +func NewOSLogWriter(subsystem, category, prefix string) *OSLogWriter { + writer := &OSLogWriter{ + subsystem: subsystem, + category: category, + prefix: prefix, + } + + // Initialize the OS log bridge + cSubsystem := C.CString(subsystem) + cCategory := C.CString(category) + defer C.free(unsafe.Pointer(cSubsystem)) + defer C.free(unsafe.Pointer(cCategory)) + + C.initOSLogBridge(cSubsystem, cCategory) + + return writer +} + +// Write implements the LogWriter interface +func (w *OSLogWriter) Write(level LogLevel, timestamp time.Time, message string) { + // Get caller information (skip 3 frames to get to the actual caller) + _, file, line, ok := runtime.Caller(3) + if !ok { + file = "unknown" + line = 0 + } else { + // Get just the filename, not the full path + for i := len(file) - 1; i > 0; i-- { + if file[i] == '/' { + file = file[i+1:] + break + } + } + } + + formattedTime := timestamp.Format("2006-01-02 15:04:05.000") + fullMessage := fmt.Sprintf("[%s] [%s] [%s] %s:%d - %s", + formattedTime, level.String(), w.prefix, file, line, message) + + cMessage := C.CString(fullMessage) + defer C.free(unsafe.Pointer(cMessage)) + + // Map Go log levels to os_log levels: + // 0=DEBUG, 1=INFO, 2=DEFAULT (WARN), 3=ERROR + var osLogLevel C.int + switch level { + case DEBUG: + osLogLevel = 0 // DEBUG + case INFO: + osLogLevel = 1 // INFO + case WARN: + osLogLevel = 2 // DEFAULT + case ERROR, FATAL: + osLogLevel = 3 // ERROR + default: + osLogLevel = 2 // DEFAULT + } + + C.logToOSLog(osLogLevel, cMessage) +} diff --git a/go.mod b/go.mod index 109e2e5..ac634ed 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,6 @@ go 1.25 require ( github.com/docker/docker v28.5.2+incompatible - github.com/google/gopacket v1.1.19 github.com/gorilla/websocket v1.5.3 github.com/prometheus/client_golang v1.23.2 github.com/vishvananda/netlink v1.3.1 @@ -18,9 +17,12 @@ require ( go.opentelemetry.io/otel/sdk v1.38.0 go.opentelemetry.io/otel/sdk/metric v1.38.0 golang.org/x/crypto v0.45.0 + golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 golang.org/x/net v0.47.0 + golang.org/x/sys v0.38.0 golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 + golang.zx2c4.com/wireguard/windows v0.5.3 google.golang.org/grpc v1.76.0 gopkg.in/yaml.v3 v3.0.1 gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c @@ -41,14 +43,9 @@ require ( github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/google/btree v1.1.3 // indirect - github.com/google/go-cmp v0.7.0 // indirect github.com/google/uuid v1.6.0 // indirect github.com/grafana/regexp v0.0.0-20240518133315-a468a5bfb3bc // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 // indirect - github.com/josharian/native v1.1.0 // indirect - github.com/mdlayher/genetlink v1.3.2 // indirect - github.com/mdlayher/netlink v1.7.2 // indirect - github.com/mdlayher/socket v0.5.1 // indirect github.com/moby/docker-image-spec v1.3.1 // indirect github.com/moby/sys/atomicwriter v0.1.0 // indirect github.com/moby/term v0.5.2 // indirect @@ -68,12 +65,11 @@ require ( go.opentelemetry.io/otel/trace v1.38.0 // indirect go.opentelemetry.io/proto/otlp v1.7.1 // indirect go.yaml.in/yaml/v2 v2.4.2 // indirect - golang.org/x/mod v0.29.0 // indirect + golang.org/x/mod v0.30.0 // indirect golang.org/x/sync v0.18.0 // indirect - golang.org/x/sys v0.38.0 // indirect golang.org/x/text v0.31.0 // indirect golang.org/x/time v0.12.0 // indirect - golang.org/x/tools v0.38.0 // indirect + golang.org/x/tools v0.39.0 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20250825161204-c5933d9347a5 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20250825161204-c5933d9347a5 // indirect diff --git a/go.sum b/go.sum index 4fe11a3..835b0e7 100644 --- a/go.sum +++ b/go.sum @@ -37,8 +37,6 @@ github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg= github.com/google/btree v1.1.3/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= -github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8= -github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= @@ -47,8 +45,6 @@ github.com/grafana/regexp v0.0.0-20240518133315-a468a5bfb3bc h1:GN2Lv3MGO7AS6PrR github.com/grafana/regexp v0.0.0-20240518133315-a468a5bfb3bc/go.mod h1:+JKpmjMGhpgPL+rXZ5nsZieVzvarn86asRlBg4uNGnk= github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 h1:8Tjv8EJ+pM1xP8mK6egEbD1OgnVTyacbefKhmbLhIhU= github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2/go.mod h1:pkJQ2tZHJ0aFOVEEot6oZmaVEZcRme73eIFmhiVuRWs= -github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA= -github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w= github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= @@ -57,14 +53,6 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= -github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy56gw= -github.com/mdlayher/genetlink v1.3.2/go.mod h1:tcC3pkCrPUGIKKsCsp0B3AdaaKuHtaxoJRz3cc+528o= -github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g= -github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw= -github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos= -github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ= -github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721 h1:RlZweED6sbSArvlE924+mUcZuXKLBHA35U7LN621Bws= -github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721/go.mod h1:Ickgr2WtCLZ2MDGd4Gr0geeCH5HybhRJbonOgQpvSxc= github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0= github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo= github.com/moby/sys/atomicwriter v0.1.0 h1:kw5D/EqkBwsBFi0ss9v1VG3wIkVhzGvLklJ+w3A14Sw= @@ -137,42 +125,34 @@ go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI= go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q= golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4= -golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= -golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= -golang.org/x/mod v0.29.0 h1:HV8lRxZC4l2cr3Zq1LvtOsi/ThTgWnUk/y64QSs8GwA= -golang.org/x/mod v0.29.0/go.mod h1:NyhrlYXJ2H4eJiRy/WDBO6HMqZQ6q9nk4JzS3NuCK+w= -golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 h1:zfMcR1Cs4KNuomFFgGefv5N0czO2XZpUbxGUy8i8ug0= +golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6/go.mod h1:46edojNIoXTNOhySWIWdix628clX9ODXwPsQuG6hsK0= +golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk= +golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc= golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY= golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU= -golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I= golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM= golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM= golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE= golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= -golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ= -golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs= -golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/tools v0.39.0 h1:ik4ho21kwuQln40uelmciQPp9SipgNDdrafrYA4TmQQ= +golang.org/x/tools v0.39.0/go.mod h1:JnefbkDPyD8UU2kI5fuf8ZX4/yUeh9W877ZeBONxUqQ= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI= golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb h1:whnFRlWMcXI9d+ZbWg+4sHnLp52d5yiIPUxMBSt4X9A= golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw= golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 h1:3GDAcqdIg1ozBNLgPy4SLT84nfcBjr6rhGtXYtrkWLU= golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10/go.mod h1:T97yPqesLiNrOYxkwmhMI0ZIlJDm+p0PMR8eRVeR5tQ= +golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE= +golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI= gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E= google.golang.org/genproto/googleapis/api v0.0.0-20250825161204-c5933d9347a5 h1:BIRfGDEjiHRrk0QKZe3Xv2ieMhtgRGeLcZQ0mIVn4EY= diff --git a/holepunch/holepunch.go b/holepunch/holepunch.go new file mode 100644 index 0000000..b6e0a6b --- /dev/null +++ b/holepunch/holepunch.go @@ -0,0 +1,517 @@ +package holepunch + +import ( + "encoding/json" + "fmt" + "net" + "sync" + "time" + + "github.com/fosrl/newt/bind" + "github.com/fosrl/newt/logger" + "github.com/fosrl/newt/util" + "golang.org/x/crypto/chacha20poly1305" + "golang.org/x/crypto/curve25519" + mrand "golang.org/x/exp/rand" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" +) + +// ExitNode represents a WireGuard exit node for hole punching +type ExitNode struct { + Endpoint string `json:"endpoint"` + PublicKey string `json:"publicKey"` +} + +// Manager handles UDP hole punching operations +type Manager struct { + mu sync.Mutex + running bool + stopChan chan struct{} + sharedBind *bind.SharedBind + ID string + token string + publicKey string + clientType string + exitNodes map[string]ExitNode // key is endpoint + updateChan chan struct{} // signals the goroutine to refresh exit nodes + + sendHolepunchInterval time.Duration +} + +const sendHolepunchIntervalMax = 60 * time.Second +const sendHolepunchIntervalMin = 1 * time.Second + +// NewManager creates a new hole punch manager +func NewManager(sharedBind *bind.SharedBind, ID string, clientType string, publicKey string) *Manager { + return &Manager{ + sharedBind: sharedBind, + ID: ID, + clientType: clientType, + publicKey: publicKey, + exitNodes: make(map[string]ExitNode), + sendHolepunchInterval: sendHolepunchIntervalMin, + } +} + +// SetToken updates the authentication token used for hole punching +func (m *Manager) SetToken(token string) { + m.mu.Lock() + defer m.mu.Unlock() + m.token = token +} + +// IsRunning returns whether hole punching is currently active +func (m *Manager) IsRunning() bool { + m.mu.Lock() + defer m.mu.Unlock() + return m.running +} + +// Stop stops any ongoing hole punch operations +func (m *Manager) Stop() { + m.mu.Lock() + defer m.mu.Unlock() + + if !m.running { + return + } + + if m.stopChan != nil { + close(m.stopChan) + m.stopChan = nil + } + + if m.updateChan != nil { + close(m.updateChan) + m.updateChan = nil + } + + m.running = false + logger.Info("Hole punch manager stopped") +} + +// AddExitNode adds a new exit node to the rotation if it doesn't already exist +func (m *Manager) AddExitNode(exitNode ExitNode) bool { + m.mu.Lock() + defer m.mu.Unlock() + + if _, exists := m.exitNodes[exitNode.Endpoint]; exists { + logger.Debug("Exit node %s already exists in rotation", exitNode.Endpoint) + return false + } + + m.exitNodes[exitNode.Endpoint] = exitNode + logger.Info("Added exit node %s to hole punch rotation", exitNode.Endpoint) + + // Signal the goroutine to refresh if running + if m.running && m.updateChan != nil { + select { + case m.updateChan <- struct{}{}: + default: + // Channel full or closed, skip + } + } + + return true +} + +// RemoveExitNode removes an exit node from the rotation +func (m *Manager) RemoveExitNode(endpoint string) bool { + m.mu.Lock() + defer m.mu.Unlock() + + if _, exists := m.exitNodes[endpoint]; !exists { + logger.Debug("Exit node %s not found in rotation", endpoint) + return false + } + + delete(m.exitNodes, endpoint) + logger.Info("Removed exit node %s from hole punch rotation", endpoint) + + // Signal the goroutine to refresh if running + if m.running && m.updateChan != nil { + select { + case m.updateChan <- struct{}{}: + default: + // Channel full or closed, skip + } + } + + return true +} + +// GetExitNodes returns a copy of the current exit nodes +func (m *Manager) GetExitNodes() []ExitNode { + m.mu.Lock() + defer m.mu.Unlock() + + nodes := make([]ExitNode, 0, len(m.exitNodes)) + for _, node := range m.exitNodes { + nodes = append(nodes, node) + } + return nodes +} + +// ResetInterval resets the hole punch interval back to the minimum value, +// allowing it to climb back up through exponential backoff. +// This is useful when network conditions change or connectivity is restored. +func (m *Manager) ResetInterval() { + m.mu.Lock() + defer m.mu.Unlock() + + if m.sendHolepunchInterval != sendHolepunchIntervalMin { + m.sendHolepunchInterval = sendHolepunchIntervalMin + logger.Info("Reset hole punch interval to minimum (%v)", sendHolepunchIntervalMin) + } + + // Signal the goroutine to apply the new interval if running + if m.running && m.updateChan != nil { + select { + case m.updateChan <- struct{}{}: + default: + // Channel full or closed, skip + } + } +} + +// TriggerHolePunch sends an immediate hole punch packet to all configured exit nodes +// This is useful for triggering hole punching on demand without waiting for the interval +func (m *Manager) TriggerHolePunch() error { + m.mu.Lock() + + if len(m.exitNodes) == 0 { + m.mu.Unlock() + return fmt.Errorf("no exit nodes configured") + } + + // Get a copy of exit nodes to work with + currentExitNodes := make([]ExitNode, 0, len(m.exitNodes)) + for _, node := range m.exitNodes { + currentExitNodes = append(currentExitNodes, node) + } + m.mu.Unlock() + + logger.Info("Triggering on-demand hole punch to %d exit nodes", len(currentExitNodes)) + + // Send hole punch to all exit nodes + successCount := 0 + for _, exitNode := range currentExitNodes { + host, err := util.ResolveDomain(exitNode.Endpoint) + if err != nil { + logger.Warn("Failed to resolve endpoint %s: %v", exitNode.Endpoint, err) + continue + } + + serverAddr := net.JoinHostPort(host, "21820") + remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr) + if err != nil { + logger.Error("Failed to resolve UDP address %s: %v", serverAddr, err) + continue + } + + if err := m.sendHolePunch(remoteAddr, exitNode.PublicKey); err != nil { + logger.Warn("Failed to send on-demand hole punch to %s: %v", exitNode.Endpoint, err) + continue + } + + logger.Debug("Sent on-demand hole punch to %s", exitNode.Endpoint) + successCount++ + } + + if successCount == 0 { + return fmt.Errorf("failed to send hole punch to any exit node") + } + + logger.Info("Successfully sent on-demand hole punch to %d/%d exit nodes", successCount, len(currentExitNodes)) + return nil +} + +// StartMultipleExitNodes starts hole punching to multiple exit nodes +func (m *Manager) StartMultipleExitNodes(exitNodes []ExitNode) error { + m.mu.Lock() + + if m.running { + m.mu.Unlock() + logger.Debug("UDP hole punch already running, skipping new request") + return fmt.Errorf("hole punch already running") + } + + // Populate exit nodes map + m.exitNodes = make(map[string]ExitNode) + for _, node := range exitNodes { + m.exitNodes[node.Endpoint] = node + } + + m.running = true + m.stopChan = make(chan struct{}) + m.updateChan = make(chan struct{}, 1) + m.mu.Unlock() + + logger.Info("Starting UDP hole punch to %d exit nodes with shared bind", len(exitNodes)) + + go m.runMultipleExitNodes() + + return nil +} + +// Start starts hole punching with the current set of exit nodes +func (m *Manager) Start() error { + m.mu.Lock() + + if m.running { + m.mu.Unlock() + logger.Debug("UDP hole punch already running") + return fmt.Errorf("hole punch already running") + } + + m.running = true + m.stopChan = make(chan struct{}) + m.updateChan = make(chan struct{}, 1) + nodeCount := len(m.exitNodes) + m.mu.Unlock() + + if nodeCount == 0 { + logger.Info("Starting UDP hole punch manager (waiting for exit nodes to be added)") + } else { + logger.Info("Starting UDP hole punch with %d exit nodes", nodeCount) + } + + go m.runMultipleExitNodes() + + return nil +} + +// runMultipleExitNodes performs hole punching to multiple exit nodes +func (m *Manager) runMultipleExitNodes() { + defer func() { + m.mu.Lock() + m.running = false + m.mu.Unlock() + logger.Info("UDP hole punch goroutine ended for all exit nodes") + }() + + // Resolve all endpoints upfront + type resolvedExitNode struct { + remoteAddr *net.UDPAddr + publicKey string + endpointName string + } + + resolveNodes := func() []resolvedExitNode { + m.mu.Lock() + currentExitNodes := make([]ExitNode, 0, len(m.exitNodes)) + for _, node := range m.exitNodes { + currentExitNodes = append(currentExitNodes, node) + } + m.mu.Unlock() + + var resolvedNodes []resolvedExitNode + for _, exitNode := range currentExitNodes { + host, err := util.ResolveDomain(exitNode.Endpoint) + if err != nil { + logger.Warn("Failed to resolve endpoint %s: %v", exitNode.Endpoint, err) + continue + } + + serverAddr := net.JoinHostPort(host, "21820") + remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr) + if err != nil { + logger.Error("Failed to resolve UDP address %s: %v", serverAddr, err) + continue + } + + resolvedNodes = append(resolvedNodes, resolvedExitNode{ + remoteAddr: remoteAddr, + publicKey: exitNode.PublicKey, + endpointName: exitNode.Endpoint, + }) + logger.Info("Resolved exit node: %s -> %s", exitNode.Endpoint, remoteAddr.String()) + } + return resolvedNodes + } + + resolvedNodes := resolveNodes() + + if len(resolvedNodes) == 0 { + logger.Info("No exit nodes available yet, waiting for nodes to be added") + } else { + // Send initial hole punch to all exit nodes + for _, node := range resolvedNodes { + if err := m.sendHolePunch(node.remoteAddr, node.publicKey); err != nil { + logger.Warn("Failed to send initial hole punch to %s: %v", node.endpointName, err) + } + } + } + + // Start with minimum interval + m.mu.Lock() + m.sendHolepunchInterval = sendHolepunchIntervalMin + m.mu.Unlock() + + ticker := time.NewTicker(m.sendHolepunchInterval) + defer ticker.Stop() + + for { + select { + case <-m.stopChan: + logger.Debug("Hole punch stopped by signal") + return + case <-m.updateChan: + // Re-resolve exit nodes when update is signaled + logger.Info("Refreshing exit nodes for hole punching") + resolvedNodes = resolveNodes() + if len(resolvedNodes) == 0 { + logger.Warn("No exit nodes available after refresh") + } else { + logger.Info("Updated resolved nodes count: %d", len(resolvedNodes)) + } + // Reset interval to minimum on update + m.mu.Lock() + m.sendHolepunchInterval = sendHolepunchIntervalMin + m.mu.Unlock() + ticker.Reset(m.sendHolepunchInterval) + // Send immediate hole punch to newly resolved nodes + for _, node := range resolvedNodes { + if err := m.sendHolePunch(node.remoteAddr, node.publicKey); err != nil { + logger.Debug("Failed to send hole punch to %s: %v", node.endpointName, err) + } + } + case <-ticker.C: + // Send hole punch to all exit nodes (if any are available) + if len(resolvedNodes) > 0 { + for _, node := range resolvedNodes { + if err := m.sendHolePunch(node.remoteAddr, node.publicKey); err != nil { + logger.Debug("Failed to send hole punch to %s: %v", node.endpointName, err) + } + } + // Exponential backoff: double the interval up to max + m.mu.Lock() + newInterval := m.sendHolepunchInterval * 2 + if newInterval > sendHolepunchIntervalMax { + newInterval = sendHolepunchIntervalMax + } + if newInterval != m.sendHolepunchInterval { + m.sendHolepunchInterval = newInterval + ticker.Reset(m.sendHolepunchInterval) + logger.Debug("Increased hole punch interval to %v", m.sendHolepunchInterval) + } + m.mu.Unlock() + } + } + } +} + +// sendHolePunch sends an encrypted hole punch packet using the shared bind +func (m *Manager) sendHolePunch(remoteAddr *net.UDPAddr, serverPubKey string) error { + m.mu.Lock() + token := m.token + ID := m.ID + m.mu.Unlock() + + if serverPubKey == "" || token == "" { + return fmt.Errorf("server public key or OLM token is empty") + } + + var payload interface{} + if m.clientType == "newt" { + payload = struct { + ID string `json:"newtId"` + Token string `json:"token"` + PublicKey string `json:"publicKey"` + }{ + ID: ID, + Token: token, + PublicKey: m.publicKey, + } + } else { + payload = struct { + ID string `json:"olmId"` + Token string `json:"token"` + PublicKey string `json:"publicKey"` + }{ + ID: ID, + Token: token, + PublicKey: m.publicKey, + } + } + + // Convert payload to JSON + payloadBytes, err := json.Marshal(payload) + if err != nil { + return fmt.Errorf("failed to marshal payload: %w", err) + } + + // Encrypt the payload using the server's WireGuard public key + encryptedPayload, err := encryptPayload(payloadBytes, serverPubKey) + if err != nil { + return fmt.Errorf("failed to encrypt payload: %w", err) + } + + jsonData, err := json.Marshal(encryptedPayload) + if err != nil { + return fmt.Errorf("failed to marshal encrypted payload: %w", err) + } + + _, err = m.sharedBind.WriteToUDP(jsonData, remoteAddr) + if err != nil { + return fmt.Errorf("failed to write to UDP: %w", err) + } + + logger.Debug("Sent UDP hole punch to %s: %s", remoteAddr.String(), string(jsonData)) + + return nil +} + +// encryptPayload encrypts the payload using ChaCha20-Poly1305 AEAD with X25519 key exchange +func encryptPayload(payload []byte, serverPublicKey string) (interface{}, error) { + // Generate an ephemeral keypair for this message + ephemeralPrivateKey, err := wgtypes.GeneratePrivateKey() + if err != nil { + return nil, fmt.Errorf("failed to generate ephemeral private key: %v", err) + } + ephemeralPublicKey := ephemeralPrivateKey.PublicKey() + + // Parse the server's public key + serverPubKey, err := wgtypes.ParseKey(serverPublicKey) + if err != nil { + return nil, fmt.Errorf("failed to parse server public key: %v", err) + } + + // Use X25519 for key exchange + var ephPrivKeyFixed [32]byte + copy(ephPrivKeyFixed[:], ephemeralPrivateKey[:]) + + // Perform X25519 key exchange + sharedSecret, err := curve25519.X25519(ephPrivKeyFixed[:], serverPubKey[:]) + if err != nil { + return nil, fmt.Errorf("failed to perform X25519 key exchange: %v", err) + } + + // Create an AEAD cipher using the shared secret + aead, err := chacha20poly1305.New(sharedSecret) + if err != nil { + return nil, fmt.Errorf("failed to create AEAD cipher: %v", err) + } + + // Generate a random nonce + nonce := make([]byte, aead.NonceSize()) + if _, err := mrand.Read(nonce); err != nil { + return nil, fmt.Errorf("failed to generate nonce: %v", err) + } + + // Encrypt the payload + ciphertext := aead.Seal(nil, nonce, payload, nil) + + // Prepare the final encrypted message + encryptedMsg := struct { + EphemeralPublicKey string `json:"ephemeralPublicKey"` + Nonce []byte `json:"nonce"` + Ciphertext []byte `json:"ciphertext"` + }{ + EphemeralPublicKey: ephemeralPublicKey.String(), + Nonce: nonce, + Ciphertext: ciphertext, + } + + return encryptedMsg, nil +} diff --git a/holepunch/tester.go b/holepunch/tester.go new file mode 100644 index 0000000..3bebc4d --- /dev/null +++ b/holepunch/tester.go @@ -0,0 +1,343 @@ +package holepunch + +import ( + "crypto/rand" + "fmt" + "net" + "net/netip" + "sync" + "time" + + "github.com/fosrl/newt/bind" + "github.com/fosrl/newt/logger" + "github.com/fosrl/newt/util" +) + +// TestResult represents the result of a connection test +type TestResult struct { + // Success indicates whether the test was successful + Success bool + // RTT is the round-trip time of the test packet + RTT time.Duration + // Endpoint is the endpoint that was tested + Endpoint string + // Error contains any error that occurred during the test + Error error +} + +// TestConnectionOptions configures the connection test +type TestConnectionOptions struct { + // Timeout is how long to wait for a response (default: 5 seconds) + Timeout time.Duration + // Retries is the number of times to retry on failure (default: 0) + Retries int +} + +// DefaultTestOptions returns the default test options +func DefaultTestOptions() TestConnectionOptions { + return TestConnectionOptions{ + Timeout: 5 * time.Second, + Retries: 0, + } +} + +// HolepunchTester monitors holepunch connectivity using magic packets +type HolepunchTester struct { + sharedBind *bind.SharedBind + mu sync.RWMutex + running bool + stopChan chan struct{} + + // Pending requests waiting for responses (key: echo data as string) + pendingRequests sync.Map // map[string]*pendingRequest + + // Callback when connection status changes + callback HolepunchStatusCallback +} + +// HolepunchStatus represents the status of a holepunch connection +type HolepunchStatus struct { + Endpoint string + Connected bool + RTT time.Duration +} + +// HolepunchStatusCallback is called when holepunch status changes +type HolepunchStatusCallback func(status HolepunchStatus) + +// pendingRequest tracks a pending test request +type pendingRequest struct { + endpoint string + sentAt time.Time + replyChan chan time.Duration +} + +// NewHolepunchTester creates a new holepunch tester using the given SharedBind +func NewHolepunchTester(sharedBind *bind.SharedBind) *HolepunchTester { + return &HolepunchTester{ + sharedBind: sharedBind, + } +} + +// SetCallback sets the callback for connection status changes +func (t *HolepunchTester) SetCallback(callback HolepunchStatusCallback) { + t.mu.Lock() + defer t.mu.Unlock() + t.callback = callback +} + +// Start begins listening for magic packet responses +func (t *HolepunchTester) Start() error { + t.mu.Lock() + defer t.mu.Unlock() + + if t.running { + return fmt.Errorf("tester already running") + } + + if t.sharedBind == nil { + return fmt.Errorf("sharedBind is nil") + } + + t.running = true + t.stopChan = make(chan struct{}) + + // Register our callback with the SharedBind to receive magic responses + t.sharedBind.SetMagicResponseCallback(t.handleResponse) + + logger.Debug("HolepunchTester started") + return nil +} + +// Stop stops the tester +func (t *HolepunchTester) Stop() { + t.mu.Lock() + defer t.mu.Unlock() + + if !t.running { + return + } + + t.running = false + close(t.stopChan) + + // Clear the callback + if t.sharedBind != nil { + t.sharedBind.SetMagicResponseCallback(nil) + } + + // Cancel all pending requests + t.pendingRequests.Range(func(key, value interface{}) bool { + if req, ok := value.(*pendingRequest); ok { + close(req.replyChan) + } + t.pendingRequests.Delete(key) + return true + }) + + logger.Debug("HolepunchTester stopped") +} + +// handleResponse is called by SharedBind when a magic response is received +func (t *HolepunchTester) handleResponse(addr netip.AddrPort, echoData []byte) { + logger.Debug("Received magic response from %s", addr.String()) + key := string(echoData) + + value, ok := t.pendingRequests.LoadAndDelete(key) + if !ok { + // No matching request found + logger.Debug("No pending request found for magic response from %s", addr.String()) + return + } + + req := value.(*pendingRequest) + rtt := time.Since(req.sentAt) + logger.Debug("Magic response matched pending request for %s (RTT: %v)", req.endpoint, rtt) + + // Send RTT to the waiting goroutine (non-blocking) + select { + case req.replyChan <- rtt: + default: + } +} + +// TestEndpoint sends a magic test packet to the endpoint and waits for a response. +// This uses the SharedBind so packets come from the same source port as WireGuard. +func (t *HolepunchTester) TestEndpoint(endpoint string, timeout time.Duration) TestResult { + result := TestResult{ + Endpoint: endpoint, + } + + t.mu.RLock() + running := t.running + sharedBind := t.sharedBind + t.mu.RUnlock() + + if !running { + result.Error = fmt.Errorf("tester not running") + return result + } + + if sharedBind == nil || sharedBind.IsClosed() { + result.Error = fmt.Errorf("sharedBind is nil or closed") + return result + } + + // Resolve the endpoint + host, err := util.ResolveDomain(endpoint) + if err != nil { + host = endpoint + } + + _, _, err = net.SplitHostPort(host) + if err != nil { + host = net.JoinHostPort(host, "21820") + } + + remoteAddr, err := net.ResolveUDPAddr("udp", host) + if err != nil { + result.Error = fmt.Errorf("failed to resolve UDP address %s: %w", host, err) + return result + } + + // Generate random data for the test packet + randomData := make([]byte, bind.MagicPacketDataLen) + if _, err := rand.Read(randomData); err != nil { + result.Error = fmt.Errorf("failed to generate random data: %w", err) + return result + } + + // Create a pending request + req := &pendingRequest{ + endpoint: endpoint, + sentAt: time.Now(), + replyChan: make(chan time.Duration, 1), + } + + key := string(randomData) + t.pendingRequests.Store(key, req) + + // Build the test request packet + request := make([]byte, bind.MagicTestRequestLen) + copy(request, bind.MagicTestRequest) + copy(request[len(bind.MagicTestRequest):], randomData) + + // Send the test packet + _, err = sharedBind.WriteToUDP(request, remoteAddr) + if err != nil { + t.pendingRequests.Delete(key) + result.Error = fmt.Errorf("failed to send test packet: %w", err) + return result + } + + // Wait for response with timeout + select { + case rtt, ok := <-req.replyChan: + if ok { + result.Success = true + result.RTT = rtt + } else { + result.Error = fmt.Errorf("request cancelled") + } + case <-time.After(timeout): + t.pendingRequests.Delete(key) + result.Error = fmt.Errorf("timeout waiting for response") + } + + return result +} + +// TestConnectionWithBind sends a magic test packet using an existing SharedBind. +// This is useful when you want to test the connection through the same socket +// that WireGuard is using, which tests the actual hole-punched path. +func TestConnectionWithBind(sharedBind *bind.SharedBind, endpoint string, opts *TestConnectionOptions) TestResult { + if opts == nil { + defaultOpts := DefaultTestOptions() + opts = &defaultOpts + } + + result := TestResult{ + Endpoint: endpoint, + } + + if sharedBind == nil { + result.Error = fmt.Errorf("sharedBind is nil") + return result + } + + if sharedBind.IsClosed() { + result.Error = fmt.Errorf("sharedBind is closed") + return result + } + + // Resolve the endpoint + host, err := util.ResolveDomain(endpoint) + if err != nil { + host = endpoint + } + + _, _, err = net.SplitHostPort(host) + if err != nil { + host = net.JoinHostPort(host, "21820") + } + + remoteAddr, err := net.ResolveUDPAddr("udp", host) + if err != nil { + result.Error = fmt.Errorf("failed to resolve UDP address %s: %w", host, err) + return result + } + + // Generate random data for the test packet + randomData := make([]byte, bind.MagicPacketDataLen) + if _, err := rand.Read(randomData); err != nil { + result.Error = fmt.Errorf("failed to generate random data: %w", err) + return result + } + + // Build the test request packet + request := make([]byte, bind.MagicTestRequestLen) + copy(request, bind.MagicTestRequest) + copy(request[len(bind.MagicTestRequest):], randomData) + + // Get the underlying UDP connection to set read deadline and read response + udpConn := sharedBind.GetUDPConn() + if udpConn == nil { + result.Error = fmt.Errorf("could not get UDP connection from SharedBind") + return result + } + + attempts := opts.Retries + 1 + for attempt := 0; attempt < attempts; attempt++ { + if attempt > 0 { + logger.Debug("Retrying connection test to %s (attempt %d/%d)", endpoint, attempt+1, attempts) + } + + // Note: We can't easily set a read deadline on the shared connection + // without affecting WireGuard, so we use a goroutine with timeout instead + startTime := time.Now() + + // Send the test packet through the shared bind + _, err = sharedBind.WriteToUDP(request, remoteAddr) + if err != nil { + result.Error = fmt.Errorf("failed to send test packet: %w", err) + if attempt < attempts-1 { + continue + } + return result + } + + // For shared bind test, we send the packet but can't easily wait for + // response without interfering with WireGuard's receive loop. + // The response will be handled by SharedBind automatically. + // We consider the test successful if the send succeeded. + // For a full round-trip test, use TestConnection() with a separate socket. + + result.RTT = time.Since(startTime) + result.Success = true + result.Error = nil + logger.Debug("Test packet sent to %s via SharedBind", endpoint) + return result + } + + return result +} diff --git a/key b/key deleted file mode 100644 index 62c22b9..0000000 --- a/key +++ /dev/null @@ -1 +0,0 @@ -oBvcoMJZXGzTZ4X+aNSCCQIjroREFBeRCs+a328xWGA= \ No newline at end of file diff --git a/linux.go b/linux.go deleted file mode 100644 index 70918d3..0000000 --- a/linux.go +++ /dev/null @@ -1,74 +0,0 @@ -//go:build linux - -package main - -import ( - "fmt" - "os" - "runtime" - - "github.com/fosrl/newt/logger" - "github.com/fosrl/newt/proxy" - "github.com/fosrl/newt/websocket" - "github.com/fosrl/newt/wg" - "github.com/fosrl/newt/wgtester" -) - -var wgServiceNative *wg.WireGuardService - -func setupClientsNative(client *websocket.Client, host string) { - - if runtime.GOOS != "linux" { - logger.Fatal("Tunnel management is only supported on Linux right now!") - os.Exit(1) - } - - // make sure we are sudo - if os.Geteuid() != 0 { - logger.Fatal("You must run this program as root to manage tunnels on Linux.") - os.Exit(1) - } - - // Create WireGuard service - wgServiceNative, err = wg.NewWireGuardService(interfaceName, mtuInt, generateAndSaveKeyTo, host, id, client) - if err != nil { - logger.Fatal("Failed to create WireGuard service: %v", err) - } - - wgTesterServer = wgtester.NewServer("0.0.0.0", wgServiceNative.Port, id) // TODO: maybe make this the same ip of the wg server? - err := wgTesterServer.Start() - if err != nil { - logger.Error("Failed to start WireGuard tester server: %v", err) - } - - client.OnTokenUpdate(func(token string) { - wgServiceNative.SetToken(token) - }) -} - -func closeWgServiceNative() { - if wgServiceNative != nil { - wgServiceNative.Close(!keepInterface) - wgServiceNative = nil - } -} - -func clientsOnConnectNative() { - if wgServiceNative != nil { - wgServiceNative.LoadRemoteConfig() - } -} - -func clientsHandleNewtConnectionNative(publicKey, endpoint string) { - if wgServiceNative != nil { - wgServiceNative.StartHolepunch(publicKey, endpoint) - } -} - -func clientsAddProxyTargetNative(pm *proxy.ProxyManager, tunnelIp string) { - // add a udp proxy for localost and the wgService port - // TODO: make sure this port is not used in a target - if wgServiceNative != nil { - pm.AddTarget("udp", tunnelIp, int(wgServiceNative.Port), fmt.Sprintf("127.0.0.1:%d", wgServiceNative.Port)) - } -} diff --git a/logger/logger.go b/logger/logger.go index 28cac91..e00ed3a 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -2,16 +2,15 @@ package logger import ( "fmt" - "io" - "log" "os" + "strings" "sync" "time" ) // Logger struct holds the logger instance type Logger struct { - logger *log.Logger + writer LogWriter level LogLevel } @@ -20,17 +19,29 @@ var ( once sync.Once ) -// NewLogger creates a new logger instance +// NewLogger creates a new logger instance with the default StandardWriter func NewLogger() *Logger { return &Logger{ - logger: log.New(os.Stdout, "", 0), + writer: NewStandardWriter(), + level: DEBUG, + } +} + +// NewLoggerWithWriter creates a new logger instance with a custom LogWriter +func NewLoggerWithWriter(writer LogWriter) *Logger { + return &Logger{ + writer: writer, level: DEBUG, } } // Init initializes the default logger -func Init() *Logger { +func Init(logger *Logger) *Logger { once.Do(func() { + if logger != nil { + defaultLogger = logger + return + } defaultLogger = NewLogger() }) return defaultLogger @@ -39,7 +50,7 @@ func Init() *Logger { // GetLogger returns the default logger instance func GetLogger() *Logger { if defaultLogger == nil { - Init() + Init(nil) } return defaultLogger } @@ -49,9 +60,11 @@ func (l *Logger) SetLevel(level LogLevel) { l.level = level } -// SetOutput sets the output destination for the logger -func (l *Logger) SetOutput(w io.Writer) { - l.logger.SetOutput(w) +// SetOutput sets the output destination for the logger (only works with StandardWriter) +func (l *Logger) SetOutput(output *os.File) { + if sw, ok := l.writer.(*StandardWriter); ok { + sw.SetOutput(output) + } } // log handles the actual logging @@ -60,24 +73,8 @@ func (l *Logger) log(level LogLevel, format string, args ...interface{}) { return } - // Get timezone from environment variable or use local timezone - timezone := os.Getenv("LOGGER_TIMEZONE") - var location *time.Location - var err error - - if timezone != "" { - location, err = time.LoadLocation(timezone) - if err != nil { - // If invalid timezone, fall back to local - location = time.Local - } - } else { - location = time.Local - } - - timestamp := time.Now().In(location).Format("2006/01/02 15:04:05") message := fmt.Sprintf(format, args...) - l.logger.Printf("%s: %s %s", level.String(), timestamp, message) + l.writer.Write(level, time.Now(), message) } // Debug logs debug level messages @@ -128,6 +125,29 @@ func Fatal(format string, args ...interface{}) { } // SetOutput sets the output destination for the default logger -func SetOutput(w io.Writer) { - GetLogger().SetOutput(w) +func SetOutput(output *os.File) { + GetLogger().SetOutput(output) +} + +// WireGuardLogger is a wrapper type that matches WireGuard's Logger interface +type WireGuardLogger struct { + Verbosef func(format string, args ...any) + Errorf func(format string, args ...any) +} + +// GetWireGuardLogger returns a WireGuard-compatible logger that writes to the newt logger +// The prepend string is added as a prefix to all log messages +func (l *Logger) GetWireGuardLogger(prepend string) *WireGuardLogger { + return &WireGuardLogger{ + Verbosef: func(format string, args ...any) { + // if the format string contains "Sending keepalive packet", skip debug logging to reduce noise + if strings.Contains(format, "Sending keepalive packet") { + return + } + l.Debug(prepend+format, args...) + }, + Errorf: func(format string, args ...any) { + l.Error(prepend+format, args...) + }, + } } diff --git a/logger/writer.go b/logger/writer.go new file mode 100644 index 0000000..860894d --- /dev/null +++ b/logger/writer.go @@ -0,0 +1,54 @@ +package logger + +import ( + "fmt" + "os" + "time" +) + +// LogWriter is an interface for writing log messages +// Implement this interface to create custom log backends (OS log, syslog, etc.) +type LogWriter interface { + // Write writes a log message with the given level, timestamp, and formatted message + Write(level LogLevel, timestamp time.Time, message string) +} + +// StandardWriter is the default log writer that writes to an io.Writer +type StandardWriter struct { + output *os.File + timezone *time.Location +} + +// NewStandardWriter creates a new standard writer with the default configuration +func NewStandardWriter() *StandardWriter { + // Get timezone from environment variable or use local timezone + timezone := os.Getenv("LOGGER_TIMEZONE") + var location *time.Location + var err error + + if timezone != "" { + location, err = time.LoadLocation(timezone) + if err != nil { + // If invalid timezone, fall back to local + location = time.Local + } + } else { + location = time.Local + } + + return &StandardWriter{ + output: os.Stdout, + timezone: location, + } +} + +// SetOutput sets the output destination +func (w *StandardWriter) SetOutput(output *os.File) { + w.output = output +} + +// Write implements the LogWriter interface +func (w *StandardWriter) Write(level LogLevel, timestamp time.Time, message string) { + formattedTime := timestamp.In(w.timezone).Format("2006/01/02 15:04:05") + fmt.Fprintf(w.output, "%s: %s %s\n", level.String(), formattedTime, message) +} diff --git a/main.go b/main.go index 57ac17c..0879a96 100644 --- a/main.go +++ b/main.go @@ -22,6 +22,7 @@ import ( "github.com/fosrl/newt/logger" "github.com/fosrl/newt/proxy" "github.com/fosrl/newt/updates" + "github.com/fosrl/newt/util" "github.com/fosrl/newt/websocket" "github.com/fosrl/newt/internal/state" @@ -115,9 +116,7 @@ var ( err error logLevel string interfaceName string - generateAndSaveKeyTo string - keepInterface bool - acceptClients bool + disableClients bool updownScript string dockerSocket string dockerEnforceNetworkValidation string @@ -168,7 +167,6 @@ func main() { logLevel = os.Getenv("LOG_LEVEL") updownScript = os.Getenv("UPDOWN_SCRIPT") interfaceName = os.Getenv("INTERFACE") - generateAndSaveKeyTo = os.Getenv("GENERATE_AND_SAVE_KEY_TO") // Metrics/observability env mirrors metricsEnabledEnv := os.Getenv("NEWT_METRICS_PROMETHEUS_ENABLED") @@ -177,10 +175,8 @@ func main() { regionEnv := os.Getenv("NEWT_REGION") asyncBytesEnv := os.Getenv("NEWT_METRICS_ASYNC_BYTES") - keepInterfaceEnv := os.Getenv("KEEP_INTERFACE") - keepInterface = keepInterfaceEnv == "true" - acceptClientsEnv := os.Getenv("ACCEPT_CLIENTS") - acceptClients = acceptClientsEnv == "true" + disableClientsEnv := os.Getenv("DISABLE_CLIENTS") + disableClients = disableClientsEnv == "true" useNativeInterfaceEnv := os.Getenv("USE_NATIVE_INTERFACE") useNativeInterface = useNativeInterfaceEnv == "true" enforceHealthcheckCertEnv := os.Getenv("ENFORCE_HC_CERT") @@ -239,17 +235,11 @@ func main() { if interfaceName == "" { flag.StringVar(&interfaceName, "interface", "newt", "Name of the WireGuard interface") } - if generateAndSaveKeyTo == "" { - flag.StringVar(&generateAndSaveKeyTo, "generateAndSaveKeyTo", "", "Path to save generated private key") - } - if keepInterfaceEnv == "" { - flag.BoolVar(&keepInterface, "keep-interface", false, "Keep the WireGuard interface") - } if useNativeInterfaceEnv == "" { - flag.BoolVar(&useNativeInterface, "native", false, "Use native WireGuard interface (requires WireGuard kernel module) and linux") + flag.BoolVar(&useNativeInterface, "native", false, "Use native WireGuard interface") } - if acceptClientsEnv == "" { - flag.BoolVar(&acceptClients, "accept-clients", false, "Accept clients on the WireGuard interface") + if disableClientsEnv == "" { + flag.BoolVar(&disableClients, "disable-clients", false, "Disable clients on the WireGuard interface") } if enforceHealthcheckCertEnv == "" { flag.BoolVar(&enforceHealthcheckCert, "enforce-hc-cert", false, "Enforce certificate validation for health checks (default: false, accepts any cert)") @@ -367,9 +357,9 @@ func main() { tlsClientCAs = append(tlsClientCAs, tlsClientCAsFlag...) } - logger.Init() - loggerLevel := parseLogLevel(logLevel) - logger.GetLogger().SetLevel(parseLogLevel(logLevel)) + logger.Init(nil) + loggerLevel := util.ParseLogLevel(logLevel) + logger.GetLogger().SetLevel(loggerLevel) // Initialize telemetry after flags are parsed (so flags override env) tcfg := telemetry.FromEnv() @@ -538,7 +528,7 @@ func main() { var wgData WgData var dockerEventMonitor *docker.EventMonitor - if acceptClients { + if !disableClients { setupClients(client) } @@ -650,7 +640,7 @@ func main() { // Create WireGuard device dev = device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger( - mapToWireGuardLogLevel(loggerLevel), + util.MapToWireGuardLogLevel(loggerLevel), "wireguard: ", )) @@ -663,7 +653,7 @@ func main() { logger.Info("Connecting to endpoint: %s", host) - endpoint, err := resolveDomain(wgData.Endpoint) + endpoint, err := util.ResolveDomain(wgData.Endpoint) if err != nil { logger.Error("Failed to resolve endpoint: %v", err) regResult = "failure" @@ -677,7 +667,7 @@ func main() { public_key=%s allowed_ip=%s/32 endpoint=%s -persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.PublicKey), wgData.ServerIP, endpoint) +persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey(wgData.PublicKey), wgData.ServerIP, endpoint) err = dev.IpcSet(config) if err != nil { @@ -747,7 +737,8 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub // } } - clientsAddProxyTarget(pm, wgData.TunnelIP) + // Start direct UDP relay from main tunnel to clients' WireGuard (bypasses proxy) + clientsStartDirectRelay(wgData.TunnelIP) if err := healthMonitor.AddTargets(wgData.HealthCheckTargets); err != nil { logger.Error("Failed to bulk add health check targets: %v", err) @@ -800,6 +791,7 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub // Close the WireGuard device and TUN closeWgTunnel() + closeClients() if stopFunc != nil { stopFunc() // stop the ws from sending more requests @@ -1397,7 +1389,12 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub "noCloud": noCloud, }, 3*time.Second) logger.Debug("Requesting exit nodes from server") - clientsOnConnect() + + if client.GetServerVersion() != "" { // to prevent issues with running newt > 1.7 versions with older servers + clientsOnConnect() + } else { + logger.Warn("CLIENTS WILL NOT WORK ON THIS VERSION OF NEWT WITH THIS VERSION OF PANGOLIN, PLEASE UPDATE THE SERVER TO 1.13 OR HIGHER OR DOWNGRADE NEWT") + } } // Send registration message to the server for backward compatibility diff --git a/netstack2/handlers.go b/netstack2/handlers.go new file mode 100644 index 0000000..bdc9feb --- /dev/null +++ b/netstack2/handlers.go @@ -0,0 +1,350 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. + */ + +package netstack2 + +import ( + "context" + "fmt" + "io" + "net" + "sync" + "time" + + "github.com/fosrl/newt/logger" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" + "gvisor.dev/gvisor/pkg/waiter" +) + +const ( + // defaultWndSize if set to zero, the default + // receive window buffer size is used instead. + defaultWndSize = 0 + + // maxConnAttempts specifies the maximum number + // of in-flight tcp connection attempts. + maxConnAttempts = 2 << 10 + + // tcpKeepaliveCount is the maximum number of + // TCP keep-alive probes to send before giving up + // and killing the connection if no response is + // obtained from the other end. + tcpKeepaliveCount = 9 + + // tcpKeepaliveIdle specifies the time a connection + // must remain idle before the first TCP keepalive + // packet is sent. Once this time is reached, + // tcpKeepaliveInterval option is used instead. + tcpKeepaliveIdle = 60 * time.Second + + // tcpKeepaliveInterval specifies the interval + // time between sending TCP keepalive packets. + tcpKeepaliveInterval = 30 * time.Second + + // tcpConnectTimeout is the default timeout for TCP handshakes. + tcpConnectTimeout = 5 * time.Second + + // tcpWaitTimeout implements a TCP half-close timeout. + tcpWaitTimeout = 60 * time.Second + + // udpSessionTimeout is the default timeout for UDP sessions. + udpSessionTimeout = 60 * time.Second + + // Buffer size for copying data + bufferSize = 32 * 1024 +) + +// TCPHandler handles TCP connections from netstack +type TCPHandler struct { + stack *stack.Stack + proxyHandler *ProxyHandler +} + +// UDPHandler handles UDP connections from netstack +type UDPHandler struct { + stack *stack.Stack + proxyHandler *ProxyHandler +} + +// NewTCPHandler creates a new TCP handler +func NewTCPHandler(s *stack.Stack, ph *ProxyHandler) *TCPHandler { + return &TCPHandler{stack: s, proxyHandler: ph} +} + +// NewUDPHandler creates a new UDP handler +func NewUDPHandler(s *stack.Stack, ph *ProxyHandler) *UDPHandler { + return &UDPHandler{stack: s, proxyHandler: ph} +} + +// InstallTCPHandler installs the TCP forwarder on the stack +func (h *TCPHandler) InstallTCPHandler() error { + tcpForwarder := tcp.NewForwarder(h.stack, defaultWndSize, maxConnAttempts, func(r *tcp.ForwarderRequest) { + var ( + wq waiter.Queue + ep tcpip.Endpoint + err tcpip.Error + id = r.ID() + ) + + // Perform a TCP three-way handshake + ep, err = r.CreateEndpoint(&wq) + if err != nil { + // RST: prevent potential half-open TCP connection leak + r.Complete(true) + return + } + defer r.Complete(false) + + // Set socket options + setTCPSocketOptions(h.stack, ep) + + // Create TCP connection from netstack endpoint + netstackConn := gonet.NewTCPConn(&wq, ep) + + // Handle the connection in a goroutine + go h.handleTCPConn(netstackConn, id) + }) + + h.stack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket) + return nil +} + +// handleTCPConn handles a TCP connection by proxying it to the actual target +func (h *TCPHandler) handleTCPConn(netstackConn *gonet.TCPConn, id stack.TransportEndpointID) { + defer netstackConn.Close() + + // Extract source and target address from the connection ID + srcIP := id.RemoteAddress.String() + srcPort := id.RemotePort + dstIP := id.LocalAddress.String() + dstPort := id.LocalPort + + logger.Info("TCP Forwarder: Handling connection %s:%d -> %s:%d", srcIP, srcPort, dstIP, dstPort) + + // Check if there's a destination rewrite for this connection (e.g., localhost targets) + actualDstIP := dstIP + if h.proxyHandler != nil { + if rewrittenAddr, ok := h.proxyHandler.LookupDestinationRewrite(srcIP, dstIP, dstPort, uint8(tcp.ProtocolNumber)); ok { + actualDstIP = rewrittenAddr.String() + logger.Info("TCP Forwarder: Using rewritten destination %s (original: %s)", actualDstIP, dstIP) + } + } + + targetAddr := fmt.Sprintf("%s:%d", actualDstIP, dstPort) + + // Create context with timeout for connection establishment + ctx, cancel := context.WithTimeout(context.Background(), tcpConnectTimeout) + defer cancel() + + // Dial the actual target using standard net package + var d net.Dialer + targetConn, err := d.DialContext(ctx, "tcp", targetAddr) + if err != nil { + logger.Info("TCP Forwarder: Failed to connect to %s: %v", targetAddr, err) + // Connection failed, netstack will handle RST + return + } + defer targetConn.Close() + + logger.Info("TCP Forwarder: Successfully connected to %s, starting bidirectional copy", targetAddr) + + // Bidirectional copy between netstack and target + pipeTCP(netstackConn, targetConn) +} + +// pipeTCP copies data bidirectionally between two connections +func pipeTCP(origin, remote net.Conn) { + wg := sync.WaitGroup{} + wg.Add(2) + + go unidirectionalStreamTCP(remote, origin, "origin->remote", &wg) + go unidirectionalStreamTCP(origin, remote, "remote->origin", &wg) + + wg.Wait() +} + +// unidirectionalStreamTCP copies data in one direction +func unidirectionalStreamTCP(dst, src net.Conn, dir string, wg *sync.WaitGroup) { + defer wg.Done() + + buf := make([]byte, bufferSize) + _, _ = io.CopyBuffer(dst, src, buf) + + // Do the upload/download side TCP half-close + if cr, ok := src.(interface{ CloseRead() error }); ok { + cr.CloseRead() + } + if cw, ok := dst.(interface{ CloseWrite() error }); ok { + cw.CloseWrite() + } + + // Set TCP half-close timeout + dst.SetReadDeadline(time.Now().Add(tcpWaitTimeout)) +} + +// setTCPSocketOptions sets TCP socket options for better performance +func setTCPSocketOptions(s *stack.Stack, ep tcpip.Endpoint) { + // TCP keepalive options + ep.SocketOptions().SetKeepAlive(true) + + idle := tcpip.KeepaliveIdleOption(tcpKeepaliveIdle) + ep.SetSockOpt(&idle) + + interval := tcpip.KeepaliveIntervalOption(tcpKeepaliveInterval) + ep.SetSockOpt(&interval) + + ep.SetSockOptInt(tcpip.KeepaliveCountOption, tcpKeepaliveCount) + + // TCP send/recv buffer size + var ss tcpip.TCPSendBufferSizeRangeOption + if err := s.TransportProtocolOption(tcp.ProtocolNumber, &ss); err == nil { + ep.SocketOptions().SetSendBufferSize(int64(ss.Default), false) + } + + var rs tcpip.TCPReceiveBufferSizeRangeOption + if err := s.TransportProtocolOption(tcp.ProtocolNumber, &rs); err == nil { + ep.SocketOptions().SetReceiveBufferSize(int64(rs.Default), false) + } +} + +// InstallUDPHandler installs the UDP forwarder on the stack +func (h *UDPHandler) InstallUDPHandler() error { + udpForwarder := udp.NewForwarder(h.stack, func(r *udp.ForwarderRequest) { + var ( + wq waiter.Queue + id = r.ID() + ) + + ep, err := r.CreateEndpoint(&wq) + if err != nil { + return + } + + // Create UDP connection from netstack endpoint + netstackConn := gonet.NewUDPConn(&wq, ep) + + // Handle the connection in a goroutine + go h.handleUDPConn(netstackConn, id) + }) + + h.stack.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket) + return nil +} + +// handleUDPConn handles a UDP connection by proxying it to the actual target +func (h *UDPHandler) handleUDPConn(netstackConn *gonet.UDPConn, id stack.TransportEndpointID) { + defer netstackConn.Close() + + // Extract source and target address from the connection ID + srcIP := id.RemoteAddress.String() + srcPort := id.RemotePort + dstIP := id.LocalAddress.String() + dstPort := id.LocalPort + + logger.Info("UDP Forwarder: Handling connection %s:%d -> %s:%d", srcIP, srcPort, dstIP, dstPort) + + // Check if there's a destination rewrite for this connection (e.g., localhost targets) + actualDstIP := dstIP + if h.proxyHandler != nil { + if rewrittenAddr, ok := h.proxyHandler.LookupDestinationRewrite(srcIP, dstIP, dstPort, uint8(udp.ProtocolNumber)); ok { + actualDstIP = rewrittenAddr.String() + logger.Info("UDP Forwarder: Using rewritten destination %s (original: %s)", actualDstIP, dstIP) + } + } + + targetAddr := fmt.Sprintf("%s:%d", actualDstIP, dstPort) + + // Resolve target address + remoteUDPAddr, err := net.ResolveUDPAddr("udp", targetAddr) + if err != nil { + logger.Info("UDP Forwarder: Failed to resolve %s: %v", targetAddr, err) + return + } + + // Resolve client address (for sending responses back) + clientAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", srcIP, srcPort)) + if err != nil { + logger.Info("UDP Forwarder: Failed to resolve client %s:%d: %v", srcIP, srcPort, err) + return + } + + // Create unconnected UDP socket (so we can use WriteTo) + targetConn, err := net.ListenUDP("udp", nil) + if err != nil { + logger.Info("UDP Forwarder: Failed to create UDP socket: %v", err) + return + } + defer targetConn.Close() + + logger.Info("UDP Forwarder: Successfully created UDP socket for %s, starting bidirectional copy", targetAddr) + + // Bidirectional copy between netstack and target + pipeUDP(netstackConn, targetConn, remoteUDPAddr, clientAddr, udpSessionTimeout) +} + +// pipeUDP copies UDP packets bidirectionally +func pipeUDP(origin, remote net.PacketConn, serverAddr, clientAddr net.Addr, timeout time.Duration) { + wg := sync.WaitGroup{} + wg.Add(2) + + // Read from origin (netstack), write to remote (target server) + go unidirectionalPacketStream(remote, origin, serverAddr, "origin->remote", &wg, timeout) + // Read from remote (target server), write to origin (netstack) with client address + go unidirectionalPacketStream(origin, remote, clientAddr, "remote->origin", &wg, timeout) + + wg.Wait() +} + +// unidirectionalPacketStream copies packets in one direction +func unidirectionalPacketStream(dst, src net.PacketConn, to net.Addr, dir string, wg *sync.WaitGroup, timeout time.Duration) { + defer wg.Done() + + logger.Info("UDP %s: Starting packet stream (to=%v)", dir, to) + err := copyPacketData(dst, src, to, timeout) + if err != nil { + logger.Info("UDP %s: Stream ended with error: %v", dir, err) + } else { + logger.Info("UDP %s: Stream ended (timeout)", dir) + } +} + +// copyPacketData copies UDP packet data with timeout +func copyPacketData(dst, src net.PacketConn, to net.Addr, timeout time.Duration) error { + buf := make([]byte, 65535) // Max UDP packet size + + for { + src.SetReadDeadline(time.Now().Add(timeout)) + n, srcAddr, err := src.ReadFrom(buf) + if ne, ok := err.(net.Error); ok && ne.Timeout() { + return nil // ignore I/O timeout + } else if err == io.EOF { + return nil // ignore EOF + } else if err != nil { + return err + } + + logger.Info("UDP copyPacketData: Read %d bytes from %v", n, srcAddr) + + // Determine write destination + writeAddr := to + if writeAddr == nil { + // If no destination specified, use the source address from the packet + writeAddr = srcAddr + } + + written, err := dst.WriteTo(buf[:n], writeAddr) + if err != nil { + logger.Info("UDP copyPacketData: Write error to %v: %v", writeAddr, err) + return err + } + logger.Info("UDP copyPacketData: Wrote %d bytes to %v", written, writeAddr) + + dst.SetReadDeadline(time.Now().Add(timeout)) + } +} diff --git a/netstack2/proxy.go b/netstack2/proxy.go new file mode 100644 index 0000000..77a9d23 --- /dev/null +++ b/netstack2/proxy.go @@ -0,0 +1,710 @@ +package netstack2 + +import ( + "context" + "fmt" + "net" + "net/netip" + "sync" + "time" + + "github.com/fosrl/newt/logger" + "gvisor.dev/gvisor/pkg/buffer" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/checksum" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" + "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" +) + +// PortRange represents an allowed range of ports (inclusive) +type PortRange struct { + Min uint16 + Max uint16 +} + +// SubnetRule represents a subnet with optional port restrictions and source address +// When RewriteTo is set, DNAT (Destination Network Address Translation) is performed: +// - Incoming packets: destination IP is rewritten to the resolved RewriteTo address +// - Outgoing packets: source IP is rewritten back to the original destination +// +// RewriteTo can be either: +// - An IP address with CIDR notation (e.g., "192.168.1.1/32") +// - A domain name (e.g., "example.com") which will be resolved at request time +// +// This allows transparent proxying where traffic appears to come from the rewritten address +type SubnetRule struct { + SourcePrefix netip.Prefix // Source IP prefix (who is sending) + DestPrefix netip.Prefix // Destination IP prefix (where it's going) + RewriteTo string // Optional rewrite address for DNAT - can be IP/CIDR or domain name + PortRanges []PortRange // empty slice means all ports allowed +} + +// ruleKey is used as a map key for fast O(1) lookups +type ruleKey struct { + sourcePrefix string + destPrefix string +} + +// SubnetLookup provides fast IP subnet and port matching with O(1) lookup performance +type SubnetLookup struct { + mu sync.RWMutex + rules map[ruleKey]*SubnetRule // Map for O(1) lookups by prefix combination +} + +// NewSubnetLookup creates a new subnet lookup table +func NewSubnetLookup() *SubnetLookup { + return &SubnetLookup{ + rules: make(map[ruleKey]*SubnetRule), + } +} + +// AddSubnet adds a subnet rule with source and destination prefixes and optional port restrictions +// If portRanges is nil or empty, all ports are allowed for this subnet +// rewriteTo can be either an IP/CIDR (e.g., "192.168.1.1/32") or a domain name (e.g., "example.com") +func (sl *SubnetLookup) AddSubnet(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange) { + sl.mu.Lock() + defer sl.mu.Unlock() + + key := ruleKey{ + sourcePrefix: sourcePrefix.String(), + destPrefix: destPrefix.String(), + } + + sl.rules[key] = &SubnetRule{ + SourcePrefix: sourcePrefix, + DestPrefix: destPrefix, + RewriteTo: rewriteTo, + PortRanges: portRanges, + } +} + +// RemoveSubnet removes a subnet rule from the lookup table +func (sl *SubnetLookup) RemoveSubnet(sourcePrefix, destPrefix netip.Prefix) { + sl.mu.Lock() + defer sl.mu.Unlock() + + key := ruleKey{ + sourcePrefix: sourcePrefix.String(), + destPrefix: destPrefix.String(), + } + + delete(sl.rules, key) +} + +// Match checks if a source IP, destination IP, and port match any subnet rule +// Returns the matched rule if BOTH: +// - The source IP is in the rule's source prefix +// - The destination IP is in the rule's destination prefix +// - The port is in an allowed range (or no port restrictions exist) +// +// Returns nil if no rule matches +func (sl *SubnetLookup) Match(srcIP, dstIP netip.Addr, port uint16) *SubnetRule { + sl.mu.RLock() + defer sl.mu.RUnlock() + + // Iterate through all rules to find matching source and destination prefixes + // This is O(n) but necessary since we need to check prefix containment, not exact match + for _, rule := range sl.rules { + // Check if source and destination IPs match their respective prefixes + if !rule.SourcePrefix.Contains(srcIP) { + continue + } + if !rule.DestPrefix.Contains(dstIP) { + continue + } + + // Both IPs match - now check port restrictions + // If no port ranges specified, all ports are allowed + if len(rule.PortRanges) == 0 { + return rule + } + + // Check if port is in any of the allowed ranges + for _, pr := range rule.PortRanges { + if port >= pr.Min && port <= pr.Max { + return rule + } + } + } + + return nil +} + +// connKey uniquely identifies a connection for NAT tracking +type connKey struct { + srcIP string + srcPort uint16 + dstIP string + dstPort uint16 + proto uint8 +} + +// destKey identifies a destination for handler lookups (without source port since it may change) +type destKey struct { + srcIP string + dstIP string + dstPort uint16 + proto uint8 +} + +// natState tracks NAT translation state for reverse translation +type natState struct { + originalDst netip.Addr // Original destination before DNAT + rewrittenTo netip.Addr // The address we rewrote to +} + +// ProxyHandler handles packet injection and extraction for promiscuous mode +type ProxyHandler struct { + proxyStack *stack.Stack + proxyEp *channel.Endpoint + proxyNotifyHandle *channel.NotificationHandle + tcpHandler *TCPHandler + udpHandler *UDPHandler + subnetLookup *SubnetLookup + natTable map[connKey]*natState + destRewriteTable map[destKey]netip.Addr // Maps original dest to rewritten dest for handler lookups + natMu sync.RWMutex + enabled bool +} + +// ProxyHandlerOptions configures the proxy handler +type ProxyHandlerOptions struct { + EnableTCP bool + EnableUDP bool + MTU int +} + +// NewProxyHandler creates a new proxy handler for promiscuous mode +func NewProxyHandler(options ProxyHandlerOptions) (*ProxyHandler, error) { + if !options.EnableTCP && !options.EnableUDP { + return nil, nil // No proxy needed + } + + handler := &ProxyHandler{ + enabled: true, + subnetLookup: NewSubnetLookup(), + natTable: make(map[connKey]*natState), + destRewriteTable: make(map[destKey]netip.Addr), + proxyEp: channel.New(1024, uint32(options.MTU), ""), + proxyStack: stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ + ipv4.NewProtocol, + ipv6.NewProtocol, + }, + TransportProtocols: []stack.TransportProtocolFactory{ + tcp.NewProtocol, + udp.NewProtocol, + icmp.NewProtocol4, + icmp.NewProtocol6, + }, + }), + } + + // Initialize TCP handler if enabled + if options.EnableTCP { + handler.tcpHandler = NewTCPHandler(handler.proxyStack, handler) + if err := handler.tcpHandler.InstallTCPHandler(); err != nil { + return nil, fmt.Errorf("failed to install TCP handler: %v", err) + } + } + + // Initialize UDP handler if enabled + if options.EnableUDP { + handler.udpHandler = NewUDPHandler(handler.proxyStack, handler) + if err := handler.udpHandler.InstallUDPHandler(); err != nil { + return nil, fmt.Errorf("failed to install UDP handler: %v", err) + } + } + + // // Example 1: Add a rule with no port restrictions (all ports allowed) + // // This accepts all traffic FROM 10.0.0.0/24 TO 10.20.20.0/24 + // sourceSubnet := netip.MustParsePrefix("10.0.0.0/24") + // destSubnet := netip.MustParsePrefix("10.20.20.0/24") + // handler.AddSubnetRule(sourceSubnet, destSubnet, nil) + + // // Example 2: Add a rule with specific port ranges + // // This accepts traffic FROM 10.0.0.5/32 TO 10.20.21.21/32 only on ports 80, 443, and 8000-9000 + // sourceIP := netip.MustParsePrefix("10.0.0.5/32") + // destIP := netip.MustParsePrefix("10.20.21.21/32") + // handler.AddSubnetRule(sourceIP, destIP, []PortRange{ + // {Min: 80, Max: 80}, + // {Min: 443, Max: 443}, + // {Min: 8000, Max: 9000}, + // }) + + return handler, nil +} + +// AddSubnetRule adds a subnet with optional port restrictions to the proxy handler +// sourcePrefix: The IP prefix of the peer sending the data +// destPrefix: The IP prefix of the destination +// rewriteTo: Optional address to rewrite destination to - can be IP/CIDR or domain name +// If portRanges is nil or empty, all ports are allowed for this subnet +func (p *ProxyHandler) AddSubnetRule(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange) { + if p == nil || !p.enabled { + return + } + p.subnetLookup.AddSubnet(sourcePrefix, destPrefix, rewriteTo, portRanges) +} + +// RemoveSubnetRule removes a subnet from the proxy handler +func (p *ProxyHandler) RemoveSubnetRule(sourcePrefix, destPrefix netip.Prefix) { + if p == nil || !p.enabled { + return + } + p.subnetLookup.RemoveSubnet(sourcePrefix, destPrefix) +} + +// LookupDestinationRewrite looks up the rewritten destination for a connection +// This is used by TCP/UDP handlers to find the actual target address +func (p *ProxyHandler) LookupDestinationRewrite(srcIP, dstIP string, dstPort uint16, proto uint8) (netip.Addr, bool) { + if p == nil || !p.enabled { + return netip.Addr{}, false + } + + key := destKey{ + srcIP: srcIP, + dstIP: dstIP, + dstPort: dstPort, + proto: proto, + } + + p.natMu.RLock() + defer p.natMu.RUnlock() + + addr, ok := p.destRewriteTable[key] + return addr, ok +} + +// resolveRewriteAddress resolves a rewrite address which can be either: +// - An IP address with CIDR notation (e.g., "192.168.1.1/32") - returns the IP directly +// - A plain IP address (e.g., "192.168.1.1") - returns the IP directly +// - A domain name (e.g., "example.com") - performs DNS lookup +func (p *ProxyHandler) resolveRewriteAddress(rewriteTo string) (netip.Addr, error) { + logger.Debug("Resolving rewrite address: %s", rewriteTo) + + // First, try to parse as a CIDR prefix (e.g., "192.168.1.1/32") + if prefix, err := netip.ParsePrefix(rewriteTo); err == nil { + return prefix.Addr(), nil + } + + // Try to parse as a plain IP address (e.g., "192.168.1.1") + if addr, err := netip.ParseAddr(rewriteTo); err == nil { + return addr, nil + } + + // Not an IP address, treat as domain name - perform DNS lookup + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + ips, err := net.DefaultResolver.LookupIP(ctx, "ip4", rewriteTo) + if err != nil { + return netip.Addr{}, fmt.Errorf("failed to resolve domain %s: %w", rewriteTo, err) + } + + if len(ips) == 0 { + return netip.Addr{}, fmt.Errorf("no IP addresses found for domain %s", rewriteTo) + } + + // Use the first resolved IP address + ip := ips[0] + if ip4 := ip.To4(); ip4 != nil { + addr := netip.AddrFrom4([4]byte{ip4[0], ip4[1], ip4[2], ip4[3]}) + logger.Debug("Resolved %s to %s", rewriteTo, addr) + return addr, nil + } + + return netip.Addr{}, fmt.Errorf("no IPv4 address found for domain %s", rewriteTo) +} + +// Initialize sets up the promiscuous NIC with the netTun's notification system +func (p *ProxyHandler) Initialize(notifiable channel.Notification) error { + if p == nil || !p.enabled { + return nil + } + + // Add notification handler + p.proxyNotifyHandle = p.proxyEp.AddNotify(notifiable) + + // Create NIC with promiscuous mode + tcpipErr := p.proxyStack.CreateNICWithOptions(1, p.proxyEp, stack.NICOptions{ + Disabled: false, + QDisc: nil, + }) + if tcpipErr != nil { + return fmt.Errorf("CreateNIC (proxy): %v", tcpipErr) + } + + // Enable promiscuous mode - accepts packets for any destination IP + if tcpipErr := p.proxyStack.SetPromiscuousMode(1, true); tcpipErr != nil { + return fmt.Errorf("SetPromiscuousMode: %v", tcpipErr) + } + + // Enable spoofing - allows sending packets from any source IP + if tcpipErr := p.proxyStack.SetSpoofing(1, true); tcpipErr != nil { + return fmt.Errorf("SetSpoofing: %v", tcpipErr) + } + + // Add default route + p.proxyStack.AddRoute(tcpip.Route{ + Destination: header.IPv4EmptySubnet, + NIC: 1, + }) + + return nil +} + +// HandleIncomingPacket processes incoming packets and determines if they should +// be injected into the proxy stack +func (p *ProxyHandler) HandleIncomingPacket(packet []byte) bool { + if p == nil || !p.enabled { + return false + } + + // Check minimum packet size + if len(packet) < header.IPv4MinimumSize { + return false + } + + // Only handle IPv4 for now + if packet[0]>>4 != 4 { + return false + } + + // Parse IPv4 header + ipv4Header := header.IPv4(packet) + srcIP := ipv4Header.SourceAddress() + dstIP := ipv4Header.DestinationAddress() + + // Convert gvisor tcpip.Address to netip.Addr + srcBytes := srcIP.As4() + srcAddr := netip.AddrFrom4(srcBytes) + dstBytes := dstIP.As4() + dstAddr := netip.AddrFrom4(dstBytes) + + // Parse transport layer to get destination port + var dstPort uint16 + protocol := ipv4Header.TransportProtocol() + headerLen := int(ipv4Header.HeaderLength()) + + // Extract port based on protocol + switch protocol { + case header.TCPProtocolNumber: + if len(packet) < headerLen+header.TCPMinimumSize { + return false + } + tcpHeader := header.TCP(packet[headerLen:]) + dstPort = tcpHeader.DestinationPort() + case header.UDPProtocolNumber: + if len(packet) < headerLen+header.UDPMinimumSize { + return false + } + udpHeader := header.UDP(packet[headerLen:]) + dstPort = udpHeader.DestinationPort() + default: + // For other protocols (ICMP, etc.), use port 0 (must match rules with no port restrictions) + dstPort = 0 + } + + // Check if the source IP, destination IP, and port match any subnet rule + matchedRule := p.subnetLookup.Match(srcAddr, dstAddr, dstPort) + if matchedRule != nil { + // Check if we need to perform DNAT + if matchedRule.RewriteTo != "" { + // Create connection tracking key using original destination + // This allows us to check if we've already resolved for this connection + var srcPort uint16 + switch protocol { + case header.TCPProtocolNumber: + tcpHeader := header.TCP(packet[headerLen:]) + srcPort = tcpHeader.SourcePort() + case header.UDPProtocolNumber: + udpHeader := header.UDP(packet[headerLen:]) + srcPort = udpHeader.SourcePort() + } + + // Key using original destination to track the connection + key := connKey{ + srcIP: srcAddr.String(), + srcPort: srcPort, + dstIP: dstAddr.String(), + dstPort: dstPort, + proto: uint8(protocol), + } + + // Key for handler lookups (doesn't include srcPort for flexibility) + dKey := destKey{ + srcIP: srcAddr.String(), + dstIP: dstAddr.String(), + dstPort: dstPort, + proto: uint8(protocol), + } + + // Check if we already have a NAT entry for this connection + p.natMu.RLock() + existingEntry, exists := p.natTable[key] + p.natMu.RUnlock() + + var newDst netip.Addr + if exists { + // Use the previously resolved address for this connection + newDst = existingEntry.rewrittenTo + logger.Debug("Using existing NAT entry for connection: %s -> %s", dstAddr, newDst) + } else { + // New connection - resolve the rewrite address + var err error + newDst, err = p.resolveRewriteAddress(matchedRule.RewriteTo) + if err != nil { + // Failed to resolve, skip DNAT but still proxy the packet + logger.Debug("Failed to resolve rewrite address: %v", err) + pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Payload: buffer.MakeWithData(packet), + }) + p.proxyEp.InjectInbound(header.IPv4ProtocolNumber, pkb) + return true + } + + // Store NAT state for this connection + p.natMu.Lock() + p.natTable[key] = &natState{ + originalDst: dstAddr, + rewrittenTo: newDst, + } + // Store destination rewrite for handler lookups + p.destRewriteTable[dKey] = newDst + p.natMu.Unlock() + logger.Debug("New NAT entry for connection: %s -> %s", dstAddr, newDst) + } + + // Check if target is loopback - if so, don't rewrite packet destination + // as gVisor will drop martian packets. Instead, the handlers will use + // destRewriteTable to find the actual target address. + if !newDst.IsLoopback() { + // Rewrite the packet only for non-loopback destinations + packet = p.rewritePacketDestination(packet, newDst) + if packet == nil { + return false + } + } else { + logger.Debug("Target is loopback, not rewriting packet - handlers will use rewrite table") + } + } + + // Inject into proxy stack + pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Payload: buffer.MakeWithData(packet), + }) + p.proxyEp.InjectInbound(header.IPv4ProtocolNumber, pkb) + return true + } + + return false +} + +// rewritePacketDestination rewrites the destination IP in a packet and recalculates checksums +func (p *ProxyHandler) rewritePacketDestination(packet []byte, newDst netip.Addr) []byte { + if len(packet) < header.IPv4MinimumSize { + return nil + } + + // Make a copy to avoid modifying the original + pkt := make([]byte, len(packet)) + copy(pkt, packet) + + ipv4Header := header.IPv4(pkt) + headerLen := int(ipv4Header.HeaderLength()) + + // Rewrite destination IP + newDstBytes := newDst.As4() + newDstAddr := tcpip.AddrFrom4(newDstBytes) + ipv4Header.SetDestinationAddress(newDstAddr) + + // Recalculate IP checksum + ipv4Header.SetChecksum(0) + ipv4Header.SetChecksum(^ipv4Header.CalculateChecksum()) + + // Update transport layer checksum if needed + protocol := ipv4Header.TransportProtocol() + switch protocol { + case header.TCPProtocolNumber: + if len(pkt) >= headerLen+header.TCPMinimumSize { + tcpHeader := header.TCP(pkt[headerLen:]) + tcpHeader.SetChecksum(0) + xsum := header.PseudoHeaderChecksum( + header.TCPProtocolNumber, + ipv4Header.SourceAddress(), + ipv4Header.DestinationAddress(), + uint16(len(pkt)-headerLen), + ) + xsum = checksum.Checksum(pkt[headerLen:], xsum) + tcpHeader.SetChecksum(^xsum) + } + case header.UDPProtocolNumber: + if len(pkt) >= headerLen+header.UDPMinimumSize { + udpHeader := header.UDP(pkt[headerLen:]) + udpHeader.SetChecksum(0) + xsum := header.PseudoHeaderChecksum( + header.UDPProtocolNumber, + ipv4Header.SourceAddress(), + ipv4Header.DestinationAddress(), + uint16(len(pkt)-headerLen), + ) + xsum = checksum.Checksum(pkt[headerLen:], xsum) + udpHeader.SetChecksum(^xsum) + } + } + + return pkt +} + +// rewritePacketSource rewrites the source IP in a packet and recalculates checksums (for reverse NAT) +func (p *ProxyHandler) rewritePacketSource(packet []byte, newSrc netip.Addr) []byte { + if len(packet) < header.IPv4MinimumSize { + return nil + } + + // Make a copy to avoid modifying the original + pkt := make([]byte, len(packet)) + copy(pkt, packet) + + ipv4Header := header.IPv4(pkt) + headerLen := int(ipv4Header.HeaderLength()) + + // Rewrite source IP + newSrcBytes := newSrc.As4() + newSrcAddr := tcpip.AddrFrom4(newSrcBytes) + ipv4Header.SetSourceAddress(newSrcAddr) + + // Recalculate IP checksum + ipv4Header.SetChecksum(0) + ipv4Header.SetChecksum(^ipv4Header.CalculateChecksum()) + + // Update transport layer checksum if needed + protocol := ipv4Header.TransportProtocol() + switch protocol { + case header.TCPProtocolNumber: + if len(pkt) >= headerLen+header.TCPMinimumSize { + tcpHeader := header.TCP(pkt[headerLen:]) + tcpHeader.SetChecksum(0) + xsum := header.PseudoHeaderChecksum( + header.TCPProtocolNumber, + ipv4Header.SourceAddress(), + ipv4Header.DestinationAddress(), + uint16(len(pkt)-headerLen), + ) + xsum = checksum.Checksum(pkt[headerLen:], xsum) + tcpHeader.SetChecksum(^xsum) + } + case header.UDPProtocolNumber: + if len(pkt) >= headerLen+header.UDPMinimumSize { + udpHeader := header.UDP(pkt[headerLen:]) + udpHeader.SetChecksum(0) + xsum := header.PseudoHeaderChecksum( + header.UDPProtocolNumber, + ipv4Header.SourceAddress(), + ipv4Header.DestinationAddress(), + uint16(len(pkt)-headerLen), + ) + xsum = checksum.Checksum(pkt[headerLen:], xsum) + udpHeader.SetChecksum(^xsum) + } + } + + return pkt +} + +// ReadOutgoingPacket reads packets from the proxy stack that need to be +// sent back through the tunnel +func (p *ProxyHandler) ReadOutgoingPacket() *buffer.View { + if p == nil || !p.enabled { + return nil + } + + pkt := p.proxyEp.Read() + if pkt != nil { + view := pkt.ToView() + pkt.DecRef() + + // Check if we need to perform reverse NAT + packet := view.AsSlice() + if len(packet) >= header.IPv4MinimumSize && packet[0]>>4 == 4 { + ipv4Header := header.IPv4(packet) + srcIP := ipv4Header.SourceAddress() + dstIP := ipv4Header.DestinationAddress() + protocol := ipv4Header.TransportProtocol() + headerLen := int(ipv4Header.HeaderLength()) + + // Extract ports + var srcPort, dstPort uint16 + switch protocol { + case header.TCPProtocolNumber: + if len(packet) >= headerLen+header.TCPMinimumSize { + tcpHeader := header.TCP(packet[headerLen:]) + srcPort = tcpHeader.SourcePort() + dstPort = tcpHeader.DestinationPort() + } + case header.UDPProtocolNumber: + if len(packet) >= headerLen+header.UDPMinimumSize { + udpHeader := header.UDP(packet[headerLen:]) + srcPort = udpHeader.SourcePort() + dstPort = udpHeader.DestinationPort() + } + } + + // Look up NAT state for reverse translation + // The key uses the original dst (before rewrite), so for replies we need to + // find the entry where the rewritten address matches the current source + p.natMu.RLock() + var natEntry *natState + for k, entry := range p.natTable { + // Match: reply's dst should be original src, reply's src should be rewritten dst + if k.srcIP == dstIP.String() && k.srcPort == dstPort && + entry.rewrittenTo.String() == srcIP.String() && k.dstPort == srcPort && + k.proto == uint8(protocol) { + natEntry = entry + break + } + } + p.natMu.RUnlock() + + if natEntry != nil { + // Perform reverse NAT - rewrite source to original destination + packet = p.rewritePacketSource(packet, natEntry.originalDst) + if packet != nil { + return buffer.NewViewWithData(packet) + } + } + } + + return view + } + + return nil +} + +// Close cleans up the proxy handler resources +func (p *ProxyHandler) Close() error { + if p == nil || !p.enabled { + return nil + } + + if p.proxyStack != nil { + p.proxyStack.RemoveNIC(1) + p.proxyStack.Close() + } + + if p.proxyEp != nil { + if p.proxyNotifyHandle != nil { + p.proxyEp.RemoveNotify(p.proxyNotifyHandle) + } + p.proxyEp.Close() + } + + return nil +} diff --git a/netstack2/tun.go b/netstack2/tun.go new file mode 100644 index 0000000..4bcea65 --- /dev/null +++ b/netstack2/tun.go @@ -0,0 +1,1149 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. + */ + +package netstack2 + +import ( + "bytes" + "context" + "crypto/rand" + "encoding/binary" + "errors" + "fmt" + "io" + "net" + "net/netip" + "os" + "regexp" + "strconv" + "strings" + "syscall" + "time" + + "golang.zx2c4.com/wireguard/tun" + + "golang.org/x/net/dns/dnsmessage" + "gvisor.dev/gvisor/pkg/buffer" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" + "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" + "gvisor.dev/gvisor/pkg/waiter" +) + +type netTun struct { + ep *channel.Endpoint + stack *stack.Stack + events chan tun.Event + notifyHandle *channel.NotificationHandle + incomingPacket chan *buffer.View + mtu int + dnsServers []netip.Addr + hasV4, hasV6 bool + // TODO: LETS NOT KEEP THIS ON THE TUN AND MOVE IT BUT WE CAN KEEP IT FOR NOW + proxyHandler *ProxyHandler // Handles promiscuous mode packet processing +} + +type Net netTun + +// NetTunOptions contains options for creating a NetTUN device +type NetTunOptions struct { + EnableTCPProxy bool + EnableUDPProxy bool +} + +// CreateNetTUN creates a new TUN device with netstack without proxying +func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device, *Net, error) { + return CreateNetTUNWithOptions(localAddresses, dnsServers, mtu, NetTunOptions{ + EnableTCPProxy: true, + EnableUDPProxy: true, + }) +} + +// CreateNetTUNWithOptions creates a new TUN device with netstack and optional TCP/UDP proxying +func CreateNetTUNWithOptions(localAddresses, dnsServers []netip.Addr, mtu int, options NetTunOptions) (tun.Device, *Net, error) { + stackOpts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}, + HandleLocal: true, + } + dev := &netTun{ + ep: channel.New(1024, uint32(mtu), ""), + stack: stack.New(stackOpts), + events: make(chan tun.Event, 10), + incomingPacket: make(chan *buffer.View), + dnsServers: dnsServers, + mtu: mtu, + } + + // Initialize proxy handler if TCP or UDP proxying is enabled + if options.EnableTCPProxy || options.EnableUDPProxy { + var err error + dev.proxyHandler, err = NewProxyHandler(ProxyHandlerOptions{ + EnableTCP: options.EnableTCPProxy, + EnableUDP: options.EnableUDPProxy, + MTU: mtu, + }) + if err != nil { + return nil, nil, fmt.Errorf("failed to create proxy handler: %v", err) + } + } + + sackEnabledOpt := tcpip.TCPSACKEnabled(true) // TCP SACK is enabled by default + tcpipErr := dev.stack.SetTransportProtocolOption(tcp.ProtocolNumber, &sackEnabledOpt) + if tcpipErr != nil { + return nil, nil, fmt.Errorf("could not enable TCP SACK: %v", tcpipErr) + } + // Create NIC 1 (main interface, no promiscuous mode) + dev.notifyHandle = dev.ep.AddNotify(dev) + tcpipErr = dev.stack.CreateNIC(1, dev.ep) + if tcpipErr != nil { + return nil, nil, fmt.Errorf("CreateNIC: %v", tcpipErr) + } + + // Initialize proxy handler after main stack is set up + if dev.proxyHandler != nil { + if err := dev.proxyHandler.Initialize(dev); err != nil { + return nil, nil, err + } + } + + if err := dev.stack.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil { + return nil, nil, fmt.Errorf("set ipv4 forwarding: %s", err) + } + + for _, ip := range localAddresses { + var protoNumber tcpip.NetworkProtocolNumber + if ip.Is4() { + protoNumber = ipv4.ProtocolNumber + } else if ip.Is6() { + protoNumber = ipv6.ProtocolNumber + } + protoAddr := tcpip.ProtocolAddress{ + Protocol: protoNumber, + AddressWithPrefix: tcpip.AddrFromSlice(ip.AsSlice()).WithPrefix(), + } + tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}) + if tcpipErr != nil { + return nil, nil, fmt.Errorf("AddProtocolAddress(%v): %v", ip, tcpipErr) + } + if ip.Is4() { + dev.hasV4 = true + } else if ip.Is6() { + dev.hasV6 = true + } + } + if dev.hasV4 { + dev.stack.AddRoute(tcpip.Route{Destination: header.IPv4EmptySubnet, NIC: 1}) + } + if dev.hasV6 { + dev.stack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: 1}) + } + + dev.events <- tun.EventUp + return dev, (*Net)(dev), nil +} + +func (tun *netTun) Name() (string, error) { + return "go", nil +} + +func (tun *netTun) File() *os.File { + return nil +} + +func (tun *netTun) Events() <-chan tun.Event { + return tun.events +} + +func (tun *netTun) Read(buf [][]byte, sizes []int, offset int) (int, error) { + view, ok := <-tun.incomingPacket + if !ok { + return 0, os.ErrClosed + } + + n, err := view.Read(buf[0][offset:]) + if err != nil { + return 0, err + } + sizes[0] = n + return 1, nil +} + +func (tun *netTun) Write(buf [][]byte, offset int) (int, error) { + for _, buf := range buf { + packet := buf[offset:] + if len(packet) == 0 { + continue + } + + // Try to handle packet via proxy handler first + if tun.proxyHandler != nil && tun.proxyHandler.HandleIncomingPacket(packet) { + // Packet was handled by proxy + continue + } + + // Default handling: inject into main stack + pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: buffer.MakeWithData(packet)}) + + switch packet[0] >> 4 { + case 4: + tun.ep.InjectInbound(header.IPv4ProtocolNumber, pkb) + case 6: + tun.ep.InjectInbound(header.IPv6ProtocolNumber, pkb) + default: + return 0, syscall.EAFNOSUPPORT + } + } + return len(buf), nil +} + +func (tun *netTun) WriteNotify() { + // Handle notifications from main endpoint (NIC 1) + pkt := tun.ep.Read() + if pkt != nil { + view := pkt.ToView() + pkt.DecRef() + tun.incomingPacket <- view + return + } + + // Handle notifications from proxy handler if it exists + // These are response packets from the proxied connections that need to go back to WireGuard + if tun.proxyHandler != nil { + view := tun.proxyHandler.ReadOutgoingPacket() + if view != nil { + tun.incomingPacket <- view + return + } + } +} + +func (tun *netTun) Close() error { + tun.stack.RemoveNIC(1) + + tun.stack.Close() + tun.ep.RemoveNotify(tun.notifyHandle) + tun.ep.Close() + + // Clean up proxy handler if it exists + if tun.proxyHandler != nil { + tun.proxyHandler.Close() + } + + if tun.events != nil { + close(tun.events) + } + + if tun.incomingPacket != nil { + close(tun.incomingPacket) + } + + return nil +} + +func (tun *netTun) MTU() (int, error) { + return tun.mtu, nil +} + +func (tun *netTun) BatchSize() int { + return 1 +} + +func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) { + var protoNumber tcpip.NetworkProtocolNumber + if endpoint.Addr().Is4() { + protoNumber = ipv4.ProtocolNumber + } else { + protoNumber = ipv6.ProtocolNumber + } + return tcpip.FullAddress{ + NIC: 1, + Addr: tcpip.AddrFromSlice(endpoint.Addr().AsSlice()), + Port: endpoint.Port(), + }, protoNumber +} + +func (net *Net) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (*gonet.TCPConn, error) { + fa, pn := convertToFullAddr(addr) + return gonet.DialContextTCP(ctx, net.stack, fa, pn) +} + +func (net *Net) DialContextTCP(ctx context.Context, addr *net.TCPAddr) (*gonet.TCPConn, error) { + if addr == nil { + return net.DialContextTCPAddrPort(ctx, netip.AddrPort{}) + } + ip, _ := netip.AddrFromSlice(addr.IP) + return net.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(ip, uint16(addr.Port))) +} + +func (net *Net) DialTCPAddrPort(addr netip.AddrPort) (*gonet.TCPConn, error) { + fa, pn := convertToFullAddr(addr) + return gonet.DialTCP(net.stack, fa, pn) +} + +func (net *Net) DialTCP(addr *net.TCPAddr) (*gonet.TCPConn, error) { + if addr == nil { + return net.DialTCPAddrPort(netip.AddrPort{}) + } + ip, _ := netip.AddrFromSlice(addr.IP) + return net.DialTCPAddrPort(netip.AddrPortFrom(ip, uint16(addr.Port))) +} + +func (net *Net) ListenTCPAddrPort(addr netip.AddrPort) (*gonet.TCPListener, error) { + fa, pn := convertToFullAddr(addr) + return gonet.ListenTCP(net.stack, fa, pn) +} + +func (net *Net) ListenTCP(addr *net.TCPAddr) (*gonet.TCPListener, error) { + if addr == nil { + return net.ListenTCPAddrPort(netip.AddrPort{}) + } + ip, _ := netip.AddrFromSlice(addr.IP) + return net.ListenTCPAddrPort(netip.AddrPortFrom(ip, uint16(addr.Port))) +} + +func (net *Net) DialUDPAddrPort(laddr, raddr netip.AddrPort) (*gonet.UDPConn, error) { + var lfa, rfa *tcpip.FullAddress + var pn tcpip.NetworkProtocolNumber + if laddr.IsValid() || laddr.Port() > 0 { + var addr tcpip.FullAddress + addr, pn = convertToFullAddr(laddr) + lfa = &addr + } + if raddr.IsValid() || raddr.Port() > 0 { + var addr tcpip.FullAddress + addr, pn = convertToFullAddr(raddr) + rfa = &addr + } + return gonet.DialUDP(net.stack, lfa, rfa, pn) +} + +func (net *Net) ListenUDPAddrPort(laddr netip.AddrPort) (*gonet.UDPConn, error) { + return net.DialUDPAddrPort(laddr, netip.AddrPort{}) +} + +func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) { + var la, ra netip.AddrPort + if laddr != nil { + ip, _ := netip.AddrFromSlice(laddr.IP) + la = netip.AddrPortFrom(ip, uint16(laddr.Port)) + } + if raddr != nil { + ip, _ := netip.AddrFromSlice(raddr.IP) + ra = netip.AddrPortFrom(ip, uint16(raddr.Port)) + } + return net.DialUDPAddrPort(la, ra) +} + +func (net *Net) ListenUDP(laddr *net.UDPAddr) (*gonet.UDPConn, error) { + return net.DialUDP(laddr, nil) +} + +// AddProxySubnetRule adds a subnet rule to the proxy handler +// If portRanges is nil or empty, all ports are allowed for this subnet +// rewriteTo can be either an IP/CIDR (e.g., "192.168.1.1/32") or a domain name (e.g., "example.com") +func (net *Net) AddProxySubnetRule(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange) { + tun := (*netTun)(net) + if tun.proxyHandler != nil { + tun.proxyHandler.AddSubnetRule(sourcePrefix, destPrefix, rewriteTo, portRanges) + } +} + +// RemoveProxySubnetRule removes a subnet rule from the proxy handler +func (net *Net) RemoveProxySubnetRule(sourcePrefix, destPrefix netip.Prefix) { + tun := (*netTun)(net) + if tun.proxyHandler != nil { + tun.proxyHandler.RemoveSubnetRule(sourcePrefix, destPrefix) + } +} + +// GetProxyHandler returns the proxy handler (for advanced use cases) +// Returns nil if proxy is not enabled +func (net *Net) GetProxyHandler() *ProxyHandler { + tun := (*netTun)(net) + return tun.proxyHandler +} + +type PingConn struct { + laddr PingAddr + raddr PingAddr + wq waiter.Queue + ep tcpip.Endpoint + deadline *time.Timer +} + +type PingAddr struct{ addr netip.Addr } + +func (ia PingAddr) String() string { + return ia.addr.String() +} + +func (ia PingAddr) Network() string { + if ia.addr.Is4() { + return "ping4" + } else if ia.addr.Is6() { + return "ping6" + } + return "ping" +} + +func (ia PingAddr) Addr() netip.Addr { + return ia.addr +} + +func PingAddrFromAddr(addr netip.Addr) *PingAddr { + return &PingAddr{addr} +} + +func (net *Net) DialPingAddr(laddr, raddr netip.Addr) (*PingConn, error) { + if !laddr.IsValid() && !raddr.IsValid() { + return nil, errors.New("ping dial: invalid address") + } + v6 := laddr.Is6() || raddr.Is6() + bind := laddr.IsValid() + if !bind { + if v6 { + laddr = netip.IPv6Unspecified() + } else { + laddr = netip.IPv4Unspecified() + } + } + + tn := icmp.ProtocolNumber4 + pn := ipv4.ProtocolNumber + if v6 { + tn = icmp.ProtocolNumber6 + pn = ipv6.ProtocolNumber + } + + pc := &PingConn{ + laddr: PingAddr{laddr}, + deadline: time.NewTimer(time.Hour << 10), + } + pc.deadline.Stop() + + ep, tcpipErr := net.stack.NewEndpoint(tn, pn, &pc.wq) + if tcpipErr != nil { + return nil, fmt.Errorf("ping socket: endpoint: %s", tcpipErr) + } + pc.ep = ep + + if bind { + fa, _ := convertToFullAddr(netip.AddrPortFrom(laddr, 0)) + if tcpipErr = pc.ep.Bind(fa); tcpipErr != nil { + return nil, fmt.Errorf("ping bind: %s", tcpipErr) + } + } + + if raddr.IsValid() { + pc.raddr = PingAddr{raddr} + fa, _ := convertToFullAddr(netip.AddrPortFrom(raddr, 0)) + if tcpipErr = pc.ep.Connect(fa); tcpipErr != nil { + return nil, fmt.Errorf("ping connect: %s", tcpipErr) + } + } + + return pc, nil +} + +func (net *Net) ListenPingAddr(laddr netip.Addr) (*PingConn, error) { + return net.DialPingAddr(laddr, netip.Addr{}) +} + +func (net *Net) DialPing(laddr, raddr *PingAddr) (*PingConn, error) { + var la, ra netip.Addr + if laddr != nil { + la = laddr.addr + } + if raddr != nil { + ra = raddr.addr + } + return net.DialPingAddr(la, ra) +} + +func (net *Net) ListenPing(laddr *PingAddr) (*PingConn, error) { + var la netip.Addr + if laddr != nil { + la = laddr.addr + } + return net.ListenPingAddr(la) +} + +func (pc *PingConn) LocalAddr() net.Addr { + return pc.laddr +} + +func (pc *PingConn) RemoteAddr() net.Addr { + return pc.raddr +} + +func (pc *PingConn) Close() error { + pc.deadline.Reset(0) + pc.ep.Close() + return nil +} + +func (pc *PingConn) SetWriteDeadline(t time.Time) error { + return errors.New("not implemented") +} + +func (pc *PingConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + var na netip.Addr + switch v := addr.(type) { + case *PingAddr: + na = v.addr + case *net.IPAddr: + na, _ = netip.AddrFromSlice(v.IP) + default: + return 0, fmt.Errorf("ping write: wrong net.Addr type") + } + if !((na.Is4() && pc.laddr.addr.Is4()) || (na.Is6() && pc.laddr.addr.Is6())) { + return 0, fmt.Errorf("ping write: mismatched protocols") + } + + buf := bytes.NewReader(p) + rfa, _ := convertToFullAddr(netip.AddrPortFrom(na, 0)) + // won't block, no deadlines + n64, tcpipErr := pc.ep.Write(buf, tcpip.WriteOptions{ + To: &rfa, + }) + if tcpipErr != nil { + return int(n64), fmt.Errorf("ping write: %s", tcpipErr) + } + + return int(n64), nil +} + +func (pc *PingConn) Write(p []byte) (n int, err error) { + return pc.WriteTo(p, &pc.raddr) +} + +func (pc *PingConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + e, notifyCh := waiter.NewChannelEntry(waiter.EventIn) + pc.wq.EventRegister(&e) + defer pc.wq.EventUnregister(&e) + + select { + case <-pc.deadline.C: + return 0, nil, os.ErrDeadlineExceeded + case <-notifyCh: + } + + w := tcpip.SliceWriter(p) + + res, tcpipErr := pc.ep.Read(&w, tcpip.ReadOptions{ + NeedRemoteAddr: true, + }) + if tcpipErr != nil { + return 0, nil, fmt.Errorf("ping read: %s", tcpipErr) + } + + remoteAddr, _ := netip.AddrFromSlice(res.RemoteAddr.Addr.AsSlice()) + return res.Count, &PingAddr{remoteAddr}, nil +} + +func (pc *PingConn) Read(p []byte) (n int, err error) { + n, _, err = pc.ReadFrom(p) + return +} + +func (pc *PingConn) SetDeadline(t time.Time) error { + // pc.SetWriteDeadline is unimplemented + + return pc.SetReadDeadline(t) +} + +func (pc *PingConn) SetReadDeadline(t time.Time) error { + pc.deadline.Reset(time.Until(t)) + return nil +} + +var ( + errNoSuchHost = errors.New("no such host") + errLameReferral = errors.New("lame referral") + errCannotUnmarshalDNSMessage = errors.New("cannot unmarshal DNS message") + errCannotMarshalDNSMessage = errors.New("cannot marshal DNS message") + errServerMisbehaving = errors.New("server misbehaving") + errInvalidDNSResponse = errors.New("invalid DNS response") + errNoAnswerFromDNSServer = errors.New("no answer from DNS server") + errServerTemporarilyMisbehaving = errors.New("server misbehaving") + errCanceled = errors.New("operation was canceled") + errTimeout = errors.New("i/o timeout") + errNumericPort = errors.New("port must be numeric") + errNoSuitableAddress = errors.New("no suitable address found") + errMissingAddress = errors.New("missing address") +) + +func (net *Net) LookupHost(host string) (addrs []string, err error) { + return net.LookupContextHost(context.Background(), host) +} + +func isDomainName(s string) bool { + l := len(s) + if l == 0 || l > 254 || l == 254 && s[l-1] != '.' { + return false + } + last := byte('.') + nonNumeric := false + partlen := 0 + for i := 0; i < len(s); i++ { + c := s[i] + switch { + default: + return false + case 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || c == '_': + nonNumeric = true + partlen++ + case '0' <= c && c <= '9': + partlen++ + case c == '-': + if last == '.' { + return false + } + partlen++ + nonNumeric = true + case c == '.': + if last == '.' || last == '-' { + return false + } + if partlen > 63 || partlen == 0 { + return false + } + partlen = 0 + } + last = c + } + if last == '-' || partlen > 63 { + return false + } + return nonNumeric +} + +func randU16() uint16 { + var b [2]byte + _, err := rand.Read(b[:]) + if err != nil { + panic(err) + } + return binary.LittleEndian.Uint16(b[:]) +} + +func newRequest(q dnsmessage.Question) (id uint16, udpReq, tcpReq []byte, err error) { + id = randU16() + b := dnsmessage.NewBuilder(make([]byte, 2, 514), dnsmessage.Header{ID: id, RecursionDesired: true}) + b.EnableCompression() + if err := b.StartQuestions(); err != nil { + return 0, nil, nil, err + } + if err := b.Question(q); err != nil { + return 0, nil, nil, err + } + tcpReq, err = b.Finish() + udpReq = tcpReq[2:] + l := len(tcpReq) - 2 + tcpReq[0] = byte(l >> 8) + tcpReq[1] = byte(l) + return id, udpReq, tcpReq, err +} + +func equalASCIIName(x, y dnsmessage.Name) bool { + if x.Length != y.Length { + return false + } + for i := 0; i < int(x.Length); i++ { + a := x.Data[i] + b := y.Data[i] + if 'A' <= a && a <= 'Z' { + a += 0x20 + } + if 'A' <= b && b <= 'Z' { + b += 0x20 + } + if a != b { + return false + } + } + return true +} + +func checkResponse(reqID uint16, reqQues dnsmessage.Question, respHdr dnsmessage.Header, respQues dnsmessage.Question) bool { + if !respHdr.Response { + return false + } + if reqID != respHdr.ID { + return false + } + if reqQues.Type != respQues.Type || reqQues.Class != respQues.Class || !equalASCIIName(reqQues.Name, respQues.Name) { + return false + } + return true +} + +func dnsPacketRoundTrip(c net.Conn, id uint16, query dnsmessage.Question, b []byte) (dnsmessage.Parser, dnsmessage.Header, error) { + if _, err := c.Write(b); err != nil { + return dnsmessage.Parser{}, dnsmessage.Header{}, err + } + b = make([]byte, 512) + for { + n, err := c.Read(b) + if err != nil { + return dnsmessage.Parser{}, dnsmessage.Header{}, err + } + var p dnsmessage.Parser + h, err := p.Start(b[:n]) + if err != nil { + continue + } + q, err := p.Question() + if err != nil || !checkResponse(id, query, h, q) { + continue + } + return p, h, nil + } +} + +func dnsStreamRoundTrip(c net.Conn, id uint16, query dnsmessage.Question, b []byte) (dnsmessage.Parser, dnsmessage.Header, error) { + if _, err := c.Write(b); err != nil { + return dnsmessage.Parser{}, dnsmessage.Header{}, err + } + b = make([]byte, 1280) + if _, err := io.ReadFull(c, b[:2]); err != nil { + return dnsmessage.Parser{}, dnsmessage.Header{}, err + } + l := int(b[0])<<8 | int(b[1]) + if l > len(b) { + b = make([]byte, l) + } + n, err := io.ReadFull(c, b[:l]) + if err != nil { + return dnsmessage.Parser{}, dnsmessage.Header{}, err + } + var p dnsmessage.Parser + h, err := p.Start(b[:n]) + if err != nil { + return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotUnmarshalDNSMessage + } + q, err := p.Question() + if err != nil { + return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotUnmarshalDNSMessage + } + if !checkResponse(id, query, h, q) { + return dnsmessage.Parser{}, dnsmessage.Header{}, errInvalidDNSResponse + } + return p, h, nil +} + +func (tnet *Net) exchange(ctx context.Context, server netip.Addr, q dnsmessage.Question, timeout time.Duration) (dnsmessage.Parser, dnsmessage.Header, error) { + q.Class = dnsmessage.ClassINET + id, udpReq, tcpReq, err := newRequest(q) + if err != nil { + return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotMarshalDNSMessage + } + + for _, useUDP := range []bool{true, false} { + ctx, cancel := context.WithDeadline(ctx, time.Now().Add(timeout)) + defer cancel() + + var c net.Conn + var err error + if useUDP { + c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, netip.AddrPortFrom(server, 53)) + } else { + c, err = tnet.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(server, 53)) + } + + if err != nil { + return dnsmessage.Parser{}, dnsmessage.Header{}, err + } + if d, ok := ctx.Deadline(); ok && !d.IsZero() { + err := c.SetDeadline(d) + if err != nil { + return dnsmessage.Parser{}, dnsmessage.Header{}, err + } + } + var p dnsmessage.Parser + var h dnsmessage.Header + if useUDP { + p, h, err = dnsPacketRoundTrip(c, id, q, udpReq) + } else { + p, h, err = dnsStreamRoundTrip(c, id, q, tcpReq) + } + c.Close() + if err != nil { + if err == context.Canceled { + err = errCanceled + } else if err == context.DeadlineExceeded { + err = errTimeout + } + return dnsmessage.Parser{}, dnsmessage.Header{}, err + } + if err := p.SkipQuestion(); err != dnsmessage.ErrSectionDone { + return dnsmessage.Parser{}, dnsmessage.Header{}, errInvalidDNSResponse + } + if h.Truncated { + continue + } + return p, h, nil + } + return dnsmessage.Parser{}, dnsmessage.Header{}, errNoAnswerFromDNSServer +} + +func checkHeader(p *dnsmessage.Parser, h dnsmessage.Header) error { + if h.RCode == dnsmessage.RCodeNameError { + return errNoSuchHost + } + _, err := p.AnswerHeader() + if err != nil && err != dnsmessage.ErrSectionDone { + return errCannotUnmarshalDNSMessage + } + if h.RCode == dnsmessage.RCodeSuccess && !h.Authoritative && !h.RecursionAvailable && err == dnsmessage.ErrSectionDone { + return errLameReferral + } + if h.RCode != dnsmessage.RCodeSuccess && h.RCode != dnsmessage.RCodeNameError { + if h.RCode == dnsmessage.RCodeServerFailure { + return errServerTemporarilyMisbehaving + } + return errServerMisbehaving + } + return nil +} + +func skipToAnswer(p *dnsmessage.Parser, qtype dnsmessage.Type) error { + for { + h, err := p.AnswerHeader() + if err == dnsmessage.ErrSectionDone { + return errNoSuchHost + } + if err != nil { + return errCannotUnmarshalDNSMessage + } + if h.Type == qtype { + return nil + } + if err := p.SkipAnswer(); err != nil { + return errCannotUnmarshalDNSMessage + } + } +} + +func (tnet *Net) tryOneName(ctx context.Context, name string, qtype dnsmessage.Type) (dnsmessage.Parser, string, error) { + var lastErr error + + n, err := dnsmessage.NewName(name) + if err != nil { + return dnsmessage.Parser{}, "", errCannotMarshalDNSMessage + } + q := dnsmessage.Question{ + Name: n, + Type: qtype, + Class: dnsmessage.ClassINET, + } + + for i := 0; i < 2; i++ { + for _, server := range tnet.dnsServers { + p, h, err := tnet.exchange(ctx, server, q, time.Second*5) + if err != nil { + dnsErr := &net.DNSError{ + Err: err.Error(), + Name: name, + Server: server.String(), + } + if nerr, ok := err.(net.Error); ok && nerr.Timeout() { + dnsErr.IsTimeout = true + } + if _, ok := err.(*net.OpError); ok { + dnsErr.IsTemporary = true + } + lastErr = dnsErr + continue + } + + if err := checkHeader(&p, h); err != nil { + dnsErr := &net.DNSError{ + Err: err.Error(), + Name: name, + Server: server.String(), + } + if err == errServerTemporarilyMisbehaving { + dnsErr.IsTemporary = true + } + if err == errNoSuchHost { + dnsErr.IsNotFound = true + return p, server.String(), dnsErr + } + lastErr = dnsErr + continue + } + + err = skipToAnswer(&p, qtype) + if err == nil { + return p, server.String(), nil + } + lastErr = &net.DNSError{ + Err: err.Error(), + Name: name, + Server: server.String(), + } + if err == errNoSuchHost { + lastErr.(*net.DNSError).IsNotFound = true + return p, server.String(), lastErr + } + } + } + return dnsmessage.Parser{}, "", lastErr +} + +func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string, error) { + if host == "" || (!tnet.hasV6 && !tnet.hasV4) { + return nil, &net.DNSError{Err: errNoSuchHost.Error(), Name: host, IsNotFound: true} + } + zlen := len(host) + if strings.IndexByte(host, ':') != -1 { + if zidx := strings.LastIndexByte(host, '%'); zidx != -1 { + zlen = zidx + } + } + if ip, err := netip.ParseAddr(host[:zlen]); err == nil { + return []string{ip.String()}, nil + } + + if !isDomainName(host) { + return nil, &net.DNSError{Err: errNoSuchHost.Error(), Name: host, IsNotFound: true} + } + type result struct { + p dnsmessage.Parser + server string + error + } + var addrsV4, addrsV6 []netip.Addr + lanes := 0 + if tnet.hasV4 { + lanes++ + } + if tnet.hasV6 { + lanes++ + } + lane := make(chan result, lanes) + var lastErr error + if tnet.hasV4 { + go func() { + p, server, err := tnet.tryOneName(ctx, host+".", dnsmessage.TypeA) + lane <- result{p, server, err} + }() + } + if tnet.hasV6 { + go func() { + p, server, err := tnet.tryOneName(ctx, host+".", dnsmessage.TypeAAAA) + lane <- result{p, server, err} + }() + } + for l := 0; l < lanes; l++ { + result := <-lane + if result.error != nil { + if lastErr == nil { + lastErr = result.error + } + continue + } + + loop: + for { + h, err := result.p.AnswerHeader() + if err != nil && err != dnsmessage.ErrSectionDone { + lastErr = &net.DNSError{ + Err: errCannotMarshalDNSMessage.Error(), + Name: host, + Server: result.server, + } + } + if err != nil { + break + } + switch h.Type { + case dnsmessage.TypeA: + a, err := result.p.AResource() + if err != nil { + lastErr = &net.DNSError{ + Err: errCannotMarshalDNSMessage.Error(), + Name: host, + Server: result.server, + } + break loop + } + addrsV4 = append(addrsV4, netip.AddrFrom4(a.A)) + + case dnsmessage.TypeAAAA: + aaaa, err := result.p.AAAAResource() + if err != nil { + lastErr = &net.DNSError{ + Err: errCannotMarshalDNSMessage.Error(), + Name: host, + Server: result.server, + } + break loop + } + addrsV6 = append(addrsV6, netip.AddrFrom16(aaaa.AAAA)) + + default: + if err := result.p.SkipAnswer(); err != nil { + lastErr = &net.DNSError{ + Err: errCannotMarshalDNSMessage.Error(), + Name: host, + Server: result.server, + } + break loop + } + continue + } + } + } + // We don't do RFC6724. Instead just put V6 addresses first if an IPv6 address is enabled + var addrs []netip.Addr + if tnet.hasV6 { + addrs = append(addrsV6, addrsV4...) + } else { + addrs = append(addrsV4, addrsV6...) + } + + if len(addrs) == 0 && lastErr != nil { + return nil, lastErr + } + saddrs := make([]string, 0, len(addrs)) + for _, ip := range addrs { + saddrs = append(saddrs, ip.String()) + } + return saddrs, nil +} + +func partialDeadline(now, deadline time.Time, addrsRemaining int) (time.Time, error) { + if deadline.IsZero() { + return deadline, nil + } + timeRemaining := deadline.Sub(now) + if timeRemaining <= 0 { + return time.Time{}, errTimeout + } + timeout := timeRemaining / time.Duration(addrsRemaining) + const saneMinimum = 2 * time.Second + if timeout < saneMinimum { + if timeRemaining < saneMinimum { + timeout = timeRemaining + } else { + timeout = saneMinimum + } + } + return now.Add(timeout), nil +} + +var protoSplitter = regexp.MustCompile(`^(tcp|udp|ping)(4|6)?$`) + +func (tnet *Net) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + if ctx == nil { + panic("nil context") + } + var acceptV4, acceptV6 bool + matches := protoSplitter.FindStringSubmatch(network) + if matches == nil { + return nil, &net.OpError{Op: "dial", Err: net.UnknownNetworkError(network)} + } else if len(matches[2]) == 0 { + acceptV4 = true + acceptV6 = true + } else { + acceptV4 = matches[2][0] == '4' + acceptV6 = !acceptV4 + } + var host string + var port int + if matches[1] == "ping" { + host = address + } else { + var sport string + var err error + host, sport, err = net.SplitHostPort(address) + if err != nil { + return nil, &net.OpError{Op: "dial", Err: err} + } + port, err = strconv.Atoi(sport) + if err != nil || port < 0 || port > 65535 { + return nil, &net.OpError{Op: "dial", Err: errNumericPort} + } + } + allAddr, err := tnet.LookupContextHost(ctx, host) + if err != nil { + return nil, &net.OpError{Op: "dial", Err: err} + } + var addrs []netip.AddrPort + for _, addr := range allAddr { + ip, err := netip.ParseAddr(addr) + if err == nil && ((ip.Is4() && acceptV4) || (ip.Is6() && acceptV6)) { + addrs = append(addrs, netip.AddrPortFrom(ip, uint16(port))) + } + } + if len(addrs) == 0 && len(allAddr) != 0 { + return nil, &net.OpError{Op: "dial", Err: errNoSuitableAddress} + } + + var firstErr error + for i, addr := range addrs { + select { + case <-ctx.Done(): + err := ctx.Err() + if err == context.Canceled { + err = errCanceled + } else if err == context.DeadlineExceeded { + err = errTimeout + } + return nil, &net.OpError{Op: "dial", Err: err} + default: + } + + dialCtx := ctx + if deadline, hasDeadline := ctx.Deadline(); hasDeadline { + partialDeadline, err := partialDeadline(time.Now(), deadline, len(addrs)-i) + if err != nil { + if firstErr == nil { + firstErr = &net.OpError{Op: "dial", Err: err} + } + break + } + if partialDeadline.Before(deadline) { + var cancel context.CancelFunc + dialCtx, cancel = context.WithDeadline(ctx, partialDeadline) + defer cancel() + } + } + + var c net.Conn + switch matches[1] { + case "tcp": + c, err = tnet.DialContextTCPAddrPort(dialCtx, addr) + case "udp": + c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, addr) + case "ping": + c, err = tnet.DialPingAddr(netip.Addr{}, addr.Addr()) + } + if err == nil { + return c, nil + } + if firstErr == nil { + firstErr = err + } + } + if firstErr == nil { + firstErr = &net.OpError{Op: "dial", Err: errMissingAddress} + } + return nil, firstErr +} + +func (tnet *Net) Dial(network, address string) (net.Conn, error) { + return tnet.DialContext(context.Background(), network, address) +} diff --git a/network/interface.go b/network/interface.go new file mode 100644 index 0000000..e110ec1 --- /dev/null +++ b/network/interface.go @@ -0,0 +1,165 @@ +package network + +import ( + "fmt" + "net" + "os/exec" + "regexp" + "runtime" + "strconv" + "time" + + "github.com/fosrl/newt/logger" + "github.com/vishvananda/netlink" +) + +// ConfigureInterface configures a network interface with an IP address and brings it up +func ConfigureInterface(interfaceName string, tunnelIp string, mtu int) error { + logger.Info("The tunnel IP is: %s", tunnelIp) + + // Parse the IP address and network + ip, ipNet, err := net.ParseCIDR(tunnelIp) + if err != nil { + return fmt.Errorf("invalid IP address: %v", err) + } + + // Convert CIDR mask to dotted decimal format (e.g., 255.255.255.0) + mask := net.IP(ipNet.Mask).String() + destinationAddress := ip.String() + + logger.Debug("The destination address is: %s", destinationAddress) + + // network.SetTunnelRemoteAddress() // what does this do? + SetIPv4Settings([]string{destinationAddress}, []string{mask}) + SetMTU(mtu) + + if interfaceName == "" { + return nil + } + + switch runtime.GOOS { + case "linux": + return configureLinux(interfaceName, ip, ipNet) + case "darwin": + return configureDarwin(interfaceName, ip, ipNet) + case "windows": + return configureWindows(interfaceName, ip, ipNet) + default: + return fmt.Errorf("unsupported operating system: %s", runtime.GOOS) + } +} + +// waitForInterfaceUp polls the network interface until it's up or times out +func waitForInterfaceUp(interfaceName string, expectedIP net.IP, timeout time.Duration) error { + logger.Info("Waiting for interface %s to be up with IP %s", interfaceName, expectedIP) + deadline := time.Now().Add(timeout) + pollInterval := 500 * time.Millisecond + + for time.Now().Before(deadline) { + // Check if interface exists and is up + iface, err := net.InterfaceByName(interfaceName) + if err == nil { + // Check if interface is up + if iface.Flags&net.FlagUp != 0 { + // Check if it has the expected IP + addrs, err := iface.Addrs() + if err == nil { + for _, addr := range addrs { + ipNet, ok := addr.(*net.IPNet) + if ok && ipNet.IP.Equal(expectedIP) { + logger.Info("Interface %s is up with correct IP", interfaceName) + return nil // Interface is up with correct IP + } + } + logger.Info("Interface %s is up but doesn't have expected IP yet", interfaceName) + } + } else { + logger.Info("Interface %s exists but is not up yet", interfaceName) + } + } else { + logger.Info("Interface %s not found yet: %v", interfaceName, err) + } + + // Wait before next check + time.Sleep(pollInterval) + } + + return fmt.Errorf("timed out waiting for interface %s to be up with IP %s", interfaceName, expectedIP) +} + +func FindUnusedUTUN() (string, error) { + ifaces, err := net.Interfaces() + if err != nil { + return "", fmt.Errorf("failed to list interfaces: %v", err) + } + used := make(map[int]bool) + re := regexp.MustCompile(`^utun(\d+)$`) + for _, iface := range ifaces { + if matches := re.FindStringSubmatch(iface.Name); len(matches) == 2 { + if num, err := strconv.Atoi(matches[1]); err == nil { + used[num] = true + } + } + } + // Try utun0 up to utun255. + for i := 0; i < 256; i++ { + if !used[i] { + return fmt.Sprintf("utun%d", i), nil + } + } + return "", fmt.Errorf("no unused utun interface found") +} + +func configureDarwin(interfaceName string, ip net.IP, ipNet *net.IPNet) error { + logger.Info("Configuring darwin interface: %s", interfaceName) + + prefix, _ := ipNet.Mask.Size() + ipStr := fmt.Sprintf("%s/%d", ip.String(), prefix) + + cmd := exec.Command("ifconfig", interfaceName, "inet", ipStr, ip.String(), "alias") + logger.Info("Running command: %v", cmd) + + out, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("ifconfig command failed: %v, output: %s", err, out) + } + + // Bring up the interface + cmd = exec.Command("ifconfig", interfaceName, "up") + logger.Info("Running command: %v", cmd) + + out, err = cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("ifconfig up command failed: %v, output: %s", err, out) + } + + return nil +} + +func configureLinux(interfaceName string, ip net.IP, ipNet *net.IPNet) error { + // Get the interface + link, err := netlink.LinkByName(interfaceName) + if err != nil { + return fmt.Errorf("failed to get interface %s: %v", interfaceName, err) + } + + // Create the IP address attributes + addr := &netlink.Addr{ + IPNet: &net.IPNet{ + IP: ip, + Mask: ipNet.Mask, + }, + } + + // Add the IP address to the interface + if err := netlink.AddrAdd(link, addr); err != nil { + return fmt.Errorf("failed to add IP address: %v", err) + } + + // Bring up the interface + if err := netlink.LinkSetUp(link); err != nil { + return fmt.Errorf("failed to bring up interface: %v", err) + } + + return nil +} diff --git a/network/interface_notwindows.go b/network/interface_notwindows.go new file mode 100644 index 0000000..5d15ace --- /dev/null +++ b/network/interface_notwindows.go @@ -0,0 +1,12 @@ +//go:build !windows + +package network + +import ( + "fmt" + "net" +) + +func configureWindows(interfaceName string, ip net.IP, ipNet *net.IPNet) error { + return fmt.Errorf("configureWindows called on non-Windows platform") +} diff --git a/network/interface_windows.go b/network/interface_windows.go new file mode 100644 index 0000000..966486b --- /dev/null +++ b/network/interface_windows.go @@ -0,0 +1,63 @@ +//go:build windows + +package network + +import ( + "fmt" + "net" + "net/netip" + + "github.com/fosrl/newt/logger" + "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" +) + +func configureWindows(interfaceName string, ip net.IP, ipNet *net.IPNet) error { + logger.Info("Configuring Windows interface: %s", interfaceName) + + // Get the LUID for the interface + iface, err := net.InterfaceByName(interfaceName) + if err != nil { + return fmt.Errorf("failed to get interface %s: %v", interfaceName, err) + } + + luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index)) + if err != nil { + return fmt.Errorf("failed to get LUID for interface %s: %v", interfaceName, err) + } + + // Create the IP address prefix + maskBits, _ := ipNet.Mask.Size() + + // Ensure we convert to the correct IP version (IPv4 vs IPv6) + var addr netip.Addr + if ip4 := ip.To4(); ip4 != nil { + // IPv4 address + addr, _ = netip.AddrFromSlice(ip4) + } else { + // IPv6 address + addr, _ = netip.AddrFromSlice(ip) + } + if !addr.IsValid() { + return fmt.Errorf("failed to convert IP address") + } + prefix := netip.PrefixFrom(addr, maskBits) + + // Add the IP address to the interface + logger.Info("Adding IP address %s to interface %s", prefix.String(), interfaceName) + err = luid.AddIPAddress(prefix) + if err != nil { + return fmt.Errorf("failed to add IP address: %v", err) + } + + // This was required when we were using the subprocess "netsh" command to bring up the interface. + // With the winipcfg library, the interface should already be up after adding the IP so we dont + // need this step anymore as far as I can tell. + + // // Wait for the interface to be up and have the correct IP + // err = waitForInterfaceUp(interfaceName, ip, 30*time.Second) + // if err != nil { + // return fmt.Errorf("interface did not come up within timeout: %v", err) + // } + + return nil +} diff --git a/network/network.go b/network/network.go deleted file mode 100644 index e359219..0000000 --- a/network/network.go +++ /dev/null @@ -1,195 +0,0 @@ -package network - -import ( - "encoding/binary" - "encoding/json" - "fmt" - "log" - "net" - "time" - - "github.com/google/gopacket" - "github.com/google/gopacket/layers" - "github.com/vishvananda/netlink" - "golang.org/x/net/bpf" - "golang.org/x/net/ipv4" -) - -const ( - udpProtocol = 17 - // EmptyUDPSize is the size of an empty UDP packet - EmptyUDPSize = 28 - timeout = time.Second * 10 -) - -// Server stores data relating to the server -type Server struct { - Hostname string - Addr *net.IPAddr - Port uint16 -} - -// PeerNet stores data about a peer's endpoint -type PeerNet struct { - Resolved bool - IP net.IP - Port uint16 - NewtID string -} - -// GetClientIP gets source ip address that will be used when sending data to dstIP -func GetClientIP(dstIP net.IP) net.IP { - routes, err := netlink.RouteGet(dstIP) - if err != nil { - log.Fatalln("Error getting route:", err) - } - return routes[0].Src -} - -// HostToAddr resolves a hostname, whether DNS or IP to a valid net.IPAddr -func HostToAddr(hostStr string) *net.IPAddr { - remoteAddrs, err := net.LookupHost(hostStr) - if err != nil { - log.Fatalln("Error parsing remote address:", err) - } - - for _, addrStr := range remoteAddrs { - if remoteAddr, err := net.ResolveIPAddr("ip4", addrStr); err == nil { - return remoteAddr - } - } - return nil -} - -// SetupRawConn creates an ipv4 and udp only RawConn and applies packet filtering -func SetupRawConn(server *Server, client *PeerNet) *ipv4.RawConn { - packetConn, err := net.ListenPacket("ip4:udp", client.IP.String()) - if err != nil { - log.Fatalln("Error creating packetConn:", err) - } - - rawConn, err := ipv4.NewRawConn(packetConn) - if err != nil { - log.Fatalln("Error creating rawConn:", err) - } - - ApplyBPF(rawConn, server, client) - - return rawConn -} - -// ApplyBPF constructs a BPF program and applies it to the RawConn -func ApplyBPF(rawConn *ipv4.RawConn, server *Server, client *PeerNet) { - const ipv4HeaderLen = 20 - const srcIPOffset = 12 - const srcPortOffset = ipv4HeaderLen + 0 - const dstPortOffset = ipv4HeaderLen + 2 - - ipArr := []byte(server.Addr.IP.To4()) - ipInt := uint32(ipArr[0])<<(3*8) + uint32(ipArr[1])<<(2*8) + uint32(ipArr[2])<<8 + uint32(ipArr[3]) - - bpfRaw, err := bpf.Assemble([]bpf.Instruction{ - bpf.LoadAbsolute{Off: srcIPOffset, Size: 4}, - bpf.JumpIf{Cond: bpf.JumpEqual, Val: ipInt, SkipFalse: 5, SkipTrue: 0}, - - bpf.LoadAbsolute{Off: srcPortOffset, Size: 2}, - bpf.JumpIf{Cond: bpf.JumpEqual, Val: uint32(server.Port), SkipFalse: 3, SkipTrue: 0}, - - bpf.LoadAbsolute{Off: dstPortOffset, Size: 2}, - bpf.JumpIf{Cond: bpf.JumpEqual, Val: uint32(client.Port), SkipFalse: 1, SkipTrue: 0}, - - bpf.RetConstant{Val: 1<<(8*4) - 1}, - bpf.RetConstant{Val: 0}, - }) - - if err != nil { - log.Fatalln("Error assembling BPF:", err) - } - - err = rawConn.SetBPF(bpfRaw) - if err != nil { - log.Fatalln("Error setting BPF:", err) - } -} - -// MakePacket constructs a request packet to send to the server -func MakePacket(payload []byte, server *Server, client *PeerNet) []byte { - buf := gopacket.NewSerializeBuffer() - - opts := gopacket.SerializeOptions{ - FixLengths: true, - ComputeChecksums: true, - } - - ipHeader := layers.IPv4{ - SrcIP: client.IP, - DstIP: server.Addr.IP, - Version: 4, - TTL: 64, - Protocol: layers.IPProtocolUDP, - } - - udpHeader := layers.UDP{ - SrcPort: layers.UDPPort(client.Port), - DstPort: layers.UDPPort(server.Port), - } - - payloadLayer := gopacket.Payload(payload) - - udpHeader.SetNetworkLayerForChecksum(&ipHeader) - - gopacket.SerializeLayers(buf, opts, &ipHeader, &udpHeader, &payloadLayer) - - return buf.Bytes() -} - -// SendPacket sends packet to the Server -func SendPacket(packet []byte, conn *ipv4.RawConn, server *Server, client *PeerNet) error { - fullPacket := MakePacket(packet, server, client) - _, err := conn.WriteToIP(fullPacket, server.Addr) - return err -} - -// SendDataPacket sends a JSON payload to the Server -func SendDataPacket(data interface{}, conn *ipv4.RawConn, server *Server, client *PeerNet) error { - jsonData, err := json.Marshal(data) - if err != nil { - return fmt.Errorf("failed to marshal payload: %v", err) - } - - return SendPacket(jsonData, conn, server, client) -} - -// RecvPacket receives a UDP packet from server -func RecvPacket(conn *ipv4.RawConn, server *Server, client *PeerNet) ([]byte, int, error) { - err := conn.SetReadDeadline(time.Now().Add(timeout)) - if err != nil { - return nil, 0, err - } - - response := make([]byte, 4096) - n, err := conn.Read(response) - if err != nil { - return nil, n, err - } - return response, n, nil -} - -// RecvDataPacket receives and unmarshals a JSON packet from server -func RecvDataPacket(conn *ipv4.RawConn, server *Server, client *PeerNet) ([]byte, error) { - response, n, err := RecvPacket(conn, server, client) - if err != nil { - return nil, err - } - - // Extract payload from UDP packet - payload := response[EmptyUDPSize:n] - return payload, nil -} - -// ParseResponse takes a response packet and parses it into an IP and port -func ParseResponse(response []byte) (net.IP, uint16) { - ip := net.IP(response[:4]) - port := binary.BigEndian.Uint16(response[4:6]) - return ip, port -} diff --git a/network/route.go b/network/route.go new file mode 100644 index 0000000..eb850ee --- /dev/null +++ b/network/route.go @@ -0,0 +1,282 @@ +package network + +import ( + "fmt" + "net" + "os/exec" + "runtime" + "strings" + + "github.com/fosrl/newt/logger" + "github.com/vishvananda/netlink" +) + +func DarwinAddRoute(destination string, gateway string, interfaceName string) error { + if runtime.GOOS != "darwin" { + return nil + } + + var cmd *exec.Cmd + + if gateway != "" { + // Route with specific gateway + cmd = exec.Command("route", "-q", "-n", "add", "-inet", destination, "-gateway", gateway) + } else if interfaceName != "" { + // Route via interface + cmd = exec.Command("route", "-q", "-n", "add", "-inet", destination, "-interface", interfaceName) + } else { + return fmt.Errorf("either gateway or interface must be specified") + } + + logger.Info("Running command: %v", cmd) + + out, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("route command failed: %v, output: %s", err, out) + } + + return nil +} + +func DarwinRemoveRoute(destination string) error { + if runtime.GOOS != "darwin" { + return nil + } + + cmd := exec.Command("route", "-q", "-n", "delete", "-inet", destination) + logger.Info("Running command: %v", cmd) + + out, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("route delete command failed: %v, output: %s", err, out) + } + + return nil +} + +func LinuxAddRoute(destination string, gateway string, interfaceName string) error { + if runtime.GOOS != "linux" { + return nil + } + + // Parse destination CIDR + _, ipNet, err := net.ParseCIDR(destination) + if err != nil { + return fmt.Errorf("invalid destination address: %v", err) + } + + // Create route + route := &netlink.Route{ + Dst: ipNet, + } + + if gateway != "" { + // Route with specific gateway + gw := net.ParseIP(gateway) + if gw == nil { + return fmt.Errorf("invalid gateway address: %s", gateway) + } + route.Gw = gw + logger.Info("Adding route to %s via gateway %s", destination, gateway) + } else if interfaceName != "" { + // Route via interface + link, err := netlink.LinkByName(interfaceName) + if err != nil { + return fmt.Errorf("failed to get interface %s: %v", interfaceName, err) + } + route.LinkIndex = link.Attrs().Index + logger.Info("Adding route to %s via interface %s", destination, interfaceName) + } else { + return fmt.Errorf("either gateway or interface must be specified") + } + + // Add the route + if err := netlink.RouteAdd(route); err != nil { + return fmt.Errorf("failed to add route: %v", err) + } + + return nil +} + +func LinuxRemoveRoute(destination string) error { + if runtime.GOOS != "linux" { + return nil + } + + // Parse destination CIDR + _, ipNet, err := net.ParseCIDR(destination) + if err != nil { + return fmt.Errorf("invalid destination address: %v", err) + } + + // Create route to delete + route := &netlink.Route{ + Dst: ipNet, + } + + logger.Info("Removing route to %s", destination) + + // Delete the route + if err := netlink.RouteDel(route); err != nil { + return fmt.Errorf("failed to delete route: %v", err) + } + + return nil +} + +// addRouteForServerIP adds an OS-specific route for the server IP +func AddRouteForServerIP(serverIP, interfaceName string) error { + if err := AddRouteForNetworkConfig(serverIP); err != nil { + return err + } + if interfaceName == "" { + return nil + } + if runtime.GOOS == "darwin" { + return DarwinAddRoute(serverIP, "", interfaceName) + } + // else if runtime.GOOS == "windows" { + // return WindowsAddRoute(serverIP, "", interfaceName) + // } else if runtime.GOOS == "linux" { + // return LinuxAddRoute(serverIP, "", interfaceName) + // } + return nil +} + +// removeRouteForServerIP removes an OS-specific route for the server IP +func RemoveRouteForServerIP(serverIP string, interfaceName string) error { + if err := RemoveRouteForNetworkConfig(serverIP); err != nil { + return err + } + if interfaceName == "" { + return nil + } + if runtime.GOOS == "darwin" { + return DarwinRemoveRoute(serverIP) + } + // else if runtime.GOOS == "windows" { + // return WindowsRemoveRoute(serverIP) + // } else if runtime.GOOS == "linux" { + // return LinuxRemoveRoute(serverIP) + // } + return nil +} + +func AddRouteForNetworkConfig(destination string) error { + // Parse the subnet to extract IP and mask + _, ipNet, err := net.ParseCIDR(destination) + if err != nil { + return fmt.Errorf("failed to parse subnet %s: %v", destination, err) + } + + // Convert CIDR mask to dotted decimal format (e.g., 255.255.255.0) + mask := net.IP(ipNet.Mask).String() + destinationAddress := ipNet.IP.String() + + AddIPv4IncludedRoute(IPv4Route{DestinationAddress: destinationAddress, SubnetMask: mask}) + + return nil +} + +func RemoveRouteForNetworkConfig(destination string) error { + // Parse the subnet to extract IP and mask + _, ipNet, err := net.ParseCIDR(destination) + if err != nil { + return fmt.Errorf("failed to parse subnet %s: %v", destination, err) + } + + // Convert CIDR mask to dotted decimal format (e.g., 255.255.255.0) + mask := net.IP(ipNet.Mask).String() + destinationAddress := ipNet.IP.String() + + RemoveIPv4IncludedRoute(IPv4Route{DestinationAddress: destinationAddress, SubnetMask: mask}) + + return nil +} + +// addRoutes adds routes for each subnet in RemoteSubnets +func AddRoutes(remoteSubnets []string, interfaceName string) error { + if len(remoteSubnets) == 0 { + return nil + } + + // Add routes for each subnet + for _, subnet := range remoteSubnets { + subnet = strings.TrimSpace(subnet) + if subnet == "" { + continue + } + + if err := AddRouteForNetworkConfig(subnet); err != nil { + logger.Error("Failed to add network config for subnet %s: %v", subnet, err) + continue + } + + // Add route based on operating system + if interfaceName == "" { + continue + } + + if runtime.GOOS == "darwin" { + if err := DarwinAddRoute(subnet, "", interfaceName); err != nil { + logger.Error("Failed to add Darwin route for subnet %s: %v", subnet, err) + return err + } + } else if runtime.GOOS == "windows" { + if err := WindowsAddRoute(subnet, "", interfaceName); err != nil { + logger.Error("Failed to add Windows route for subnet %s: %v", subnet, err) + return err + } + } else if runtime.GOOS == "linux" { + if err := LinuxAddRoute(subnet, "", interfaceName); err != nil { + logger.Error("Failed to add Linux route for subnet %s: %v", subnet, err) + return err + } + } + + logger.Info("Added route for remote subnet: %s", subnet) + } + return nil +} + +// removeRoutesForRemoteSubnets removes routes for each subnet in RemoteSubnets +func RemoveRoutes(remoteSubnets []string) error { + if len(remoteSubnets) == 0 { + return nil + } + + // Remove routes for each subnet + for _, subnet := range remoteSubnets { + subnet = strings.TrimSpace(subnet) + if subnet == "" { + continue + } + + if err := RemoveRouteForNetworkConfig(subnet); err != nil { + logger.Error("Failed to remove network config for subnet %s: %v", subnet, err) + continue + } + + // Remove route based on operating system + if runtime.GOOS == "darwin" { + if err := DarwinRemoveRoute(subnet); err != nil { + logger.Error("Failed to remove Darwin route for subnet %s: %v", subnet, err) + return err + } + } else if runtime.GOOS == "windows" { + if err := WindowsRemoveRoute(subnet); err != nil { + logger.Error("Failed to remove Windows route for subnet %s: %v", subnet, err) + return err + } + } else if runtime.GOOS == "linux" { + if err := LinuxRemoveRoute(subnet); err != nil { + logger.Error("Failed to remove Linux route for subnet %s: %v", subnet, err) + return err + } + } + + logger.Info("Removed route for remote subnet: %s", subnet) + } + + return nil +} diff --git a/network/route_notwindows.go b/network/route_notwindows.go new file mode 100644 index 0000000..6984c71 --- /dev/null +++ b/network/route_notwindows.go @@ -0,0 +1,11 @@ +//go:build !windows + +package network + +func WindowsAddRoute(destination string, gateway string, interfaceName string) error { + return nil +} + +func WindowsRemoveRoute(destination string) error { + return nil +} diff --git a/network/route_windows.go b/network/route_windows.go new file mode 100644 index 0000000..ba613b6 --- /dev/null +++ b/network/route_windows.go @@ -0,0 +1,148 @@ +//go:build windows + +package network + +import ( + "fmt" + "net" + "net/netip" + "runtime" + + "github.com/fosrl/newt/logger" + "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" +) + +func WindowsAddRoute(destination string, gateway string, interfaceName string) error { + if runtime.GOOS != "windows" { + return nil + } + + // Parse destination CIDR + _, ipNet, err := net.ParseCIDR(destination) + if err != nil { + return fmt.Errorf("invalid destination address: %v", err) + } + + // Convert to netip.Prefix + maskBits, _ := ipNet.Mask.Size() + + // Ensure we convert to the correct IP version (IPv4 vs IPv6) + var addr netip.Addr + if ip4 := ipNet.IP.To4(); ip4 != nil { + // IPv4 address + addr, _ = netip.AddrFromSlice(ip4) + } else { + // IPv6 address + addr, _ = netip.AddrFromSlice(ipNet.IP) + } + if !addr.IsValid() { + return fmt.Errorf("failed to convert destination IP") + } + prefix := netip.PrefixFrom(addr, maskBits) + + var luid winipcfg.LUID + var nextHop netip.Addr + + if interfaceName != "" { + // Get the interface LUID - needed for both gateway and interface-only routes + iface, err := net.InterfaceByName(interfaceName) + if err != nil { + return fmt.Errorf("failed to get interface %s: %v", interfaceName, err) + } + + luid, err = winipcfg.LUIDFromIndex(uint32(iface.Index)) + if err != nil { + return fmt.Errorf("failed to get LUID for interface %s: %v", interfaceName, err) + } + } + + if gateway != "" { + // Route with specific gateway + gwIP := net.ParseIP(gateway) + if gwIP == nil { + return fmt.Errorf("invalid gateway address: %s", gateway) + } + // Convert to correct IP version + if ip4 := gwIP.To4(); ip4 != nil { + nextHop, _ = netip.AddrFromSlice(ip4) + } else { + nextHop, _ = netip.AddrFromSlice(gwIP) + } + if !nextHop.IsValid() { + return fmt.Errorf("failed to convert gateway IP") + } + logger.Info("Adding route to %s via gateway %s on interface %s", destination, gateway, interfaceName) + } else if interfaceName != "" { + // Route via interface only + if addr.Is4() { + nextHop = netip.IPv4Unspecified() + } else { + nextHop = netip.IPv6Unspecified() + } + logger.Info("Adding route to %s via interface %s", destination, interfaceName) + } else { + return fmt.Errorf("either gateway or interface must be specified") + } + + // Add the route using winipcfg + err = luid.AddRoute(prefix, nextHop, 1) + if err != nil { + return fmt.Errorf("failed to add route: %v", err) + } + + return nil +} + +func WindowsRemoveRoute(destination string) error { + // Parse destination CIDR + _, ipNet, err := net.ParseCIDR(destination) + if err != nil { + return fmt.Errorf("invalid destination address: %v", err) + } + + // Convert to netip.Prefix + maskBits, _ := ipNet.Mask.Size() + + // Ensure we convert to the correct IP version (IPv4 vs IPv6) + var addr netip.Addr + if ip4 := ipNet.IP.To4(); ip4 != nil { + // IPv4 address + addr, _ = netip.AddrFromSlice(ip4) + } else { + // IPv6 address + addr, _ = netip.AddrFromSlice(ipNet.IP) + } + if !addr.IsValid() { + return fmt.Errorf("failed to convert destination IP") + } + prefix := netip.PrefixFrom(addr, maskBits) + + // Get all routes and find the one to delete + // We need to get the LUID from the existing route + var family winipcfg.AddressFamily + if addr.Is4() { + family = 2 // AF_INET + } else { + family = 23 // AF_INET6 + } + + routes, err := winipcfg.GetIPForwardTable2(family) + if err != nil { + return fmt.Errorf("failed to get route table: %v", err) + } + + // Find and delete matching route + for _, route := range routes { + routePrefix := route.DestinationPrefix.Prefix() + if routePrefix == prefix { + logger.Info("Removing route to %s", destination) + err = route.Delete() + if err != nil { + return fmt.Errorf("failed to delete route: %v", err) + } + return nil + } + } + + return fmt.Errorf("route to %s not found", destination) +} diff --git a/network/settings.go b/network/settings.go new file mode 100644 index 0000000..e7792e0 --- /dev/null +++ b/network/settings.go @@ -0,0 +1,190 @@ +package network + +import ( + "encoding/json" + "sync" + + "github.com/fosrl/newt/logger" +) + +// NetworkSettings represents the network configuration for the tunnel +type NetworkSettings struct { + TunnelRemoteAddress string `json:"tunnel_remote_address,omitempty"` + MTU *int `json:"mtu,omitempty"` + DNSServers []string `json:"dns_servers,omitempty"` + IPv4Addresses []string `json:"ipv4_addresses,omitempty"` + IPv4SubnetMasks []string `json:"ipv4_subnet_masks,omitempty"` + IPv4IncludedRoutes []IPv4Route `json:"ipv4_included_routes,omitempty"` + IPv4ExcludedRoutes []IPv4Route `json:"ipv4_excluded_routes,omitempty"` + IPv6Addresses []string `json:"ipv6_addresses,omitempty"` + IPv6NetworkPrefixes []string `json:"ipv6_network_prefixes,omitempty"` + IPv6IncludedRoutes []IPv6Route `json:"ipv6_included_routes,omitempty"` + IPv6ExcludedRoutes []IPv6Route `json:"ipv6_excluded_routes,omitempty"` +} + +// IPv4Route represents an IPv4 route +type IPv4Route struct { + DestinationAddress string `json:"destination_address"` + SubnetMask string `json:"subnet_mask,omitempty"` + GatewayAddress string `json:"gateway_address,omitempty"` + IsDefault bool `json:"is_default,omitempty"` +} + +// IPv6Route represents an IPv6 route +type IPv6Route struct { + DestinationAddress string `json:"destination_address"` + NetworkPrefixLength int `json:"network_prefix_length,omitempty"` + GatewayAddress string `json:"gateway_address,omitempty"` + IsDefault bool `json:"is_default,omitempty"` +} + +var ( + networkSettings NetworkSettings + networkSettingsMutex sync.RWMutex + incrementor int +) + +// SetTunnelRemoteAddress sets the tunnel remote address +func SetTunnelRemoteAddress(address string) { + networkSettingsMutex.Lock() + defer networkSettingsMutex.Unlock() + networkSettings.TunnelRemoteAddress = address + incrementor++ + logger.Info("Set tunnel remote address: %s", address) +} + +// SetMTU sets the MTU value +func SetMTU(mtu int) { + networkSettingsMutex.Lock() + defer networkSettingsMutex.Unlock() + networkSettings.MTU = &mtu + incrementor++ + logger.Info("Set MTU: %d", mtu) +} + +// SetDNSServers sets the DNS servers +func SetDNSServers(servers []string) { + networkSettingsMutex.Lock() + defer networkSettingsMutex.Unlock() + networkSettings.DNSServers = servers + incrementor++ + logger.Info("Set DNS servers: %v", servers) +} + +// SetIPv4Settings sets IPv4 addresses and subnet masks +func SetIPv4Settings(addresses []string, subnetMasks []string) { + networkSettingsMutex.Lock() + defer networkSettingsMutex.Unlock() + networkSettings.IPv4Addresses = addresses + networkSettings.IPv4SubnetMasks = subnetMasks + incrementor++ + logger.Info("Set IPv4 addresses: %v, subnet masks: %v", addresses, subnetMasks) +} + +// SetIPv4IncludedRoutes sets the included IPv4 routes +func SetIPv4IncludedRoutes(routes []IPv4Route) { + networkSettingsMutex.Lock() + defer networkSettingsMutex.Unlock() + networkSettings.IPv4IncludedRoutes = routes + incrementor++ + logger.Info("Set IPv4 included routes: %d routes", len(routes)) +} + +func AddIPv4IncludedRoute(route IPv4Route) { + networkSettingsMutex.Lock() + defer networkSettingsMutex.Unlock() + + // make sure it does not already exist + for _, r := range networkSettings.IPv4IncludedRoutes { + if r == route { + logger.Info("IPv4 included route already exists: %+v", route) + return + } + } + + networkSettings.IPv4IncludedRoutes = append(networkSettings.IPv4IncludedRoutes, route) + incrementor++ + logger.Info("Added IPv4 included route: %+v", route) +} + +func RemoveIPv4IncludedRoute(route IPv4Route) { + networkSettingsMutex.Lock() + defer networkSettingsMutex.Unlock() + routes := networkSettings.IPv4IncludedRoutes + for i, r := range routes { + if r == route { + networkSettings.IPv4IncludedRoutes = append(routes[:i], routes[i+1:]...) + logger.Info("Removed IPv4 included route: %+v", route) + return + } + } + incrementor++ + logger.Info("IPv4 included route not found for removal: %+v", route) +} + +func SetIPv4ExcludedRoutes(routes []IPv4Route) { + networkSettingsMutex.Lock() + defer networkSettingsMutex.Unlock() + networkSettings.IPv4ExcludedRoutes = routes + incrementor++ + logger.Info("Set IPv4 excluded routes: %d routes", len(routes)) +} + +// SetIPv6Settings sets IPv6 addresses and network prefixes +func SetIPv6Settings(addresses []string, networkPrefixes []string) { + networkSettingsMutex.Lock() + defer networkSettingsMutex.Unlock() + networkSettings.IPv6Addresses = addresses + networkSettings.IPv6NetworkPrefixes = networkPrefixes + incrementor++ + logger.Info("Set IPv6 addresses: %v, network prefixes: %v", addresses, networkPrefixes) +} + +// SetIPv6IncludedRoutes sets the included IPv6 routes +func SetIPv6IncludedRoutes(routes []IPv6Route) { + networkSettingsMutex.Lock() + defer networkSettingsMutex.Unlock() + networkSettings.IPv6IncludedRoutes = routes + incrementor++ + logger.Info("Set IPv6 included routes: %d routes", len(routes)) +} + +// SetIPv6ExcludedRoutes sets the excluded IPv6 routes +func SetIPv6ExcludedRoutes(routes []IPv6Route) { + networkSettingsMutex.Lock() + defer networkSettingsMutex.Unlock() + networkSettings.IPv6ExcludedRoutes = routes + incrementor++ + logger.Info("Set IPv6 excluded routes: %d routes", len(routes)) +} + +// ClearNetworkSettings clears all network settings +func ClearNetworkSettings() { + networkSettingsMutex.Lock() + defer networkSettingsMutex.Unlock() + networkSettings = NetworkSettings{} + incrementor++ + logger.Info("Cleared all network settings") +} + +func GetJSON() (string, error) { + networkSettingsMutex.RLock() + defer networkSettingsMutex.RUnlock() + data, err := json.MarshalIndent(networkSettings, "", " ") + if err != nil { + return "", err + } + return string(data), nil +} + +func GetSettings() NetworkSettings { + networkSettingsMutex.RLock() + defer networkSettingsMutex.RUnlock() + return networkSettings +} + +func GetIncrementor() int { + networkSettingsMutex.Lock() + defer networkSettingsMutex.Unlock() + return incrementor +} diff --git a/stub.go b/stub.go index 3bdbe19..e711da1 100644 --- a/stub.go +++ b/stub.go @@ -32,3 +32,8 @@ func clientsAddProxyTargetNative(pm *proxy.ProxyManager, tunnelIp string) { _ = tunnelIp // No-op for non-Linux systems } + +func clientsStartDirectRelayNative(tunnelIP string) { + _ = tunnelIP + // No-op for non-Linux systems +} diff --git a/util/util.go b/util/util.go new file mode 100644 index 0000000..66f718b --- /dev/null +++ b/util/util.go @@ -0,0 +1,213 @@ +package util + +import ( + "encoding/base64" + "encoding/binary" + "encoding/hex" + "fmt" + "net" + "strings" + + mathrand "math/rand/v2" + + "github.com/fosrl/newt/logger" + "golang.zx2c4.com/wireguard/device" +) + +func ResolveDomain(domain string) (string, error) { + // trim whitespace + domain = strings.TrimSpace(domain) + + // Remove any protocol prefix if present (do this first, before splitting host/port) + domain = strings.TrimPrefix(domain, "http://") + domain = strings.TrimPrefix(domain, "https://") + + // if there are any trailing slashes, remove them + domain = strings.TrimSuffix(domain, "/") + + // Check if there's a port in the domain + host, port, err := net.SplitHostPort(domain) + if err != nil { + // No port found, use the domain as is + host = domain + port = "" + } + + // Lookup IP addresses + ips, err := net.LookupIP(host) + if err != nil { + return "", fmt.Errorf("DNS lookup failed: %v", err) + } + + if len(ips) == 0 { + return "", fmt.Errorf("no IP addresses found for domain %s", host) + } + + // Get the first IPv4 address if available + var ipAddr string + for _, ip := range ips { + if ipv4 := ip.To4(); ipv4 != nil { + ipAddr = ipv4.String() + break + } + } + + // If no IPv4 found, use the first IP (might be IPv6) + if ipAddr == "" { + ipAddr = ips[0].String() + } + + // Add port back if it existed + if port != "" { + ipAddr = net.JoinHostPort(ipAddr, port) + } + + return ipAddr, nil +} + +func ParseLogLevel(level string) logger.LogLevel { + switch strings.ToUpper(level) { + case "DEBUG": + return logger.DEBUG + case "INFO": + return logger.INFO + case "WARN": + return logger.WARN + case "ERROR": + return logger.ERROR + case "FATAL": + return logger.FATAL + default: + return logger.INFO // default to INFO if invalid level provided + } +} + +// find an available UDP port in the range [minPort, maxPort] and also the next port for the wgtester +func FindAvailableUDPPort(minPort, maxPort uint16) (uint16, error) { + if maxPort < minPort { + return 0, fmt.Errorf("invalid port range: min=%d, max=%d", minPort, maxPort) + } + + // We need to check port+1 as well, so adjust the max port to avoid going out of range + adjustedMaxPort := maxPort - 1 + if adjustedMaxPort < minPort { + return 0, fmt.Errorf("insufficient port range to find consecutive ports: min=%d, max=%d", minPort, maxPort) + } + + // Create a slice of all ports in the range (excluding the last one) + portRange := make([]uint16, adjustedMaxPort-minPort+1) + for i := range portRange { + portRange[i] = minPort + uint16(i) + } + + // Fisher-Yates shuffle to randomize the port order + for i := len(portRange) - 1; i > 0; i-- { + j := mathrand.IntN(i + 1) + portRange[i], portRange[j] = portRange[j], portRange[i] + } + + // Try each port in the randomized order + for _, port := range portRange { + // Check if port is available + addr1 := &net.UDPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: int(port), + } + conn1, err1 := net.ListenUDP("udp", addr1) + if err1 != nil { + continue // Port is in use or there was an error, try next port + } + + conn1.Close() + return port, nil + } + + return 0, fmt.Errorf("no available consecutive UDP ports found in range %d-%d", minPort, maxPort) +} + +func FixKey(key string) string { + // Remove any whitespace + key = strings.TrimSpace(key) + + // Decode from base64 + decoded, err := base64.StdEncoding.DecodeString(key) + if err != nil { + logger.Fatal("Error decoding base64: %v", err) + } + + // Convert to hex + return hex.EncodeToString(decoded) +} + +// this is the opposite of FixKey +func UnfixKey(hexKey string) string { + // Decode from hex + decoded, err := hex.DecodeString(hexKey) + if err != nil { + logger.Fatal("Error decoding hex: %v", err) + } + + // Convert to base64 + return base64.StdEncoding.EncodeToString(decoded) +} + +func MapToWireGuardLogLevel(level logger.LogLevel) int { + switch level { + case logger.DEBUG: + return device.LogLevelVerbose + // case logger.INFO: + // return device.LogLevel + case logger.WARN: + return device.LogLevelError + case logger.ERROR, logger.FATAL: + return device.LogLevelSilent + default: + return device.LogLevelSilent + } +} + +// GetProtocol returns protocol number from IPv4 packet (fast path) +func GetProtocol(packet []byte) (uint8, bool) { + if len(packet) < 20 { + return 0, false + } + version := packet[0] >> 4 + if version == 4 { + return packet[9], true + } else if version == 6 { + if len(packet) < 40 { + return 0, false + } + return packet[6], true + } + return 0, false +} + +// GetDestPort returns destination port from TCP/UDP packet (fast path) +func GetDestPort(packet []byte) (uint16, bool) { + if len(packet) < 20 { + return 0, false + } + + version := packet[0] >> 4 + var headerLen int + + if version == 4 { + ihl := packet[0] & 0x0F + headerLen = int(ihl) * 4 + if len(packet) < headerLen+4 { + return 0, false + } + } else if version == 6 { + headerLen = 40 + if len(packet) < headerLen+4 { + return 0, false + } + } else { + return 0, false + } + + // Destination port is at bytes 2-3 of TCP/UDP header + port := binary.BigEndian.Uint16(packet[headerLen+2 : headerLen+4]) + return port, true +} diff --git a/websocket/client.go b/websocket/client.go index a3ba757..da1fa88 100644 --- a/websocket/client.go +++ b/websocket/client.go @@ -46,6 +46,7 @@ type Client struct { metricsCtxMu sync.RWMutex metricsCtx context.Context configNeedsSave bool // Flag to track if config needs to be saved + serverVersion string } type ClientOption func(*Client) @@ -149,6 +150,10 @@ func (c *Client) GetConfig() *Config { return c.config } +func (c *Client) GetServerVersion() string { + return c.serverVersion +} + // Connect establishes the WebSocket connection func (c *Client) Connect() error { go c.connectWithRetry() @@ -206,6 +211,26 @@ func (c *Client) SendMessage(messageType string, data interface{}) error { return nil } +// SendMessage sends a message through the WebSocket connection +func (c *Client) SendMessageNoLog(messageType string, data interface{}) error { + if c.conn == nil { + return fmt.Errorf("not connected") + } + + msg := WSMessage{ + Type: messageType, + Data: data, + } + + c.writeMux.Lock() + defer c.writeMux.Unlock() + if err := c.conn.WriteJSON(msg); err != nil { + return err + } + telemetry.IncWSMessage(c.metricsContext(), "out", "text") + return nil +} + func (c *Client) SendMessageInterval(messageType string, data interface{}, interval time.Duration) (stop func()) { stopChan := make(chan struct{}) go func() { @@ -331,9 +356,11 @@ func (c *Client) getToken() (string, error) { } defer resp.Body.Close() + body, _ := io.ReadAll(resp.Body) + logger.Debug("Token response body: %s", string(body)) + if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) - logger.Error("Failed to get token with status code: %d, body: %s", resp.StatusCode, string(body)) + logger.Error("Failed to get token with status code: %d", resp.StatusCode) telemetry.IncConnAttempt(ctx, "auth", "failure") etype := "io_error" if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden { @@ -348,7 +375,7 @@ func (c *Client) getToken() (string, error) { } var tokenResp TokenResponse - if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil { + if err := json.Unmarshal(body, &tokenResp); err != nil { logger.Error("Failed to decode token response.") return "", fmt.Errorf("failed to decode token response: %w", err) } @@ -361,6 +388,11 @@ func (c *Client) getToken() (string, error) { return "", fmt.Errorf("received empty token from server") } + // print server version + logger.Info("Server version: %s", tokenResp.Data.ServerVersion) + + c.serverVersion = tokenResp.Data.ServerVersion + logger.Debug("Received token: %s", tokenResp.Data.Token) telemetry.IncConnAttempt(ctx, "auth", "success") diff --git a/websocket/types.go b/websocket/types.go index 229ab50..1196d64 100644 --- a/websocket/types.go +++ b/websocket/types.go @@ -9,7 +9,8 @@ type Config struct { type TokenResponse struct { Data struct { - Token string `json:"token"` + Token string `json:"token"` + ServerVersion string `json:"serverVersion"` } `json:"data"` Success bool `json:"success"` Message string `json:"message"` diff --git a/wg/wg.go b/wg/wg.go deleted file mode 100644 index 4b9e7f7..0000000 --- a/wg/wg.go +++ /dev/null @@ -1,1030 +0,0 @@ -//go:build linux - -package wg - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "net" - "os" - "strconv" - "strings" - "sync" - "time" - - "math/rand" - - "github.com/fosrl/newt/logger" - "github.com/fosrl/newt/network" - "github.com/fosrl/newt/websocket" - "github.com/vishvananda/netlink" - "golang.org/x/crypto/chacha20poly1305" - "golang.org/x/crypto/curve25519" - "golang.zx2c4.com/wireguard/conn" - "golang.zx2c4.com/wireguard/wgctrl" - "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - - "github.com/fosrl/newt/internal/telemetry" -) - -type WgConfig struct { - IpAddress string `json:"ipAddress"` - Peers []Peer `json:"peers"` -} - -type Peer struct { - PublicKey string `json:"publicKey"` - AllowedIPs []string `json:"allowedIps"` - Endpoint string `json:"endpoint"` -} - -type PeerBandwidth struct { - PublicKey string `json:"publicKey"` - BytesIn float64 `json:"bytesIn"` - BytesOut float64 `json:"bytesOut"` -} - -type PeerReading struct { - BytesReceived int64 - BytesTransmitted int64 - LastChecked time.Time -} - -type WireGuardService struct { - interfaceName string - mtu int - client *websocket.Client - wgClient *wgctrl.Client - config WgConfig - key wgtypes.Key - keyFilePath string - newtId string - lastReadings map[string]PeerReading - mu sync.Mutex - Port uint16 - stopHolepunch chan struct{} - host string - serverPubKey string - holePunchEndpoint string - token string - stopGetConfig func() - interfaceCreated bool -} - -// Add this type definition -type fixedPortBind struct { - port uint16 - conn.Bind -} - -func (b *fixedPortBind) Open(port uint16) ([]conn.ReceiveFunc, uint16, error) { - // Ignore the requested port and use our fixed port - return b.Bind.Open(b.port) -} - -func NewFixedPortBind(port uint16) conn.Bind { - return &fixedPortBind{ - port: port, - Bind: conn.NewDefaultBind(), - } -} - -// find an available UDP port in the range [minPort, maxPort] and also the next port for the wgtester -func FindAvailableUDPPort(minPort, maxPort uint16) (uint16, error) { - if maxPort < minPort { - return 0, fmt.Errorf("invalid port range: min=%d, max=%d", minPort, maxPort) - } - - // We need to check port+1 as well, so adjust the max port to avoid going out of range - adjustedMaxPort := maxPort - 1 - if adjustedMaxPort < minPort { - return 0, fmt.Errorf("insufficient port range to find consecutive ports: min=%d, max=%d", minPort, maxPort) - } - - // Create a slice of all ports in the range (excluding the last one) - portRange := make([]uint16, adjustedMaxPort-minPort+1) - for i := range portRange { - portRange[i] = minPort + uint16(i) - } - - // Fisher-Yates shuffle to randomize the port order - rand.Seed(time.Now().UnixNano()) - for i := len(portRange) - 1; i > 0; i-- { - j := rand.Intn(i + 1) - portRange[i], portRange[j] = portRange[j], portRange[i] - } - - // Try each port in the randomized order - for _, port := range portRange { - // Check if port is available - addr1 := &net.UDPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: int(port), - } - conn1, err1 := net.ListenUDP("udp", addr1) - if err1 != nil { - continue // Port is in use or there was an error, try next port - } - - // Check if port+1 is also available - addr2 := &net.UDPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: int(port + 1), - } - conn2, err2 := net.ListenUDP("udp", addr2) - if err2 != nil { - // The next port is not available, so close the first connection and try again - conn1.Close() - continue - } - - // Both ports are available, close connections and return the first port - conn1.Close() - conn2.Close() - return port, nil - } - - return 0, fmt.Errorf("no available consecutive UDP ports found in range %d-%d", minPort, maxPort) -} - -func NewWireGuardService(interfaceName string, mtu int, generateAndSaveKeyTo string, host string, newtId string, wsClient *websocket.Client) (*WireGuardService, error) { - wgClient, err := wgctrl.New() - if err != nil { - return nil, fmt.Errorf("failed to create WireGuard client: %v", err) - } - - var key wgtypes.Key - var port uint16 - // if generateAndSaveKeyTo is provided, generate a private key and save it to the file. if the file already exists, load the key from the file - key, err = wgtypes.GeneratePrivateKey() - if err != nil { - return nil, fmt.Errorf("failed to generate private key: %v", err) - } - - // Load or generate private key - if generateAndSaveKeyTo != "" { - if _, err := os.Stat(generateAndSaveKeyTo); os.IsNotExist(err) { - keyData, err := os.ReadFile(generateAndSaveKeyTo) - if err != nil { - return nil, fmt.Errorf("failed to read private key: %v", err) - } - key, err = wgtypes.ParseKey(strings.TrimSpace(string(keyData))) - if err != nil { - return nil, fmt.Errorf("failed to parse private key: %v", err) - } - } else { - err = os.WriteFile(generateAndSaveKeyTo, []byte(key.String()), 0600) - if err != nil { - return nil, fmt.Errorf("failed to save private key: %v", err) - } - } - } - - // Get the existing wireguard port - device, err := wgClient.Device(interfaceName) - if err == nil { - port = uint16(device.ListenPort) - // also set the private key to the existing key - key = device.PrivateKey - if port != 0 { - logger.Info("WireGuard interface %s already exists with port %d\n", interfaceName, port) - } else { - port, err = FindAvailableUDPPort(49152, 65535) - if err != nil { - fmt.Printf("Error finding available port: %v\n", err) - return nil, err - } - } - } else { - port, err = FindAvailableUDPPort(49152, 65535) - if err != nil { - fmt.Printf("Error finding available port: %v\n", err) - return nil, err - } - } - - service := &WireGuardService{ - interfaceName: interfaceName, - mtu: mtu, - client: wsClient, - wgClient: wgClient, - key: key, - Port: port, - keyFilePath: generateAndSaveKeyTo, - newtId: newtId, - host: host, - lastReadings: make(map[string]PeerReading), - stopHolepunch: make(chan struct{}), - } - - // Register websocket handlers - wsClient.RegisterHandler("newt/wg/receive-config", service.handleConfig) - wsClient.RegisterHandler("newt/wg/peer/add", service.handleAddPeer) - wsClient.RegisterHandler("newt/wg/peer/remove", service.handleRemovePeer) - wsClient.RegisterHandler("newt/wg/peer/update", service.handleUpdatePeer) - - return service, nil -} - -func (s *WireGuardService) Close(rm bool) { - if s.stopGetConfig != nil { - s.stopGetConfig() - s.stopGetConfig = nil - } - - s.wgClient.Close() - // Remove the WireGuard interface - if rm { - if err := s.removeInterface(); err != nil { - logger.Error("Failed to remove WireGuard interface: %v", err) - } - - // Remove the private key file - // if s.keyFilePath != "" { - // if err := os.Remove(s.keyFilePath); err != nil { - // logger.Error("Failed to remove private key file: %v", err) - // } - // } - } -} - -func (s *WireGuardService) StartHolepunch(serverPubKey string, endpoint string) { - // if the device is already created dont start a new holepunch - if s.interfaceCreated { - return - } - - s.serverPubKey = serverPubKey - s.holePunchEndpoint = endpoint - - logger.Debug("Starting UDP hole punch to %s", s.holePunchEndpoint) - - s.stopHolepunch = make(chan struct{}) - - // start the UDP holepunch - go s.keepSendingUDPHolePunch(s.holePunchEndpoint) -} - -func (s *WireGuardService) SetToken(token string) { - s.token = token -} - -func (s *WireGuardService) LoadRemoteConfig() error { - s.stopGetConfig = s.client.SendMessageInterval("newt/wg/get-config", map[string]interface{}{ - "publicKey": s.key.PublicKey().String(), - "port": s.Port, - }, 2*time.Second) - - logger.Info("Requesting WireGuard configuration from remote server") - go s.periodicBandwidthCheck() - - return nil -} - -func (s *WireGuardService) handleConfig(msg websocket.WSMessage) { - ctx := context.Background() - if s.client != nil { - ctx = s.client.MetricsContext() - } - result := "success" - defer func() { - telemetry.IncConfigReload(ctx, result) - }() - - var config WgConfig - - logger.Debug("Received message: %v", msg) - logger.Info("Received WireGuard clients configuration from remote server") - - jsonData, err := json.Marshal(msg.Data) - if err != nil { - logger.Info("Error marshaling data: %v", err) - result = "failure" - return - } - - if err := json.Unmarshal(jsonData, &config); err != nil { - logger.Info("Error unmarshaling target data: %v", err) - result = "failure" - return - } - s.config = config - - if s.stopGetConfig != nil { - s.stopGetConfig() - s.stopGetConfig = nil - } - - // telemetry: config reload success - // Optional reconnect reason mapping: config change - if s.serverPubKey != "" { - telemetry.IncReconnect(ctx, s.serverPubKey, "client", telemetry.ReasonConfigChange) - } - - // Ensure the WireGuard interface and peers are configured - start := time.Now() - if err := s.ensureWireguardInterface(config); err != nil { - logger.Error("Failed to ensure WireGuard interface: %v", err) - telemetry.ObserveConfigApply(ctx, "interface", "failure", time.Since(start).Seconds()) - result = "failure" - } else { - telemetry.ObserveConfigApply(ctx, "interface", "success", time.Since(start).Seconds()) - } - - startPeers := time.Now() - if err := s.ensureWireguardPeers(config.Peers); err != nil { - logger.Error("Failed to ensure WireGuard peers: %v", err) - telemetry.ObserveConfigApply(ctx, "peer", "failure", time.Since(startPeers).Seconds()) - result = "failure" - } else { - telemetry.ObserveConfigApply(ctx, "peer", "success", time.Since(startPeers).Seconds()) - } -} - -func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error { - // Check if the WireGuard interface exists - _, err := netlink.LinkByName(s.interfaceName) - if err != nil { - if _, ok := err.(netlink.LinkNotFoundError); ok { - // Interface doesn't exist, so create it - err = s.createWireGuardInterface() - if err != nil { - logger.Fatal("Failed to create WireGuard interface: %v", err) - } - s.interfaceCreated = true - logger.Info("Created WireGuard interface %s\n", s.interfaceName) - } else { - logger.Fatal("Error checking for WireGuard interface: %v", err) - } - } else { - logger.Info("WireGuard interface %s already exists\n", s.interfaceName) - - // get the exising wireguard port - device, err := s.wgClient.Device(s.interfaceName) - if err != nil { - return fmt.Errorf("failed to get device: %v", err) - } - - // get the existing port - s.Port = uint16(device.ListenPort) - logger.Info("WireGuard interface %s already exists with port %d\n", s.interfaceName, s.Port) - - s.interfaceCreated = true - return nil - } - - // stop the holepunch its a channel - if s.stopHolepunch != nil { - close(s.stopHolepunch) - s.stopHolepunch = nil - } - - logger.Info("Assigning IP address %s to interface %s\n", wgconfig.IpAddress, s.interfaceName) - // Assign IP address to the interface - err = s.assignIPAddress(wgconfig.IpAddress) - if err != nil { - logger.Fatal("Failed to assign IP address: %v", err) - } - - // Check if the interface already exists - _, err = s.wgClient.Device(s.interfaceName) - if err != nil { - if errors.Is(err, os.ErrNotExist) { - return fmt.Errorf("interface %s does not exist", s.interfaceName) - } - return fmt.Errorf("failed to get device: %v", err) - } - - // Parse the private key - key, err := wgtypes.ParseKey(s.key.String()) - if err != nil { - return fmt.Errorf("failed to parse private key: %v", err) - } - - config := wgtypes.Config{ - PrivateKey: &key, - ListenPort: new(int), - } - - // Use the service's fixed port instead of the config port - *config.ListenPort = int(s.Port) - - // Create and configure the WireGuard interface - err = s.wgClient.ConfigureDevice(s.interfaceName, config) - if err != nil { - return fmt.Errorf("failed to configure WireGuard device: %v", err) - } - - // bring up the interface - link, err := netlink.LinkByName(s.interfaceName) - if err != nil { - return fmt.Errorf("failed to get interface: %v", err) - } - - if err := netlink.LinkSetMTU(link, s.mtu); err != nil { - return fmt.Errorf("failed to set MTU: %v", err) - } - - if err := netlink.LinkSetUp(link); err != nil { - return fmt.Errorf("failed to bring up interface: %v", err) - } - - // if err := s.ensureMSSClamping(); err != nil { - // logger.Warn("Failed to ensure MSS clamping: %v", err) - // } - - logger.Info("WireGuard interface %s created and configured", s.interfaceName) - - return nil -} - -func (s *WireGuardService) createWireGuardInterface() error { - wgLink := &netlink.GenericLink{ - LinkAttrs: netlink.LinkAttrs{Name: s.interfaceName}, - LinkType: "wireguard", - } - return netlink.LinkAdd(wgLink) -} - -func (s *WireGuardService) assignIPAddress(ipAddress string) error { - link, err := netlink.LinkByName(s.interfaceName) - if err != nil { - return fmt.Errorf("failed to get interface: %v", err) - } - - addr, err := netlink.ParseAddr(ipAddress) - if err != nil { - return fmt.Errorf("failed to parse IP address: %v", err) - } - - return netlink.AddrAdd(link, addr) -} - -func (s *WireGuardService) ensureWireguardPeers(peers []Peer) error { - // get the current peers - device, err := s.wgClient.Device(s.interfaceName) - if err != nil { - return fmt.Errorf("failed to get device: %v", err) - } - - // get the peer public keys - var currentPeers []string - for _, peer := range device.Peers { - currentPeers = append(currentPeers, peer.PublicKey.String()) - } - - // remove any peers that are not in the config - for _, peer := range currentPeers { - found := false - for _, configPeer := range peers { - if peer == configPeer.PublicKey { - found = true - break - } - } - if !found { - err := s.removePeer(peer) - if err != nil { - return fmt.Errorf("failed to remove peer: %v", err) - } - } - } - - // add any peers that are in the config but not in the current peers - for _, configPeer := range peers { - found := false - for _, peer := range currentPeers { - if configPeer.PublicKey == peer { - found = true - break - } - } - if !found { - err := s.addPeer(configPeer) - if err != nil { - return fmt.Errorf("failed to add peer: %v", err) - } - } - } - - return nil -} - -func (s *WireGuardService) handleAddPeer(msg websocket.WSMessage) { - logger.Debug("Received message: %v", msg.Data) - var peer Peer - - jsonData, err := json.Marshal(msg.Data) - if err != nil { - logger.Info("Error marshaling data: %v", err) - } - - if err := json.Unmarshal(jsonData, &peer); err != nil { - logger.Info("Error unmarshaling target data: %v", err) - } - - err = s.addPeer(peer) - if err != nil { - logger.Info("Error adding peer: %v", err) - return - } -} - -func (s *WireGuardService) addPeer(peer Peer) error { - pubKey, err := wgtypes.ParseKey(peer.PublicKey) - if err != nil { - return fmt.Errorf("failed to parse public key: %v", err) - } - - // parse allowed IPs into array of net.IPNet - var allowedIPs []net.IPNet - for _, ipStr := range peer.AllowedIPs { - _, ipNet, err := net.ParseCIDR(ipStr) - if err != nil { - return fmt.Errorf("failed to parse allowed IP: %v", err) - } - allowedIPs = append(allowedIPs, *ipNet) - } - // add keep alive using *time.Duration of 1 second - keepalive := time.Second - - var peerConfig wgtypes.PeerConfig - if peer.Endpoint != "" { - endpoint, err := net.ResolveUDPAddr("udp", peer.Endpoint) - if err != nil { - return fmt.Errorf("failed to resolve endpoint address: %w", err) - } - - peerConfig = wgtypes.PeerConfig{ - PublicKey: pubKey, - AllowedIPs: allowedIPs, - PersistentKeepaliveInterval: &keepalive, - Endpoint: endpoint, - } - } else { - peerConfig = wgtypes.PeerConfig{ - PublicKey: pubKey, - AllowedIPs: allowedIPs, - PersistentKeepaliveInterval: &keepalive, - } - logger.Info("Added peer with no endpoint!") - } - - config := wgtypes.Config{ - Peers: []wgtypes.PeerConfig{peerConfig}, - } - - if err := s.wgClient.ConfigureDevice(s.interfaceName, config); err != nil { - return fmt.Errorf("failed to add peer: %v", err) - } - - logger.Info("Peer %s added successfully", peer.PublicKey) - - return nil -} - -func (s *WireGuardService) handleRemovePeer(msg websocket.WSMessage) { - logger.Debug("Received message: %v", msg.Data) - // parse the publicKey from the message which is json { "publicKey": "asdfasdfl;akjsdf" } - type RemoveRequest struct { - PublicKey string `json:"publicKey"` - } - - jsonData, err := json.Marshal(msg.Data) - if err != nil { - logger.Info("Error marshaling data: %v", err) - } - - var request RemoveRequest - if err := json.Unmarshal(jsonData, &request); err != nil { - logger.Info("Error unmarshaling data: %v", err) - return - } - - if err := s.removePeer(request.PublicKey); err != nil { - logger.Info("Error removing peer: %v", err) - return - } -} - -func (s *WireGuardService) removePeer(publicKey string) error { - pubKey, err := wgtypes.ParseKey(publicKey) - if err != nil { - return fmt.Errorf("failed to parse public key: %v", err) - } - - peerConfig := wgtypes.PeerConfig{ - PublicKey: pubKey, - Remove: true, - } - - config := wgtypes.Config{ - Peers: []wgtypes.PeerConfig{peerConfig}, - } - - if err := s.wgClient.ConfigureDevice(s.interfaceName, config); err != nil { - return fmt.Errorf("failed to remove peer: %v", err) - } - - logger.Info("Peer %s removed successfully", publicKey) - - return nil -} - -func (s *WireGuardService) handleUpdatePeer(msg websocket.WSMessage) { - logger.Debug("Received message: %v", msg.Data) - // Define a struct to match the incoming message structure with optional fields - type UpdatePeerRequest struct { - PublicKey string `json:"publicKey"` - AllowedIPs []string `json:"allowedIps,omitempty"` - Endpoint string `json:"endpoint,omitempty"` - } - jsonData, err := json.Marshal(msg.Data) - if err != nil { - logger.Info("Error marshaling data: %v", err) - return - } - var request UpdatePeerRequest - if err := json.Unmarshal(jsonData, &request); err != nil { - logger.Info("Error unmarshaling peer data: %v", err) - return - } - // First, get the current peer configuration to preserve any unmodified fields - device, err := s.wgClient.Device(s.interfaceName) - if err != nil { - logger.Info("Error getting WireGuard device: %v", err) - return - } - pubKey, err := wgtypes.ParseKey(request.PublicKey) - if err != nil { - logger.Info("Error parsing public key: %v", err) - return - } - // Find the existing peer configuration - var currentPeer *wgtypes.Peer - for _, p := range device.Peers { - if p.PublicKey == pubKey { - currentPeer = &p - break - } - } - if currentPeer == nil { - logger.Info("Peer %s not found, cannot update", request.PublicKey) - return - } - // Create the update peer config - peerConfig := wgtypes.PeerConfig{ - PublicKey: pubKey, - UpdateOnly: true, - } - // Keep the default persistent keepalive of 1 second - keepalive := time.Second - peerConfig.PersistentKeepaliveInterval = &keepalive - - // Handle Endpoint field special case - // If Endpoint is included in the request but empty, we want to remove the endpoint - // If Endpoint is not included, we don't modify it - endpointSpecified := false - for key := range msg.Data.(map[string]interface{}) { - if key == "endpoint" { - endpointSpecified = true - break - } - } - - // Only update AllowedIPs if provided in the request - if len(request.AllowedIPs) > 0 { - var allowedIPs []net.IPNet - for _, ipStr := range request.AllowedIPs { - _, ipNet, err := net.ParseCIDR(ipStr) - if err != nil { - logger.Info("Error parsing allowed IP %s: %v", ipStr, err) - return - } - allowedIPs = append(allowedIPs, *ipNet) - } - peerConfig.AllowedIPs = allowedIPs - peerConfig.ReplaceAllowedIPs = true - logger.Info("Updating AllowedIPs for peer %s", request.PublicKey) - } else if endpointSpecified && request.Endpoint == "" { - peerConfig.ReplaceAllowedIPs = false - } - - if endpointSpecified { - if request.Endpoint != "" { - // Update to new endpoint - endpoint, err := net.ResolveUDPAddr("udp", request.Endpoint) - if err != nil { - logger.Info("Error resolving endpoint address %s: %v", request.Endpoint, err) - return - } - peerConfig.Endpoint = endpoint - logger.Info("Updating Endpoint for peer %s to %s", request.PublicKey, request.Endpoint) - } else { - // specify any address to listen for any incoming packets - peerConfig.Endpoint = &net.UDPAddr{ - IP: net.IPv4(127, 0, 0, 1), - } - logger.Info("Removing Endpoint for peer %s", request.PublicKey) - } - } - - // Apply the configuration update - config := wgtypes.Config{ - Peers: []wgtypes.PeerConfig{peerConfig}, - } - if err := s.wgClient.ConfigureDevice(s.interfaceName, config); err != nil { - logger.Info("Error updating peer configuration: %v", err) - return - } - logger.Info("Peer %s updated successfully", request.PublicKey) -} - -func (s *WireGuardService) periodicBandwidthCheck() { - ticker := time.NewTicker(10 * time.Second) - defer ticker.Stop() - - for range ticker.C { - if err := s.reportPeerBandwidth(); err != nil { - logger.Info("Failed to report peer bandwidth: %v", err) - } - } -} - -func (s *WireGuardService) calculatePeerBandwidth() ([]PeerBandwidth, error) { - device, err := s.wgClient.Device(s.interfaceName) - if err != nil { - return nil, fmt.Errorf("failed to get device: %v", err) - } - - peerBandwidths := []PeerBandwidth{} - now := time.Now() - - s.mu.Lock() - defer s.mu.Unlock() - - for _, peer := range device.Peers { - publicKey := peer.PublicKey.String() - currentReading := PeerReading{ - BytesReceived: peer.ReceiveBytes, - BytesTransmitted: peer.TransmitBytes, - LastChecked: now, - } - - var bytesInDiff, bytesOutDiff float64 - lastReading, exists := s.lastReadings[publicKey] - - if exists { - timeDiff := currentReading.LastChecked.Sub(lastReading.LastChecked).Seconds() - if timeDiff > 0 { - // Calculate bytes transferred since last reading - bytesInDiff = float64(currentReading.BytesReceived - lastReading.BytesReceived) - bytesOutDiff = float64(currentReading.BytesTransmitted - lastReading.BytesTransmitted) - - // Handle counter wraparound (if the counter resets or overflows) - if bytesInDiff < 0 { - bytesInDiff = float64(currentReading.BytesReceived) - } - if bytesOutDiff < 0 { - bytesOutDiff = float64(currentReading.BytesTransmitted) - } - - // Convert to MB - bytesInMB := bytesInDiff / (1024 * 1024) - bytesOutMB := bytesOutDiff / (1024 * 1024) - - peerBandwidths = append(peerBandwidths, PeerBandwidth{ - PublicKey: publicKey, - BytesIn: bytesInMB, - BytesOut: bytesOutMB, - }) - } else { - // If readings are too close together or time hasn't passed, report 0 - peerBandwidths = append(peerBandwidths, PeerBandwidth{ - PublicKey: publicKey, - BytesIn: 0, - BytesOut: 0, - }) - } - } else { - // For first reading of a peer, report 0 to establish baseline - peerBandwidths = append(peerBandwidths, PeerBandwidth{ - PublicKey: publicKey, - BytesIn: 0, - BytesOut: 0, - }) - } - - // Update the last reading - s.lastReadings[publicKey] = currentReading - } - - // Clean up old peers - for publicKey := range s.lastReadings { - found := false - for _, peer := range device.Peers { - if peer.PublicKey.String() == publicKey { - found = true - break - } - } - if !found { - delete(s.lastReadings, publicKey) - } - } - - return peerBandwidths, nil -} - -func (s *WireGuardService) reportPeerBandwidth() error { - bandwidths, err := s.calculatePeerBandwidth() - if err != nil { - return fmt.Errorf("failed to calculate peer bandwidth: %v", err) - } - - err = s.client.SendMessage("newt/receive-bandwidth", map[string]interface{}{ - "bandwidthData": bandwidths, - }) - if err != nil { - return fmt.Errorf("failed to send bandwidth data: %v", err) - } - - return nil -} - -func (s *WireGuardService) sendUDPHolePunch(serverAddr string) error { - - if s.serverPubKey == "" || s.token == "" { - logger.Debug("Server public key or token not set, skipping UDP hole punch") - return nil - } - - // Parse server address - serverSplit := strings.Split(serverAddr, ":") - if len(serverSplit) < 2 { - return fmt.Errorf("invalid server address format, expected hostname:port") - } - - serverHostname := serverSplit[0] - serverPort, err := strconv.ParseUint(serverSplit[1], 10, 16) - if err != nil { - return fmt.Errorf("failed to parse server port: %v", err) - } - - // Resolve server hostname to IP - serverIPAddr := network.HostToAddr(serverHostname) - if serverIPAddr == nil { - return fmt.Errorf("failed to resolve server hostname") - } - - // Get client IP based on route to server - clientIP := network.GetClientIP(serverIPAddr.IP) - - // Create server and client configs - server := &network.Server{ - Hostname: serverHostname, - Addr: serverIPAddr, - Port: uint16(serverPort), - } - - client := &network.PeerNet{ - IP: clientIP, - Port: s.Port, - NewtID: s.newtId, - } - - // Setup raw connection with BPF filtering - rawConn := network.SetupRawConn(server, client) - defer rawConn.Close() - - // Create JSON payload - payload := struct { - NewtID string `json:"newtId"` - Token string `json:"token"` - }{ - NewtID: s.newtId, - Token: s.token, - } - - // Convert payload to JSON - payloadBytes, err := json.Marshal(payload) - if err != nil { - return fmt.Errorf("failed to marshal payload: %v", err) - } - - // Encrypt the payload using the server's WireGuard public key - encryptedPayload, err := s.encryptPayload(payloadBytes) - if err != nil { - return fmt.Errorf("failed to encrypt payload: %v", err) - } - - // Send the encrypted packet using the raw connection - err = network.SendDataPacket(encryptedPayload, rawConn, server, client) - if err != nil { - return fmt.Errorf("failed to send UDP packet: %v", err) - } - - return nil -} - -func (s *WireGuardService) encryptPayload(payload []byte) (interface{}, error) { - // Generate an ephemeral keypair for this message - ephemeralPrivateKey, err := wgtypes.GeneratePrivateKey() - if err != nil { - return nil, fmt.Errorf("failed to generate ephemeral private key: %v", err) - } - ephemeralPublicKey := ephemeralPrivateKey.PublicKey() - - // Parse the server's public key - serverPubKey, err := wgtypes.ParseKey(s.serverPubKey) - if err != nil { - return nil, fmt.Errorf("failed to parse server public key: %v", err) - } - - // Use X25519 for key exchange (replacing deprecated ScalarMult) - var ephPrivKeyFixed [32]byte - copy(ephPrivKeyFixed[:], ephemeralPrivateKey[:]) - - // Perform X25519 key exchange - sharedSecret, err := curve25519.X25519(ephPrivKeyFixed[:], serverPubKey[:]) - if err != nil { - return nil, fmt.Errorf("failed to perform X25519 key exchange: %v", err) - } - - // Create an AEAD cipher using the shared secret - aead, err := chacha20poly1305.New(sharedSecret) - if err != nil { - return nil, fmt.Errorf("failed to create AEAD cipher: %v", err) - } - - // Generate a random nonce - nonce := make([]byte, aead.NonceSize()) - if _, err := rand.Read(nonce); err != nil { - return nil, fmt.Errorf("failed to generate nonce: %v", err) - } - - // Encrypt the payload - ciphertext := aead.Seal(nil, nonce, payload, nil) - - // Prepare the final encrypted message - encryptedMsg := struct { - EphemeralPublicKey string `json:"ephemeralPublicKey"` - Nonce []byte `json:"nonce"` - Ciphertext []byte `json:"ciphertext"` - }{ - EphemeralPublicKey: ephemeralPublicKey.String(), - Nonce: nonce, - Ciphertext: ciphertext, - } - - return encryptedMsg, nil -} - -func (s *WireGuardService) keepSendingUDPHolePunch(host string) { - logger.Info("Starting UDP hole punch routine to %s:21820", host) - - // send initial hole punch - if err := s.sendUDPHolePunch(host + ":21820"); err != nil { - logger.Debug("Failed to send initial UDP hole punch: %v", err) - } - - ticker := time.NewTicker(3 * time.Second) - defer ticker.Stop() - - timeout := time.NewTimer(15 * time.Second) - defer timeout.Stop() - - for { - select { - case <-s.stopHolepunch: - logger.Info("Stopping UDP holepunch") - return - case <-timeout.C: - logger.Info("UDP holepunch routine timed out after 15 seconds") - return - case <-ticker.C: - if err := s.sendUDPHolePunch(host + ":21820"); err != nil { - logger.Debug("Failed to send UDP hole punch: %v", err) - } - } - } -} - -func (s *WireGuardService) removeInterface() error { - // Remove the WireGuard interface - link, err := netlink.LinkByName(s.interfaceName) - if err != nil { - return fmt.Errorf("failed to get interface: %v", err) - } - - err = netlink.LinkDel(link) - if err != nil { - return fmt.Errorf("failed to delete interface: %v", err) - } - - logger.Info("WireGuard interface %s removed successfully", s.interfaceName) - - return nil -} diff --git a/wgnetstack/wgnetstack.go b/wgnetstack/wgnetstack.go deleted file mode 100644 index 664d1f0..0000000 --- a/wgnetstack/wgnetstack.go +++ /dev/null @@ -1,1305 +0,0 @@ -package wgnetstack - -import ( - "context" - "crypto/rand" - "encoding/base64" - "encoding/hex" - "encoding/json" - "fmt" - mathrand "math/rand/v2" - "net" - "net/netip" - "os" - "strconv" - "strings" - "sync" - "time" - - "github.com/fosrl/newt/logger" - "github.com/fosrl/newt/network" - "github.com/fosrl/newt/proxy" - "github.com/fosrl/newt/websocket" - "golang.org/x/crypto/chacha20poly1305" - "golang.org/x/crypto/curve25519" - "golang.zx2c4.com/wireguard/conn" - "golang.zx2c4.com/wireguard/device" - "golang.zx2c4.com/wireguard/tun" - "golang.zx2c4.com/wireguard/tun/netstack" - "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - - "github.com/fosrl/newt/internal/telemetry" -) - -type WgConfig struct { - IpAddress string `json:"ipAddress"` - Peers []Peer `json:"peers"` - Targets TargetsByType `json:"targets"` -} - -type TargetsByType struct { - UDP []string `json:"udp"` - TCP []string `json:"tcp"` -} - -type TargetData struct { - Targets []string `json:"targets"` -} - -type Peer struct { - PublicKey string `json:"publicKey"` - AllowedIPs []string `json:"allowedIps"` - Endpoint string `json:"endpoint"` -} - -type PeerBandwidth struct { - PublicKey string `json:"publicKey"` - BytesIn float64 `json:"bytesIn"` - BytesOut float64 `json:"bytesOut"` -} - -type PeerReading struct { - BytesReceived int64 - BytesTransmitted int64 - LastChecked time.Time -} - -type WireGuardService struct { - interfaceName string - mtu int - client *websocket.Client - config WgConfig - key wgtypes.Key - keyFilePath string - newtId string - lastReadings map[string]PeerReading - mu sync.Mutex - Port uint16 - stopHolepunch chan struct{} - host string - serverPubKey string - holePunchEndpoint string - token string - stopGetConfig func() - // Netstack fields - tun tun.Device - tnet *netstack.Net - device *device.Device - dns []netip.Addr - // Callback for when netstack is ready - onNetstackReady func(*netstack.Net) - // Callback for when netstack is closed - onNetstackClose func() - othertnet *netstack.Net - // Proxy manager for tunnel - proxyManager *proxy.ProxyManager - TunnelIP string -} - -// GetProxyManager returns the proxy manager for this WireGuardService -func (s *WireGuardService) GetProxyManager() *proxy.ProxyManager { - return s.proxyManager -} - -// AddProxyTarget adds a target to the proxy manager -func (s *WireGuardService) AddProxyTarget(proto, listenIP string, port int, targetAddr string) error { - if s.proxyManager == nil { - return fmt.Errorf("proxy manager not initialized") - } - return s.proxyManager.AddTarget(proto, listenIP, port, targetAddr) -} - -// RemoveProxyTarget removes a target from the proxy manager -func (s *WireGuardService) RemoveProxyTarget(proto, listenIP string, port int) error { - if s.proxyManager == nil { - return fmt.Errorf("proxy manager not initialized") - } - return s.proxyManager.RemoveTarget(proto, listenIP, port) -} - -// Add this type definition -type fixedPortBind struct { - port uint16 - conn.Bind -} - -func (b *fixedPortBind) Open(port uint16) ([]conn.ReceiveFunc, uint16, error) { - // Ignore the requested port and use our fixed port - return b.Bind.Open(b.port) -} - -func NewFixedPortBind(port uint16) conn.Bind { - return &fixedPortBind{ - port: port, - Bind: conn.NewDefaultBind(), - } -} - -// find an available UDP port in the range [minPort, maxPort] and also the next port for the wgtester -func FindAvailableUDPPort(minPort, maxPort uint16) (uint16, error) { - if maxPort < minPort { - return 0, fmt.Errorf("invalid port range: min=%d, max=%d", minPort, maxPort) - } - - // We need to check port+1 as well, so adjust the max port to avoid going out of range - adjustedMaxPort := maxPort - 1 - if adjustedMaxPort < minPort { - return 0, fmt.Errorf("insufficient port range to find consecutive ports: min=%d, max=%d", minPort, maxPort) - } - - // Create a slice of all ports in the range (excluding the last one) - portRange := make([]uint16, adjustedMaxPort-minPort+1) - for i := range portRange { - portRange[i] = minPort + uint16(i) - } - - // Fisher-Yates shuffle to randomize the port order - for i := len(portRange) - 1; i > 0; i-- { - j := mathrand.IntN(i + 1) - portRange[i], portRange[j] = portRange[j], portRange[i] - } - - // Try each port in the randomized order - for _, port := range portRange { - // Check if port is available - addr1 := &net.UDPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: int(port), - } - conn1, err1 := net.ListenUDP("udp", addr1) - if err1 != nil { - continue // Port is in use or there was an error, try next port - } - - conn1.Close() - return port, nil - } - - return 0, fmt.Errorf("no available consecutive UDP ports found in range %d-%d", minPort, maxPort) -} - -func NewWireGuardService(interfaceName string, mtu int, generateAndSaveKeyTo string, host string, newtId string, wsClient *websocket.Client, dns string) (*WireGuardService, error) { - var key wgtypes.Key - var err error - - key, err = wgtypes.GeneratePrivateKey() - if err != nil { - return nil, fmt.Errorf("failed to generate private key: %v", err) - } - - // Load or generate private key - if generateAndSaveKeyTo != "" { - if _, err := os.Stat(generateAndSaveKeyTo); os.IsNotExist(err) { - // File doesn't exist, save the generated key - err = os.WriteFile(generateAndSaveKeyTo, []byte(key.String()), 0600) - if err != nil { - return nil, fmt.Errorf("failed to save private key: %v", err) - } - } else { - // File exists, read the existing key - keyData, err := os.ReadFile(generateAndSaveKeyTo) - if err != nil { - return nil, fmt.Errorf("failed to read private key: %v", err) - } - key, err = wgtypes.ParseKey(strings.TrimSpace(string(keyData))) - if err != nil { - return nil, fmt.Errorf("failed to parse private key: %v", err) - } - } - } - - // Find an available port - port, err := FindAvailableUDPPort(49152, 65535) - if err != nil { - return nil, fmt.Errorf("error finding available port: %v", err) - } - - // Parse DNS addresses - dnsAddrs := []netip.Addr{netip.MustParseAddr(dns)} - - service := &WireGuardService{ - interfaceName: interfaceName, - mtu: mtu, - client: wsClient, - key: key, - keyFilePath: generateAndSaveKeyTo, - newtId: newtId, - host: host, - lastReadings: make(map[string]PeerReading), - stopHolepunch: make(chan struct{}), - Port: port, - dns: dnsAddrs, - proxyManager: proxy.NewProxyManagerWithoutTNet(), - } - - // Register websocket handlers - wsClient.RegisterHandler("newt/wg/receive-config", service.handleConfig) - wsClient.RegisterHandler("newt/wg/peer/add", service.handleAddPeer) - wsClient.RegisterHandler("newt/wg/peer/remove", service.handleRemovePeer) - wsClient.RegisterHandler("newt/wg/peer/update", service.handleUpdatePeer) - wsClient.RegisterHandler("newt/wg/tcp/add", service.addTcpTarget) - wsClient.RegisterHandler("newt/wg/udp/add", service.addUdpTarget) - wsClient.RegisterHandler("newt/wg/udp/remove", service.removeUdpTarget) - wsClient.RegisterHandler("newt/wg/tcp/remove", service.removeTcpTarget) - - return service, nil -} - -// ReportRTT allows reporting native RTTs to telemetry, rate-limited externally. -func (s *WireGuardService) ReportRTT(seconds float64) { - if s.serverPubKey == "" { return } - telemetry.ObserveTunnelLatency(context.Background(), s.serverPubKey, "wireguard", seconds) -} - -func (s *WireGuardService) addTcpTarget(msg websocket.WSMessage) { - logger.Debug("Received: %+v", msg) - - // if there is no wgData or pm, we can't add targets - if s.TunnelIP == "" || s.proxyManager == nil { - logger.Info("No tunnel IP or proxy manager available") - return -} - - targetData, err := parseTargetData(msg.Data) - if err != nil { - logger.Info("Error parsing target data: %v", err) - return - } - - if len(targetData.Targets) > 0 { - s.updateTargets(s.proxyManager, "add", s.TunnelIP, "tcp", targetData) - } -} - -func (s *WireGuardService) addUdpTarget(msg websocket.WSMessage) { - logger.Info("Received: %+v", msg) - - // if there is no wgData or pm, we can't add targets - if s.TunnelIP == "" || s.proxyManager == nil { - logger.Info("No tunnel IP or proxy manager available") - return - } - - targetData, err := parseTargetData(msg.Data) - if err != nil { - logger.Info("Error parsing target data: %v", err) - return - } - - if len(targetData.Targets) > 0 { - s.updateTargets(s.proxyManager, "add", s.TunnelIP, "udp", targetData) - } -} - -func (s *WireGuardService) removeUdpTarget(msg websocket.WSMessage) { - logger.Info("Received: %+v", msg) - - // if there is no wgData or pm, we can't add targets - if s.TunnelIP == "" || s.proxyManager == nil { - logger.Info("No tunnel IP or proxy manager available") - return - } - - targetData, err := parseTargetData(msg.Data) - if err != nil { - logger.Info("Error parsing target data: %v", err) - return - } - - if len(targetData.Targets) > 0 { - s.updateTargets(s.proxyManager, "remove", s.TunnelIP, "udp", targetData) - } -} - -func (s *WireGuardService) removeTcpTarget(msg websocket.WSMessage) { - logger.Info("Received: %+v", msg) - - // if there is no wgData or pm, we can't add targets - if s.TunnelIP == "" || s.proxyManager == nil { - logger.Info("No tunnel IP or proxy manager available") - return - } - - targetData, err := parseTargetData(msg.Data) - if err != nil { - logger.Info("Error parsing target data: %v", err) - return - } - - if len(targetData.Targets) > 0 { - s.updateTargets(s.proxyManager, "remove", s.TunnelIP, "tcp", targetData) - } -} - -func (s *WireGuardService) SetOthertnet(tnet *netstack.Net) { - s.othertnet = tnet -} - -func (s *WireGuardService) Close(rm bool) { - if s.stopGetConfig != nil { - s.stopGetConfig() - s.stopGetConfig = nil - } - - s.mu.Lock() - defer s.mu.Unlock() - - // Close WireGuard device first - this will automatically close the TUN device - if s.device != nil { - s.device.Close() - s.device = nil - } - - // Clear references but don't manually close since device.Close() already did it - if s.tnet != nil { - s.tnet = nil - } - if s.tun != nil { - s.tun = nil // Don't call tun.Close() here since device.Close() already closed it - } -} - -func (s *WireGuardService) StartHolepunch(serverPubKey string, endpoint string) { - // if the device is already created dont start a new holepunch - if s.device != nil { - return - } - - s.serverPubKey = serverPubKey - s.holePunchEndpoint = endpoint - - logger.Debug("Starting UDP hole punch to %s", s.holePunchEndpoint) - - // Create a new stop channel for this holepunch session - s.stopHolepunch = make(chan struct{}) - - // start the UDP holepunch - go s.keepSendingUDPHolePunch(s.holePunchEndpoint) -} - -func (s *WireGuardService) SetToken(token string) { - s.token = token -} - -// GetNetstackNet returns the netstack network interface for use by other components -func (s *WireGuardService) GetNetstackNet() *netstack.Net { - s.mu.Lock() - defer s.mu.Unlock() - return s.tnet -} - -// IsReady returns true if the WireGuard service is ready to use -func (s *WireGuardService) IsReady() bool { - s.mu.Lock() - defer s.mu.Unlock() - return s.device != nil && s.tnet != nil -} - -// GetPublicKey returns the public key of this WireGuard service -func (s *WireGuardService) GetPublicKey() wgtypes.Key { - return s.key.PublicKey() -} - -// SetOnNetstackReady sets a callback function to be called when the netstack interface is ready -func (s *WireGuardService) SetOnNetstackReady(callback func(*netstack.Net)) { - s.onNetstackReady = callback -} - -func (s *WireGuardService) SetOnNetstackClose(callback func()) { - s.onNetstackClose = callback -} - -func (s *WireGuardService) LoadRemoteConfig() error { - if s.stopGetConfig != nil { - s.stopGetConfig() - s.stopGetConfig = nil - } - s.stopGetConfig = s.client.SendMessageInterval("newt/wg/get-config", map[string]interface{}{ - "publicKey": s.key.PublicKey().String(), - "port": s.Port, - }, 2*time.Second) - - logger.Info("Requesting WireGuard configuration from remote server") - go s.periodicBandwidthCheck() - - return nil -} - -func (s *WireGuardService) handleConfig(msg websocket.WSMessage) { - var config WgConfig - - logger.Debug("Received message: %v", msg) - logger.Info("Received WireGuard clients configuration from remote server") - - jsonData, err := json.Marshal(msg.Data) - if err != nil { - logger.Info("Error marshaling data: %v", err) - return - } - - if err := json.Unmarshal(jsonData, &config); err != nil { - logger.Info("Error unmarshaling target data: %v", err) - return - } - s.config = config - - if s.stopGetConfig != nil { - s.stopGetConfig() - s.stopGetConfig = nil - } - - // Ensure the WireGuard interface and peers are configured - if err := s.ensureWireguardInterface(config); err != nil { - logger.Error("Failed to ensure WireGuard interface: %v", err) - } - - if err := s.ensureWireguardPeers(config.Peers); err != nil { - logger.Error("Failed to ensure WireGuard peers: %v", err) - } - - // add the targets if there are any - if len(config.Targets.TCP) > 0 { - s.updateTargets(s.proxyManager, "add", s.TunnelIP, "tcp", TargetData{Targets: config.Targets.TCP}) - } - - if len(config.Targets.UDP) > 0 { - s.updateTargets(s.proxyManager, "add", s.TunnelIP, "udp", TargetData{Targets: config.Targets.UDP}) - } - - // Create ProxyManager for this tunnel - s.proxyManager.Start() -} - -func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error { - s.mu.Lock() - - // split off the cidr from the IP address - parts := strings.Split(wgconfig.IpAddress, "/") - if len(parts) != 2 { - s.mu.Unlock() - return fmt.Errorf("invalid IP address format: %s", wgconfig.IpAddress) - } - // Parse the IP address and CIDR mask - tunnelIP := netip.MustParseAddr(parts[0]) - - // stop the holepunch its a channel - if s.stopHolepunch != nil { - close(s.stopHolepunch) - s.stopHolepunch = nil - } - - // Parse the IP address from the config - // tunnelIP := netip.MustParseAddr(wgconfig.IpAddress) - - // Create TUN device and network stack using netstack - var err error - s.tun, s.tnet, err = netstack.CreateNetTUN( - []netip.Addr{tunnelIP}, - s.dns, - s.mtu) - if err != nil { - s.mu.Unlock() - return fmt.Errorf("failed to create TUN device: %v", err) - } - - s.proxyManager.SetTNet(s.tnet) - s.TunnelIP = tunnelIP.String() - - // Create WireGuard device - s.device = device.NewDevice(s.tun, NewFixedPortBind(s.Port), device.NewLogger( - device.LogLevelSilent, // Use silent logging by default - could be made configurable - "wireguard: ", - )) - - // logger.Info("Private key is %s", fixKey(s.key.String())) - - // Configure WireGuard with private key - config := fmt.Sprintf("private_key=%s", fixKey(s.key.String())) - - err = s.device.IpcSet(config) - if err != nil { - s.mu.Unlock() - return fmt.Errorf("failed to configure WireGuard device: %v", err) - } - - // Bring up the device - err = s.device.Up() - if err != nil { - s.mu.Unlock() - return fmt.Errorf("failed to bring up WireGuard device: %v", err) - } - - logger.Info("WireGuard netstack device created and configured") - - // Store callback and tnet reference before releasing mutex - callback := s.onNetstackReady - tnet := s.tnet - - // Release the mutex before calling the callback - s.mu.Unlock() - - // Call the callback if it's set to notify that netstack is ready - if callback != nil { - callback(tnet) - } - - // Note: we already unlocked above, so don't use defer unlock - return nil -} - -func fixKey(key string) string { - // Remove any whitespace - key = strings.TrimSpace(key) - - // Decode from base64 - decoded, err := base64.StdEncoding.DecodeString(key) - if err != nil { - logger.Fatal("Error decoding base64: %v", err) - } - - // Convert to hex - return hex.EncodeToString(decoded) -} - -func (s *WireGuardService) ensureWireguardPeers(peers []Peer) error { - // For netstack, we need to manage peers differently - // We'll configure peers directly on the device using IPC - - // First, clear all existing peers by getting current config and removing them - currentConfig, err := s.device.IpcGet() - if err != nil { - return fmt.Errorf("failed to get current device config: %v", err) - } - - // Parse current peers and remove them - lines := strings.Split(currentConfig, "\n") - var currentPeerKeys []string - for _, line := range lines { - if strings.HasPrefix(line, "public_key=") { - pubKey := strings.TrimPrefix(line, "public_key=") - currentPeerKeys = append(currentPeerKeys, pubKey) - } - } - - // Remove existing peers - for _, pubKey := range currentPeerKeys { - removeConfig := fmt.Sprintf("public_key=%s\nremove=true", pubKey) - if err := s.device.IpcSet(removeConfig); err != nil { - logger.Warn("Failed to remove peer %s: %v", pubKey, err) - } - } - - // Add new peers - for _, peer := range peers { - if err := s.addPeerToDevice(peer); err != nil { - return fmt.Errorf("failed to add peer: %v", err) - } - } - - return nil -} - -func (s *WireGuardService) addPeerToDevice(peer Peer) error { - // parse the key first - pubKey, err := wgtypes.ParseKey(peer.PublicKey) - if err != nil { - return fmt.Errorf("failed to parse public key: %v", err) - } - - // Build IPC configuration string for the peer - config := fmt.Sprintf("public_key=%s", fixKey(pubKey.String())) - - // Add allowed IPs - for _, allowedIP := range peer.AllowedIPs { - config += fmt.Sprintf("\nallowed_ip=%s", allowedIP) - } - - // Add endpoint if specified - if peer.Endpoint != "" { - config += fmt.Sprintf("\nendpoint=%s", peer.Endpoint) - } - - // Add persistent keepalive - config += "\npersistent_keepalive_interval=25" - - // Apply the configuration - if err := s.device.IpcSet(config); err != nil { - return fmt.Errorf("failed to configure peer: %v", err) - } - - logger.Info("Peer %s added successfully", peer.PublicKey) - return nil -} - -func (s *WireGuardService) handleAddPeer(msg websocket.WSMessage) { - logger.Debug("Received message: %v", msg.Data) - var peer Peer - - jsonData, err := json.Marshal(msg.Data) - if err != nil { - logger.Info("Error marshaling data: %v", err) - return - } - - if err := json.Unmarshal(jsonData, &peer); err != nil { - logger.Info("Error unmarshaling target data: %v", err) - return - } - - if s.device == nil { - logger.Info("WireGuard device is not initialized") - return - } - - err = s.addPeerToDevice(peer) - if err != nil { - logger.Info("Error adding peer: %v", err) - return - } -} - -func (s *WireGuardService) handleRemovePeer(msg websocket.WSMessage) { - logger.Debug("Received message: %v", msg.Data) - // parse the publicKey from the message which is json { "publicKey": "asdfasdfl;akjsdf" } - type RemoveRequest struct { - PublicKey string `json:"publicKey"` - } - - jsonData, err := json.Marshal(msg.Data) - if err != nil { - logger.Info("Error marshaling data: %v", err) - return - } - - var request RemoveRequest - if err := json.Unmarshal(jsonData, &request); err != nil { - logger.Info("Error unmarshaling data: %v", err) - return - } - - if s.device == nil { - logger.Info("WireGuard device is not initialized") - return - } - - if err := s.removePeer(request.PublicKey); err != nil { - logger.Info("Error removing peer: %v", err) - return - } -} - -func (s *WireGuardService) removePeer(publicKey string) error { - - // Parse the public key - pubKey, err := wgtypes.ParseKey(publicKey) - if err != nil { - return fmt.Errorf("failed to parse public key: %v", err) - } - - // Build IPC configuration string to remove the peer - config := fmt.Sprintf("public_key=%s\nremove=true", fixKey(pubKey.String())) - - if err := s.device.IpcSet(config); err != nil { - return fmt.Errorf("failed to remove peer: %v", err) - } - - logger.Info("Peer %s removed successfully", publicKey) - return nil -} - -func (s *WireGuardService) handleUpdatePeer(msg websocket.WSMessage) { - logger.Debug("Received message: %v", msg.Data) - // Define a struct to match the incoming message structure with optional fields - type UpdatePeerRequest struct { - PublicKey string `json:"publicKey"` - AllowedIPs []string `json:"allowedIps,omitempty"` - Endpoint string `json:"endpoint,omitempty"` - } - - jsonData, err := json.Marshal(msg.Data) - if err != nil { - logger.Info("Error marshaling data: %v", err) - return - } - - var request UpdatePeerRequest - if err := json.Unmarshal(jsonData, &request); err != nil { - logger.Info("Error unmarshaling peer data: %v", err) - return - } - - // Parse the public key - pubKey, err := wgtypes.ParseKey(request.PublicKey) - if err != nil { - logger.Info("Failed to parse public key: %v", err) - return - } - - if s.device == nil { - logger.Info("WireGuard device is not initialized") - return - } - - // Build IPC configuration string to update the peer - config := fmt.Sprintf("public_key=%s\nupdate_only=true", fixKey(pubKey.String())) - - // Handle AllowedIPs update - if len(request.AllowedIPs) > 0 { - config += "\nreplace_allowed_ips=true" - for _, allowedIP := range request.AllowedIPs { - config += fmt.Sprintf("\nallowed_ip=%s", allowedIP) - } - logger.Info("Updating AllowedIPs for peer %s", request.PublicKey) - } - - // Handle Endpoint field special case - endpointSpecified := false - for key := range msg.Data.(map[string]interface{}) { - if key == "endpoint" { - endpointSpecified = true - break - } - } - - if endpointSpecified { - if request.Endpoint != "" { - config += fmt.Sprintf("\nendpoint=%s", request.Endpoint) - logger.Info("Updating Endpoint for peer %s to %s", request.PublicKey, request.Endpoint) - } else { - config += "\nendpoint=0.0.0.0:0" // Remove endpoint - logger.Info("Removing Endpoint for peer %s", request.PublicKey) - } - } - - // Always set persistent keepalive - config += "\npersistent_keepalive_interval=25" - - // Apply the configuration update - if err := s.device.IpcSet(config); err != nil { - logger.Info("Error updating peer configuration: %v", err) - return - } - - logger.Info("Peer %s updated successfully", request.PublicKey) -} - -func (s *WireGuardService) periodicBandwidthCheck() { - ticker := time.NewTicker(10 * time.Second) - defer ticker.Stop() - - for range ticker.C { - if err := s.reportPeerBandwidth(); err != nil { - logger.Info("Failed to report peer bandwidth: %v", err) - } - } -} - -func (s *WireGuardService) calculatePeerBandwidth() ([]PeerBandwidth, error) { - if s.device == nil { - return []PeerBandwidth{}, nil - } - - // Get device statistics using IPC - stats, err := s.device.IpcGet() - if err != nil { - return nil, fmt.Errorf("failed to get device statistics: %v", err) - } - - peerBandwidths := []PeerBandwidth{} - now := time.Now() - - s.mu.Lock() - defer s.mu.Unlock() - - // Parse the IPC response to extract peer statistics - lines := strings.Split(stats, "\n") - var currentPubKey string - var rxBytes, txBytes int64 - - for _, line := range lines { - if strings.HasPrefix(line, "public_key=") { - // Process previous peer if we have one - if currentPubKey != "" { - bandwidth := s.processPeerBandwidth(currentPubKey, rxBytes, txBytes, now) - if bandwidth != nil { - peerBandwidths = append(peerBandwidths, *bandwidth) - } - } - // Start new peer - currentPubKey = strings.TrimPrefix(line, "public_key=") - rxBytes = 0 - txBytes = 0 - } else if strings.HasPrefix(line, "rx_bytes=") { - rxBytes, _ = strconv.ParseInt(strings.TrimPrefix(line, "rx_bytes="), 10, 64) - } else if strings.HasPrefix(line, "tx_bytes=") { - txBytes, _ = strconv.ParseInt(strings.TrimPrefix(line, "tx_bytes="), 10, 64) - } - } - - // Process the last peer - if currentPubKey != "" { - bandwidth := s.processPeerBandwidth(currentPubKey, rxBytes, txBytes, now) - if bandwidth != nil { - peerBandwidths = append(peerBandwidths, *bandwidth) - } - } - - // Clean up old peers - devicePeers := make(map[string]bool) - lines = strings.Split(stats, "\n") - for _, line := range lines { - if strings.HasPrefix(line, "public_key=") { - pubKey := strings.TrimPrefix(line, "public_key=") - devicePeers[pubKey] = true - } - } - - for publicKey := range s.lastReadings { - if !devicePeers[publicKey] { - delete(s.lastReadings, publicKey) - } - } - - // parse the public keys and have them as base64 in the opposite order to fixKey - for i := range peerBandwidths { - pubKeyBytes, err := base64.StdEncoding.DecodeString(peerBandwidths[i].PublicKey) - if err != nil { - logger.Info("Failed to decode public key %s: %v", peerBandwidths[i].PublicKey, err) - continue - } - // Convert to hex - peerBandwidths[i].PublicKey = hex.EncodeToString(pubKeyBytes) - } - - return peerBandwidths, nil -} - -func (s *WireGuardService) processPeerBandwidth(publicKey string, rxBytes, txBytes int64, now time.Time) *PeerBandwidth { - currentReading := PeerReading{ - BytesReceived: rxBytes, - BytesTransmitted: txBytes, - LastChecked: now, - } - - var bytesInDiff, bytesOutDiff float64 - lastReading, exists := s.lastReadings[publicKey] - - if exists { - timeDiff := currentReading.LastChecked.Sub(lastReading.LastChecked).Seconds() - if timeDiff > 0 { - // Calculate bytes transferred since last reading - bytesInDiff = float64(currentReading.BytesReceived - lastReading.BytesReceived) - bytesOutDiff = float64(currentReading.BytesTransmitted - lastReading.BytesTransmitted) - - // Handle counter wraparound (if the counter resets or overflows) - if bytesInDiff < 0 { - bytesInDiff = float64(currentReading.BytesReceived) - } - if bytesOutDiff < 0 { - bytesOutDiff = float64(currentReading.BytesTransmitted) - } - - // Convert to MB - bytesInMB := bytesInDiff / (1024 * 1024) - bytesOutMB := bytesOutDiff / (1024 * 1024) - - // Update the last reading - s.lastReadings[publicKey] = currentReading - - return &PeerBandwidth{ - PublicKey: publicKey, - BytesIn: bytesInMB, - BytesOut: bytesOutMB, - } - } - } - - // For first reading or if readings are too close together, report 0 - s.lastReadings[publicKey] = currentReading - return &PeerBandwidth{ - PublicKey: publicKey, - BytesIn: 0, - BytesOut: 0, - } -} - -func (s *WireGuardService) reportPeerBandwidth() error { - bandwidths, err := s.calculatePeerBandwidth() - if err != nil { - return fmt.Errorf("failed to calculate peer bandwidth: %v", err) - } - - err = s.client.SendMessage("newt/receive-bandwidth", map[string]interface{}{ - "bandwidthData": bandwidths, - }) - if err != nil { - return fmt.Errorf("failed to send bandwidth data: %v", err) - } - - return nil -} - -func (s *WireGuardService) sendUDPHolePunch(serverAddr string) error { - - if s.serverPubKey == "" || s.token == "" { - logger.Debug("Server public key or token not set, skipping UDP hole punch") - return nil - } - - // Parse server address - serverSplit := strings.Split(serverAddr, ":") - if len(serverSplit) < 2 { - return fmt.Errorf("invalid server address format, expected hostname:port") - } - - serverHostname := serverSplit[0] - serverPort, err := strconv.ParseUint(serverSplit[1], 10, 16) - if err != nil { - return fmt.Errorf("failed to parse server port: %v", err) - } - - // Resolve server hostname to IP - serverIPAddr := network.HostToAddr(serverHostname) - if serverIPAddr == nil { - return fmt.Errorf("failed to resolve server hostname") - } - - // Create local UDP address using the same port as WireGuard - localAddr := &net.UDPAddr{ - IP: net.IPv4zero, - Port: int(s.Port), - } - - // Create remote server address - remoteAddr := &net.UDPAddr{ - IP: serverIPAddr.IP, - Port: int(serverPort), - } - - // Create UDP connection bound to the same port as WireGuard - conn, err := net.DialUDP("udp", localAddr, remoteAddr) - if err != nil { - return fmt.Errorf("failed to create netstack UDP connection: %v", err) - } - defer conn.Close() - - // Create JSON payload - payload := struct { - NewtID string `json:"newtId"` - Token string `json:"token"` - }{ - NewtID: s.newtId, - Token: s.token, - } - - // Convert payload to JSON - payloadBytes, err := json.Marshal(payload) - if err != nil { - return fmt.Errorf("failed to marshal payload: %v", err) - } - - // Encrypt the payload using the server's WireGuard public key - encryptedPayload, err := s.encryptPayload(payloadBytes) - if err != nil { - return fmt.Errorf("failed to encrypt payload: %v", err) - } - - // Convert encrypted payload to JSON - jsonData, err := json.Marshal(encryptedPayload) - if err != nil { - return fmt.Errorf("failed to marshal encrypted payload: %v", err) - } - - // Send the encrypted packet using the netstack UDP connection - _, err = conn.Write(jsonData) - if err != nil { - return fmt.Errorf("failed to send UDP packet: %v", err) - } - - logger.Debug("Sent UDP hole punch to %s via netstack", remoteAddr.String()) - - return nil -} - -func (s *WireGuardService) encryptPayload(payload []byte) (interface{}, error) { - // Generate an ephemeral keypair for this message - ephemeralPrivateKey, err := wgtypes.GeneratePrivateKey() - if err != nil { - return nil, fmt.Errorf("failed to generate ephemeral private key: %v", err) - } - ephemeralPublicKey := ephemeralPrivateKey.PublicKey() - - // Parse the server's public key - serverPubKey, err := wgtypes.ParseKey(s.serverPubKey) - if err != nil { - return nil, fmt.Errorf("failed to parse server public key: %v", err) - } - - // Use X25519 for key exchange (replacing deprecated ScalarMult) - var ephPrivKeyFixed [32]byte - copy(ephPrivKeyFixed[:], ephemeralPrivateKey[:]) - - // Perform X25519 key exchange - sharedSecret, err := curve25519.X25519(ephPrivKeyFixed[:], serverPubKey[:]) - if err != nil { - return nil, fmt.Errorf("failed to perform X25519 key exchange: %v", err) - } - - // Create an AEAD cipher using the shared secret - aead, err := chacha20poly1305.New(sharedSecret) - if err != nil { - return nil, fmt.Errorf("failed to create AEAD cipher: %v", err) - } - - // Generate a random nonce - nonce := make([]byte, aead.NonceSize()) - if _, err := rand.Read(nonce); err != nil { - return nil, fmt.Errorf("failed to generate nonce: %v", err) - } - - // Encrypt the payload - ciphertext := aead.Seal(nil, nonce, payload, nil) - - // Prepare the final encrypted message - encryptedMsg := struct { - EphemeralPublicKey string `json:"ephemeralPublicKey"` - Nonce []byte `json:"nonce"` - Ciphertext []byte `json:"ciphertext"` - }{ - EphemeralPublicKey: ephemeralPublicKey.String(), - Nonce: nonce, - Ciphertext: ciphertext, - } - - return encryptedMsg, nil -} - -func (s *WireGuardService) keepSendingUDPHolePunch(host string) { - logger.Info("Starting UDP hole punch routine to %s:21820", host) - - // send initial hole punch - if err := s.sendUDPHolePunch(host + ":21820"); err != nil { - logger.Debug("Failed to send initial UDP hole punch: %v", err) - } - - ticker := time.NewTicker(3 * time.Second) - defer ticker.Stop() - - timeout := time.NewTimer(15 * time.Second) - defer timeout.Stop() - - for { - select { - case <-s.stopHolepunch: - logger.Info("Stopping UDP holepunch") - return - case <-timeout.C: - logger.Info("UDP holepunch routine timed out after 15 seconds") - return - case <-ticker.C: - if err := s.sendUDPHolePunch(host + ":21820"); err != nil { - logger.Debug("Failed to send UDP hole punch: %v", err) - } - } - } -} - -func (s *WireGuardService) updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto string, targetData TargetData) error { - var replace = false - for _, t := range targetData.Targets { - // Split the first number off of the target with : separator and use as the port - parts := strings.Split(t, ":") - if len(parts) != 3 { - logger.Info("Invalid target format: %s", t) - continue - } - - // Get the port as an int - port := 0 - _, err := fmt.Sscanf(parts[0], "%d", &port) - if err != nil { - logger.Info("Invalid port: %s", parts[0]) - continue - } - - if action == "add" { - target := parts[1] + ":" + parts[2] - - // Call updown script if provided - processedTarget := target - - // Only remove the specific target if it exists - err := pm.RemoveTarget(proto, tunnelIP, port) - if err != nil { - // Ignore "target not found" errors as this is expected for new targets - if !strings.Contains(err.Error(), "target not found") { - logger.Error("Failed to remove existing target: %v", err) - } - } else { - replace = true // We successfully removed an existing target - } - - // Add the new target - pm.AddTarget(proto, tunnelIP, port, processedTarget) - - } else if action == "remove" { - logger.Info("Removing target with port %d", port) - - err := pm.RemoveTarget(proto, tunnelIP, port) - if err != nil { - logger.Error("Failed to remove target: %v", err) - return err - } - } - } - - if replace { - // If we replaced any targets, we need to hot swap the netstack - if err := s.ReplaceNetstack(); err != nil { - logger.Error("Failed to replace netstack after updating targets: %v", err) - return err - } - logger.Info("Netstack replaced successfully after updating targets") - } else { - logger.Info("No targets updated, no netstack replacement needed") - } - - return nil -} - -func parseTargetData(data interface{}) (TargetData, error) { - var targetData TargetData - jsonData, err := json.Marshal(data) - if err != nil { - logger.Info("Error marshaling data: %v", err) - return targetData, err - } - - if err := json.Unmarshal(jsonData, &targetData); err != nil { - logger.Info("Error unmarshaling target data: %v", err) - return targetData, err - } - return targetData, nil -} - -// Add this method to WireGuardService -func (s *WireGuardService) ReplaceNetstack() error { - s.mu.Lock() - defer s.mu.Unlock() - - if s.device == nil || s.tun == nil { - return fmt.Errorf("WireGuard device not initialized") - } - - // Parse the current tunnel IP from the existing config - parts := strings.Split(s.config.IpAddress, "/") - if len(parts) != 2 { - return fmt.Errorf("invalid IP address format: %s", s.config.IpAddress) - } - tunnelIP := netip.MustParseAddr(parts[0]) - - // Stop the proxy manager temporarily - s.proxyManager.Stop() - - // Create new TUN device and netstack with new DNS - newTun, newTnet, err := netstack.CreateNetTUN( - []netip.Addr{tunnelIP}, - s.dns, - s.mtu) - if err != nil { - // Restart proxy manager with old tnet on failure - s.proxyManager.Start() - return fmt.Errorf("failed to create new TUN device: %v", err) - } - - // Get current device config before closing - currentConfig, err := s.device.IpcGet() - if err != nil { - newTun.Close() - s.proxyManager.Start() - return fmt.Errorf("failed to get current device config: %v", err) - } - - // Filter out read-only fields from the config - filteredConfig := s.filterReadOnlyFields(currentConfig) - - // if onNetstackClose callback is set, call it - if s.onNetstackClose != nil { - s.onNetstackClose() - } - - // Close old device (this closes the old TUN device) - s.device.Close() - - // Update references - s.tun = newTun - s.tnet = newTnet - - // Create new WireGuard device with same port - s.device = device.NewDevice(s.tun, NewFixedPortBind(s.Port), device.NewLogger( - device.LogLevelSilent, - "wireguard: ", - )) - - // Restore the configuration (without read-only fields) - err = s.device.IpcSet(filteredConfig) - if err != nil { - return fmt.Errorf("failed to restore WireGuard configuration: %v", err) - } - - // Bring up the device - err = s.device.Up() - if err != nil { - return fmt.Errorf("failed to bring up new WireGuard device: %v", err) - } - - // Update proxy manager with new tnet and restart - s.proxyManager.SetTNet(s.tnet) - s.proxyManager.Start() - - s.proxyManager.PrintTargets() - - // Call the netstack ready callback if set - if s.onNetstackReady != nil { - go s.onNetstackReady(s.tnet) - } - - return nil -} - -// filterReadOnlyFields removes read-only fields from WireGuard IPC configuration -func (s *WireGuardService) filterReadOnlyFields(config string) string { - lines := strings.Split(config, "\n") - var filteredLines []string - - // List of read-only fields that should not be included in IpcSet - readOnlyFields := map[string]bool{ - "last_handshake_time_sec": true, - "last_handshake_time_nsec": true, - "rx_bytes": true, - "tx_bytes": true, - "protocol_version": true, - } - - for _, line := range lines { - if line == "" { - continue - } - - // Check if this line contains a read-only field - isReadOnly := false - for field := range readOnlyFields { - if strings.HasPrefix(line, field+"=") { - isReadOnly = true - break - } - } - - // Only include non-read-only lines - if !isReadOnly { - filteredLines = append(filteredLines, line) - } - } - - return strings.Join(filteredLines, "\n") -} diff --git a/wgtester/wgtester.go b/wgtester/wgtester.go index 26988f6..c76db64 100644 --- a/wgtester/wgtester.go +++ b/wgtester/wgtester.go @@ -3,12 +3,13 @@ package wgtester import ( "encoding/binary" "fmt" + "io" "net" "sync" "time" "github.com/fosrl/newt/logger" - "golang.zx2c4.com/wireguard/tun/netstack" + "github.com/fosrl/newt/netstack2" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" ) @@ -39,7 +40,7 @@ type Server struct { newtID string outputPrefix string useNetstack bool - tnet interface{} // Will be *netstack.Net when using netstack + tnet interface{} // Will be *netstack2.Net when using netstack } // NewServer creates a new connection test server using UDP @@ -56,7 +57,7 @@ func NewServer(serverAddr string, serverPort uint16, newtID string) *Server { } // NewServerWithNetstack creates a new connection test server using WireGuard netstack -func NewServerWithNetstack(serverAddr string, serverPort uint16, newtID string, tnet *netstack.Net) *Server { +func NewServerWithNetstack(serverAddr string, serverPort uint16, newtID string, tnet *netstack2.Net) *Server { return &Server{ serverAddr: serverAddr, serverPort: serverPort + 1, // use the next port for the server @@ -82,7 +83,7 @@ func (s *Server) Start() error { if s.useNetstack && s.tnet != nil { // Use WireGuard netstack - tnet := s.tnet.(*netstack.Net) + tnet := s.tnet.(*netstack2.Net) udpAddr := &net.UDPAddr{Port: int(s.serverPort)} netstackConn, err := tnet.ListenUDP(udpAddr) if err != nil { @@ -130,7 +131,7 @@ func (s *Server) Stop() { } // RestartWithNetstack stops the current server and restarts it with netstack -func (s *Server) RestartWithNetstack(tnet *netstack.Net) error { +func (s *Server) RestartWithNetstack(tnet *netstack2.Net) error { s.Stop() // Update configuration to use netstack @@ -187,6 +188,10 @@ func (s *Server) handleConnections() { case <-s.shutdownCh: return // Don't log error if we're shutting down default: + // Don't log EOF errors during shutdown - these are expected when connection is closed + if err == io.EOF { + return + } logger.Error("%sError reading from UDP: %v", s.outputPrefix, err) } continue @@ -219,7 +224,7 @@ func (s *Server) handleConnections() { copy(responsePacket[5:13], buffer[5:13]) // Log response being sent for debugging - logger.Debug("%sSending response to %s", s.outputPrefix, addr.String()) + // logger.Debug("%sSending response to %s", s.outputPrefix, addr.String()) // Send the response packet - handle both regular UDP and netstack UDP if s.useNetstack { @@ -235,7 +240,7 @@ func (s *Server) handleConnections() { if err != nil { logger.Error("%sError sending response: %v", s.outputPrefix, err) } else { - logger.Debug("%sResponse sent successfully", s.outputPrefix) + // logger.Debug("%sResponse sent successfully", s.outputPrefix) } } }