From febe13a4f8afa317d2cdb7d12af11c6adcf88774 Mon Sep 17 00:00:00 2001 From: Owen Date: Sat, 15 Nov 2025 16:32:44 -0500 Subject: [PATCH] Centralize some functions --- bind/shared_bind.go | 378 ---------------------------------- bind/shared_bind_test.go | 424 --------------------------------------- go.mod | 12 +- go.sum | 18 +- holepunch/holepunch.go | 351 -------------------------------- olm/common.go | 106 +--------- olm/olm.go | 15 +- 7 files changed, 26 insertions(+), 1278 deletions(-) delete mode 100644 bind/shared_bind.go delete mode 100644 bind/shared_bind_test.go delete mode 100644 holepunch/holepunch.go diff --git a/bind/shared_bind.go b/bind/shared_bind.go deleted file mode 100644 index bff66bf..0000000 --- a/bind/shared_bind.go +++ /dev/null @@ -1,378 +0,0 @@ -//go:build !js - -package bind - -import ( - "fmt" - "net" - "net/netip" - "runtime" - "sync" - "sync/atomic" - - "golang.org/x/net/ipv4" - "golang.org/x/net/ipv6" - wgConn "golang.zx2c4.com/wireguard/conn" -) - -// Endpoint represents a network endpoint for the SharedBind -type Endpoint struct { - AddrPort netip.AddrPort -} - -// ClearSrc implements the wgConn.Endpoint interface -func (e *Endpoint) ClearSrc() {} - -// DstIP implements the wgConn.Endpoint interface -func (e *Endpoint) DstIP() netip.Addr { - return e.AddrPort.Addr() -} - -// SrcIP implements the wgConn.Endpoint interface -func (e *Endpoint) SrcIP() netip.Addr { - return netip.Addr{} -} - -// DstToBytes implements the wgConn.Endpoint interface -func (e *Endpoint) DstToBytes() []byte { - b, _ := e.AddrPort.MarshalBinary() - return b -} - -// DstToString implements the wgConn.Endpoint interface -func (e *Endpoint) DstToString() string { - return e.AddrPort.String() -} - -// SrcToString implements the wgConn.Endpoint interface -func (e *Endpoint) SrcToString() string { - return "" -} - -// SharedBind is a thread-safe UDP bind that can be shared between WireGuard -// and hole punch senders. It wraps a single UDP connection and implements -// reference counting to prevent premature closure. -type SharedBind struct { - mu sync.RWMutex - - // The underlying UDP connection - udpConn *net.UDPConn - - // IPv4 and IPv6 packet connections for advanced features - ipv4PC *ipv4.PacketConn - ipv6PC *ipv6.PacketConn - - // Reference counting to prevent closing while in use - refCount atomic.Int32 - closed atomic.Bool - - // Channels for receiving data - recvFuncs []wgConn.ReceiveFunc - - // Port binding information - port uint16 -} - -// New creates a new SharedBind from an existing UDP connection. -// The SharedBind takes ownership of the connection and will close it -// when all references are released. -func New(udpConn *net.UDPConn) (*SharedBind, error) { - if udpConn == nil { - return nil, fmt.Errorf("udpConn cannot be nil") - } - - bind := &SharedBind{ - udpConn: udpConn, - } - - // Initialize reference count to 1 (the creator holds the first reference) - bind.refCount.Store(1) - - // Get the local port - if addr, ok := udpConn.LocalAddr().(*net.UDPAddr); ok { - bind.port = uint16(addr.Port) - } - - return bind, nil -} - -// AddRef increments the reference count. Call this when sharing -// the bind with another component. -func (b *SharedBind) AddRef() { - newCount := b.refCount.Add(1) - // Optional: Add logging for debugging - _ = newCount // Placeholder for potential logging -} - -// Release decrements the reference count. When it reaches zero, -// the underlying UDP connection is closed. -func (b *SharedBind) Release() error { - newCount := b.refCount.Add(-1) - // Optional: Add logging for debugging - _ = newCount // Placeholder for potential logging - - if newCount < 0 { - // This should never happen with proper usage - b.refCount.Store(0) - return fmt.Errorf("SharedBind reference count went negative") - } - - if newCount == 0 { - return b.closeConnection() - } - - return nil -} - -// closeConnection actually closes the UDP connection -func (b *SharedBind) closeConnection() error { - if !b.closed.CompareAndSwap(false, true) { - // Already closed - return nil - } - - b.mu.Lock() - defer b.mu.Unlock() - - var err error - if b.udpConn != nil { - err = b.udpConn.Close() - b.udpConn = nil - } - - b.ipv4PC = nil - b.ipv6PC = nil - - return err -} - -// GetUDPConn returns the underlying UDP connection. -// The caller must not close this connection directly. -func (b *SharedBind) GetUDPConn() *net.UDPConn { - b.mu.RLock() - defer b.mu.RUnlock() - return b.udpConn -} - -// GetRefCount returns the current reference count (for debugging) -func (b *SharedBind) GetRefCount() int32 { - return b.refCount.Load() -} - -// IsClosed returns whether the bind is closed -func (b *SharedBind) IsClosed() bool { - return b.closed.Load() -} - -// WriteToUDP writes data to a specific UDP address. -// This is thread-safe and can be used by hole punch senders. -func (b *SharedBind) WriteToUDP(data []byte, addr *net.UDPAddr) (int, error) { - if b.closed.Load() { - return 0, net.ErrClosed - } - - b.mu.RLock() - conn := b.udpConn - b.mu.RUnlock() - - if conn == nil { - return 0, net.ErrClosed - } - - return conn.WriteToUDP(data, addr) -} - -// Close implements the WireGuard Bind interface. -// It decrements the reference count and closes the connection if no references remain. -func (b *SharedBind) Close() error { - return b.Release() -} - -// Open implements the WireGuard Bind interface. -// Since the connection is already open, this just sets up the receive functions. -func (b *SharedBind) Open(uport uint16) ([]wgConn.ReceiveFunc, uint16, error) { - if b.closed.Load() { - return nil, 0, net.ErrClosed - } - - b.mu.Lock() - defer b.mu.Unlock() - - if b.udpConn == nil { - return nil, 0, net.ErrClosed - } - - // Set up IPv4 and IPv6 packet connections for advanced features - if runtime.GOOS == "linux" || runtime.GOOS == "android" { - b.ipv4PC = ipv4.NewPacketConn(b.udpConn) - b.ipv6PC = ipv6.NewPacketConn(b.udpConn) - } - - // Create receive functions - recvFuncs := make([]wgConn.ReceiveFunc, 0, 2) - - // Add IPv4 receive function - if b.ipv4PC != nil || runtime.GOOS != "linux" { - recvFuncs = append(recvFuncs, b.makeReceiveIPv4()) - } - - // Add IPv6 receive function if needed - // For now, we focus on IPv4 for hole punching use case - - b.recvFuncs = recvFuncs - return recvFuncs, b.port, nil -} - -// makeReceiveIPv4 creates a receive function for IPv4 packets -func (b *SharedBind) makeReceiveIPv4() wgConn.ReceiveFunc { - return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) { - if b.closed.Load() { - return 0, net.ErrClosed - } - - b.mu.RLock() - conn := b.udpConn - pc := b.ipv4PC - b.mu.RUnlock() - - if conn == nil { - return 0, net.ErrClosed - } - - // Use batch reading on Linux for performance - if pc != nil && (runtime.GOOS == "linux" || runtime.GOOS == "android") { - return b.receiveIPv4Batch(pc, bufs, sizes, eps) - } - - // Fallback to simple read for other platforms - return b.receiveIPv4Simple(conn, bufs, sizes, eps) - } -} - -// receiveIPv4Batch uses batch reading for better performance on Linux -func (b *SharedBind) receiveIPv4Batch(pc *ipv4.PacketConn, bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (int, error) { - // Create messages for batch reading - msgs := make([]ipv4.Message, len(bufs)) - for i := range bufs { - msgs[i].Buffers = [][]byte{bufs[i]} - msgs[i].OOB = make([]byte, 0) // No OOB data needed for basic use - } - - numMsgs, err := pc.ReadBatch(msgs, 0) - if err != nil { - return 0, err - } - - for i := 0; i < numMsgs; i++ { - sizes[i] = msgs[i].N - if sizes[i] == 0 { - continue - } - - if msgs[i].Addr != nil { - if udpAddr, ok := msgs[i].Addr.(*net.UDPAddr); ok { - addrPort := udpAddr.AddrPort() - eps[i] = &wgConn.StdNetEndpoint{AddrPort: addrPort} - } - } - } - - return numMsgs, nil -} - -// receiveIPv4Simple uses simple ReadFromUDP for non-Linux platforms -func (b *SharedBind) receiveIPv4Simple(conn *net.UDPConn, bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (int, error) { - n, addr, err := conn.ReadFromUDP(bufs[0]) - if err != nil { - return 0, err - } - - sizes[0] = n - if addr != nil { - addrPort := addr.AddrPort() - eps[0] = &wgConn.StdNetEndpoint{AddrPort: addrPort} - } - - return 1, nil -} - -// Send implements the WireGuard Bind interface. -// It sends packets to the specified endpoint. -func (b *SharedBind) Send(bufs [][]byte, ep wgConn.Endpoint) error { - if b.closed.Load() { - return net.ErrClosed - } - - b.mu.RLock() - conn := b.udpConn - b.mu.RUnlock() - - if conn == nil { - return net.ErrClosed - } - - // Extract the destination address from the endpoint - var destAddr *net.UDPAddr - - // Try to cast to StdNetEndpoint first - if stdEp, ok := ep.(*wgConn.StdNetEndpoint); ok { - destAddr = net.UDPAddrFromAddrPort(stdEp.AddrPort) - } else { - // Fallback: construct from DstIP and DstToBytes - dstBytes := ep.DstToBytes() - if len(dstBytes) >= 6 { // Minimum for IPv4 (4 bytes) + port (2 bytes) - var addr netip.Addr - var port uint16 - - if len(dstBytes) >= 18 { // IPv6 (16 bytes) + port (2 bytes) - addr, _ = netip.AddrFromSlice(dstBytes[:16]) - port = uint16(dstBytes[16]) | uint16(dstBytes[17])<<8 - } else { // IPv4 - addr, _ = netip.AddrFromSlice(dstBytes[:4]) - port = uint16(dstBytes[4]) | uint16(dstBytes[5])<<8 - } - - if addr.IsValid() { - destAddr = net.UDPAddrFromAddrPort(netip.AddrPortFrom(addr, port)) - } - } - } - - if destAddr == nil { - return fmt.Errorf("could not extract destination address from endpoint") - } - - // Send all buffers to the destination - for _, buf := range bufs { - _, err := conn.WriteToUDP(buf, destAddr) - if err != nil { - return err - } - } - - return nil -} - -// SetMark implements the WireGuard Bind interface. -// It's a no-op for this implementation. -func (b *SharedBind) SetMark(mark uint32) error { - // Not implemented for this use case - return nil -} - -// BatchSize returns the preferred batch size for sending packets. -func (b *SharedBind) BatchSize() int { - if runtime.GOOS == "linux" || runtime.GOOS == "android" { - return wgConn.IdealBatchSize - } - return 1 -} - -// ParseEndpoint creates a new endpoint from a string address. -func (b *SharedBind) ParseEndpoint(s string) (wgConn.Endpoint, error) { - addrPort, err := netip.ParseAddrPort(s) - if err != nil { - return nil, err - } - return &wgConn.StdNetEndpoint{AddrPort: addrPort}, nil -} diff --git a/bind/shared_bind_test.go b/bind/shared_bind_test.go deleted file mode 100644 index 6e1ec66..0000000 --- a/bind/shared_bind_test.go +++ /dev/null @@ -1,424 +0,0 @@ -//go:build !js - -package bind - -import ( - "net" - "net/netip" - "sync" - "testing" - "time" - - wgConn "golang.zx2c4.com/wireguard/conn" -) - -// TestSharedBindCreation tests basic creation and initialization -func TestSharedBindCreation(t *testing.T) { - // Create a UDP connection - udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) - if err != nil { - t.Fatalf("Failed to create UDP connection: %v", err) - } - defer udpConn.Close() - - // Create SharedBind - bind, err := New(udpConn) - if err != nil { - t.Fatalf("Failed to create SharedBind: %v", err) - } - - if bind == nil { - t.Fatal("SharedBind is nil") - } - - // Verify initial reference count - if bind.refCount.Load() != 1 { - t.Errorf("Expected initial refCount to be 1, got %d", bind.refCount.Load()) - } - - // Clean up - if err := bind.Close(); err != nil { - t.Errorf("Failed to close SharedBind: %v", err) - } -} - -// TestSharedBindReferenceCount tests reference counting -func TestSharedBindReferenceCount(t *testing.T) { - udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) - if err != nil { - t.Fatalf("Failed to create UDP connection: %v", err) - } - - bind, err := New(udpConn) - if err != nil { - t.Fatalf("Failed to create SharedBind: %v", err) - } - - // Add references - bind.AddRef() - if bind.refCount.Load() != 2 { - t.Errorf("Expected refCount to be 2, got %d", bind.refCount.Load()) - } - - bind.AddRef() - if bind.refCount.Load() != 3 { - t.Errorf("Expected refCount to be 3, got %d", bind.refCount.Load()) - } - - // Release references - bind.Release() - if bind.refCount.Load() != 2 { - t.Errorf("Expected refCount to be 2 after release, got %d", bind.refCount.Load()) - } - - bind.Release() - bind.Release() // This should close the connection - - if !bind.closed.Load() { - t.Error("Expected bind to be closed after all references released") - } -} - -// TestSharedBindWriteToUDP tests the WriteToUDP functionality -func TestSharedBindWriteToUDP(t *testing.T) { - // Create sender - senderConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) - if err != nil { - t.Fatalf("Failed to create sender UDP connection: %v", err) - } - - senderBind, err := New(senderConn) - if err != nil { - t.Fatalf("Failed to create sender SharedBind: %v", err) - } - defer senderBind.Close() - - // Create receiver - receiverConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) - if err != nil { - t.Fatalf("Failed to create receiver UDP connection: %v", err) - } - defer receiverConn.Close() - - receiverAddr := receiverConn.LocalAddr().(*net.UDPAddr) - - // Send data - testData := []byte("Hello, SharedBind!") - n, err := senderBind.WriteToUDP(testData, receiverAddr) - if err != nil { - t.Fatalf("WriteToUDP failed: %v", err) - } - - if n != len(testData) { - t.Errorf("Expected to send %d bytes, sent %d", len(testData), n) - } - - // Receive data - buf := make([]byte, 1024) - receiverConn.SetReadDeadline(time.Now().Add(2 * time.Second)) - n, _, err = receiverConn.ReadFromUDP(buf) - if err != nil { - t.Fatalf("Failed to receive data: %v", err) - } - - if string(buf[:n]) != string(testData) { - t.Errorf("Expected to receive %q, got %q", testData, buf[:n]) - } -} - -// TestSharedBindConcurrentWrites tests thread-safety -func TestSharedBindConcurrentWrites(t *testing.T) { - // Create sender - senderConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) - if err != nil { - t.Fatalf("Failed to create sender UDP connection: %v", err) - } - - senderBind, err := New(senderConn) - if err != nil { - t.Fatalf("Failed to create sender SharedBind: %v", err) - } - defer senderBind.Close() - - // Create receiver - receiverConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) - if err != nil { - t.Fatalf("Failed to create receiver UDP connection: %v", err) - } - defer receiverConn.Close() - - receiverAddr := receiverConn.LocalAddr().(*net.UDPAddr) - - // Launch concurrent writes - numGoroutines := 100 - var wg sync.WaitGroup - wg.Add(numGoroutines) - - for i := 0; i < numGoroutines; i++ { - go func(id int) { - defer wg.Done() - data := []byte{byte(id)} - _, err := senderBind.WriteToUDP(data, receiverAddr) - if err != nil { - t.Errorf("WriteToUDP failed in goroutine %d: %v", id, err) - } - }(i) - } - - wg.Wait() -} - -// TestSharedBindWireGuardInterface tests WireGuard Bind interface implementation -func TestSharedBindWireGuardInterface(t *testing.T) { - udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) - if err != nil { - t.Fatalf("Failed to create UDP connection: %v", err) - } - - bind, err := New(udpConn) - if err != nil { - t.Fatalf("Failed to create SharedBind: %v", err) - } - defer bind.Close() - - // Test Open - recvFuncs, port, err := bind.Open(0) - if err != nil { - t.Fatalf("Open failed: %v", err) - } - - if len(recvFuncs) == 0 { - t.Error("Expected at least one receive function") - } - - if port == 0 { - t.Error("Expected non-zero port") - } - - // Test SetMark (should be a no-op) - if err := bind.SetMark(0); err != nil { - t.Errorf("SetMark failed: %v", err) - } - - // Test BatchSize - batchSize := bind.BatchSize() - if batchSize <= 0 { - t.Error("Expected positive batch size") - } -} - -// TestSharedBindSend tests the Send method with WireGuard endpoints -func TestSharedBindSend(t *testing.T) { - // Create sender - senderConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) - if err != nil { - t.Fatalf("Failed to create sender UDP connection: %v", err) - } - - senderBind, err := New(senderConn) - if err != nil { - t.Fatalf("Failed to create sender SharedBind: %v", err) - } - defer senderBind.Close() - - // Create receiver - receiverConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) - if err != nil { - t.Fatalf("Failed to create receiver UDP connection: %v", err) - } - defer receiverConn.Close() - - receiverAddr := receiverConn.LocalAddr().(*net.UDPAddr) - - // Create an endpoint - addrPort := receiverAddr.AddrPort() - endpoint := &wgConn.StdNetEndpoint{AddrPort: addrPort} - - // Send data - testData := []byte("WireGuard packet") - bufs := [][]byte{testData} - err = senderBind.Send(bufs, endpoint) - if err != nil { - t.Fatalf("Send failed: %v", err) - } - - // Receive data - buf := make([]byte, 1024) - receiverConn.SetReadDeadline(time.Now().Add(2 * time.Second)) - n, _, err := receiverConn.ReadFromUDP(buf) - if err != nil { - t.Fatalf("Failed to receive data: %v", err) - } - - if string(buf[:n]) != string(testData) { - t.Errorf("Expected to receive %q, got %q", testData, buf[:n]) - } -} - -// TestSharedBindMultipleUsers simulates WireGuard and hole punch using the same bind -func TestSharedBindMultipleUsers(t *testing.T) { - // Create shared bind - udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) - if err != nil { - t.Fatalf("Failed to create UDP connection: %v", err) - } - - sharedBind, err := New(udpConn) - if err != nil { - t.Fatalf("Failed to create SharedBind: %v", err) - } - - // Add reference for hole punch sender - sharedBind.AddRef() - - // Create receiver - receiverConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) - if err != nil { - t.Fatalf("Failed to create receiver UDP connection: %v", err) - } - defer receiverConn.Close() - - receiverAddr := receiverConn.LocalAddr().(*net.UDPAddr) - - var wg sync.WaitGroup - - // Simulate WireGuard using the bind - wg.Add(1) - go func() { - defer wg.Done() - addrPort := receiverAddr.AddrPort() - endpoint := &wgConn.StdNetEndpoint{AddrPort: addrPort} - - for i := 0; i < 10; i++ { - data := []byte("WireGuard packet") - bufs := [][]byte{data} - if err := sharedBind.Send(bufs, endpoint); err != nil { - t.Errorf("WireGuard Send failed: %v", err) - } - time.Sleep(10 * time.Millisecond) - } - }() - - // Simulate hole punch sender using the bind - wg.Add(1) - go func() { - defer wg.Done() - for i := 0; i < 10; i++ { - data := []byte("Hole punch packet") - if _, err := sharedBind.WriteToUDP(data, receiverAddr); err != nil { - t.Errorf("Hole punch WriteToUDP failed: %v", err) - } - time.Sleep(10 * time.Millisecond) - } - }() - - wg.Wait() - - // Release the hole punch reference - sharedBind.Release() - - // Close WireGuard's reference (should close the connection) - sharedBind.Close() - - if !sharedBind.closed.Load() { - t.Error("Expected bind to be closed after all users released it") - } -} - -// TestEndpoint tests the Endpoint implementation -func TestEndpoint(t *testing.T) { - addr := netip.MustParseAddr("192.168.1.1") - addrPort := netip.AddrPortFrom(addr, 51820) - - ep := &Endpoint{AddrPort: addrPort} - - // Test DstIP - if ep.DstIP() != addr { - t.Errorf("Expected DstIP to be %v, got %v", addr, ep.DstIP()) - } - - // Test DstToString - expected := "192.168.1.1:51820" - if ep.DstToString() != expected { - t.Errorf("Expected DstToString to be %q, got %q", expected, ep.DstToString()) - } - - // Test DstToBytes - bytes := ep.DstToBytes() - if len(bytes) == 0 { - t.Error("Expected DstToBytes to return non-empty slice") - } - - // Test SrcIP (should be zero) - if ep.SrcIP().IsValid() { - t.Error("Expected SrcIP to be invalid") - } - - // Test ClearSrc (should not panic) - ep.ClearSrc() -} - -// TestParseEndpoint tests the ParseEndpoint method -func TestParseEndpoint(t *testing.T) { - udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) - if err != nil { - t.Fatalf("Failed to create UDP connection: %v", err) - } - - bind, err := New(udpConn) - if err != nil { - t.Fatalf("Failed to create SharedBind: %v", err) - } - defer bind.Close() - - tests := []struct { - name string - input string - wantErr bool - checkAddr func(*testing.T, wgConn.Endpoint) - }{ - { - name: "valid IPv4", - input: "192.168.1.1:51820", - wantErr: false, - checkAddr: func(t *testing.T, ep wgConn.Endpoint) { - if ep.DstToString() != "192.168.1.1:51820" { - t.Errorf("Expected 192.168.1.1:51820, got %s", ep.DstToString()) - } - }, - }, - { - name: "valid IPv6", - input: "[::1]:51820", - wantErr: false, - checkAddr: func(t *testing.T, ep wgConn.Endpoint) { - if ep.DstToString() != "[::1]:51820" { - t.Errorf("Expected [::1]:51820, got %s", ep.DstToString()) - } - }, - }, - { - name: "invalid - missing port", - input: "192.168.1.1", - wantErr: true, - }, - { - name: "invalid - bad format", - input: "not-an-address", - wantErr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - ep, err := bind.ParseEndpoint(tt.input) - if (err != nil) != tt.wantErr { - t.Errorf("ParseEndpoint() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !tt.wantErr && tt.checkAddr != nil { - tt.checkAddr(t, ep) - } - }) - } -} diff --git a/go.mod b/go.mod index e6ae7f2..0c16b81 100644 --- a/go.mod +++ b/go.mod @@ -4,13 +4,12 @@ go 1.25 require ( github.com/Microsoft/go-winio v0.6.2 - github.com/fosrl/newt v0.0.0-20250929233849-71c5bf7e65f7 + github.com/fosrl/newt v0.0.0 github.com/gorilla/websocket v1.5.3 github.com/vishvananda/netlink v1.3.1 - golang.org/x/crypto v0.43.0 - golang.org/x/exp v0.0.0-20250718183923-645b1fa84792 - golang.org/x/net v0.45.0 - golang.org/x/sys v0.37.0 + golang.org/x/crypto v0.44.0 + golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 + golang.org/x/sys v0.38.0 golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 software.sslmate.com/src/go-pkcs12 v0.6.0 @@ -18,6 +17,9 @@ require ( require ( github.com/vishvananda/netns v0.0.5 // indirect + golang.org/x/net v0.47.0 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect gvisor.dev/gvisor v0.0.0-20250718192347-d7830d968c56 // indirect ) + +replace github.com/fosrl/newt => ../newt diff --git a/go.sum b/go.sum index 88dc4e7..d2dbb17 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,5 @@ github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= -github.com/fosrl/newt v0.0.0-20250929233849-71c5bf7e65f7 h1:6bSU8Efyhx1SR53iSw1Wjk5V8vDfizGAudq/GlE9b+o= -github.com/fosrl/newt v0.0.0-20250929233849-71c5bf7e65f7/go.mod h1:Ac0k2FmAMC+hu21rAK+p7EnnEGrqKO/QZuGTVHA/XDM= github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg= github.com/google/btree v1.1.3/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= @@ -12,16 +10,16 @@ github.com/vishvananda/netlink v1.3.1 h1:3AEMt62VKqz90r0tmNhog0r/PpWKmrEShJU0wJW github.com/vishvananda/netlink v1.3.1/go.mod h1:ARtKouGSTGchR8aMwmkzC0qiNPrrWO5JS/XMVl45+b4= github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY= github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= -golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04= -golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0= -golang.org/x/exp v0.0.0-20250718183923-645b1fa84792 h1:R9PFI6EUdfVKgwKjZef7QIwGcBKu86OEFpJ9nUEP2l4= -golang.org/x/exp v0.0.0-20250718183923-645b1fa84792/go.mod h1:A+z0yzpGtvnG90cToK5n2tu8UJVP2XUATh+r+sfOOOc= -golang.org/x/net v0.45.0 h1:RLBg5JKixCy82FtLJpeNlVM0nrSqpCRYzVU1n8kj0tM= -golang.org/x/net v0.45.0/go.mod h1:ECOoLqd5U3Lhyeyo/QDCEVQ4sNgYsqvCZ722XogGieY= +golang.org/x/crypto v0.44.0 h1:A97SsFvM3AIwEEmTBiaxPPTYpDC47w720rdiiUvgoAU= +golang.org/x/crypto v0.44.0/go.mod h1:013i+Nw79BMiQiMsOPcVCB5ZIJbYkerPrGnOa00tvmc= +golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 h1:zfMcR1Cs4KNuomFFgGefv5N0czO2XZpUbxGUy8i8ug0= +golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6/go.mod h1:46edojNIoXTNOhySWIWdix628clX9ODXwPsQuG6hsK0= +golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY= +golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU= golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ= -golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= +golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE= golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg= diff --git a/holepunch/holepunch.go b/holepunch/holepunch.go deleted file mode 100644 index 187d3fe..0000000 --- a/holepunch/holepunch.go +++ /dev/null @@ -1,351 +0,0 @@ -package holepunch - -import ( - "encoding/json" - "fmt" - "net" - "sync" - "time" - - "github.com/fosrl/newt/logger" - "github.com/fosrl/olm/bind" - "golang.org/x/crypto/chacha20poly1305" - "golang.org/x/crypto/curve25519" - "golang.org/x/exp/rand" - "golang.zx2c4.com/wireguard/wgctrl/wgtypes" -) - -// DomainResolver is a function type for resolving domains to IP addresses -type DomainResolver func(string) (string, error) - -// ExitNode represents a WireGuard exit node for hole punching -type ExitNode struct { - Endpoint string `json:"endpoint"` - PublicKey string `json:"publicKey"` -} - -// Manager handles UDP hole punching operations -type Manager struct { - mu sync.Mutex - running bool - stopChan chan struct{} - sharedBind *bind.SharedBind - olmID string - token string - domainResolver DomainResolver -} - -// NewManager creates a new hole punch manager -func NewManager(sharedBind *bind.SharedBind, olmID string, domainResolver DomainResolver) *Manager { - return &Manager{ - sharedBind: sharedBind, - olmID: olmID, - domainResolver: domainResolver, - } -} - -// SetToken updates the authentication token used for hole punching -func (m *Manager) SetToken(token string) { - m.mu.Lock() - defer m.mu.Unlock() - m.token = token -} - -// IsRunning returns whether hole punching is currently active -func (m *Manager) IsRunning() bool { - m.mu.Lock() - defer m.mu.Unlock() - return m.running -} - -// Stop stops any ongoing hole punch operations -func (m *Manager) Stop() { - m.mu.Lock() - defer m.mu.Unlock() - - if !m.running { - return - } - - if m.stopChan != nil { - close(m.stopChan) - m.stopChan = nil - } - - m.running = false - logger.Info("Hole punch manager stopped") -} - -// StartMultipleExitNodes starts hole punching to multiple exit nodes -func (m *Manager) StartMultipleExitNodes(exitNodes []ExitNode) error { - m.mu.Lock() - - if m.running { - m.mu.Unlock() - logger.Debug("UDP hole punch already running, skipping new request") - return fmt.Errorf("hole punch already running") - } - - if len(exitNodes) == 0 { - m.mu.Unlock() - logger.Warn("No exit nodes provided for hole punching") - return fmt.Errorf("no exit nodes provided") - } - - m.running = true - m.stopChan = make(chan struct{}) - m.mu.Unlock() - - logger.Info("Starting UDP hole punch to %d exit nodes with shared bind", len(exitNodes)) - - go m.runMultipleExitNodes(exitNodes) - - return nil -} - -// StartSingleEndpoint starts hole punching to a single endpoint (legacy mode) -func (m *Manager) StartSingleEndpoint(endpoint, serverPubKey string) error { - m.mu.Lock() - - if m.running { - m.mu.Unlock() - logger.Debug("UDP hole punch already running, skipping new request") - return fmt.Errorf("hole punch already running") - } - - m.running = true - m.stopChan = make(chan struct{}) - m.mu.Unlock() - - logger.Info("Starting UDP hole punch to %s with shared bind", endpoint) - - go m.runSingleEndpoint(endpoint, serverPubKey) - - return nil -} - -// runMultipleExitNodes performs hole punching to multiple exit nodes -func (m *Manager) runMultipleExitNodes(exitNodes []ExitNode) { - defer func() { - m.mu.Lock() - m.running = false - m.mu.Unlock() - logger.Info("UDP hole punch goroutine ended for all exit nodes") - }() - - // Resolve all endpoints upfront - type resolvedExitNode struct { - remoteAddr *net.UDPAddr - publicKey string - endpointName string - } - - var resolvedNodes []resolvedExitNode - for _, exitNode := range exitNodes { - host, err := m.domainResolver(exitNode.Endpoint) - if err != nil { - logger.Warn("Failed to resolve endpoint %s: %v", exitNode.Endpoint, err) - continue - } - - serverAddr := net.JoinHostPort(host, "21820") - remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr) - if err != nil { - logger.Error("Failed to resolve UDP address %s: %v", serverAddr, err) - continue - } - - resolvedNodes = append(resolvedNodes, resolvedExitNode{ - remoteAddr: remoteAddr, - publicKey: exitNode.PublicKey, - endpointName: exitNode.Endpoint, - }) - logger.Info("Resolved exit node: %s -> %s", exitNode.Endpoint, remoteAddr.String()) - } - - if len(resolvedNodes) == 0 { - logger.Error("No exit nodes could be resolved") - return - } - - // Send initial hole punch to all exit nodes - for _, node := range resolvedNodes { - if err := m.sendHolePunch(node.remoteAddr, node.publicKey); err != nil { - logger.Warn("Failed to send initial hole punch to %s: %v", node.endpointName, err) - } - } - - ticker := time.NewTicker(250 * time.Millisecond) - defer ticker.Stop() - - timeout := time.NewTimer(15 * time.Second) - defer timeout.Stop() - - for { - select { - case <-m.stopChan: - logger.Debug("Hole punch stopped by signal") - return - case <-timeout.C: - logger.Debug("Hole punch timeout reached") - return - case <-ticker.C: - // Send hole punch to all exit nodes - for _, node := range resolvedNodes { - if err := m.sendHolePunch(node.remoteAddr, node.publicKey); err != nil { - logger.Debug("Failed to send hole punch to %s: %v", node.endpointName, err) - } - } - } - } -} - -// runSingleEndpoint performs hole punching to a single endpoint -func (m *Manager) runSingleEndpoint(endpoint, serverPubKey string) { - defer func() { - m.mu.Lock() - m.running = false - m.mu.Unlock() - logger.Info("UDP hole punch goroutine ended for %s", endpoint) - }() - - host, err := m.domainResolver(endpoint) - if err != nil { - logger.Error("Failed to resolve domain %s: %v", endpoint, err) - return - } - - serverAddr := net.JoinHostPort(host, "21820") - - remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr) - if err != nil { - logger.Error("Failed to resolve UDP address %s: %v", serverAddr, err) - return - } - - // Execute once immediately before starting the loop - if err := m.sendHolePunch(remoteAddr, serverPubKey); err != nil { - logger.Warn("Failed to send initial hole punch: %v", err) - } - - ticker := time.NewTicker(250 * time.Millisecond) - defer ticker.Stop() - - timeout := time.NewTimer(15 * time.Second) - defer timeout.Stop() - - for { - select { - case <-m.stopChan: - logger.Debug("Hole punch stopped by signal") - return - case <-timeout.C: - logger.Debug("Hole punch timeout reached") - return - case <-ticker.C: - if err := m.sendHolePunch(remoteAddr, serverPubKey); err != nil { - logger.Debug("Failed to send hole punch: %v", err) - } - } - } -} - -// sendHolePunch sends an encrypted hole punch packet using the shared bind -func (m *Manager) sendHolePunch(remoteAddr *net.UDPAddr, serverPubKey string) error { - m.mu.Lock() - token := m.token - olmID := m.olmID - m.mu.Unlock() - - if serverPubKey == "" || token == "" { - return fmt.Errorf("server public key or OLM token is empty") - } - - payload := struct { - OlmID string `json:"olmId"` - Token string `json:"token"` - }{ - OlmID: olmID, - Token: token, - } - - // Convert payload to JSON - payloadBytes, err := json.Marshal(payload) - if err != nil { - return fmt.Errorf("failed to marshal payload: %w", err) - } - - // Encrypt the payload using the server's WireGuard public key - encryptedPayload, err := encryptPayload(payloadBytes, serverPubKey) - if err != nil { - return fmt.Errorf("failed to encrypt payload: %w", err) - } - - jsonData, err := json.Marshal(encryptedPayload) - if err != nil { - return fmt.Errorf("failed to marshal encrypted payload: %w", err) - } - - _, err = m.sharedBind.WriteToUDP(jsonData, remoteAddr) - if err != nil { - return fmt.Errorf("failed to write to UDP: %w", err) - } - - logger.Debug("Sent UDP hole punch to %s: %s", remoteAddr.String(), string(jsonData)) - - return nil -} - -// encryptPayload encrypts the payload using ChaCha20-Poly1305 AEAD with X25519 key exchange -func encryptPayload(payload []byte, serverPublicKey string) (interface{}, error) { - // Generate an ephemeral keypair for this message - ephemeralPrivateKey, err := wgtypes.GeneratePrivateKey() - if err != nil { - return nil, fmt.Errorf("failed to generate ephemeral private key: %v", err) - } - ephemeralPublicKey := ephemeralPrivateKey.PublicKey() - - // Parse the server's public key - serverPubKey, err := wgtypes.ParseKey(serverPublicKey) - if err != nil { - return nil, fmt.Errorf("failed to parse server public key: %v", err) - } - - // Use X25519 for key exchange - var ephPrivKeyFixed [32]byte - copy(ephPrivKeyFixed[:], ephemeralPrivateKey[:]) - - // Perform X25519 key exchange - sharedSecret, err := curve25519.X25519(ephPrivKeyFixed[:], serverPubKey[:]) - if err != nil { - return nil, fmt.Errorf("failed to perform X25519 key exchange: %v", err) - } - - // Create an AEAD cipher using the shared secret - aead, err := chacha20poly1305.New(sharedSecret) - if err != nil { - return nil, fmt.Errorf("failed to create AEAD cipher: %v", err) - } - - // Generate a random nonce - nonce := make([]byte, aead.NonceSize()) - if _, err := rand.Read(nonce); err != nil { - return nil, fmt.Errorf("failed to generate nonce: %v", err) - } - - // Encrypt the payload - ciphertext := aead.Seal(nil, nonce, payload, nil) - - // Prepare the final encrypted message - encryptedMsg := struct { - EphemeralPublicKey string `json:"ephemeralPublicKey"` - Nonce []byte `json:"nonce"` - Ciphertext []byte `json:"ciphertext"` - }{ - EphemeralPublicKey: ephemeralPublicKey.String(), - Nonce: nonce, - Ciphertext: ciphertext, - } - - return encryptedMsg, nil -} diff --git a/olm/common.go b/olm/common.go index c15b66d..1a10eda 100644 --- a/olm/common.go +++ b/olm/common.go @@ -13,10 +13,10 @@ import ( "time" "github.com/fosrl/newt/logger" + "github.com/fosrl/newt/util" "github.com/fosrl/olm/peermonitor" "github.com/fosrl/olm/websocket" "github.com/vishvananda/netlink" - "golang.org/x/exp/rand" "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) @@ -156,23 +156,6 @@ func fixKey(key string) string { return hex.EncodeToString(decoded) } -func parseLogLevel(level string) logger.LogLevel { - switch strings.ToUpper(level) { - case "DEBUG": - return logger.DEBUG - case "INFO": - return logger.INFO - case "WARN": - return logger.WARN - case "ERROR": - return logger.ERROR - case "FATAL": - return logger.FATAL - default: - return logger.INFO // default to INFO if invalid level provided - } -} - func mapToWireGuardLogLevel(level logger.LogLevel) int { switch level { case logger.DEBUG: @@ -188,89 +171,6 @@ func mapToWireGuardLogLevel(level logger.LogLevel) int { } } -func ResolveDomain(domain string) (string, error) { - // First handle any protocol prefix - domain = strings.TrimPrefix(strings.TrimPrefix(domain, "https://"), "http://") - - // if there are any trailing slashes, remove them - domain = strings.TrimSuffix(domain, "/") - - // Now split host and port - host, port, err := net.SplitHostPort(domain) - if err != nil { - // No port found, use the domain as is - host = domain - port = "" - } - - // Lookup IP addresses - ips, err := net.LookupIP(host) - if err != nil { - return "", fmt.Errorf("DNS lookup failed: %v", err) - } - - if len(ips) == 0 { - return "", fmt.Errorf("no IP addresses found for domain %s", host) - } - - // Get the first IPv4 address if available - var ipAddr string - for _, ip := range ips { - if ipv4 := ip.To4(); ipv4 != nil { - ipAddr = ipv4.String() - break - } - } - - // If no IPv4 found, use the first IP (might be IPv6) - if ipAddr == "" { - ipAddr = ips[0].String() - } - - // Add port back if it existed - if port != "" { - ipAddr = net.JoinHostPort(ipAddr, port) - } - - return ipAddr, nil -} - -func FindAvailableUDPPort(minPort, maxPort uint16) (uint16, error) { - if maxPort < minPort { - return 0, fmt.Errorf("invalid port range: min=%d, max=%d", minPort, maxPort) - } - - // Create a slice of all ports in the range - portRange := make([]uint16, maxPort-minPort+1) - for i := range portRange { - portRange[i] = minPort + uint16(i) - } - - // Fisher-Yates shuffle to randomize the port order - rand.Seed(uint64(time.Now().UnixNano())) - for i := len(portRange) - 1; i > 0; i-- { - j := rand.Intn(i + 1) - portRange[i], portRange[j] = portRange[j], portRange[i] - } - - // Try each port in the randomized order - for _, port := range portRange { - addr := &net.UDPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: int(port), - } - conn, err := net.ListenUDP("udp", addr) - if err != nil { - continue // Port is in use or there was an error, try next port - } - _ = conn.SetDeadline(time.Now()) - conn.Close() - return port, nil - } - - return 0, fmt.Errorf("no available UDP ports found in range %d-%d", minPort, maxPort) -} - func sendPing(olm *websocket.Client) error { err := olm.SendMessage("olm/ping", map[string]interface{}{ "timestamp": time.Now().Unix(), @@ -311,7 +211,7 @@ func keepSendingPing(olm *websocket.Client) { // ConfigurePeer sets up or updates a peer within the WireGuard device func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes.Key, endpoint string) error { - siteHost, err := ResolveDomain(siteConfig.Endpoint) + siteHost, err := util.ResolveDomain(siteConfig.Endpoint) if err != nil { return fmt.Errorf("failed to resolve endpoint for site %d: %v", siteConfig.SiteId, err) } @@ -368,7 +268,7 @@ func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes monitorPeer := net.JoinHostPort(monitorAddress, strconv.Itoa(int(siteConfig.ServerPort+1))) // +1 for the monitor port logger.Debug("Setting up peer monitor for site %d at %s", siteConfig.SiteId, monitorPeer) - primaryRelay, err := ResolveDomain(endpoint) // Using global endpoint variable + primaryRelay, err := util.ResolveDomain(endpoint) // Using global endpoint variable if err != nil { logger.Warn("Failed to resolve primary relay endpoint: %v", err) } diff --git a/olm/olm.go b/olm/olm.go index fb20e3f..5943456 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -9,11 +9,12 @@ import ( "strconv" "time" + "github.com/fosrl/newt/bind" + "github.com/fosrl/newt/holepunch" "github.com/fosrl/newt/logger" "github.com/fosrl/newt/updates" + "github.com/fosrl/newt/util" "github.com/fosrl/olm/api" - "github.com/fosrl/olm/bind" - "github.com/fosrl/olm/holepunch" "github.com/fosrl/olm/peermonitor" "github.com/fosrl/olm/websocket" "golang.zx2c4.com/wireguard/device" @@ -78,7 +79,7 @@ func Run(ctx context.Context, config Config) { ctx, cancel := context.WithCancel(ctx) defer cancel() - logger.GetLogger().SetLevel(parseLogLevel(config.LogLevel)) + logger.GetLogger().SetLevel(util.ParseLogLevel(config.LogLevel)) if err := updates.CheckForUpdate("fosrl", "olm", config.Version); err != nil { logger.Debug("Failed to check for updates: %v", err) @@ -203,7 +204,7 @@ func TunnelProcess(ctx context.Context, config Config, id string, secret string, var ( interfaceName = config.InterfaceName - loggerLevel = parseLogLevel(config.LogLevel) + loggerLevel = util.ParseLogLevel(config.LogLevel) ) // Create a new olm client using the provided credentials @@ -231,7 +232,7 @@ func TunnelProcess(ctx context.Context, config Config, id string, secret string, // Create shared UDP socket for both holepunch and WireGuard if sharedBind == nil { - sourcePort, err := FindAvailableUDPPort(49152, 65535) + sourcePort, err := util.FindAvailableUDPPort(49152, 65535) if err != nil { logger.Error("Error finding available port: %v", err) return @@ -263,7 +264,7 @@ func TunnelProcess(ctx context.Context, config Config, id string, secret string, // Create the holepunch manager if holePunchManager == nil { - holePunchManager = holepunch.NewManager(sharedBind, id, ResolveDomain) + holePunchManager = holepunch.NewManager(sharedBind, id, "olm") } olm.RegisterHandler("olm/wg/holepunch/all", func(msg websocket.WSMessage) { @@ -705,7 +706,7 @@ func TunnelProcess(ctx context.Context, config Config, id string, secret string, return } - primaryRelay, err := ResolveDomain(relayData.Endpoint) + primaryRelay, err := util.ResolveDomain(relayData.Endpoint) if err != nil { logger.Warn("Failed to resolve primary relay endpoint: %v", err) }