mirror of
https://github.com/fosrl/olm.git
synced 2026-02-08 05:56:41 +00:00
378
bind/shared_bind.go
Normal file
378
bind/shared_bind.go
Normal file
@@ -0,0 +1,378 @@
|
|||||||
|
//go:build !js
|
||||||
|
|
||||||
|
package bind
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"runtime"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
|
"golang.org/x/net/ipv4"
|
||||||
|
"golang.org/x/net/ipv6"
|
||||||
|
wgConn "golang.zx2c4.com/wireguard/conn"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Endpoint represents a network endpoint for the SharedBind
|
||||||
|
type Endpoint struct {
|
||||||
|
AddrPort netip.AddrPort
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearSrc implements the wgConn.Endpoint interface
|
||||||
|
func (e *Endpoint) ClearSrc() {}
|
||||||
|
|
||||||
|
// DstIP implements the wgConn.Endpoint interface
|
||||||
|
func (e *Endpoint) DstIP() netip.Addr {
|
||||||
|
return e.AddrPort.Addr()
|
||||||
|
}
|
||||||
|
|
||||||
|
// SrcIP implements the wgConn.Endpoint interface
|
||||||
|
func (e *Endpoint) SrcIP() netip.Addr {
|
||||||
|
return netip.Addr{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DstToBytes implements the wgConn.Endpoint interface
|
||||||
|
func (e *Endpoint) DstToBytes() []byte {
|
||||||
|
b, _ := e.AddrPort.MarshalBinary()
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
// DstToString implements the wgConn.Endpoint interface
|
||||||
|
func (e *Endpoint) DstToString() string {
|
||||||
|
return e.AddrPort.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// SrcToString implements the wgConn.Endpoint interface
|
||||||
|
func (e *Endpoint) SrcToString() string {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// SharedBind is a thread-safe UDP bind that can be shared between WireGuard
|
||||||
|
// and hole punch senders. It wraps a single UDP connection and implements
|
||||||
|
// reference counting to prevent premature closure.
|
||||||
|
type SharedBind struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
|
||||||
|
// The underlying UDP connection
|
||||||
|
udpConn *net.UDPConn
|
||||||
|
|
||||||
|
// IPv4 and IPv6 packet connections for advanced features
|
||||||
|
ipv4PC *ipv4.PacketConn
|
||||||
|
ipv6PC *ipv6.PacketConn
|
||||||
|
|
||||||
|
// Reference counting to prevent closing while in use
|
||||||
|
refCount atomic.Int32
|
||||||
|
closed atomic.Bool
|
||||||
|
|
||||||
|
// Channels for receiving data
|
||||||
|
recvFuncs []wgConn.ReceiveFunc
|
||||||
|
|
||||||
|
// Port binding information
|
||||||
|
port uint16
|
||||||
|
}
|
||||||
|
|
||||||
|
// New creates a new SharedBind from an existing UDP connection.
|
||||||
|
// The SharedBind takes ownership of the connection and will close it
|
||||||
|
// when all references are released.
|
||||||
|
func New(udpConn *net.UDPConn) (*SharedBind, error) {
|
||||||
|
if udpConn == nil {
|
||||||
|
return nil, fmt.Errorf("udpConn cannot be nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
bind := &SharedBind{
|
||||||
|
udpConn: udpConn,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize reference count to 1 (the creator holds the first reference)
|
||||||
|
bind.refCount.Store(1)
|
||||||
|
|
||||||
|
// Get the local port
|
||||||
|
if addr, ok := udpConn.LocalAddr().(*net.UDPAddr); ok {
|
||||||
|
bind.port = uint16(addr.Port)
|
||||||
|
}
|
||||||
|
|
||||||
|
return bind, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddRef increments the reference count. Call this when sharing
|
||||||
|
// the bind with another component.
|
||||||
|
func (b *SharedBind) AddRef() {
|
||||||
|
newCount := b.refCount.Add(1)
|
||||||
|
// Optional: Add logging for debugging
|
||||||
|
_ = newCount // Placeholder for potential logging
|
||||||
|
}
|
||||||
|
|
||||||
|
// Release decrements the reference count. When it reaches zero,
|
||||||
|
// the underlying UDP connection is closed.
|
||||||
|
func (b *SharedBind) Release() error {
|
||||||
|
newCount := b.refCount.Add(-1)
|
||||||
|
// Optional: Add logging for debugging
|
||||||
|
_ = newCount // Placeholder for potential logging
|
||||||
|
|
||||||
|
if newCount < 0 {
|
||||||
|
// This should never happen with proper usage
|
||||||
|
b.refCount.Store(0)
|
||||||
|
return fmt.Errorf("SharedBind reference count went negative")
|
||||||
|
}
|
||||||
|
|
||||||
|
if newCount == 0 {
|
||||||
|
return b.closeConnection()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// closeConnection actually closes the UDP connection
|
||||||
|
func (b *SharedBind) closeConnection() error {
|
||||||
|
if !b.closed.CompareAndSwap(false, true) {
|
||||||
|
// Already closed
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
b.mu.Lock()
|
||||||
|
defer b.mu.Unlock()
|
||||||
|
|
||||||
|
var err error
|
||||||
|
if b.udpConn != nil {
|
||||||
|
err = b.udpConn.Close()
|
||||||
|
b.udpConn = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ipv4PC = nil
|
||||||
|
b.ipv6PC = nil
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUDPConn returns the underlying UDP connection.
|
||||||
|
// The caller must not close this connection directly.
|
||||||
|
func (b *SharedBind) GetUDPConn() *net.UDPConn {
|
||||||
|
b.mu.RLock()
|
||||||
|
defer b.mu.RUnlock()
|
||||||
|
return b.udpConn
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRefCount returns the current reference count (for debugging)
|
||||||
|
func (b *SharedBind) GetRefCount() int32 {
|
||||||
|
return b.refCount.Load()
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsClosed returns whether the bind is closed
|
||||||
|
func (b *SharedBind) IsClosed() bool {
|
||||||
|
return b.closed.Load()
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteToUDP writes data to a specific UDP address.
|
||||||
|
// This is thread-safe and can be used by hole punch senders.
|
||||||
|
func (b *SharedBind) WriteToUDP(data []byte, addr *net.UDPAddr) (int, error) {
|
||||||
|
if b.closed.Load() {
|
||||||
|
return 0, net.ErrClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
b.mu.RLock()
|
||||||
|
conn := b.udpConn
|
||||||
|
b.mu.RUnlock()
|
||||||
|
|
||||||
|
if conn == nil {
|
||||||
|
return 0, net.ErrClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
return conn.WriteToUDP(data, addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close implements the WireGuard Bind interface.
|
||||||
|
// It decrements the reference count and closes the connection if no references remain.
|
||||||
|
func (b *SharedBind) Close() error {
|
||||||
|
return b.Release()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Open implements the WireGuard Bind interface.
|
||||||
|
// Since the connection is already open, this just sets up the receive functions.
|
||||||
|
func (b *SharedBind) Open(uport uint16) ([]wgConn.ReceiveFunc, uint16, error) {
|
||||||
|
if b.closed.Load() {
|
||||||
|
return nil, 0, net.ErrClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
b.mu.Lock()
|
||||||
|
defer b.mu.Unlock()
|
||||||
|
|
||||||
|
if b.udpConn == nil {
|
||||||
|
return nil, 0, net.ErrClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set up IPv4 and IPv6 packet connections for advanced features
|
||||||
|
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
|
||||||
|
b.ipv4PC = ipv4.NewPacketConn(b.udpConn)
|
||||||
|
b.ipv6PC = ipv6.NewPacketConn(b.udpConn)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create receive functions
|
||||||
|
recvFuncs := make([]wgConn.ReceiveFunc, 0, 2)
|
||||||
|
|
||||||
|
// Add IPv4 receive function
|
||||||
|
if b.ipv4PC != nil || runtime.GOOS != "linux" {
|
||||||
|
recvFuncs = append(recvFuncs, b.makeReceiveIPv4())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add IPv6 receive function if needed
|
||||||
|
// For now, we focus on IPv4 for hole punching use case
|
||||||
|
|
||||||
|
b.recvFuncs = recvFuncs
|
||||||
|
return recvFuncs, b.port, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// makeReceiveIPv4 creates a receive function for IPv4 packets
|
||||||
|
func (b *SharedBind) makeReceiveIPv4() wgConn.ReceiveFunc {
|
||||||
|
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
|
||||||
|
if b.closed.Load() {
|
||||||
|
return 0, net.ErrClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
b.mu.RLock()
|
||||||
|
conn := b.udpConn
|
||||||
|
pc := b.ipv4PC
|
||||||
|
b.mu.RUnlock()
|
||||||
|
|
||||||
|
if conn == nil {
|
||||||
|
return 0, net.ErrClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use batch reading on Linux for performance
|
||||||
|
if pc != nil && (runtime.GOOS == "linux" || runtime.GOOS == "android") {
|
||||||
|
return b.receiveIPv4Batch(pc, bufs, sizes, eps)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback to simple read for other platforms
|
||||||
|
return b.receiveIPv4Simple(conn, bufs, sizes, eps)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// receiveIPv4Batch uses batch reading for better performance on Linux
|
||||||
|
func (b *SharedBind) receiveIPv4Batch(pc *ipv4.PacketConn, bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (int, error) {
|
||||||
|
// Create messages for batch reading
|
||||||
|
msgs := make([]ipv4.Message, len(bufs))
|
||||||
|
for i := range bufs {
|
||||||
|
msgs[i].Buffers = [][]byte{bufs[i]}
|
||||||
|
msgs[i].OOB = make([]byte, 0) // No OOB data needed for basic use
|
||||||
|
}
|
||||||
|
|
||||||
|
numMsgs, err := pc.ReadBatch(msgs, 0)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < numMsgs; i++ {
|
||||||
|
sizes[i] = msgs[i].N
|
||||||
|
if sizes[i] == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if msgs[i].Addr != nil {
|
||||||
|
if udpAddr, ok := msgs[i].Addr.(*net.UDPAddr); ok {
|
||||||
|
addrPort := udpAddr.AddrPort()
|
||||||
|
eps[i] = &wgConn.StdNetEndpoint{AddrPort: addrPort}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return numMsgs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// receiveIPv4Simple uses simple ReadFromUDP for non-Linux platforms
|
||||||
|
func (b *SharedBind) receiveIPv4Simple(conn *net.UDPConn, bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (int, error) {
|
||||||
|
n, addr, err := conn.ReadFromUDP(bufs[0])
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
sizes[0] = n
|
||||||
|
if addr != nil {
|
||||||
|
addrPort := addr.AddrPort()
|
||||||
|
eps[0] = &wgConn.StdNetEndpoint{AddrPort: addrPort}
|
||||||
|
}
|
||||||
|
|
||||||
|
return 1, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send implements the WireGuard Bind interface.
|
||||||
|
// It sends packets to the specified endpoint.
|
||||||
|
func (b *SharedBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
|
||||||
|
if b.closed.Load() {
|
||||||
|
return net.ErrClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
b.mu.RLock()
|
||||||
|
conn := b.udpConn
|
||||||
|
b.mu.RUnlock()
|
||||||
|
|
||||||
|
if conn == nil {
|
||||||
|
return net.ErrClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract the destination address from the endpoint
|
||||||
|
var destAddr *net.UDPAddr
|
||||||
|
|
||||||
|
// Try to cast to StdNetEndpoint first
|
||||||
|
if stdEp, ok := ep.(*wgConn.StdNetEndpoint); ok {
|
||||||
|
destAddr = net.UDPAddrFromAddrPort(stdEp.AddrPort)
|
||||||
|
} else {
|
||||||
|
// Fallback: construct from DstIP and DstToBytes
|
||||||
|
dstBytes := ep.DstToBytes()
|
||||||
|
if len(dstBytes) >= 6 { // Minimum for IPv4 (4 bytes) + port (2 bytes)
|
||||||
|
var addr netip.Addr
|
||||||
|
var port uint16
|
||||||
|
|
||||||
|
if len(dstBytes) >= 18 { // IPv6 (16 bytes) + port (2 bytes)
|
||||||
|
addr, _ = netip.AddrFromSlice(dstBytes[:16])
|
||||||
|
port = uint16(dstBytes[16]) | uint16(dstBytes[17])<<8
|
||||||
|
} else { // IPv4
|
||||||
|
addr, _ = netip.AddrFromSlice(dstBytes[:4])
|
||||||
|
port = uint16(dstBytes[4]) | uint16(dstBytes[5])<<8
|
||||||
|
}
|
||||||
|
|
||||||
|
if addr.IsValid() {
|
||||||
|
destAddr = net.UDPAddrFromAddrPort(netip.AddrPortFrom(addr, port))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if destAddr == nil {
|
||||||
|
return fmt.Errorf("could not extract destination address from endpoint")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send all buffers to the destination
|
||||||
|
for _, buf := range bufs {
|
||||||
|
_, err := conn.WriteToUDP(buf, destAddr)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetMark implements the WireGuard Bind interface.
|
||||||
|
// It's a no-op for this implementation.
|
||||||
|
func (b *SharedBind) SetMark(mark uint32) error {
|
||||||
|
// Not implemented for this use case
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// BatchSize returns the preferred batch size for sending packets.
|
||||||
|
func (b *SharedBind) BatchSize() int {
|
||||||
|
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
|
||||||
|
return wgConn.IdealBatchSize
|
||||||
|
}
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseEndpoint creates a new endpoint from a string address.
|
||||||
|
func (b *SharedBind) ParseEndpoint(s string) (wgConn.Endpoint, error) {
|
||||||
|
addrPort, err := netip.ParseAddrPort(s)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &wgConn.StdNetEndpoint{AddrPort: addrPort}, nil
|
||||||
|
}
|
||||||
424
bind/shared_bind_test.go
Normal file
424
bind/shared_bind_test.go
Normal file
@@ -0,0 +1,424 @@
|
|||||||
|
//go:build !js
|
||||||
|
|
||||||
|
package bind
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
wgConn "golang.zx2c4.com/wireguard/conn"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestSharedBindCreation tests basic creation and initialization
|
||||||
|
func TestSharedBindCreation(t *testing.T) {
|
||||||
|
// Create a UDP connection
|
||||||
|
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create UDP connection: %v", err)
|
||||||
|
}
|
||||||
|
defer udpConn.Close()
|
||||||
|
|
||||||
|
// Create SharedBind
|
||||||
|
bind, err := New(udpConn)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create SharedBind: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if bind == nil {
|
||||||
|
t.Fatal("SharedBind is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify initial reference count
|
||||||
|
if bind.refCount.Load() != 1 {
|
||||||
|
t.Errorf("Expected initial refCount to be 1, got %d", bind.refCount.Load())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clean up
|
||||||
|
if err := bind.Close(); err != nil {
|
||||||
|
t.Errorf("Failed to close SharedBind: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSharedBindReferenceCount tests reference counting
|
||||||
|
func TestSharedBindReferenceCount(t *testing.T) {
|
||||||
|
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create UDP connection: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
bind, err := New(udpConn)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create SharedBind: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add references
|
||||||
|
bind.AddRef()
|
||||||
|
if bind.refCount.Load() != 2 {
|
||||||
|
t.Errorf("Expected refCount to be 2, got %d", bind.refCount.Load())
|
||||||
|
}
|
||||||
|
|
||||||
|
bind.AddRef()
|
||||||
|
if bind.refCount.Load() != 3 {
|
||||||
|
t.Errorf("Expected refCount to be 3, got %d", bind.refCount.Load())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Release references
|
||||||
|
bind.Release()
|
||||||
|
if bind.refCount.Load() != 2 {
|
||||||
|
t.Errorf("Expected refCount to be 2 after release, got %d", bind.refCount.Load())
|
||||||
|
}
|
||||||
|
|
||||||
|
bind.Release()
|
||||||
|
bind.Release() // This should close the connection
|
||||||
|
|
||||||
|
if !bind.closed.Load() {
|
||||||
|
t.Error("Expected bind to be closed after all references released")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSharedBindWriteToUDP tests the WriteToUDP functionality
|
||||||
|
func TestSharedBindWriteToUDP(t *testing.T) {
|
||||||
|
// Create sender
|
||||||
|
senderConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create sender UDP connection: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
senderBind, err := New(senderConn)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create sender SharedBind: %v", err)
|
||||||
|
}
|
||||||
|
defer senderBind.Close()
|
||||||
|
|
||||||
|
// Create receiver
|
||||||
|
receiverConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create receiver UDP connection: %v", err)
|
||||||
|
}
|
||||||
|
defer receiverConn.Close()
|
||||||
|
|
||||||
|
receiverAddr := receiverConn.LocalAddr().(*net.UDPAddr)
|
||||||
|
|
||||||
|
// Send data
|
||||||
|
testData := []byte("Hello, SharedBind!")
|
||||||
|
n, err := senderBind.WriteToUDP(testData, receiverAddr)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("WriteToUDP failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if n != len(testData) {
|
||||||
|
t.Errorf("Expected to send %d bytes, sent %d", len(testData), n)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Receive data
|
||||||
|
buf := make([]byte, 1024)
|
||||||
|
receiverConn.SetReadDeadline(time.Now().Add(2 * time.Second))
|
||||||
|
n, _, err = receiverConn.ReadFromUDP(buf)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to receive data: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if string(buf[:n]) != string(testData) {
|
||||||
|
t.Errorf("Expected to receive %q, got %q", testData, buf[:n])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSharedBindConcurrentWrites tests thread-safety
|
||||||
|
func TestSharedBindConcurrentWrites(t *testing.T) {
|
||||||
|
// Create sender
|
||||||
|
senderConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create sender UDP connection: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
senderBind, err := New(senderConn)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create sender SharedBind: %v", err)
|
||||||
|
}
|
||||||
|
defer senderBind.Close()
|
||||||
|
|
||||||
|
// Create receiver
|
||||||
|
receiverConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create receiver UDP connection: %v", err)
|
||||||
|
}
|
||||||
|
defer receiverConn.Close()
|
||||||
|
|
||||||
|
receiverAddr := receiverConn.LocalAddr().(*net.UDPAddr)
|
||||||
|
|
||||||
|
// Launch concurrent writes
|
||||||
|
numGoroutines := 100
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(numGoroutines)
|
||||||
|
|
||||||
|
for i := 0; i < numGoroutines; i++ {
|
||||||
|
go func(id int) {
|
||||||
|
defer wg.Done()
|
||||||
|
data := []byte{byte(id)}
|
||||||
|
_, err := senderBind.WriteToUDP(data, receiverAddr)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("WriteToUDP failed in goroutine %d: %v", id, err)
|
||||||
|
}
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSharedBindWireGuardInterface tests WireGuard Bind interface implementation
|
||||||
|
func TestSharedBindWireGuardInterface(t *testing.T) {
|
||||||
|
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create UDP connection: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
bind, err := New(udpConn)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create SharedBind: %v", err)
|
||||||
|
}
|
||||||
|
defer bind.Close()
|
||||||
|
|
||||||
|
// Test Open
|
||||||
|
recvFuncs, port, err := bind.Open(0)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Open failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(recvFuncs) == 0 {
|
||||||
|
t.Error("Expected at least one receive function")
|
||||||
|
}
|
||||||
|
|
||||||
|
if port == 0 {
|
||||||
|
t.Error("Expected non-zero port")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test SetMark (should be a no-op)
|
||||||
|
if err := bind.SetMark(0); err != nil {
|
||||||
|
t.Errorf("SetMark failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test BatchSize
|
||||||
|
batchSize := bind.BatchSize()
|
||||||
|
if batchSize <= 0 {
|
||||||
|
t.Error("Expected positive batch size")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSharedBindSend tests the Send method with WireGuard endpoints
|
||||||
|
func TestSharedBindSend(t *testing.T) {
|
||||||
|
// Create sender
|
||||||
|
senderConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create sender UDP connection: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
senderBind, err := New(senderConn)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create sender SharedBind: %v", err)
|
||||||
|
}
|
||||||
|
defer senderBind.Close()
|
||||||
|
|
||||||
|
// Create receiver
|
||||||
|
receiverConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create receiver UDP connection: %v", err)
|
||||||
|
}
|
||||||
|
defer receiverConn.Close()
|
||||||
|
|
||||||
|
receiverAddr := receiverConn.LocalAddr().(*net.UDPAddr)
|
||||||
|
|
||||||
|
// Create an endpoint
|
||||||
|
addrPort := receiverAddr.AddrPort()
|
||||||
|
endpoint := &wgConn.StdNetEndpoint{AddrPort: addrPort}
|
||||||
|
|
||||||
|
// Send data
|
||||||
|
testData := []byte("WireGuard packet")
|
||||||
|
bufs := [][]byte{testData}
|
||||||
|
err = senderBind.Send(bufs, endpoint)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Send failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Receive data
|
||||||
|
buf := make([]byte, 1024)
|
||||||
|
receiverConn.SetReadDeadline(time.Now().Add(2 * time.Second))
|
||||||
|
n, _, err := receiverConn.ReadFromUDP(buf)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to receive data: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if string(buf[:n]) != string(testData) {
|
||||||
|
t.Errorf("Expected to receive %q, got %q", testData, buf[:n])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSharedBindMultipleUsers simulates WireGuard and hole punch using the same bind
|
||||||
|
func TestSharedBindMultipleUsers(t *testing.T) {
|
||||||
|
// Create shared bind
|
||||||
|
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create UDP connection: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sharedBind, err := New(udpConn)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create SharedBind: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add reference for hole punch sender
|
||||||
|
sharedBind.AddRef()
|
||||||
|
|
||||||
|
// Create receiver
|
||||||
|
receiverConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create receiver UDP connection: %v", err)
|
||||||
|
}
|
||||||
|
defer receiverConn.Close()
|
||||||
|
|
||||||
|
receiverAddr := receiverConn.LocalAddr().(*net.UDPAddr)
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
|
// Simulate WireGuard using the bind
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
addrPort := receiverAddr.AddrPort()
|
||||||
|
endpoint := &wgConn.StdNetEndpoint{AddrPort: addrPort}
|
||||||
|
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
data := []byte("WireGuard packet")
|
||||||
|
bufs := [][]byte{data}
|
||||||
|
if err := sharedBind.Send(bufs, endpoint); err != nil {
|
||||||
|
t.Errorf("WireGuard Send failed: %v", err)
|
||||||
|
}
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Simulate hole punch sender using the bind
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
data := []byte("Hole punch packet")
|
||||||
|
if _, err := sharedBind.WriteToUDP(data, receiverAddr); err != nil {
|
||||||
|
t.Errorf("Hole punch WriteToUDP failed: %v", err)
|
||||||
|
}
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
// Release the hole punch reference
|
||||||
|
sharedBind.Release()
|
||||||
|
|
||||||
|
// Close WireGuard's reference (should close the connection)
|
||||||
|
sharedBind.Close()
|
||||||
|
|
||||||
|
if !sharedBind.closed.Load() {
|
||||||
|
t.Error("Expected bind to be closed after all users released it")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestEndpoint tests the Endpoint implementation
|
||||||
|
func TestEndpoint(t *testing.T) {
|
||||||
|
addr := netip.MustParseAddr("192.168.1.1")
|
||||||
|
addrPort := netip.AddrPortFrom(addr, 51820)
|
||||||
|
|
||||||
|
ep := &Endpoint{AddrPort: addrPort}
|
||||||
|
|
||||||
|
// Test DstIP
|
||||||
|
if ep.DstIP() != addr {
|
||||||
|
t.Errorf("Expected DstIP to be %v, got %v", addr, ep.DstIP())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test DstToString
|
||||||
|
expected := "192.168.1.1:51820"
|
||||||
|
if ep.DstToString() != expected {
|
||||||
|
t.Errorf("Expected DstToString to be %q, got %q", expected, ep.DstToString())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test DstToBytes
|
||||||
|
bytes := ep.DstToBytes()
|
||||||
|
if len(bytes) == 0 {
|
||||||
|
t.Error("Expected DstToBytes to return non-empty slice")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test SrcIP (should be zero)
|
||||||
|
if ep.SrcIP().IsValid() {
|
||||||
|
t.Error("Expected SrcIP to be invalid")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test ClearSrc (should not panic)
|
||||||
|
ep.ClearSrc()
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestParseEndpoint tests the ParseEndpoint method
|
||||||
|
func TestParseEndpoint(t *testing.T) {
|
||||||
|
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create UDP connection: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
bind, err := New(udpConn)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create SharedBind: %v", err)
|
||||||
|
}
|
||||||
|
defer bind.Close()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
wantErr bool
|
||||||
|
checkAddr func(*testing.T, wgConn.Endpoint)
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid IPv4",
|
||||||
|
input: "192.168.1.1:51820",
|
||||||
|
wantErr: false,
|
||||||
|
checkAddr: func(t *testing.T, ep wgConn.Endpoint) {
|
||||||
|
if ep.DstToString() != "192.168.1.1:51820" {
|
||||||
|
t.Errorf("Expected 192.168.1.1:51820, got %s", ep.DstToString())
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid IPv6",
|
||||||
|
input: "[::1]:51820",
|
||||||
|
wantErr: false,
|
||||||
|
checkAddr: func(t *testing.T, ep wgConn.Endpoint) {
|
||||||
|
if ep.DstToString() != "[::1]:51820" {
|
||||||
|
t.Errorf("Expected [::1]:51820, got %s", ep.DstToString())
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid - missing port",
|
||||||
|
input: "192.168.1.1",
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid - bad format",
|
||||||
|
input: "not-an-address",
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
ep, err := bind.ParseEndpoint(tt.input)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("ParseEndpoint() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !tt.wantErr && tt.checkAddr != nil {
|
||||||
|
tt.checkAddr(t, ep)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
1
olm-binary.REMOVED.git-id
Normal file
1
olm-binary.REMOVED.git-id
Normal file
@@ -0,0 +1 @@
|
|||||||
|
767662d6fa777b3bb77d47a1c44eb5fb60249e87
|
||||||
1
olm-test.REMOVED.git-id
Normal file
1
olm-test.REMOVED.git-id
Normal file
@@ -0,0 +1 @@
|
|||||||
|
ba2c118fd96937229ef54dcd0b82fe5d53d94a87
|
||||||
209
olm/common.go
209
olm/common.go
@@ -14,13 +14,13 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/fosrl/newt/logger"
|
"github.com/fosrl/newt/logger"
|
||||||
|
"github.com/fosrl/olm/bind"
|
||||||
"github.com/fosrl/olm/peermonitor"
|
"github.com/fosrl/olm/peermonitor"
|
||||||
"github.com/fosrl/olm/websocket"
|
"github.com/fosrl/olm/websocket"
|
||||||
"github.com/vishvananda/netlink"
|
"github.com/vishvananda/netlink"
|
||||||
"golang.org/x/crypto/chacha20poly1305"
|
"golang.org/x/crypto/chacha20poly1305"
|
||||||
"golang.org/x/crypto/curve25519"
|
"golang.org/x/crypto/curve25519"
|
||||||
"golang.org/x/exp/rand"
|
"golang.org/x/exp/rand"
|
||||||
"golang.zx2c4.com/wireguard/conn"
|
|
||||||
"golang.zx2c4.com/wireguard/device"
|
"golang.zx2c4.com/wireguard/device"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
)
|
)
|
||||||
@@ -82,11 +82,6 @@ const (
|
|||||||
ENV_WG_PROCESS_FOREGROUND = "WG_PROCESS_FOREGROUND"
|
ENV_WG_PROCESS_FOREGROUND = "WG_PROCESS_FOREGROUND"
|
||||||
)
|
)
|
||||||
|
|
||||||
type fixedPortBind struct {
|
|
||||||
port uint16
|
|
||||||
conn.Bind
|
|
||||||
}
|
|
||||||
|
|
||||||
// PeerAction represents a request to add, update, or remove a peer
|
// PeerAction represents a request to add, update, or remove a peer
|
||||||
type PeerAction struct {
|
type PeerAction struct {
|
||||||
Action string `json:"action"` // "add", "update", or "remove"
|
Action string `json:"action"` // "add", "update", or "remove"
|
||||||
@@ -124,11 +119,6 @@ type RelayPeerData struct {
|
|||||||
PublicKey string `json:"publicKey"`
|
PublicKey string `json:"publicKey"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *fixedPortBind) Open(port uint16) ([]conn.ReceiveFunc, uint16, error) {
|
|
||||||
// Ignore the requested port and use our fixed port
|
|
||||||
return b.Bind.Open(b.port)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Helper function to format endpoints correctly
|
// Helper function to format endpoints correctly
|
||||||
func formatEndpoint(endpoint string) string {
|
func formatEndpoint(endpoint string) string {
|
||||||
if endpoint == "" {
|
if endpoint == "" {
|
||||||
@@ -156,13 +146,6 @@ func formatEndpoint(endpoint string) string {
|
|||||||
return endpoint
|
return endpoint
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewFixedPortBind(port uint16) conn.Bind {
|
|
||||||
return &fixedPortBind{
|
|
||||||
port: port,
|
|
||||||
Bind: conn.NewDefaultBind(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func fixKey(key string) string {
|
func fixKey(key string) string {
|
||||||
// Remove any whitespace
|
// Remove any whitespace
|
||||||
key = strings.TrimSpace(key)
|
key = strings.TrimSpace(key)
|
||||||
@@ -523,6 +506,196 @@ func keepSendingUDPHolePunch(endpoint string, olmID string, sourcePort uint16, s
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// keepSendingUDPHolePunchToMultipleExitNodesWithSharedBind sends hole punch packets using the shared bind
|
||||||
|
func keepSendingUDPHolePunchToMultipleExitNodesWithSharedBind(exitNodes []ExitNode, olmID string, sharedBind *bind.SharedBind) {
|
||||||
|
if len(exitNodes) == 0 {
|
||||||
|
logger.Warn("No exit nodes provided for hole punching")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if hole punching is already running
|
||||||
|
if holePunchRunning {
|
||||||
|
logger.Debug("UDP hole punch already running, skipping new request")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set the flag to indicate hole punching is running
|
||||||
|
holePunchRunning = true
|
||||||
|
defer func() {
|
||||||
|
holePunchRunning = false
|
||||||
|
logger.Info("UDP hole punch goroutine ended")
|
||||||
|
}()
|
||||||
|
|
||||||
|
logger.Info("Starting UDP hole punch to %d exit nodes with shared bind", len(exitNodes))
|
||||||
|
defer logger.Info("UDP hole punch goroutine ended for all exit nodes")
|
||||||
|
|
||||||
|
// Resolve all endpoints upfront
|
||||||
|
type resolvedExitNode struct {
|
||||||
|
remoteAddr *net.UDPAddr
|
||||||
|
publicKey string
|
||||||
|
endpointName string
|
||||||
|
}
|
||||||
|
|
||||||
|
var resolvedNodes []resolvedExitNode
|
||||||
|
for _, exitNode := range exitNodes {
|
||||||
|
host, err := resolveDomain(exitNode.Endpoint)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to resolve endpoint %s: %v", exitNode.Endpoint, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
serverAddr := net.JoinHostPort(host, "21820")
|
||||||
|
remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to resolve UDP address for %s: %v", exitNode.Endpoint, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
resolvedNodes = append(resolvedNodes, resolvedExitNode{
|
||||||
|
remoteAddr: remoteAddr,
|
||||||
|
publicKey: exitNode.PublicKey,
|
||||||
|
endpointName: exitNode.Endpoint,
|
||||||
|
})
|
||||||
|
logger.Info("Resolved exit node: %s -> %s", exitNode.Endpoint, remoteAddr.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(resolvedNodes) == 0 {
|
||||||
|
logger.Error("No exit nodes could be resolved")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send initial hole punch to all exit nodes
|
||||||
|
for _, node := range resolvedNodes {
|
||||||
|
if err := sendUDPHolePunchWithBind(sharedBind, node.remoteAddr, olmID, node.publicKey); err != nil {
|
||||||
|
logger.Error("Failed to send initial UDP hole punch to %s: %v", node.endpointName, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ticker := time.NewTicker(250 * time.Millisecond)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
timeout := time.NewTimer(15 * time.Second)
|
||||||
|
defer timeout.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-stopHolepunch:
|
||||||
|
logger.Info("Stopping UDP holepunch for all exit nodes")
|
||||||
|
return
|
||||||
|
case <-timeout.C:
|
||||||
|
logger.Info("UDP holepunch routine timed out after 15 seconds for all exit nodes")
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
// Send hole punch to all exit nodes
|
||||||
|
for _, node := range resolvedNodes {
|
||||||
|
if err := sendUDPHolePunchWithBind(sharedBind, node.remoteAddr, olmID, node.publicKey); err != nil {
|
||||||
|
logger.Error("Failed to send UDP hole punch to %s: %v", node.endpointName, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// keepSendingUDPHolePunchWithSharedBind sends hole punch packets to a single endpoint using shared bind
|
||||||
|
func keepSendingUDPHolePunchWithSharedBind(endpoint string, olmID string, sharedBind *bind.SharedBind, serverPubKey string) {
|
||||||
|
// Check if hole punching is already running
|
||||||
|
if holePunchRunning {
|
||||||
|
logger.Debug("UDP hole punch already running, skipping new request")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set the flag to indicate hole punching is running
|
||||||
|
holePunchRunning = true
|
||||||
|
defer func() {
|
||||||
|
holePunchRunning = false
|
||||||
|
logger.Info("UDP hole punch goroutine ended")
|
||||||
|
}()
|
||||||
|
|
||||||
|
logger.Info("Starting UDP hole punch to %s with shared bind", endpoint)
|
||||||
|
defer logger.Info("UDP hole punch goroutine ended for %s", endpoint)
|
||||||
|
|
||||||
|
host, err := resolveDomain(endpoint)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to resolve domain %s: %v", endpoint, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
serverAddr := net.JoinHostPort(host, "21820")
|
||||||
|
|
||||||
|
remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to resolve UDP address %s: %v", serverAddr, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute once immediately before starting the loop
|
||||||
|
if err := sendUDPHolePunchWithBind(sharedBind, remoteAddr, olmID, serverPubKey); err != nil {
|
||||||
|
logger.Error("Failed to send initial UDP hole punch: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ticker := time.NewTicker(250 * time.Millisecond)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
timeout := time.NewTimer(15 * time.Second)
|
||||||
|
defer timeout.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-stopHolepunch:
|
||||||
|
logger.Info("Stopping UDP holepunch")
|
||||||
|
return
|
||||||
|
case <-timeout.C:
|
||||||
|
logger.Info("UDP holepunch routine timed out after 15 seconds")
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
if err := sendUDPHolePunchWithBind(sharedBind, remoteAddr, olmID, serverPubKey); err != nil {
|
||||||
|
logger.Error("Failed to send UDP hole punch: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// sendUDPHolePunchWithBind sends an encrypted hole punch packet using the shared bind
|
||||||
|
func sendUDPHolePunchWithBind(sharedBind *bind.SharedBind, remoteAddr *net.UDPAddr, olmID string, serverPubKey string) error {
|
||||||
|
if serverPubKey == "" || olmToken == "" {
|
||||||
|
return fmt.Errorf("server public key or OLM token is empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := struct {
|
||||||
|
OlmID string `json:"olmId"`
|
||||||
|
Token string `json:"token"`
|
||||||
|
}{
|
||||||
|
OlmID: olmID,
|
||||||
|
Token: olmToken,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert payload to JSON
|
||||||
|
payloadBytes, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to marshal payload: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encrypt the payload using the server's WireGuard public key
|
||||||
|
encryptedPayload, err := encryptPayload(payloadBytes, serverPubKey)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to encrypt payload: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(encryptedPayload)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to marshal encrypted payload: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = sharedBind.WriteToUDP(jsonData, remoteAddr)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to write to UDP: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Debug("Sent UDP hole punch to %s: %s", remoteAddr.String(), string(jsonData))
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func FindAvailableUDPPort(minPort, maxPort uint16) (uint16, error) {
|
func FindAvailableUDPPort(minPort, maxPort uint16) (uint16, error) {
|
||||||
if maxPort < minPort {
|
if maxPort < minPort {
|
||||||
return 0, fmt.Errorf("invalid port range: min=%d, max=%d", minPort, maxPort)
|
return 0, fmt.Errorf("invalid port range: min=%d, max=%d", minPort, maxPort)
|
||||||
|
|||||||
47
olm/olm.go
47
olm/olm.go
@@ -12,6 +12,7 @@ import (
|
|||||||
"github.com/fosrl/newt/logger"
|
"github.com/fosrl/newt/logger"
|
||||||
"github.com/fosrl/newt/updates"
|
"github.com/fosrl/newt/updates"
|
||||||
"github.com/fosrl/olm/api"
|
"github.com/fosrl/olm/api"
|
||||||
|
"github.com/fosrl/olm/bind"
|
||||||
"github.com/fosrl/olm/peermonitor"
|
"github.com/fosrl/olm/peermonitor"
|
||||||
"github.com/fosrl/olm/websocket"
|
"github.com/fosrl/olm/websocket"
|
||||||
"golang.zx2c4.com/wireguard/device"
|
"golang.zx2c4.com/wireguard/device"
|
||||||
@@ -67,6 +68,7 @@ var (
|
|||||||
olmClient *websocket.Client
|
olmClient *websocket.Client
|
||||||
tunnelCancel context.CancelFunc
|
tunnelCancel context.CancelFunc
|
||||||
tunnelRunning bool
|
tunnelRunning bool
|
||||||
|
sharedBind *bind.SharedBind
|
||||||
)
|
)
|
||||||
|
|
||||||
func Run(ctx context.Context, config Config) {
|
func Run(ctx context.Context, config Config) {
|
||||||
@@ -226,12 +228,38 @@ func TunnelProcess(ctx context.Context, config Config, id string, secret string,
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Create shared UDP socket for both holepunch and WireGuard
|
||||||
|
if sharedBind == nil {
|
||||||
sourcePort, err := FindAvailableUDPPort(49152, 65535)
|
sourcePort, err := FindAvailableUDPPort(49152, 65535)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Error finding available port: %v", err)
|
logger.Error("Error finding available port: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
localAddr := &net.UDPAddr{
|
||||||
|
Port: int(sourcePort),
|
||||||
|
IP: net.IPv4zero,
|
||||||
|
}
|
||||||
|
|
||||||
|
udpConn, err := net.ListenUDP("udp", localAddr)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to create shared UDP socket: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
sharedBind, err = bind.New(udpConn)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to create shared bind: %v", err)
|
||||||
|
udpConn.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add a reference for the hole punch senders (creator already has one reference for WireGuard)
|
||||||
|
sharedBind.AddRef()
|
||||||
|
|
||||||
|
logger.Info("Created shared UDP socket on port %d (refcount: %d)", sourcePort, sharedBind.GetRefCount())
|
||||||
|
}
|
||||||
|
|
||||||
olm.RegisterHandler("olm/wg/holepunch/all", func(msg websocket.WSMessage) {
|
olm.RegisterHandler("olm/wg/holepunch/all", func(msg websocket.WSMessage) {
|
||||||
logger.Debug("Received message: %v", msg.Data)
|
logger.Debug("Received message: %v", msg.Data)
|
||||||
|
|
||||||
@@ -251,7 +279,7 @@ func TunnelProcess(ctx context.Context, config Config, id string, secret string,
|
|||||||
|
|
||||||
// Start a single hole punch goroutine for all exit nodes
|
// Start a single hole punch goroutine for all exit nodes
|
||||||
logger.Info("Starting hole punch for %d exit nodes", len(holePunchData.ExitNodes))
|
logger.Info("Starting hole punch for %d exit nodes", len(holePunchData.ExitNodes))
|
||||||
go keepSendingUDPHolePunchToMultipleExitNodes(holePunchData.ExitNodes, id, sourcePort)
|
go keepSendingUDPHolePunchToMultipleExitNodesWithSharedBind(holePunchData.ExitNodes, id, sharedBind)
|
||||||
})
|
})
|
||||||
|
|
||||||
olm.RegisterHandler("olm/wg/holepunch", func(msg websocket.WSMessage) {
|
olm.RegisterHandler("olm/wg/holepunch", func(msg websocket.WSMessage) {
|
||||||
@@ -289,7 +317,7 @@ func TunnelProcess(ctx context.Context, config Config, id string, secret string,
|
|||||||
|
|
||||||
// Start hole punching for each exit node
|
// Start hole punching for each exit node
|
||||||
logger.Info("Starting hole punch for exit node: %s with public key: %s", legacyHolePunchData.Endpoint, legacyHolePunchData.ServerPubKey)
|
logger.Info("Starting hole punch for exit node: %s with public key: %s", legacyHolePunchData.Endpoint, legacyHolePunchData.ServerPubKey)
|
||||||
go keepSendingUDPHolePunch(legacyHolePunchData.Endpoint, id, sourcePort, legacyHolePunchData.ServerPubKey)
|
go keepSendingUDPHolePunchWithSharedBind(legacyHolePunchData.Endpoint, id, sharedBind, legacyHolePunchData.ServerPubKey)
|
||||||
})
|
})
|
||||||
|
|
||||||
olm.RegisterHandler("olm/wg/connect", func(msg websocket.WSMessage) {
|
olm.RegisterHandler("olm/wg/connect", func(msg websocket.WSMessage) {
|
||||||
@@ -305,7 +333,7 @@ func TunnelProcess(ctx context.Context, config Config, id string, secret string,
|
|||||||
stopRegister = nil
|
stopRegister = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
close(stopHolepunch)
|
// close(stopHolepunch)
|
||||||
|
|
||||||
// wait 10 milliseconds to ensure the previous connection is closed
|
// wait 10 milliseconds to ensure the previous connection is closed
|
||||||
logger.Debug("Waiting 500 milliseconds to ensure previous connection is closed")
|
logger.Debug("Waiting 500 milliseconds to ensure previous connection is closed")
|
||||||
@@ -367,7 +395,7 @@ func TunnelProcess(ctx context.Context, config Config, id string, secret string,
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
dev = device.NewDevice(tdev, NewFixedPortBind(uint16(sourcePort)), device.NewLogger(mapToWireGuardLogLevel(loggerLevel), "wireguard: "))
|
dev = device.NewDevice(tdev, sharedBind, device.NewLogger(mapToWireGuardLogLevel(loggerLevel), "wireguard: "))
|
||||||
|
|
||||||
uapiListener, err = uapiListen(interfaceName, fileUAPI)
|
uapiListener, err = uapiListen(interfaceName, fileUAPI)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -804,7 +832,7 @@ func Stop() {
|
|||||||
uapiListener = nil
|
uapiListener = nil
|
||||||
}
|
}
|
||||||
if dev != nil {
|
if dev != nil {
|
||||||
dev.Close()
|
dev.Close() // This will call sharedBind.Close() which releases WireGuard's reference
|
||||||
dev = nil
|
dev = nil
|
||||||
}
|
}
|
||||||
// Close TUN device
|
// Close TUN device
|
||||||
@@ -813,6 +841,15 @@ func Stop() {
|
|||||||
tdev = nil
|
tdev = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Release the hole punch reference to the shared bind
|
||||||
|
if sharedBind != nil {
|
||||||
|
// Release hole punch reference (WireGuard already released its reference via dev.Close())
|
||||||
|
logger.Debug("Releasing shared bind (refcount before release: %d)", sharedBind.GetRefCount())
|
||||||
|
sharedBind.Release()
|
||||||
|
sharedBind = nil
|
||||||
|
logger.Info("Released shared UDP bind")
|
||||||
|
}
|
||||||
|
|
||||||
logger.Info("Olm service stopped")
|
logger.Info("Olm service stopped")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user