package udp

import (
	"context"
	"errors"
	"fmt"
	"net"
	"net/netip"
	"sync"
	"sync/atomic"
	"time"

	log "github.com/sirupsen/logrus"
	"golang.org/x/time/rate"

	"github.com/netbirdio/netbird/proxy/internal/accesslog"
	"github.com/netbirdio/netbird/proxy/internal/netutil"
	"github.com/netbirdio/netbird/proxy/internal/restrict"
	"github.com/netbirdio/netbird/proxy/internal/types"
)

const (
	// DefaultSessionTTL is the default idle timeout for UDP sessions before cleanup.
	DefaultSessionTTL = 30 * time.Second
	// cleanupInterval is how often the cleanup goroutine runs.
	cleanupInterval = time.Minute
	// maxPacketSize is the maximum UDP packet size we'll handle.
	maxPacketSize = 65535
	// DefaultMaxSessions is the default cap on concurrent UDP sessions per relay.
	DefaultMaxSessions = 1024
	// sessionCreateRate is the sustained limit on new session creation, in sessions per second.
	sessionCreateRate = 50
	// sessionCreateBurst is the burst allowance for session creation.
	sessionCreateBurst = 100
	// defaultDialTimeout is the fallback dial timeout for backend connections.
	defaultDialTimeout = 30 * time.Second
)

// l4Logger sends layer-4 access log entries to the management server.
type l4Logger interface {
	LogL4(entry accesslog.L4Entry)
}

// SessionObserver receives callbacks for UDP session lifecycle events.
// All methods must be safe for concurrent use.
type SessionObserver interface {
	UDPSessionStarted(accountID types.AccountID)
	UDPSessionEnded(accountID types.AccountID)
	UDPSessionDialError(accountID types.AccountID)
	UDPSessionRejected(accountID types.AccountID)
	UDPPacketRelayed(direction types.RelayDirection, bytes int)
}

// clientAddr is a typed key for UDP session lookups.
type clientAddr string

// Relay listens for incoming UDP packets on a dedicated port and
// maintains per-client sessions that relay packets to a backend
// through the WireGuard tunnel.
type Relay struct {
	logger      *log.Entry
	listener    net.PacketConn
	target      string
	domain      string
	accountID   types.AccountID
	serviceID   types.ServiceID
	dialFunc    types.DialContextFunc
	dialTimeout time.Duration
	sessionTTL  time.Duration
	maxSessions int
	filter      *restrict.Filter
	geo         restrict.GeoResolver

	mu       sync.RWMutex
	sessions map[clientAddr]*session

	bufPool     sync.Pool
	sessLimiter *rate.Limiter
	sessWg      sync.WaitGroup
	ctx         context.Context
	cancel      context.CancelFunc
	observer    SessionObserver
	accessLog   l4Logger
}

type session struct {
	backend   net.Conn
	addr      net.Addr
	createdAt time.Time
	// lastSeen stores the last activity timestamp as unix nanoseconds, so the
	// cleanup loop and relay goroutines can read and refresh it without locking.
	lastSeen atomic.Int64
	cancel   context.CancelFunc
	// bytesIn tracks total bytes received from the client.
	bytesIn atomic.Int64
	// bytesOut tracks total bytes sent back to the client.
	bytesOut atomic.Int64
}

func (s *session) updateLastSeen() {
	s.lastSeen.Store(time.Now().UnixNano())
}

func (s *session) idleDuration() time.Duration {
	return time.Since(time.Unix(0, s.lastSeen.Load()))
}

// RelayConfig holds the configuration for a UDP relay.
type RelayConfig struct {
	Logger      *log.Entry
	Listener    net.PacketConn
	Target      string
	Domain      string
	AccountID   types.AccountID
	ServiceID   types.ServiceID
	DialFunc    types.DialContextFunc
	DialTimeout time.Duration
	SessionTTL  time.Duration
	MaxSessions int
	AccessLog   l4Logger
	// Filter holds connection-level IP/geo restrictions. Nil means no restrictions.
	Filter *restrict.Filter
	// Geo is the geolocation lookup used for country-based restrictions.
	Geo restrict.GeoResolver
}

