[management,proxy,client] Add L4 capabilities (TLS/TCP/UDP) (#5530)

This commit is contained in:
Viktor Liu
2026-03-14 01:36:44 +08:00
committed by GitHub
parent fe9b844511
commit 3e6baea405
90 changed files with 9611 additions and 1397 deletions

496
proxy/internal/udp/relay.go Normal file
View File

@@ -0,0 +1,496 @@
package udp
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"sync"
"sync/atomic"
"time"
log "github.com/sirupsen/logrus"
"golang.org/x/time/rate"
"github.com/netbirdio/netbird/proxy/internal/accesslog"
"github.com/netbirdio/netbird/proxy/internal/netutil"
"github.com/netbirdio/netbird/proxy/internal/types"
)
const (
	// DefaultSessionTTL is the default idle timeout for UDP sessions before cleanup.
	DefaultSessionTTL = 30 * time.Second
	// cleanupInterval is how often the background cleaner goroutine scans for idle sessions.
	cleanupInterval = time.Minute
	// maxPacketSize is the maximum UDP packet size we'll handle (the UDP datagram limit).
	maxPacketSize = 65535
	// DefaultMaxSessions is the default cap on concurrent UDP sessions per relay.
	DefaultMaxSessions = 1024
	// sessionCreateRate limits new session creation per second (token refill rate).
	sessionCreateRate = 50
	// sessionCreateBurst is the burst allowance for session creation.
	sessionCreateBurst = 100
	// defaultDialTimeout is the fallback dial timeout for backend connections.
	defaultDialTimeout = 30 * time.Second
)
// l4Logger sends layer-4 access log entries to the management server.
// Implementations must be safe for concurrent use; the relay calls LogL4
// from multiple goroutines (session teardown paths).
type l4Logger interface {
	LogL4(entry accesslog.L4Entry)
}
// SessionObserver receives callbacks for UDP session lifecycle events.
// All methods must be safe for concurrent use.
type SessionObserver interface {
	// UDPSessionStarted is called after a backend connection is established.
	UDPSessionStarted(accountID types.AccountID)
	// UDPSessionEnded is called exactly once per started session when it is torn down.
	UDPSessionEnded(accountID types.AccountID)
	// UDPSessionDialError is called when dialing the backend fails.
	UDPSessionDialError(accountID types.AccountID)
	// UDPSessionRejected is called when a session is refused (limit or rate cap).
	UDPSessionRejected(accountID types.AccountID)
	// UDPPacketRelayed is called per relayed packet with the byte count written.
	UDPPacketRelayed(direction types.RelayDirection, bytes int)
}
// clientAddr is a typed key for UDP session lookups, derived from the
// client's address string (IP:port as returned by net.Addr.String).
type clientAddr string
// Relay listens for incoming UDP packets on a dedicated port and
// maintains per-client sessions that relay packets to a backend
// through the WireGuard tunnel.
type Relay struct {
	logger      *log.Entry
	listener    net.PacketConn        // public-facing socket clients send to
	target      string                // backend address packets are relayed to
	domain      string                // hostname recorded in access log entries
	accountID   types.AccountID
	serviceID   types.ServiceID
	dialFunc    types.DialContextFunc // dials the backend connection
	dialTimeout time.Duration         // bound on backend dials
	sessionTTL  time.Duration         // idle timeout before a session is reaped
	maxSessions int                   // cap on concurrent sessions
	mu          sync.RWMutex          // guards sessions
	// sessions maps client addresses to sessions; a nil value is a
	// placeholder meaning a dial is in progress for that key.
	sessions map[clientAddr]*session
	// bufPool recycles maxPacketSize read buffers (stored as *[]byte).
	bufPool sync.Pool
	// sessLimiter throttles new session creation.
	sessLimiter *rate.Limiter
	// sessWg tracks per-session backend-to-client goroutines.
	sessWg sync.WaitGroup
	ctx    context.Context
	cancel context.CancelFunc
	// observer, if set before Serve, receives lifecycle callbacks.
	observer  SessionObserver
	accessLog l4Logger
}
// session tracks one client address and its dedicated backend connection.
type session struct {
	backend   net.Conn  // connection to the backend
	addr      net.Addr  // client address on the public-facing listener
	createdAt time.Time // when the session was established
	// lastSeen stores the last activity timestamp as unix nanoseconds.
	lastSeen atomic.Int64
	// cancel stops the session's backend-to-client goroutine.
	cancel context.CancelFunc
	// bytesIn tracks total bytes received from the client.
	bytesIn atomic.Int64
	// bytesOut tracks total bytes sent back to the client.
	bytesOut atomic.Int64
}
// updateLastSeen records the current time as the session's last activity.
func (s *session) updateLastSeen() {
	now := time.Now()
	s.lastSeen.Store(now.UnixNano())
}
// idleDuration reports how long ago the session last saw traffic.
func (s *session) idleDuration() time.Duration {
	last := time.Unix(0, s.lastSeen.Load())
	return time.Since(last)
}
// RelayConfig holds the configuration for a UDP relay.
type RelayConfig struct {
	Logger    *log.Entry
	Listener  net.PacketConn // public-facing UDP socket
	Target    string         // backend address to relay packets to
	Domain    string         // hostname recorded in access log entries
	AccountID types.AccountID
	ServiceID types.ServiceID
	DialFunc  types.DialContextFunc // dials backend connections
	// DialTimeout bounds backend dials; 0 selects defaultDialTimeout.
	DialTimeout time.Duration
	// SessionTTL is the idle timeout before a session is reaped; 0 selects DefaultSessionTTL.
	SessionTTL time.Duration
	// MaxSessions caps concurrent sessions; 0 selects DefaultMaxSessions.
	MaxSessions int
	AccessLog   l4Logger
}
// New creates a UDP relay for the given listener and backend target.
// MaxSessions caps the number of concurrent sessions; use 0 for DefaultMaxSessions.
// DialTimeout controls how long to wait for backend connections; use 0 for default.
// SessionTTL is the idle timeout before a session is reaped; use 0 for DefaultSessionTTL.
func New(parentCtx context.Context, cfg RelayConfig) *Relay {
	ctx, cancel := context.WithCancel(parentCtx)
	r := &Relay{
		logger:      cfg.Logger,
		listener:    cfg.Listener,
		target:      cfg.Target,
		domain:      cfg.Domain,
		accountID:   cfg.AccountID,
		serviceID:   cfg.ServiceID,
		accessLog:   cfg.AccessLog,
		dialFunc:    cfg.DialFunc,
		dialTimeout: cfg.DialTimeout,
		sessionTTL:  cfg.SessionTTL,
		maxSessions: cfg.MaxSessions,
		sessions:    make(map[clientAddr]*session),
		bufPool: sync.Pool{
			New: func() any {
				// Pool pointers to slices to avoid an allocation per Get.
				b := make([]byte, maxPacketSize)
				return &b
			},
		},
		sessLimiter: rate.NewLimiter(sessionCreateRate, sessionCreateBurst),
		ctx:         ctx,
		cancel:      cancel,
	}
	// Substitute defaults for unset (zero or negative) tunables.
	if r.maxSessions <= 0 {
		r.maxSessions = DefaultMaxSessions
	}
	if r.dialTimeout <= 0 {
		r.dialTimeout = defaultDialTimeout
	}
	if r.sessionTTL <= 0 {
		r.sessionTTL = DefaultSessionTTL
	}
	return r
}
// ServiceID returns the service ID this relay was configured with;
// used by callers to correlate the relay with its service entry.
func (r *Relay) ServiceID() types.ServiceID {
	return r.serviceID
}
// SetObserver sets the session lifecycle observer. Must be called before
// Serve: the field is written without synchronization and is read
// concurrently by the serve loop and session goroutines afterward.
func (r *Relay) SetObserver(obs SessionObserver) {
	r.observer = obs
}
// Serve starts the relay loop. It blocks until the context is canceled
// or the listener is closed.
func (r *Relay) Serve() {
	go r.cleanupLoop()
	for {
		bufp := r.bufPool.Get().(*[]byte)
		n, from, err := r.listener.ReadFrom(*bufp)
		if err != nil {
			r.bufPool.Put(bufp)
			// A closed listener or canceled context means shutdown.
			if r.ctx.Err() != nil || errors.Is(err, net.ErrClosed) {
				return
			}
			r.logger.Debugf("UDP read: %v", err)
			continue
		}
		r.forwardToBackend(from, (*bufp)[:n])
		r.bufPool.Put(bufp)
	}
}

