mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 07:16:38 +00:00
[management,proxy,client] Add L4 capabilities (TLS/TCP/UDP) (#5530)
This commit is contained in:
496
proxy/internal/udp/relay.go
Normal file
496
proxy/internal/udp/relay.go
Normal file
@@ -0,0 +1,496 @@
|
||||
package udp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/time/rate"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/accesslog"
|
||||
"github.com/netbirdio/netbird/proxy/internal/netutil"
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
)
|
||||
|
||||
const (
|
||||
// DefaultSessionTTL is the default idle timeout for UDP sessions before cleanup.
|
||||
DefaultSessionTTL = 30 * time.Second
|
||||
// cleanupInterval is how often the cleaner goroutine runs.
|
||||
cleanupInterval = time.Minute
|
||||
// maxPacketSize is the maximum UDP packet size we'll handle.
|
||||
maxPacketSize = 65535
|
||||
// DefaultMaxSessions is the default cap on concurrent UDP sessions per relay.
|
||||
DefaultMaxSessions = 1024
|
||||
// sessionCreateRate limits new session creation per second.
|
||||
sessionCreateRate = 50
|
||||
// sessionCreateBurst is the burst allowance for session creation.
|
||||
sessionCreateBurst = 100
|
||||
// defaultDialTimeout is the fallback dial timeout for backend connections.
|
||||
defaultDialTimeout = 30 * time.Second
|
||||
)
|
||||
|
||||
// l4Logger sends layer-4 access log entries to the management server.
|
||||
type l4Logger interface {
|
||||
LogL4(entry accesslog.L4Entry)
|
||||
}
|
||||
|
||||
// SessionObserver receives callbacks for UDP session lifecycle events.
|
||||
// All methods must be safe for concurrent use.
|
||||
type SessionObserver interface {
|
||||
UDPSessionStarted(accountID types.AccountID)
|
||||
UDPSessionEnded(accountID types.AccountID)
|
||||
UDPSessionDialError(accountID types.AccountID)
|
||||
UDPSessionRejected(accountID types.AccountID)
|
||||
UDPPacketRelayed(direction types.RelayDirection, bytes int)
|
||||
}
|
||||
|
||||
// clientAddr is a typed key for UDP session lookups.
|
||||
type clientAddr string
|
||||
|
||||
// Relay listens for incoming UDP packets on a dedicated port and
|
||||
// maintains per-client sessions that relay packets to a backend
|
||||
// through the WireGuard tunnel.
|
||||
type Relay struct {
|
||||
logger *log.Entry
|
||||
listener net.PacketConn
|
||||
target string
|
||||
domain string
|
||||
accountID types.AccountID
|
||||
serviceID types.ServiceID
|
||||
dialFunc types.DialContextFunc
|
||||
dialTimeout time.Duration
|
||||
sessionTTL time.Duration
|
||||
maxSessions int
|
||||
|
||||
mu sync.RWMutex
|
||||
sessions map[clientAddr]*session
|
||||
|
||||
bufPool sync.Pool
|
||||
sessLimiter *rate.Limiter
|
||||
sessWg sync.WaitGroup
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
observer SessionObserver
|
||||
accessLog l4Logger
|
||||
}
|
||||
|
||||
type session struct {
|
||||
backend net.Conn
|
||||
addr net.Addr
|
||||
createdAt time.Time
|
||||
// lastSeen stores the last activity timestamp as unix nanoseconds.
|
||||
lastSeen atomic.Int64
|
||||
cancel context.CancelFunc
|
||||
// bytesIn tracks total bytes received from the client.
|
||||
bytesIn atomic.Int64
|
||||
// bytesOut tracks total bytes sent back to the client.
|
||||
bytesOut atomic.Int64
|
||||
}
|
||||
|
||||
func (s *session) updateLastSeen() {
|
||||
s.lastSeen.Store(time.Now().UnixNano())
|
||||
}
|
||||
|
||||
func (s *session) idleDuration() time.Duration {
|
||||
return time.Since(time.Unix(0, s.lastSeen.Load()))
|
||||
}
|
||||
|
||||
// RelayConfig holds the configuration for a UDP relay.
|
||||
type RelayConfig struct {
|
||||
Logger *log.Entry
|
||||
Listener net.PacketConn
|
||||
Target string
|
||||
Domain string
|
||||
AccountID types.AccountID
|
||||
ServiceID types.ServiceID
|
||||
DialFunc types.DialContextFunc
|
||||
DialTimeout time.Duration
|
||||
SessionTTL time.Duration
|
||||
MaxSessions int
|
||||
AccessLog l4Logger
|
||||
}
|
||||
|
||||
// New creates a UDP relay for the given listener and backend target.
|
||||
// MaxSessions caps the number of concurrent sessions; use 0 for DefaultMaxSessions.
|
||||
// DialTimeout controls how long to wait for backend connections; use 0 for default.
|
||||
// SessionTTL is the idle timeout before a session is reaped; use 0 for DefaultSessionTTL.
|
||||
func New(parentCtx context.Context, cfg RelayConfig) *Relay {
|
||||
maxSessions := cfg.MaxSessions
|
||||
dialTimeout := cfg.DialTimeout
|
||||
sessionTTL := cfg.SessionTTL
|
||||
if maxSessions <= 0 {
|
||||
maxSessions = DefaultMaxSessions
|
||||
}
|
||||
if dialTimeout <= 0 {
|
||||
dialTimeout = defaultDialTimeout
|
||||
}
|
||||
if sessionTTL <= 0 {
|
||||
sessionTTL = DefaultSessionTTL
|
||||
}
|
||||
ctx, cancel := context.WithCancel(parentCtx)
|
||||
return &Relay{
|
||||
logger: cfg.Logger,
|
||||
listener: cfg.Listener,
|
||||
target: cfg.Target,
|
||||
domain: cfg.Domain,
|
||||
accountID: cfg.AccountID,
|
||||
serviceID: cfg.ServiceID,
|
||||
accessLog: cfg.AccessLog,
|
||||
dialFunc: cfg.DialFunc,
|
||||
dialTimeout: dialTimeout,
|
||||
sessionTTL: sessionTTL,
|
||||
maxSessions: maxSessions,
|
||||
sessions: make(map[clientAddr]*session),
|
||||
bufPool: sync.Pool{
|
||||
New: func() any {
|
||||
buf := make([]byte, maxPacketSize)
|
||||
return &buf
|
||||
},
|
||||
},
|
||||
sessLimiter: rate.NewLimiter(sessionCreateRate, sessionCreateBurst),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
// ServiceID returns the service ID associated with this relay.
|
||||
func (r *Relay) ServiceID() types.ServiceID {
|
||||
return r.serviceID
|
||||
}
|
||||
|
||||
// SetObserver sets the session lifecycle observer. Must be called before Serve.
|
||||
func (r *Relay) SetObserver(obs SessionObserver) {
|
||||
r.observer = obs
|
||||
}
|
||||
|
||||
// Serve starts the relay loop. It blocks until the context is canceled
|
||||
// or the listener is closed.
|
||||
func (r *Relay) Serve() {
|
||||
go r.cleanupLoop()
|
||||
|
||||
for {
|
||||
bufp := r.bufPool.Get().(*[]byte)
|
||||
buf := *bufp
|
||||
|
||||
n, addr, err := r.listener.ReadFrom(buf)
|
||||
if err != nil {
|
||||
r.bufPool.Put(bufp)
|
||||
if r.ctx.Err() != nil || errors.Is(err, net.ErrClosed) {
|
||||
return
|
||||
}
|
||||
r.logger.Debugf("UDP read: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
data := buf[:n]
|
||||
sess, err := r.getOrCreateSession(addr)
|
||||
if err != nil {
|
||||
r.bufPool.Put(bufp)
|
||||
r.logger.Debugf("create UDP session for %s: %v", addr, err)
|
||||
continue
|
||||
}
|
||||
|
||||
sess.updateLastSeen()
|
||||
|
||||
nw, err := sess.backend.Write(data)
|
||||
if err != nil {
|
||||
r.bufPool.Put(bufp)
|
||||
if !netutil.IsExpectedError(err) {
|
||||
r.logger.Debugf("UDP write to backend for %s: %v", addr, err)
|
||||
}
|
||||
r.removeSession(sess)
|
||||
continue
|
||||
}
|
||||
sess.bytesIn.Add(int64(nw))
|
||||
|
||||
if r.observer != nil {
|
||||
r.observer.UDPPacketRelayed(types.RelayDirectionClientToBackend, nw)
|
||||
}
|
||||
r.bufPool.Put(bufp)
|
||||
}
|
||||
}
|
||||
|
||||
// getOrCreateSession returns an existing session or creates a new one.
|
||||
func (r *Relay) getOrCreateSession(addr net.Addr) (*session, error) {
|
||||
key := clientAddr(addr.String())
|
||||
|
||||
r.mu.RLock()
|
||||
sess, ok := r.sessions[key]
|
||||
r.mu.RUnlock()
|
||||
if ok && sess != nil {
|
||||
return sess, nil
|
||||
}
|
||||
|
||||
// Check before taking the write lock: if the relay is shutting down,
|
||||
// don't create new sessions. This prevents orphaned goroutines when
|
||||
// Serve() processes a packet that was already read before Close().
|
||||
if r.ctx.Err() != nil {
|
||||
return nil, r.ctx.Err()
|
||||
}
|
||||
|
||||
r.mu.Lock()
|
||||
|
||||
if sess, ok = r.sessions[key]; ok && sess != nil {
|
||||
r.mu.Unlock()
|
||||
return sess, nil
|
||||
}
|
||||
if ok {
|
||||
// Another goroutine is dialing for this key, skip.
|
||||
r.mu.Unlock()
|
||||
return nil, fmt.Errorf("session dial in progress for %s", key)
|
||||
}
|
||||
|
||||
if len(r.sessions) >= r.maxSessions {
|
||||
r.mu.Unlock()
|
||||
if r.observer != nil {
|
||||
r.observer.UDPSessionRejected(r.accountID)
|
||||
}
|
||||
return nil, fmt.Errorf("session limit reached (%d)", r.maxSessions)
|
||||
}
|
||||
|
||||
if !r.sessLimiter.Allow() {
|
||||
r.mu.Unlock()
|
||||
if r.observer != nil {
|
||||
r.observer.UDPSessionRejected(r.accountID)
|
||||
}
|
||||
return nil, fmt.Errorf("session creation rate limited")
|
||||
}
|
||||
|
||||
// Reserve the slot with a nil session so concurrent callers for the same
|
||||
// key see it exists and wait. Release the lock before dialing.
|
||||
r.sessions[key] = nil
|
||||
r.mu.Unlock()
|
||||
|
||||
dialCtx, dialCancel := context.WithTimeout(r.ctx, r.dialTimeout)
|
||||
backend, err := r.dialFunc(dialCtx, "udp", r.target)
|
||||
dialCancel()
|
||||
if err != nil {
|
||||
r.mu.Lock()
|
||||
delete(r.sessions, key)
|
||||
r.mu.Unlock()
|
||||
if r.observer != nil {
|
||||
r.observer.UDPSessionDialError(r.accountID)
|
||||
}
|
||||
return nil, fmt.Errorf("dial backend %s: %w", r.target, err)
|
||||
}
|
||||
|
||||
sessCtx, sessCancel := context.WithCancel(r.ctx)
|
||||
sess = &session{
|
||||
backend: backend,
|
||||
addr: addr,
|
||||
createdAt: time.Now(),
|
||||
cancel: sessCancel,
|
||||
}
|
||||
sess.updateLastSeen()
|
||||
|
||||
r.mu.Lock()
|
||||
r.sessions[key] = sess
|
||||
r.mu.Unlock()
|
||||
|
||||
if r.observer != nil {
|
||||
r.observer.UDPSessionStarted(r.accountID)
|
||||
}
|
||||
|
||||
r.sessWg.Go(func() {
|
||||
r.relayBackendToClient(sessCtx, sess)
|
||||
})
|
||||
|
||||
r.logger.Debugf("UDP session created for %s", addr)
|
||||
return sess, nil
|
||||
}
|
||||
|
||||
// relayBackendToClient reads packets from the backend and writes them
|
||||
// back to the client through the public-facing listener.
|
||||
func (r *Relay) relayBackendToClient(ctx context.Context, sess *session) {
|
||||
bufp := r.bufPool.Get().(*[]byte)
|
||||
defer r.bufPool.Put(bufp)
|
||||
defer r.removeSession(sess)
|
||||
|
||||
for ctx.Err() == nil {
|
||||
data, ok := r.readBackendPacket(sess, *bufp)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if data == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
sess.updateLastSeen()
|
||||
|
||||
nw, err := r.listener.WriteTo(data, sess.addr)
|
||||
if err != nil {
|
||||
if !netutil.IsExpectedError(err) {
|
||||
r.logger.Debugf("UDP write to client %s: %v", sess.addr, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
sess.bytesOut.Add(int64(nw))
|
||||
|
||||
if r.observer != nil {
|
||||
r.observer.UDPPacketRelayed(types.RelayDirectionBackendToClient, nw)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// readBackendPacket reads one packet from the backend with an idle deadline.
|
||||
// Returns (data, true) on success, (nil, true) on idle timeout that should
|
||||
// retry, or (nil, false) when the session should be torn down.
|
||||
func (r *Relay) readBackendPacket(sess *session, buf []byte) ([]byte, bool) {
|
||||
if err := sess.backend.SetReadDeadline(time.Now().Add(r.sessionTTL)); err != nil {
|
||||
r.logger.Debugf("set backend read deadline for %s: %v", sess.addr, err)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
n, err := sess.backend.Read(buf)
|
||||
if err != nil {
|
||||
if netutil.IsTimeout(err) {
|
||||
if sess.idleDuration() > r.sessionTTL {
|
||||
return nil, false
|
||||
}
|
||||
return nil, true
|
||||
}
|
||||
if !netutil.IsExpectedError(err) {
|
||||
r.logger.Debugf("UDP read from backend for %s: %v", sess.addr, err)
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return buf[:n], true
|
||||
}
|
||||
|
||||
// cleanupLoop periodically removes idle sessions.
|
||||
func (r *Relay) cleanupLoop() {
|
||||
ticker := time.NewTicker(cleanupInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-r.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
r.cleanupIdleSessions()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupIdleSessions closes sessions that have been idle for too long.
|
||||
func (r *Relay) cleanupIdleSessions() {
|
||||
var expired []*session
|
||||
|
||||
r.mu.Lock()
|
||||
for key, sess := range r.sessions {
|
||||
if sess == nil {
|
||||
continue
|
||||
}
|
||||
idle := sess.idleDuration()
|
||||
if idle > r.sessionTTL {
|
||||
r.logger.Debugf("UDP session %s idle for %s, closing (client→backend: %d bytes, backend→client: %d bytes)",
|
||||
sess.addr, idle, sess.bytesIn.Load(), sess.bytesOut.Load())
|
||||
delete(r.sessions, key)
|
||||
sess.cancel()
|
||||
if err := sess.backend.Close(); err != nil {
|
||||
r.logger.Debugf("close idle session %s backend: %v", sess.addr, err)
|
||||
}
|
||||
expired = append(expired, sess)
|
||||
}
|
||||
}
|
||||
r.mu.Unlock()
|
||||
|
||||
for _, sess := range expired {
|
||||
if r.observer != nil {
|
||||
r.observer.UDPSessionEnded(r.accountID)
|
||||
}
|
||||
r.logSessionEnd(sess)
|
||||
}
|
||||
}
|
||||
|
||||
// removeSession removes a session from the map if it still matches the
|
||||
// given pointer. This is safe to call concurrently with cleanupIdleSessions
|
||||
// because the identity check prevents double-close when both paths race.
|
||||
func (r *Relay) removeSession(sess *session) {
|
||||
r.mu.Lock()
|
||||
key := clientAddr(sess.addr.String())
|
||||
removed := r.sessions[key] == sess
|
||||
if removed {
|
||||
delete(r.sessions, key)
|
||||
sess.cancel()
|
||||
if err := sess.backend.Close(); err != nil {
|
||||
r.logger.Debugf("close session %s backend: %v", sess.addr, err)
|
||||
}
|
||||
}
|
||||
r.mu.Unlock()
|
||||
|
||||
if removed {
|
||||
r.logger.Debugf("UDP session %s ended (client→backend: %d bytes, backend→client: %d bytes)",
|
||||
sess.addr, sess.bytesIn.Load(), sess.bytesOut.Load())
|
||||
if r.observer != nil {
|
||||
r.observer.UDPSessionEnded(r.accountID)
|
||||
}
|
||||
r.logSessionEnd(sess)
|
||||
}
|
||||
}
|
||||
|
||||
// logSessionEnd sends an access log entry for a completed UDP session.
|
||||
func (r *Relay) logSessionEnd(sess *session) {
|
||||
if r.accessLog == nil {
|
||||
return
|
||||
}
|
||||
|
||||
var sourceIP netip.Addr
|
||||
if ap, err := netip.ParseAddrPort(sess.addr.String()); err == nil {
|
||||
sourceIP = ap.Addr().Unmap()
|
||||
}
|
||||
|
||||
r.accessLog.LogL4(accesslog.L4Entry{
|
||||
AccountID: r.accountID,
|
||||
ServiceID: r.serviceID,
|
||||
Protocol: accesslog.ProtocolUDP,
|
||||
Host: r.domain,
|
||||
SourceIP: sourceIP,
|
||||
DurationMs: time.Unix(0, sess.lastSeen.Load()).Sub(sess.createdAt).Milliseconds(),
|
||||
BytesUpload: sess.bytesIn.Load(),
|
||||
BytesDownload: sess.bytesOut.Load(),
|
||||
})
|
||||
}
|
||||
|
||||
// Close stops the relay, waits for all session goroutines to exit,
|
||||
// and cleans up remaining sessions.
|
||||
func (r *Relay) Close() {
|
||||
r.cancel()
|
||||
if err := r.listener.Close(); err != nil {
|
||||
r.logger.Debugf("close UDP listener: %v", err)
|
||||
}
|
||||
|
||||
var closedSessions []*session
|
||||
r.mu.Lock()
|
||||
for key, sess := range r.sessions {
|
||||
if sess == nil {
|
||||
delete(r.sessions, key)
|
||||
continue
|
||||
}
|
||||
r.logger.Debugf("UDP session %s closed (client→backend: %d bytes, backend→client: %d bytes)",
|
||||
sess.addr, sess.bytesIn.Load(), sess.bytesOut.Load())
|
||||
sess.cancel()
|
||||
if err := sess.backend.Close(); err != nil {
|
||||
r.logger.Debugf("close session %s backend: %v", sess.addr, err)
|
||||
}
|
||||
delete(r.sessions, key)
|
||||
closedSessions = append(closedSessions, sess)
|
||||
}
|
||||
r.mu.Unlock()
|
||||
|
||||
for _, sess := range closedSessions {
|
||||
if r.observer != nil {
|
||||
r.observer.UDPSessionEnded(r.accountID)
|
||||
}
|
||||
r.logSessionEnd(sess)
|
||||
}
|
||||
|
||||
r.sessWg.Wait()
|
||||
}
|
||||
493
proxy/internal/udp/relay_test.go
Normal file
493
proxy/internal/udp/relay_test.go
Normal file
@@ -0,0 +1,493 @@
|
||||
package udp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
)
|
||||
|
||||
func TestRelay_BasicPacketExchange(t *testing.T) {
|
||||
// Set up a UDP backend that echoes packets.
|
||||
backend, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
go func() {
|
||||
buf := make([]byte, 65535)
|
||||
for {
|
||||
n, addr, err := backend.ReadFrom(buf)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, _ = backend.WriteTo(buf[:n], addr)
|
||||
}
|
||||
}()
|
||||
|
||||
// Set up the relay's public-facing listener.
|
||||
listener, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
defer listener.Close()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
logger := log.NewEntry(log.StandardLogger())
|
||||
backendAddr := backend.LocalAddr().String()
|
||||
|
||||
dialFunc := func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return net.Dial(network, address)
|
||||
}
|
||||
|
||||
relay := New(ctx, RelayConfig{Logger: logger, Listener: listener, Target: backendAddr, DialFunc: dialFunc})
|
||||
go relay.Serve()
|
||||
defer relay.Close()
|
||||
|
||||
// Create a client and send a packet to the relay.
|
||||
client, err := net.Dial("udp", listener.LocalAddr().String())
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
|
||||
testData := []byte("hello UDP relay")
|
||||
_, err = client.Write(testData)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Read the echoed response.
|
||||
if err := client.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
buf := make([]byte, 1024)
|
||||
n, err := client.Read(buf)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testData, buf[:n], "should receive echoed packet")
|
||||
}
|
||||
|
||||
func TestRelay_MultipleClients(t *testing.T) {
|
||||
backend, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
go func() {
|
||||
buf := make([]byte, 65535)
|
||||
for {
|
||||
n, addr, err := backend.ReadFrom(buf)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, _ = backend.WriteTo(buf[:n], addr)
|
||||
}
|
||||
}()
|
||||
|
||||
listener, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
defer listener.Close()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
logger := log.NewEntry(log.StandardLogger())
|
||||
dialFunc := func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return net.Dial(network, address)
|
||||
}
|
||||
|
||||
relay := New(ctx, RelayConfig{Logger: logger, Listener: listener, Target: backend.LocalAddr().String(), DialFunc: dialFunc})
|
||||
go relay.Serve()
|
||||
defer relay.Close()
|
||||
|
||||
// Two clients, each should get their own session.
|
||||
for i, msg := range []string{"client-1", "client-2"} {
|
||||
client, err := net.Dial("udp", listener.LocalAddr().String())
|
||||
require.NoError(t, err, "client %d", i)
|
||||
defer client.Close()
|
||||
|
||||
_, err = client.Write([]byte(msg))
|
||||
require.NoError(t, err)
|
||||
|
||||
if err := client.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
buf := make([]byte, 1024)
|
||||
n, err := client.Read(buf)
|
||||
require.NoError(t, err, "client %d read", i)
|
||||
assert.Equal(t, msg, string(buf[:n]), "client %d should get own echo", i)
|
||||
}
|
||||
|
||||
// Verify two sessions were created.
|
||||
relay.mu.RLock()
|
||||
sessionCount := len(relay.sessions)
|
||||
relay.mu.RUnlock()
|
||||
assert.Equal(t, 2, sessionCount, "should have two sessions")
|
||||
}
|
||||
|
||||
func TestRelay_Close(t *testing.T) {
|
||||
listener, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
logger := log.NewEntry(log.StandardLogger())
|
||||
dialFunc := func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return net.Dial(network, address)
|
||||
}
|
||||
|
||||
relay := New(ctx, RelayConfig{Logger: logger, Listener: listener, Target: "127.0.0.1:9999", DialFunc: dialFunc})
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
relay.Serve()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
relay.Close()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("Serve did not return after Close")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRelay_SessionCleanup(t *testing.T) {
|
||||
backend, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
go func() {
|
||||
buf := make([]byte, 65535)
|
||||
for {
|
||||
n, addr, err := backend.ReadFrom(buf)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, _ = backend.WriteTo(buf[:n], addr)
|
||||
}
|
||||
}()
|
||||
|
||||
listener, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
defer listener.Close()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
logger := log.NewEntry(log.StandardLogger())
|
||||
dialFunc := func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return net.Dial(network, address)
|
||||
}
|
||||
|
||||
relay := New(ctx, RelayConfig{Logger: logger, Listener: listener, Target: backend.LocalAddr().String(), DialFunc: dialFunc})
|
||||
go relay.Serve()
|
||||
defer relay.Close()
|
||||
|
||||
// Create a session.
|
||||
client, err := net.Dial("udp", listener.LocalAddr().String())
|
||||
require.NoError(t, err)
|
||||
_, err = client.Write([]byte("hello"))
|
||||
require.NoError(t, err)
|
||||
|
||||
if err := client.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
buf := make([]byte, 1024)
|
||||
_, err = client.Read(buf)
|
||||
require.NoError(t, err)
|
||||
client.Close()
|
||||
|
||||
// Verify session exists.
|
||||
relay.mu.RLock()
|
||||
assert.Equal(t, 1, len(relay.sessions))
|
||||
relay.mu.RUnlock()
|
||||
|
||||
// Make session appear idle by setting lastSeen to the past.
|
||||
relay.mu.Lock()
|
||||
for _, sess := range relay.sessions {
|
||||
sess.lastSeen.Store(time.Now().Add(-2 * DefaultSessionTTL).UnixNano())
|
||||
}
|
||||
relay.mu.Unlock()
|
||||
|
||||
// Trigger cleanup manually.
|
||||
relay.cleanupIdleSessions()
|
||||
|
||||
relay.mu.RLock()
|
||||
assert.Equal(t, 0, len(relay.sessions), "idle sessions should be cleaned up")
|
||||
relay.mu.RUnlock()
|
||||
}
|
||||
|
||||
// TestRelay_CloseAndRecreate verifies that closing a relay and creating a new
|
||||
// one on the same port works cleanly (simulates port mapping modify cycle).
|
||||
func TestRelay_CloseAndRecreate(t *testing.T) {
|
||||
backend, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
go func() {
|
||||
buf := make([]byte, 65535)
|
||||
for {
|
||||
n, addr, err := backend.ReadFrom(buf)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, _ = backend.WriteTo(buf[:n], addr)
|
||||
}
|
||||
}()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
logger := log.NewEntry(log.StandardLogger())
|
||||
dialFunc := func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return net.Dial(network, address)
|
||||
}
|
||||
|
||||
// First relay.
|
||||
ln1, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
relay1 := New(ctx, RelayConfig{Logger: logger, Listener: ln1, Target: backend.LocalAddr().String(), DialFunc: dialFunc})
|
||||
go relay1.Serve()
|
||||
|
||||
client1, err := net.Dial("udp", ln1.LocalAddr().String())
|
||||
require.NoError(t, err)
|
||||
_, err = client1.Write([]byte("relay1"))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, client1.SetReadDeadline(time.Now().Add(2*time.Second)))
|
||||
buf := make([]byte, 1024)
|
||||
n, err := client1.Read(buf)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "relay1", string(buf[:n]))
|
||||
client1.Close()
|
||||
|
||||
// Close first relay.
|
||||
relay1.Close()
|
||||
|
||||
// Second relay on same port.
|
||||
port := ln1.LocalAddr().(*net.UDPAddr).Port
|
||||
ln2, err := net.ListenPacket("udp", fmt.Sprintf("127.0.0.1:%d", port))
|
||||
require.NoError(t, err)
|
||||
|
||||
relay2 := New(ctx, RelayConfig{Logger: logger, Listener: ln2, Target: backend.LocalAddr().String(), DialFunc: dialFunc})
|
||||
go relay2.Serve()
|
||||
defer relay2.Close()
|
||||
|
||||
client2, err := net.Dial("udp", ln2.LocalAddr().String())
|
||||
require.NoError(t, err)
|
||||
defer client2.Close()
|
||||
_, err = client2.Write([]byte("relay2"))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, client2.SetReadDeadline(time.Now().Add(2*time.Second)))
|
||||
n, err = client2.Read(buf)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "relay2", string(buf[:n]), "second relay should work on same port")
|
||||
}
|
||||
|
||||
func TestRelay_SessionLimit(t *testing.T) {
|
||||
backend, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
go func() {
|
||||
buf := make([]byte, 65535)
|
||||
for {
|
||||
n, addr, err := backend.ReadFrom(buf)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, _ = backend.WriteTo(buf[:n], addr)
|
||||
}
|
||||
}()
|
||||
|
||||
listener, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
defer listener.Close()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
logger := log.NewEntry(log.StandardLogger())
|
||||
dialFunc := func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return net.Dial(network, address)
|
||||
}
|
||||
|
||||
// Create a relay with a max of 2 sessions.
|
||||
relay := New(ctx, RelayConfig{Logger: logger, Listener: listener, Target: backend.LocalAddr().String(), DialFunc: dialFunc, MaxSessions: 2})
|
||||
go relay.Serve()
|
||||
defer relay.Close()
|
||||
|
||||
// Create 2 clients to fill up the session limit.
|
||||
for i := range 2 {
|
||||
client, err := net.Dial("udp", listener.LocalAddr().String())
|
||||
require.NoError(t, err, "client %d", i)
|
||||
defer client.Close()
|
||||
|
||||
_, err = client.Write([]byte("hello"))
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, client.SetReadDeadline(time.Now().Add(2*time.Second)))
|
||||
buf := make([]byte, 1024)
|
||||
_, err = client.Read(buf)
|
||||
require.NoError(t, err, "client %d should get response", i)
|
||||
}
|
||||
|
||||
relay.mu.RLock()
|
||||
assert.Equal(t, 2, len(relay.sessions), "should have exactly 2 sessions")
|
||||
relay.mu.RUnlock()
|
||||
|
||||
// Third client should get its packet dropped (session creation fails).
|
||||
client3, err := net.Dial("udp", listener.LocalAddr().String())
|
||||
require.NoError(t, err)
|
||||
defer client3.Close()
|
||||
|
||||
_, err = client3.Write([]byte("should be dropped"))
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, client3.SetReadDeadline(time.Now().Add(500*time.Millisecond)))
|
||||
buf := make([]byte, 1024)
|
||||
_, err = client3.Read(buf)
|
||||
assert.Error(t, err, "third client should time out because session was rejected")
|
||||
|
||||
relay.mu.RLock()
|
||||
assert.Equal(t, 2, len(relay.sessions), "session count should not exceed limit")
|
||||
relay.mu.RUnlock()
|
||||
}
|
||||
|
||||
// testObserver records UDP session lifecycle events for test assertions.
|
||||
type testObserver struct {
|
||||
mu sync.Mutex
|
||||
started int
|
||||
ended int
|
||||
rejected int
|
||||
dialErr int
|
||||
packets int
|
||||
bytes int
|
||||
}
|
||||
|
||||
func (o *testObserver) UDPSessionStarted(types.AccountID) { o.mu.Lock(); o.started++; o.mu.Unlock() }
|
||||
func (o *testObserver) UDPSessionEnded(types.AccountID) { o.mu.Lock(); o.ended++; o.mu.Unlock() }
|
||||
func (o *testObserver) UDPSessionDialError(types.AccountID) { o.mu.Lock(); o.dialErr++; o.mu.Unlock() }
|
||||
func (o *testObserver) UDPSessionRejected(types.AccountID) { o.mu.Lock(); o.rejected++; o.mu.Unlock() }
|
||||
func (o *testObserver) UDPPacketRelayed(_ types.RelayDirection, b int) {
|
||||
o.mu.Lock()
|
||||
o.packets++
|
||||
o.bytes += b
|
||||
o.mu.Unlock()
|
||||
}
|
||||
|
||||
func TestRelay_CloseFiresObserverEnded(t *testing.T) {
|
||||
backend, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
go func() {
|
||||
buf := make([]byte, 65535)
|
||||
for {
|
||||
n, addr, err := backend.ReadFrom(buf)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, _ = backend.WriteTo(buf[:n], addr)
|
||||
}
|
||||
}()
|
||||
|
||||
listener, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
defer listener.Close()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
logger := log.NewEntry(log.StandardLogger())
|
||||
dialFunc := func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return net.Dial(network, address)
|
||||
}
|
||||
|
||||
obs := &testObserver{}
|
||||
relay := New(ctx, RelayConfig{Logger: logger, Listener: listener, Target: backend.LocalAddr().String(), AccountID: "test-acct", DialFunc: dialFunc})
|
||||
relay.SetObserver(obs)
|
||||
go relay.Serve()
|
||||
|
||||
// Create two sessions.
|
||||
for i := range 2 {
|
||||
client, err := net.Dial("udp", listener.LocalAddr().String())
|
||||
require.NoError(t, err, "client %d", i)
|
||||
|
||||
_, err = client.Write([]byte("hello"))
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, client.SetReadDeadline(time.Now().Add(2*time.Second)))
|
||||
buf := make([]byte, 1024)
|
||||
_, err = client.Read(buf)
|
||||
require.NoError(t, err)
|
||||
client.Close()
|
||||
}
|
||||
|
||||
obs.mu.Lock()
|
||||
assert.Equal(t, 2, obs.started, "should have 2 started events")
|
||||
obs.mu.Unlock()
|
||||
|
||||
// Close should fire UDPSessionEnded for all remaining sessions.
|
||||
relay.Close()
|
||||
|
||||
obs.mu.Lock()
|
||||
assert.Equal(t, 2, obs.ended, "Close should fire UDPSessionEnded for each session")
|
||||
obs.mu.Unlock()
|
||||
}
|
||||
|
||||
func TestRelay_SessionRateLimit(t *testing.T) {
|
||||
backend, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
go func() {
|
||||
buf := make([]byte, 65535)
|
||||
for {
|
||||
n, addr, err := backend.ReadFrom(buf)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, _ = backend.WriteTo(buf[:n], addr)
|
||||
}
|
||||
}()
|
||||
|
||||
listener, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
defer listener.Close()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
logger := log.NewEntry(log.StandardLogger())
|
||||
dialFunc := func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return net.Dial(network, address)
|
||||
}
|
||||
|
||||
obs := &testObserver{}
|
||||
// High max sessions (1000) but the relay uses a rate limiter internally
|
||||
// (default: 50/s burst 100). We exhaust the burst by creating sessions
|
||||
// rapidly, then verify that subsequent creates are rejected.
|
||||
relay := New(ctx, RelayConfig{Logger: logger, Listener: listener, Target: backend.LocalAddr().String(), AccountID: "test-acct", DialFunc: dialFunc, MaxSessions: 1000})
|
||||
relay.SetObserver(obs)
|
||||
go relay.Serve()
|
||||
defer relay.Close()
|
||||
|
||||
// Exhaust the burst by calling getOrCreateSession directly with
|
||||
// synthetic addresses. This is faster than real UDP round-trips.
|
||||
for i := range sessionCreateBurst + 20 {
|
||||
addr := &net.UDPAddr{IP: net.IPv4(10, 0, byte(i/256), byte(i%256)), Port: 10000 + i}
|
||||
_, _ = relay.getOrCreateSession(addr)
|
||||
}
|
||||
|
||||
obs.mu.Lock()
|
||||
rejected := obs.rejected
|
||||
obs.mu.Unlock()
|
||||
|
||||
assert.Greater(t, rejected, 0, "some sessions should be rate-limited")
|
||||
}
|
||||
Reference in New Issue
Block a user