// New creates a UDP relay for the given listener and backend target.
// MaxSessions caps the number of concurrent sessions; use 0 for DefaultMaxSessions.
// DialTimeout controls how long to wait for backend connections; use 0 for the default.
// SessionTTL is the idle timeout before a session is reaped; use 0 for DefaultSessionTTL.
func New(parentCtx context.Context, cfg RelayConfig) *Relay {
	maxSessions := cfg.MaxSessions
	dialTimeout := cfg.DialTimeout
	sessionTTL := cfg.SessionTTL
	if maxSessions <= 0 {
		maxSessions = DefaultMaxSessions
	}
	if dialTimeout <= 0 {
		dialTimeout = defaultDialTimeout
	}
	if sessionTTL <= 0 {
		sessionTTL = DefaultSessionTTL
	}
	ctx, cancel := context.WithCancel(parentCtx)
	return &Relay{
		logger:      cfg.Logger,
		listener:    cfg.Listener,
		target:      cfg.Target,
		domain:      cfg.Domain,
		accountID:   cfg.AccountID,
		serviceID:   cfg.ServiceID,
		accessLog:   cfg.AccessLog,
		dialFunc:    cfg.DialFunc,
		dialTimeout: dialTimeout,
		sessionTTL:  sessionTTL,
		maxSessions: maxSessions,
		filter:      cfg.Filter,
		geo:         cfg.Geo,
		sessions:    make(map[clientAddr]*session),
		bufPool: sync.Pool{
			New: func() any {
				buf := make([]byte, maxPacketSize)
				return &buf
			},
		},
		sessLimiter: rate.NewLimiter(sessionCreateRate, sessionCreateBurst),
		ctx:         ctx,
		cancel:      cancel,
	}
}

// ServiceID returns the service ID associated with this relay.
func (r *Relay) ServiceID() types.ServiceID {
	return r.serviceID
}

// SetObserver sets the session lifecycle observer. Must be called before Serve.
func (r *Relay) SetObserver(obs SessionObserver) {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.observer = obs
}

// getObserver returns the current session lifecycle observer.
func (r *Relay) getObserver() SessionObserver {
	r.mu.RLock()
	defer r.mu.RUnlock()
	return r.observer
}

// Serve starts the relay loop. It blocks until the context is canceled
// or the listener is closed.
func (r *Relay) Serve() {
	go r.cleanupLoop()

	for {
		bufp := r.bufPool.Get().(*[]byte)
		buf := *bufp

		n, addr, err := r.listener.ReadFrom(buf)
		if err != nil {
			r.bufPool.Put(bufp)
			if r.ctx.Err() != nil || errors.Is(err, net.ErrClosed) {
				return
			}
			r.logger.Debugf("UDP read: %v", err)
			continue
		}

		data := buf[:n]
		sess, err := r.getOrCreateSession(addr)
		if err != nil {
			r.bufPool.Put(bufp)
			r.logger.Debugf("create UDP session for %s: %v", addr, err)
			continue
		}

		sess.updateLastSeen()

		nw, err := sess.backend.Write(data)
		if err != nil {
			r.bufPool.Put(bufp)
			if !netutil.IsExpectedError(err) {
				r.logger.Debugf("UDP write to backend for %s: %v", addr, err)
			}
			r.removeSession(sess)
			continue
		}
		sess.bytesIn.Add(int64(nw))

		if obs := r.getObserver(); obs != nil {
			obs.UDPPacketRelayed(types.RelayDirectionClientToBackend, nw)
		}
		r.bufPool.Put(bufp)
	}
}

// getOrCreateSession returns an existing session or creates a new one.
func (r *Relay) getOrCreateSession(addr net.Addr) (*session, error) {
	key := clientAddr(addr.String())

	r.mu.RLock()
	sess, ok := r.sessions[key]
	r.mu.RUnlock()
	if ok && sess != nil {
		return sess, nil
	}

	// Check before taking the write lock: if the relay is shutting down,
	// don't create new sessions. This prevents orphaned goroutines when
	// Serve() processes a packet that was already read before Close().
	if r.ctx.Err() != nil {
		return nil, r.ctx.Err()
	}

	if err := r.checkAccessRestrictions(addr); err != nil {
		return nil, err
	}

	r.mu.Lock()

	if sess, ok = r.sessions[key]; ok && sess != nil {
		r.mu.Unlock()
		return sess, nil
	}
	if ok {
		// Another goroutine is dialing for this key; skip.
		r.mu.Unlock()
		return nil, fmt.Errorf("session dial in progress for %s", key)
	}

	if len(r.sessions) >= r.maxSessions {
		r.mu.Unlock()
		if obs := r.getObserver(); obs != nil {
			obs.UDPSessionRejected(r.accountID)
		}
		return nil, fmt.Errorf("session limit reached (%d)", r.maxSessions)
	}

	if !r.sessLimiter.Allow() {
		r.mu.Unlock()
		if obs := r.getObserver(); obs != nil {
			obs.UDPSessionRejected(r.accountID)
		}
		return nil, errors.New("session creation rate limited")
	}

	// Reserve the slot with a nil session so concurrent callers for the same
	// key see it exists and wait. Release the lock before dialing.
	r.sessions[key] = nil
	r.mu.Unlock()

	dialCtx, dialCancel := context.WithTimeout(r.ctx, r.dialTimeout)
	backend, err := r.dialFunc(dialCtx, "udp", r.target)
	dialCancel()
	if err != nil {
		r.mu.Lock()
		delete(r.sessions, key)
		r.mu.Unlock()
		if obs := r.getObserver(); obs != nil {
			obs.UDPSessionDialError(r.accountID)
		}
		return nil, fmt.Errorf("dial backend %s: %w", r.target, err)
	}

	sessCtx, sessCancel := context.WithCancel(r.ctx)
	sess = &session{
		backend:   backend,
		addr:      addr,
		createdAt: time.Now(),
		cancel:    sessCancel,
	}
	sess.updateLastSeen()

	r.mu.Lock()
	r.sessions[key] = sess
	r.mu.Unlock()

	if obs := r.getObserver(); obs != nil {
		obs.UDPSessionStarted(r.accountID)
	}

	r.sessWg.Go(func() {
		r.relayBackendToClient(sessCtx, sess)
	})

	r.logger.Debugf("UDP session created for %s", addr)
	return sess, nil
}