// forwardToBackend routes one client datagram into its session's backend
// connection, creating the session on first contact. Errors drop the packet.
func (r *Relay) forwardToBackend(from net.Addr, data []byte) {
	sess, err := r.getOrCreateSession(from)
	if err != nil {
		r.logger.Debugf("create UDP session for %s: %v", from, err)
		return
	}
	sess.updateLastSeen()
	nw, err := sess.backend.Write(data)
	if err != nil {
		if !netutil.IsExpectedError(err) {
			r.logger.Debugf("UDP write to backend for %s: %v", from, err)
		}
		// A failed backend write tears the session down; the next packet
		// from this client will dial a fresh one.
		r.removeSession(sess)
		return
	}
	sess.bytesIn.Add(int64(nw))
	if r.observer != nil {
		r.observer.UDPPacketRelayed(types.RelayDirectionClientToBackend, nw)
	}
}
// getOrCreateSession returns an existing session or creates a new one.
//
// Concurrency protocol: a nil map entry acts as a "dial in progress"
// placeholder so that at most one goroutine dials the backend per client
// address; packets for that key arriving mid-dial are dropped.
func (r *Relay) getOrCreateSession(addr net.Addr) (*session, error) {
	key := clientAddr(addr.String())
	// Fast path: read-lock lookup for an established session.
	r.mu.RLock()
	sess, ok := r.sessions[key]
	r.mu.RUnlock()
	if ok && sess != nil {
		return sess, nil
	}
	// Check before taking the write lock: if the relay is shutting down,
	// don't create new sessions. This prevents orphaned goroutines when
	// Serve() processes a packet that was already read before Close().
	if r.ctx.Err() != nil {
		return nil, r.ctx.Err()
	}
	r.mu.Lock()
	// Re-check under the write lock: another goroutine may have finished
	// dialing while we waited for the lock.
	if sess, ok = r.sessions[key]; ok && sess != nil {
		r.mu.Unlock()
		return sess, nil
	}
	if ok {
		// Another goroutine is dialing for this key, skip.
		r.mu.Unlock()
		return nil, fmt.Errorf("session dial in progress for %s", key)
	}
	// Enforce the concurrent-session cap and creation rate while still
	// holding the lock, so the count cannot be raced past the limit.
	if len(r.sessions) >= r.maxSessions {
		r.mu.Unlock()
		if r.observer != nil {
			r.observer.UDPSessionRejected(r.accountID)
		}
		return nil, fmt.Errorf("session limit reached (%d)", r.maxSessions)
	}
	if !r.sessLimiter.Allow() {
		r.mu.Unlock()
		if r.observer != nil {
			r.observer.UDPSessionRejected(r.accountID)
		}
		return nil, fmt.Errorf("session creation rate limited")
	}
	// Reserve the slot with a nil session so concurrent callers for the same
	// key see it exists and wait. Release the lock before dialing.
	r.sessions[key] = nil
	r.mu.Unlock()
	// Dial outside the lock; the timeout context derives from r.ctx so a
	// relay shutdown also cancels in-flight dials.
	dialCtx, dialCancel := context.WithTimeout(r.ctx, r.dialTimeout)
	backend, err := r.dialFunc(dialCtx, "udp", r.target)
	dialCancel()
	if err != nil {
		// Dial failed: release the placeholder so a later packet can retry.
		r.mu.Lock()
		delete(r.sessions, key)
		r.mu.Unlock()
		if r.observer != nil {
			r.observer.UDPSessionDialError(r.accountID)
		}
		return nil, fmt.Errorf("dial backend %s: %w", r.target, err)
	}
	// The session context derives from the relay context, so Close() also
	// tears down the backend-to-client goroutine.
	sessCtx, sessCancel := context.WithCancel(r.ctx)
	sess = &session{
		backend:   backend,
		addr:      addr,
		createdAt: time.Now(),
		cancel:    sessCancel,
	}
	sess.updateLastSeen()
	r.mu.Lock()
	r.sessions[key] = sess
	r.mu.Unlock()
	if r.observer != nil {
		r.observer.UDPSessionStarted(r.accountID)
	}
	// NOTE(review): if Close() runs between the dial completing and this
	// point, the session is inserted after Close has drained the map, and
	// sessWg.Go below may race with sessWg.Wait. The canceled sessCtx makes
	// the goroutine exit promptly, but the window looks possible — confirm.
	r.sessWg.Go(func() {
		r.relayBackendToClient(sessCtx, sess)
	})
	r.logger.Debugf("UDP session created for %s", addr)
	return sess, nil
}
// relayBackendToClient reads packets from the backend and writes them
// back to the client through the public-facing listener. It runs as a
// per-session goroutine and tears the session down on exit.
func (r *Relay) relayBackendToClient(ctx context.Context, sess *session) {
	scratch := r.bufPool.Get().(*[]byte)
	defer r.bufPool.Put(scratch)
	defer r.removeSession(sess)
	for {
		if ctx.Err() != nil {
			return
		}
		data, ok := r.readBackendPacket(sess, *scratch)
		if !ok {
			return
		}
		if data == nil {
			// Idle-deadline tick with recent activity; keep reading.
			continue
		}
		sess.updateLastSeen()
		written, err := r.listener.WriteTo(data, sess.addr)
		if err != nil {
			if !netutil.IsExpectedError(err) {
				r.logger.Debugf("UDP write to client %s: %v", sess.addr, err)
			}
			return
		}
		sess.bytesOut.Add(int64(written))
		if r.observer != nil {
			r.observer.UDPPacketRelayed(types.RelayDirectionBackendToClient, written)
		}
	}
}
// readBackendPacket reads one packet from the backend with an idle deadline.
// Returns (data, true) on success, (nil, true) on idle timeout that should
// retry, or (nil, false) when the session should be torn down.
func (r *Relay) readBackendPacket(sess *session, buf []byte) ([]byte, bool) {
	deadline := time.Now().Add(r.sessionTTL)
	if err := sess.backend.SetReadDeadline(deadline); err != nil {
		r.logger.Debugf("set backend read deadline for %s: %v", sess.addr, err)
		return nil, false
	}
	n, err := sess.backend.Read(buf)
	if err == nil {
		return buf[:n], true
	}
	if netutil.IsTimeout(err) {
		// Keep the session alive only if it saw traffic within the TTL.
		return nil, sess.idleDuration() <= r.sessionTTL
	}
	if !netutil.IsExpectedError(err) {
		r.logger.Debugf("UDP read from backend for %s: %v", sess.addr, err)
	}
	return nil, false
}
// cleanupLoop periodically reaps idle sessions until the relay context
// is canceled.
func (r *Relay) cleanupLoop() {
	tick := time.NewTicker(cleanupInterval)
	defer tick.Stop()
	for {
		select {
		case <-tick.C:
			r.cleanupIdleSessions()
		case <-r.ctx.Done():
			return
		}
	}
}
// cleanupIdleSessions closes sessions that have been idle for too long.
// Observer and access-log notifications happen after the lock is released.
func (r *Relay) cleanupIdleSessions() {
	var reaped []*session
	r.mu.Lock()
	for key, sess := range r.sessions {
		if sess == nil {
			// Dial in progress for this key; nothing to reap.
			continue
		}
		idle := sess.idleDuration()
		if idle <= r.sessionTTL {
			continue
		}
		r.logger.Debugf("UDP session %s idle for %s, closing (client→backend: %d bytes, backend→client: %d bytes)",
			sess.addr, idle, sess.bytesIn.Load(), sess.bytesOut.Load())
		delete(r.sessions, key)
		sess.cancel()
		if err := sess.backend.Close(); err != nil {
			r.logger.Debugf("close idle session %s backend: %v", sess.addr, err)
		}
		reaped = append(reaped, sess)
	}
	r.mu.Unlock()
	for _, sess := range reaped {
		if r.observer != nil {
			r.observer.UDPSessionEnded(r.accountID)
		}
		r.logSessionEnd(sess)
	}
}
// removeSession removes a session from the map if it still matches the
// given pointer. This is safe to call concurrently with cleanupIdleSessions
// because the identity check prevents double-close when both paths race.
func (r *Relay) removeSession(sess *session) {
	key := clientAddr(sess.addr.String())
	r.mu.Lock()
	owned := r.sessions[key] == sess
	if owned {
		delete(r.sessions, key)
		sess.cancel()
		if err := sess.backend.Close(); err != nil {
			r.logger.Debugf("close session %s backend: %v", sess.addr, err)
		}
	}
	r.mu.Unlock()
	if !owned {
		// Another path already tore this session down; avoid double
		// notification.
		return
	}
	r.logger.Debugf("UDP session %s ended (client→backend: %d bytes, backend→client: %d bytes)",
		sess.addr, sess.bytesIn.Load(), sess.bytesOut.Load())
	if r.observer != nil {
		r.observer.UDPSessionEnded(r.accountID)
	}
	r.logSessionEnd(sess)
}
// logSessionEnd sends an access log entry for a completed UDP session.
// No-op when no access logger is configured.
func (r *Relay) logSessionEnd(sess *session) {
	if r.accessLog == nil {
		return
	}
	// Best effort: a client address that isn't IP:port leaves SourceIP zero.
	var src netip.Addr
	if ap, parseErr := netip.ParseAddrPort(sess.addr.String()); parseErr == nil {
		src = ap.Addr().Unmap()
	}
	lastActive := time.Unix(0, sess.lastSeen.Load())
	r.accessLog.LogL4(accesslog.L4Entry{
		AccountID:     r.accountID,
		ServiceID:     r.serviceID,
		Protocol:      accesslog.ProtocolUDP,
		Host:          r.domain,
		SourceIP:      src,
		DurationMs:    lastActive.Sub(sess.createdAt).Milliseconds(),
		BytesUpload:   sess.bytesIn.Load(),
		BytesDownload: sess.bytesOut.Load(),
	})
}
// Close stops the relay, waits for all session goroutines to exit,
// and cleans up remaining sessions.
func (r *Relay) Close() {
	r.cancel()
	if err := r.listener.Close(); err != nil {
		r.logger.Debugf("close UDP listener: %v", err)
	}
	r.mu.Lock()
	ended := make([]*session, 0, len(r.sessions))
	for key, sess := range r.sessions {
		delete(r.sessions, key)
		if sess == nil {
			// Placeholder for an in-flight dial; nothing to close.
			continue
		}
		r.logger.Debugf("UDP session %s closed (client→backend: %d bytes, backend→client: %d bytes)",
			sess.addr, sess.bytesIn.Load(), sess.bytesOut.Load())
		sess.cancel()
		if err := sess.backend.Close(); err != nil {
			r.logger.Debugf("close session %s backend: %v", sess.addr, err)
		}
		ended = append(ended, sess)
	}
	r.mu.Unlock()
	// Notify outside the lock; observers may do arbitrary work.
	for _, sess := range ended {
		if r.observer != nil {
			r.observer.UDPSessionEnded(r.accountID)
		}
		r.logSessionEnd(sess)
	}
	r.sessWg.Wait()
}