// checkAccessRestrictions rejects clients that fail the IP/geo filter.
// A nil filter allows all clients.
func (r *Relay) checkAccessRestrictions(addr net.Addr) error {
	if r.filter == nil {
		return nil
	}
	clientIP, err := addrFromUDPAddr(addr)
	if err != nil {
		return fmt.Errorf("parse client address %s for restriction check: %w", addr, err)
	}
	if v := r.filter.Check(clientIP, r.geo); v != restrict.Allow {
		r.logDeny(clientIP, v)
		return fmt.Errorf("access restricted for %s", addr)
	}
	return nil
}

// relayBackendToClient reads packets from the backend and writes them
// back to the client through the public-facing listener.
func (r *Relay) relayBackendToClient(ctx context.Context, sess *session) {
	bufp := r.bufPool.Get().(*[]byte)
	defer r.bufPool.Put(bufp)
	defer r.removeSession(sess)

	for ctx.Err() == nil {
		data, ok := r.readBackendPacket(sess, *bufp)
		if !ok {
			return
		}
		if data == nil {
			continue
		}

		sess.updateLastSeen()

		nw, err := r.listener.WriteTo(data, sess.addr)
		if err != nil {
			if !netutil.IsExpectedError(err) {
				r.logger.Debugf("UDP write to client %s: %v", sess.addr, err)
			}
			return
		}
		sess.bytesOut.Add(int64(nw))

		if obs := r.getObserver(); obs != nil {
			obs.UDPPacketRelayed(types.RelayDirectionBackendToClient, nw)
		}
	}
}

// readBackendPacket reads one packet from the backend with an idle deadline.
// It returns (data, true) on success, (nil, true) on an idle timeout that
// should be retried, or (nil, false) when the session should be torn down.
// A read deadline alone cannot end the session: client-to-backend traffic
// also refreshes lastSeen, so on timeout we re-check the overall idle time.
func (r *Relay) readBackendPacket(sess *session, buf []byte) ([]byte, bool) {
	if err := sess.backend.SetReadDeadline(time.Now().Add(r.sessionTTL)); err != nil {
		r.logger.Debugf("set backend read deadline for %s: %v", sess.addr, err)
		return nil, false
	}

	n, err := sess.backend.Read(buf)
	if err != nil {
		if netutil.IsTimeout(err) {
			if sess.idleDuration() > r.sessionTTL {
				return nil, false
			}
			return nil, true
		}
		if !netutil.IsExpectedError(err) {
			r.logger.Debugf("UDP read from backend for %s: %v", sess.addr, err)
		}
		return nil, false
	}

	return buf[:n], true
}

// cleanupLoop periodically removes idle sessions.
func (r *Relay) cleanupLoop() {
	ticker := time.NewTicker(cleanupInterval)
	defer ticker.Stop()

	for {
		select {
		case <-r.ctx.Done():
			return
		case <-ticker.C:
			r.cleanupIdleSessions()
		}
	}
}

// cleanupIdleSessions closes sessions that have been idle for too long.
func (r *Relay) cleanupIdleSessions() {
	var expired []*session

	r.mu.Lock()
	for key, sess := range r.sessions {
		if sess == nil {
			continue
		}
		idle := sess.idleDuration()
		if idle > r.sessionTTL {
			r.logger.Debugf("UDP session %s idle for %s, closing (client→backend: %d bytes, backend→client: %d bytes)",
				sess.addr, idle, sess.bytesIn.Load(), sess.bytesOut.Load())
			delete(r.sessions, key)
			sess.cancel()
			if err := sess.backend.Close(); err != nil {
				r.logger.Debugf("close idle session %s backend: %v", sess.addr, err)
			}
			expired = append(expired, sess)
		}
	}
	r.mu.Unlock()

	// Notify the observer and emit access logs outside the lock.
	obs := r.getObserver()
	for _, sess := range expired {
		if obs != nil {
			obs.UDPSessionEnded(r.accountID)
		}
		r.logSessionEnd(sess)
	}
}

// removeSession removes a session from the map if it still matches the
// given pointer. This is safe to call concurrently with cleanupIdleSessions
// because the identity check prevents double-close when both paths race.
func (r *Relay) removeSession(sess *session) {
	r.mu.Lock()
	key := clientAddr(sess.addr.String())
	removed := r.sessions[key] == sess
	if removed {
		delete(r.sessions, key)
		sess.cancel()
		if err := sess.backend.Close(); err != nil {
			r.logger.Debugf("close session %s backend: %v", sess.addr, err)
		}
	}
	r.mu.Unlock()

	if removed {
		r.logger.Debugf("UDP session %s ended (client→backend: %d bytes, backend→client: %d bytes)",
			sess.addr, sess.bytesIn.Load(), sess.bytesOut.Load())
		if obs := r.getObserver(); obs != nil {
			obs.UDPSessionEnded(r.accountID)
		}
		r.logSessionEnd(sess)
	}
}

// logSessionEnd sends an access log entry for a completed UDP session.
func (r *Relay) logSessionEnd(sess *session) {
	if r.accessLog == nil {
		return
	}

	var sourceIP netip.Addr
	if ap, err := netip.ParseAddrPort(sess.addr.String()); err == nil {
		sourceIP = ap.Addr().Unmap()
	}

	r.accessLog.LogL4(accesslog.L4Entry{
		AccountID:     r.accountID,
		ServiceID:     r.serviceID,
		Protocol:      accesslog.ProtocolUDP,
		Host:          r.domain,
		SourceIP:      sourceIP,
		DurationMs:    time.Unix(0, sess.lastSeen.Load()).Sub(sess.createdAt).Milliseconds(),
		BytesUpload:   sess.bytesIn.Load(),
		BytesDownload: sess.bytesOut.Load(),
	})
}

// logDeny sends an access log entry for a denied UDP packet.
func (r *Relay) logDeny(clientIP netip.Addr, verdict restrict.Verdict) {
	if r.accessLog == nil {
		return
	}

	r.accessLog.LogL4(accesslog.L4Entry{
		AccountID:  r.accountID,
		ServiceID:  r.serviceID,
		Protocol:   accesslog.ProtocolUDP,
		Host:       r.domain,
		SourceIP:   clientIP,
		DenyReason: verdict.String(),
	})
}

// Close stops the relay, waits for all session goroutines to exit,
// and cleans up remaining sessions.
func (r *Relay) Close() {
	r.cancel()
	if err := r.listener.Close(); err != nil {
		r.logger.Debugf("close UDP listener: %v", err)
	}

	var closedSessions []*session
	r.mu.Lock()
	for key, sess := range r.sessions {
		if sess == nil {
			delete(r.sessions, key)
			continue
		}
		r.logger.Debugf("UDP session %s closed (client→backend: %d bytes, backend→client: %d bytes)",
			sess.addr, sess.bytesIn.Load(), sess.bytesOut.Load())
		sess.cancel()
		if err := sess.backend.Close(); err != nil {
			r.logger.Debugf("close session %s backend: %v", sess.addr, err)
		}
		delete(r.sessions, key)
		closedSessions = append(closedSessions, sess)
	}
	r.mu.Unlock()

	obs := r.getObserver()
	for _, sess := range closedSessions {
		if obs != nil {
			obs.UDPSessionEnded(r.accountID)
		}
		r.logSessionEnd(sess)
	}

	r.sessWg.Wait()
}

// addrFromUDPAddr extracts a netip.Addr from a net.Addr.
func addrFromUDPAddr(addr net.Addr) (netip.Addr, error) {
	ap, err := netip.ParseAddrPort(addr.String())
	if err != nil {
		return netip.Addr{}, err
	}
	return ap.Addr().Unmap(), nil
}