View File

@@ -0,0 +1,493 @@
package udp
import (
"context"
"fmt"
"net"
"sync"
"testing"
"time"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/proxy/internal/types"
)
// TestRelay_BasicPacketExchange verifies a single client round-trip
// through the relay to an echoing UDP backend.
func TestRelay_BasicPacketExchange(t *testing.T) {
	// Echo backend: returns every datagram to its sender.
	echo, err := net.ListenPacket("udp", "127.0.0.1:0")
	require.NoError(t, err)
	defer echo.Close()
	go func() {
		scratch := make([]byte, 65535)
		for {
			n, from, err := echo.ReadFrom(scratch)
			if err != nil {
				return
			}
			_, _ = echo.WriteTo(scratch[:n], from)
		}
	}()
	// Public-facing relay listener.
	front, err := net.ListenPacket("udp", "127.0.0.1:0")
	require.NoError(t, err)
	defer front.Close()
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	dial := func(ctx context.Context, network, address string) (net.Conn, error) {
		return net.Dial(network, address)
	}
	relay := New(ctx, RelayConfig{
		Logger:   log.NewEntry(log.StandardLogger()),
		Listener: front,
		Target:   echo.LocalAddr().String(),
		DialFunc: dial,
	})
	go relay.Serve()
	defer relay.Close()
	// Send one packet through the relay and expect it echoed back.
	client, err := net.Dial("udp", front.LocalAddr().String())
	require.NoError(t, err)
	defer client.Close()
	payload := []byte("hello UDP relay")
	_, err = client.Write(payload)
	require.NoError(t, err)
	require.NoError(t, client.SetReadDeadline(time.Now().Add(2*time.Second)))
	reply := make([]byte, 1024)
	n, err := client.Read(reply)
	require.NoError(t, err)
	assert.Equal(t, payload, reply[:n], "should receive echoed packet")
}
// TestRelay_MultipleClients verifies that distinct client addresses get
// independent sessions and each receives only its own echo.
func TestRelay_MultipleClients(t *testing.T) {
	echo, err := net.ListenPacket("udp", "127.0.0.1:0")
	require.NoError(t, err)
	defer echo.Close()
	go func() {
		scratch := make([]byte, 65535)
		for {
			n, from, err := echo.ReadFrom(scratch)
			if err != nil {
				return
			}
			_, _ = echo.WriteTo(scratch[:n], from)
		}
	}()
	front, err := net.ListenPacket("udp", "127.0.0.1:0")
	require.NoError(t, err)
	defer front.Close()
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	dial := func(ctx context.Context, network, address string) (net.Conn, error) {
		return net.Dial(network, address)
	}
	relay := New(ctx, RelayConfig{
		Logger:   log.NewEntry(log.StandardLogger()),
		Listener: front,
		Target:   echo.LocalAddr().String(),
		DialFunc: dial,
	})
	go relay.Serve()
	defer relay.Close()
	// Two clients, each should get their own session.
	for i, msg := range []string{"client-1", "client-2"} {
		conn, err := net.Dial("udp", front.LocalAddr().String())
		require.NoError(t, err, "client %d", i)
		defer conn.Close()
		_, err = conn.Write([]byte(msg))
		require.NoError(t, err)
		require.NoError(t, conn.SetReadDeadline(time.Now().Add(2*time.Second)))
		reply := make([]byte, 1024)
		n, err := conn.Read(reply)
		require.NoError(t, err, "client %d read", i)
		assert.Equal(t, msg, string(reply[:n]), "client %d should get own echo", i)
	}
	// Verify two sessions were created.
	relay.mu.RLock()
	got := len(relay.sessions)
	relay.mu.RUnlock()
	assert.Equal(t, 2, got, "should have two sessions")
}
// TestRelay_Close verifies that Close unblocks a running Serve loop.
func TestRelay_Close(t *testing.T) {
	front, err := net.ListenPacket("udp", "127.0.0.1:0")
	require.NoError(t, err)
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	dial := func(ctx context.Context, network, address string) (net.Conn, error) {
		return net.Dial(network, address)
	}
	relay := New(ctx, RelayConfig{
		Logger:   log.NewEntry(log.StandardLogger()),
		Listener: front,
		Target:   "127.0.0.1:9999",
		DialFunc: dial,
	})
	served := make(chan struct{})
	go func() {
		defer close(served)
		relay.Serve()
	}()
	relay.Close()
	select {
	case <-served:
	case <-time.After(5 * time.Second):
		t.Fatal("Serve did not return after Close")
	}
}
// TestRelay_SessionCleanup verifies that cleanupIdleSessions reaps
// sessions whose lastSeen timestamp is older than the TTL.
func TestRelay_SessionCleanup(t *testing.T) {
	echo, err := net.ListenPacket("udp", "127.0.0.1:0")
	require.NoError(t, err)
	defer echo.Close()
	go func() {
		scratch := make([]byte, 65535)
		for {
			n, from, err := echo.ReadFrom(scratch)
			if err != nil {
				return
			}
			_, _ = echo.WriteTo(scratch[:n], from)
		}
	}()
	front, err := net.ListenPacket("udp", "127.0.0.1:0")
	require.NoError(t, err)
	defer front.Close()
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	dial := func(ctx context.Context, network, address string) (net.Conn, error) {
		return net.Dial(network, address)
	}
	relay := New(ctx, RelayConfig{
		Logger:   log.NewEntry(log.StandardLogger()),
		Listener: front,
		Target:   echo.LocalAddr().String(),
		DialFunc: dial,
	})
	go relay.Serve()
	defer relay.Close()
	// Establish one session via a full round-trip.
	conn, err := net.Dial("udp", front.LocalAddr().String())
	require.NoError(t, err)
	_, err = conn.Write([]byte("hello"))
	require.NoError(t, err)
	require.NoError(t, conn.SetReadDeadline(time.Now().Add(2*time.Second)))
	reply := make([]byte, 1024)
	_, err = conn.Read(reply)
	require.NoError(t, err)
	conn.Close()
	relay.mu.RLock()
	assert.Equal(t, 1, len(relay.sessions))
	relay.mu.RUnlock()
	// Backdate lastSeen so the session looks idle beyond the TTL.
	relay.mu.Lock()
	for _, sess := range relay.sessions {
		sess.lastSeen.Store(time.Now().Add(-2 * DefaultSessionTTL).UnixNano())
	}
	relay.mu.Unlock()
	// Trigger cleanup manually instead of waiting for the ticker.
	relay.cleanupIdleSessions()
	relay.mu.RLock()
	assert.Equal(t, 0, len(relay.sessions), "idle sessions should be cleaned up")
	relay.mu.RUnlock()
}
// TestRelay_CloseAndRecreate verifies that closing a relay and creating a new
// one on the same port works cleanly (simulates port mapping modify cycle).
func TestRelay_CloseAndRecreate(t *testing.T) {
	echo, err := net.ListenPacket("udp", "127.0.0.1:0")
	require.NoError(t, err)
	defer echo.Close()
	go func() {
		scratch := make([]byte, 65535)
		for {
			n, from, err := echo.ReadFrom(scratch)
			if err != nil {
				return
			}
			_, _ = echo.WriteTo(scratch[:n], from)
		}
	}()
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	logger := log.NewEntry(log.StandardLogger())
	dial := func(ctx context.Context, network, address string) (net.Conn, error) {
		return net.Dial(network, address)
	}
	target := echo.LocalAddr().String()
	// First relay: one round-trip, then shut down.
	ln1, err := net.ListenPacket("udp", "127.0.0.1:0")
	require.NoError(t, err)
	relay1 := New(ctx, RelayConfig{Logger: logger, Listener: ln1, Target: target, DialFunc: dial})
	go relay1.Serve()
	conn1, err := net.Dial("udp", ln1.LocalAddr().String())
	require.NoError(t, err)
	_, err = conn1.Write([]byte("relay1"))
	require.NoError(t, err)
	require.NoError(t, conn1.SetReadDeadline(time.Now().Add(2*time.Second)))
	reply := make([]byte, 1024)
	n, err := conn1.Read(reply)
	require.NoError(t, err)
	assert.Equal(t, "relay1", string(reply[:n]))
	conn1.Close()
	relay1.Close()
	// Second relay reuses the exact same port.
	port := ln1.LocalAddr().(*net.UDPAddr).Port
	ln2, err := net.ListenPacket("udp", fmt.Sprintf("127.0.0.1:%d", port))
	require.NoError(t, err)
	relay2 := New(ctx, RelayConfig{Logger: logger, Listener: ln2, Target: target, DialFunc: dial})
	go relay2.Serve()
	defer relay2.Close()
	conn2, err := net.Dial("udp", ln2.LocalAddr().String())
	require.NoError(t, err)
	defer conn2.Close()
	_, err = conn2.Write([]byte("relay2"))
	require.NoError(t, err)
	require.NoError(t, conn2.SetReadDeadline(time.Now().Add(2*time.Second)))
	n, err = conn2.Read(reply)
	require.NoError(t, err)
	assert.Equal(t, "relay2", string(reply[:n]), "second relay should work on same port")
}
// TestRelay_SessionLimit verifies that packets from a client beyond the
// MaxSessions cap are dropped rather than creating a new session.
func TestRelay_SessionLimit(t *testing.T) {
	echo, err := net.ListenPacket("udp", "127.0.0.1:0")
	require.NoError(t, err)
	defer echo.Close()
	go func() {
		scratch := make([]byte, 65535)
		for {
			n, from, err := echo.ReadFrom(scratch)
			if err != nil {
				return
			}
			_, _ = echo.WriteTo(scratch[:n], from)
		}
	}()
	front, err := net.ListenPacket("udp", "127.0.0.1:0")
	require.NoError(t, err)
	defer front.Close()
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	dial := func(ctx context.Context, network, address string) (net.Conn, error) {
		return net.Dial(network, address)
	}
	// Relay capped at two concurrent sessions.
	relay := New(ctx, RelayConfig{
		Logger:      log.NewEntry(log.StandardLogger()),
		Listener:    front,
		Target:      echo.LocalAddr().String(),
		DialFunc:    dial,
		MaxSessions: 2,
	})
	go relay.Serve()
	defer relay.Close()
	// Fill the session table with two clients.
	for i := range 2 {
		conn, err := net.Dial("udp", front.LocalAddr().String())
		require.NoError(t, err, "client %d", i)
		defer conn.Close()
		_, err = conn.Write([]byte("hello"))
		require.NoError(t, err)
		require.NoError(t, conn.SetReadDeadline(time.Now().Add(2*time.Second)))
		scratch := make([]byte, 1024)
		_, err = conn.Read(scratch)
		require.NoError(t, err, "client %d should get response", i)
	}
	relay.mu.RLock()
	assert.Equal(t, 2, len(relay.sessions), "should have exactly 2 sessions")
	relay.mu.RUnlock()
	// A third client's packet is dropped because session creation is refused.
	extra, err := net.Dial("udp", front.LocalAddr().String())
	require.NoError(t, err)
	defer extra.Close()
	_, err = extra.Write([]byte("should be dropped"))
	require.NoError(t, err)
	require.NoError(t, extra.SetReadDeadline(time.Now().Add(500*time.Millisecond)))
	scratch := make([]byte, 1024)
	_, err = extra.Read(scratch)
	assert.Error(t, err, "third client should time out because session was rejected")
	relay.mu.RLock()
	assert.Equal(t, 2, len(relay.sessions), "session count should not exceed limit")
	relay.mu.RUnlock()
}
// testObserver records UDP session lifecycle events for test assertions.
// All counters are guarded by mu so tests can read them while the relay
// is still running.
type testObserver struct {
	mu       sync.Mutex
	started  int // UDPSessionStarted calls
	ended    int // UDPSessionEnded calls
	rejected int // UDPSessionRejected calls
	dialErr  int // UDPSessionDialError calls
	packets  int // UDPPacketRelayed calls
	bytes    int // sum of relayed byte counts
}

func (o *testObserver) UDPSessionStarted(types.AccountID) { o.mu.Lock(); o.started++; o.mu.Unlock() }
func (o *testObserver) UDPSessionEnded(types.AccountID) { o.mu.Lock(); o.ended++; o.mu.Unlock() }
func (o *testObserver) UDPSessionDialError(types.AccountID) { o.mu.Lock(); o.dialErr++; o.mu.Unlock() }
func (o *testObserver) UDPSessionRejected(types.AccountID) { o.mu.Lock(); o.rejected++; o.mu.Unlock() }
func (o *testObserver) UDPPacketRelayed(_ types.RelayDirection, b int) {
	o.mu.Lock()
	o.packets++
	o.bytes += b
	o.mu.Unlock()
}
// TestRelay_CloseFiresObserverEnded verifies that Close reports a
// UDPSessionEnded event for every session still alive at shutdown.
func TestRelay_CloseFiresObserverEnded(t *testing.T) {
	echo, err := net.ListenPacket("udp", "127.0.0.1:0")
	require.NoError(t, err)
	defer echo.Close()
	go func() {
		scratch := make([]byte, 65535)
		for {
			n, from, err := echo.ReadFrom(scratch)
			if err != nil {
				return
			}
			_, _ = echo.WriteTo(scratch[:n], from)
		}
	}()
	front, err := net.ListenPacket("udp", "127.0.0.1:0")
	require.NoError(t, err)
	defer front.Close()
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	dial := func(ctx context.Context, network, address string) (net.Conn, error) {
		return net.Dial(network, address)
	}
	obs := &testObserver{}
	relay := New(ctx, RelayConfig{
		Logger:    log.NewEntry(log.StandardLogger()),
		Listener:  front,
		Target:    echo.LocalAddr().String(),
		AccountID: "test-acct",
		DialFunc:  dial,
	})
	relay.SetObserver(obs)
	go relay.Serve()
	// Establish two sessions via full round-trips.
	for i := range 2 {
		conn, err := net.Dial("udp", front.LocalAddr().String())
		require.NoError(t, err, "client %d", i)
		_, err = conn.Write([]byte("hello"))
		require.NoError(t, err)
		require.NoError(t, conn.SetReadDeadline(time.Now().Add(2*time.Second)))
		scratch := make([]byte, 1024)
		_, err = conn.Read(scratch)
		require.NoError(t, err)
		conn.Close()
	}
	obs.mu.Lock()
	assert.Equal(t, 2, obs.started, "should have 2 started events")
	obs.mu.Unlock()
	// Close should fire UDPSessionEnded for all remaining sessions.
	relay.Close()
	obs.mu.Lock()
	assert.Equal(t, 2, obs.ended, "Close should fire UDPSessionEnded for each session")
	obs.mu.Unlock()
}
// TestRelay_SessionRateLimit verifies that session creation beyond the
// limiter burst is rejected even when the session cap is not reached.
func TestRelay_SessionRateLimit(t *testing.T) {
	echo, err := net.ListenPacket("udp", "127.0.0.1:0")
	require.NoError(t, err)
	defer echo.Close()
	go func() {
		scratch := make([]byte, 65535)
		for {
			n, from, err := echo.ReadFrom(scratch)
			if err != nil {
				return
			}
			_, _ = echo.WriteTo(scratch[:n], from)
		}
	}()
	front, err := net.ListenPacket("udp", "127.0.0.1:0")
	require.NoError(t, err)
	defer front.Close()
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	dial := func(ctx context.Context, network, address string) (net.Conn, error) {
		return net.Dial(network, address)
	}
	obs := &testObserver{}
	// High max sessions (1000) but the relay uses a rate limiter internally
	// (default: 50/s burst 100). We exhaust the burst by creating sessions
	// rapidly, then verify that subsequent creates are rejected.
	relay := New(ctx, RelayConfig{
		Logger:      log.NewEntry(log.StandardLogger()),
		Listener:    front,
		Target:      echo.LocalAddr().String(),
		AccountID:   "test-acct",
		DialFunc:    dial,
		MaxSessions: 1000,
	})
	relay.SetObserver(obs)
	go relay.Serve()
	defer relay.Close()
	// Exhaust the burst by calling getOrCreateSession directly with
	// synthetic addresses. This is faster than real UDP round-trips.
	for i := 0; i < sessionCreateBurst+20; i++ {
		from := &net.UDPAddr{IP: net.IPv4(10, 0, byte(i/256), byte(i%256)), Port: 10000 + i}
		_, _ = relay.getOrCreateSession(from)
	}
	obs.mu.Lock()
	rejected := obs.rejected
	obs.mu.Unlock()
	assert.Greater(t, rejected, 0, "some sessions should be rate-limited")
}