diff --git a/go.mod b/go.mod index cf55b9260..773869cb5 100644 --- a/go.mod +++ b/go.mod @@ -78,8 +78,8 @@ require ( github.com/pion/logging v0.2.4 github.com/pion/randutil v0.1.0 github.com/pion/stun/v2 v2.0.0 - github.com/pion/stun/v3 v3.0.0 - github.com/pion/transport/v3 v3.0.7 + github.com/pion/stun/v3 v3.1.0 + github.com/pion/transport/v3 v3.1.1 github.com/pion/turn/v3 v3.0.1 github.com/pkg/sftp v1.13.9 github.com/prometheus/client_golang v1.23.2 @@ -241,7 +241,7 @@ require ( github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/image-spec v1.1.0 // indirect github.com/pion/dtls/v2 v2.2.10 // indirect - github.com/pion/dtls/v3 v3.0.7 // indirect + github.com/pion/dtls/v3 v3.0.9 // indirect github.com/pion/mdns/v2 v2.0.7 // indirect github.com/pion/transport/v2 v2.2.4 // indirect github.com/pion/turn/v4 v4.1.1 // indirect @@ -263,7 +263,7 @@ require ( github.com/tklauser/numcpus v0.8.0 // indirect github.com/vishvananda/netns v0.0.5 // indirect github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect - github.com/wlynxg/anet v0.0.3 // indirect + github.com/wlynxg/anet v0.0.5 // indirect github.com/yuin/goldmark v1.7.8 // indirect github.com/zeebo/blake3 v0.2.3 // indirect go.opentelemetry.io/auto/sdk v1.2.1 // indirect diff --git a/go.sum b/go.sum index e89e0ef12..4ea00b399 100644 --- a/go.sum +++ b/go.sum @@ -444,8 +444,8 @@ github.com/petermattis/goid v0.0.0-20250303134427-723919f7f203/go.mod h1:pxMtw7c github.com/pion/dtls/v2 v2.2.7/go.mod h1:8WiMkebSHFD0T+dIU+UeBaoV7kDhOW5oDCzZ7WZ/F9s= github.com/pion/dtls/v2 v2.2.10 h1:u2Axk+FyIR1VFTPurktB+1zoEPGIW3bmyj3LEFrXjAA= github.com/pion/dtls/v2 v2.2.10/go.mod h1:d9SYc9fch0CqK90mRk1dC7AkzzpwJj6u2GU3u+9pqFE= -github.com/pion/dtls/v3 v3.0.7 h1:bItXtTYYhZwkPFk4t1n3Kkf5TDrfj6+4wG+CZR8uI9Q= -github.com/pion/dtls/v3 v3.0.7/go.mod h1:uDlH5VPrgOQIw59irKYkMudSFprY9IEFCqz/eTz16f8= +github.com/pion/dtls/v3 v3.0.9 h1:4AijfFRm8mAjd1gfdlB1wzJF3fjjR/VPIpJgkEtvYmM= +github.com/pion/dtls/v3 v3.0.9/go.mod 
h1:abApPjgadS/ra1wvUzHLc3o2HvoxppAh+NZkyApL4Os= github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms= github.com/pion/logging v0.2.4 h1:tTew+7cmQ+Mc1pTBLKH2puKsOvhm32dROumOZ655zB8= github.com/pion/logging v0.2.4/go.mod h1:DffhXTKYdNZU+KtJ5pyQDjvOAh/GsNSyv1lbkFbe3so= @@ -455,14 +455,14 @@ github.com/pion/randutil v0.1.0 h1:CFG1UdESneORglEsnimhUjf33Rwjubwj6xfiOXBa3mA= github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TBZq+j8= github.com/pion/stun/v2 v2.0.0 h1:A5+wXKLAypxQri59+tmQKVs7+l6mMM+3d+eER9ifRU0= github.com/pion/stun/v2 v2.0.0/go.mod h1:22qRSh08fSEttYUmJZGlriq9+03jtVmXNODgLccj8GQ= -github.com/pion/stun/v3 v3.0.0 h1:4h1gwhWLWuZWOJIJR9s2ferRO+W3zA/b6ijOI6mKzUw= -github.com/pion/stun/v3 v3.0.0/go.mod h1:HvCN8txt8mwi4FBvS3EmDghW6aQJ24T+y+1TKjB5jyU= +github.com/pion/stun/v3 v3.1.0 h1:bS1jjT3tGWZ4UPmIUeyalOylamTMTFg1OvXtY/r6seM= +github.com/pion/stun/v3 v3.1.0/go.mod h1:egmx1CUcfSSGJxQCOjtVlomfPqmQ58BibPyuOWNGQEU= github.com/pion/transport/v2 v2.2.1/go.mod h1:cXXWavvCnFF6McHTft3DWS9iic2Mftcz1Aq29pGcU5g= github.com/pion/transport/v2 v2.2.4 h1:41JJK6DZQYSeVLxILA2+F4ZkKb4Xd/tFJZRFZQ9QAlo= github.com/pion/transport/v2 v2.2.4/go.mod h1:q2U/tf9FEfnSBGSW6w5Qp5PFWRLRj3NjLhCCgpRK4p0= github.com/pion/transport/v3 v3.0.1/go.mod h1:UY7kiITrlMv7/IKgd5eTUcaahZx5oUN3l9SzK5f5xE0= -github.com/pion/transport/v3 v3.0.7 h1:iRbMH05BzSNwhILHoBoAPxoB9xQgOaJk+591KC9P1o0= -github.com/pion/transport/v3 v3.0.7/go.mod h1:YleKiTZ4vqNxVwh77Z0zytYi7rXHl7j6uPLGhhz9rwo= +github.com/pion/transport/v3 v3.1.1 h1:Tr684+fnnKlhPceU+ICdrw6KKkTms+5qHMgw6bIkYOM= +github.com/pion/transport/v3 v3.1.1/go.mod h1:+c2eewC5WJQHiAA46fkMMzoYZSuGzA/7E2FPrOYHctQ= github.com/pion/turn/v3 v3.0.1 h1:wLi7BTQr6/Q20R0vt/lHbjv6y4GChFtC33nkYbasoT8= github.com/pion/turn/v3 v3.0.1/go.mod h1:MrJDKgqryDyWy1/4NT9TWfXWGMC7UHT6pJIv1+gMeNE= github.com/pion/turn/v4 v4.1.1 h1:9UnY2HB99tpDyz3cVVZguSxcqkJ1DsTSZ+8TGruh4fc= @@ -574,8 +574,8 @@ github.com/vmihailenco/msgpack/v5 
v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IU github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok= github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g= github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds= -github.com/wlynxg/anet v0.0.3 h1:PvR53psxFXstc12jelG6f1Lv4MWqE0tI76/hHGjh9rg= -github.com/wlynxg/anet v0.0.3/go.mod h1:eay5PRQr7fIVAMbTbchTnO9gG65Hg/uYGdc7mguHxoA= +github.com/wlynxg/anet v0.0.5 h1:J3VJGi1gvo0JwZ/P1/Yc/8p63SoW98B5dHkYDmpgvvU= +github.com/wlynxg/anet v0.0.5/go.mod h1:eay5PRQr7fIVAMbTbchTnO9gG65Hg/uYGdc7mguHxoA= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= diff --git a/relay/cmd/root.go b/relay/cmd/root.go index e7dadcfdf..20c565c3d 100644 --- a/relay/cmd/root.go +++ b/relay/cmd/root.go @@ -6,6 +6,7 @@ import ( "crypto/tls" "errors" "fmt" + "net" "net/http" "os" "os/signal" @@ -22,6 +23,7 @@ import ( "github.com/netbirdio/netbird/relay/server" "github.com/netbirdio/netbird/shared/relay/auth" "github.com/netbirdio/netbird/signal/metrics" + "github.com/netbirdio/netbird/stun" "github.com/netbirdio/netbird/util" ) @@ -43,6 +45,10 @@ type Config struct { LogLevel string LogFile string HealthcheckListenAddress string + // STUN server configuration + EnableSTUN bool + STUNPorts []int + STUNLogLevel string } func (c Config) Validate() error { @@ -52,6 +58,25 @@ func (c Config) Validate() error { if c.AuthSecret == "" { return fmt.Errorf("auth secret is required") } + + // Validate STUN configuration + if c.EnableSTUN { + if len(c.STUNPorts) == 0 { + return fmt.Errorf("--stun-ports is required when --enable-stun is set") + } + + seen := make(map[int]bool) + for _, port := range c.STUNPorts { + if port <= 0 || port > 
65535 { + return fmt.Errorf("invalid STUN port %d: must be between 1 and 65535", port) + } + if seen[port] { + return fmt.Errorf("duplicate STUN port %d", port) + } + seen[port] = true + } + } + return nil } @@ -91,6 +116,9 @@ func init() { rootCmd.PersistentFlags().StringVar(&cobraConfig.LogLevel, "log-level", "info", "log level") rootCmd.PersistentFlags().StringVar(&cobraConfig.LogFile, "log-file", "console", "log file") rootCmd.PersistentFlags().StringVarP(&cobraConfig.HealthcheckListenAddress, "health-listen-address", "H", ":9000", "listen address of healthcheck server") + rootCmd.PersistentFlags().BoolVar(&cobraConfig.EnableSTUN, "enable-stun", false, "enable embedded STUN server") + rootCmd.PersistentFlags().IntSliceVar(&cobraConfig.STUNPorts, "stun-ports", []int{3478}, "ports for the embedded STUN server (can be specified multiple times or comma-separated)") + rootCmd.PersistentFlags().StringVar(&cobraConfig.STUNLogLevel, "stun-log-level", "info", "log level for STUN server (panic, fatal, error, warn, info, debug, trace)") setFlagsFromEnvVars(rootCmd) } @@ -119,21 +147,14 @@ func execute(cmd *cobra.Command, args []string) error { return fmt.Errorf("failed to initialize log: %s", err) } + // Resource creation phase (fail fast before starting any goroutines) + metricsServer, err := metrics.NewServer(cobraConfig.MetricsPort, "") if err != nil { log.Debugf("setup metrics: %v", err) return fmt.Errorf("setup metrics: %v", err) } - wg.Add(1) - go func() { - defer wg.Done() - log.Infof("running metrics server: %s%s", metricsServer.Addr, metricsServer.Endpoint) - if err := metricsServer.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) { - log.Fatalf("Failed to start metrics server: %v", err) - } - }() - srvListenerCfg := server.ListenerConfig{ Address: cobraConfig.ListenAddress, } @@ -145,6 +166,12 @@ func execute(cmd *cobra.Command, args []string) error { } srvListenerCfg.TLSConfig = tlsConfig + // Create STUN listeners early to fail fast + stunListeners, err 
:= createSTUNListeners() + if err != nil { + return err + } + hashedSecret := sha256.Sum256([]byte(cobraConfig.AuthSecret)) authenticator := auth.NewTimedHMACValidator(hashedSecret[:], 24*time.Hour) @@ -155,60 +182,145 @@ func execute(cmd *cobra.Command, args []string) error { TLSSupport: tlsSupport, } - srv, err := server.NewServer(cfg) + srv, err := createRelayServer(cfg) if err != nil { - log.Debugf("failed to create relay server: %v", err) - return fmt.Errorf("failed to create relay server: %v", err) + cleanupSTUNListeners(stunListeners) + return err } + + hCfg := healthcheck.Config{ + ListenAddress: cobraConfig.HealthcheckListenAddress, + ServiceChecker: srv, + } + httpHealthcheck, err := createHealthCheck(hCfg) + if err != nil { + cleanupSTUNListeners(stunListeners) + return err + } + + var stunServer *stun.Server + if len(stunListeners) > 0 { + stunServer = stun.NewServer(stunListeners, cobraConfig.STUNLogLevel) + } + + // Start all servers (only after all resources are successfully created) + startServers(&wg, metricsServer, srv, srvListenerCfg, httpHealthcheck, stunServer) + + waitForExitSignal() + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + err = shutdownServers(ctx, metricsServer, srv, httpHealthcheck, stunServer) + wg.Wait() + return err +} + +func startServers(wg *sync.WaitGroup, metricsServer *metrics.Metrics, srv *server.Server, srvListenerCfg server.ListenerConfig, httpHealthcheck *healthcheck.Server, stunServer *stun.Server) { + wg.Add(1) + go func() { + defer wg.Done() + log.Infof("running metrics server: %s%s", metricsServer.Addr, metricsServer.Endpoint) + if err := metricsServer.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) { + log.Fatalf("failed to start metrics server: %v", err) + } + }() + instanceURL := srv.InstanceURL() log.Infof("server will be available on: %s", instanceURL.String()) wg.Add(1) go func() { defer wg.Done() if err := srv.Listen(srvListenerCfg); err != nil { - 
log.Fatalf("failed to bind server: %s", err) + log.Fatalf("failed to bind relay server: %s", err) } }() - hCfg := healthcheck.Config{ - ListenAddress: cobraConfig.HealthcheckListenAddress, - ServiceChecker: srv, - } - httpHealthcheck, err := healthcheck.NewServer(hCfg) - if err != nil { - log.Debugf("failed to create healthcheck server: %v", err) - return fmt.Errorf("failed to create healthcheck server: %v", err) - } wg.Add(1) go func() { defer wg.Done() if err := httpHealthcheck.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) { - log.Fatalf("Failed to start healthcheck server: %v", err) + log.Fatalf("failed to start healthcheck server: %v", err) } }() - // it will block until exit signal - waitForExitSignal() + if stunServer != nil { + wg.Add(1) + go func() { + defer wg.Done() + if err := stunServer.Listen(); err != nil { + if errors.Is(err, stun.ErrServerClosed) { + return + } + log.Errorf("STUN server error: %v", err) + } + }() + } +} - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() +func shutdownServers(ctx context.Context, metricsServer *metrics.Metrics, srv *server.Server, httpHealthcheck *healthcheck.Server, stunServer *stun.Server) error { + var errs error - var shutDownErrors error if err := httpHealthcheck.Shutdown(ctx); err != nil { - shutDownErrors = multierror.Append(shutDownErrors, fmt.Errorf("failed to close healthcheck server: %v", err)) + errs = multierror.Append(errs, fmt.Errorf("failed to close healthcheck server: %w", err)) + } + + if stunServer != nil { + if err := stunServer.Shutdown(); err != nil { + errs = multierror.Append(errs, fmt.Errorf("failed to close STUN server: %w", err)) + } } if err := srv.Shutdown(ctx); err != nil { - shutDownErrors = multierror.Append(shutDownErrors, fmt.Errorf("failed to close server: %s", err)) + errs = multierror.Append(errs, fmt.Errorf("failed to close relay server: %w", err)) } log.Infof("shutting down metrics server") if err := metricsServer.Shutdown(ctx); 
err != nil { - shutDownErrors = multierror.Append(shutDownErrors, fmt.Errorf("failed to close metrics server: %v", err)) + errs = multierror.Append(errs, fmt.Errorf("failed to close metrics server: %w", err)) } - wg.Wait() - return shutDownErrors + return errs +} + +func createHealthCheck(hCfg healthcheck.Config) (*healthcheck.Server, error) { + httpHealthcheck, err := healthcheck.NewServer(hCfg) + if err != nil { + log.Debugf("failed to create healthcheck server: %v", err) + return nil, fmt.Errorf("failed to create healthcheck server: %v", err) + } + return httpHealthcheck, nil +} + +func createRelayServer(cfg server.Config) (*server.Server, error) { + srv, err := server.NewServer(cfg) + if err != nil { + return nil, fmt.Errorf("failed to create relay server: %v", err) + } + return srv, nil +} + +func cleanupSTUNListeners(stunListeners []*net.UDPConn) { + for _, l := range stunListeners { + _ = l.Close() + } +} + +func createSTUNListeners() ([]*net.UDPConn, error) { + var stunListeners []*net.UDPConn + if cobraConfig.EnableSTUN { + for _, port := range cobraConfig.STUNPorts { + listener, err := net.ListenUDP("udp", &net.UDPAddr{Port: port}) + if err != nil { + // Close already opened listeners on failure + cleanupSTUNListeners(stunListeners) + log.Debugf("failed to create STUN listener on port %d: %v", port, err) + return nil, fmt.Errorf("failed to create STUN listener on port %d: %v", port, err) + } + stunListeners = append(stunListeners, listener) + } + } + return stunListeners, nil } func handleTLSConfig(cfg *Config) (*tls.Config, bool, error) { diff --git a/stun/server.go b/stun/server.go new file mode 100644 index 000000000..be5717d48 --- /dev/null +++ b/stun/server.go @@ -0,0 +1,170 @@ +// Package stun provides an embedded STUN server for NAT traversal discovery. 
+package stun + +import ( + "errors" + "fmt" + "net" + "sync" + + "github.com/hashicorp/go-multierror" + nberrors "github.com/netbirdio/netbird/client/errors" + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/formatter" + "github.com/pion/stun/v3" +) + +// ErrServerClosed is returned by Listen when the server is shut down gracefully. +var ErrServerClosed = errors.New("stun: server closed") + +// ErrNoListeners is returned by Listen when no UDP connections were provided. +var ErrNoListeners = errors.New("stun: no listeners configured") + +// Server implements a STUN server that responds to binding requests +// with the client's reflexive transport address. +type Server struct { + conns []*net.UDPConn + logger *log.Entry + logLevel log.Level + + wg sync.WaitGroup +} + +// NewServer creates a new STUN server with the given UDP listeners. +// The caller is responsible for creating and providing the listeners. +// logLevel can be: panic, fatal, error, warn, info, debug, trace +func NewServer(conns []*net.UDPConn, logLevel string) *Server { + level, err := log.ParseLevel(logLevel) + if err != nil { + level = log.InfoLevel + } + + // Create a separate logger with its own level setting + // This allows --stun-log-level to work independently of --log-level + stunLogger := log.New() + stunLogger.SetOutput(log.StandardLogger().Out) + stunLogger.SetLevel(level) + // Use the formatter package to set up formatter, ReportCaller, and context hook + formatter.SetTextFormatter(stunLogger) + + logger := stunLogger.WithField("component", "stun-server") + logger.Infof("STUN server log level set to: %s", level.String()) + + return &Server{ + conns: conns, + logger: logger, + logLevel: level, + } +} + +// Listen starts the STUN server and blocks until the server is shut down. +// Returns ErrServerClosed when shut down gracefully via Shutdown. +// Returns ErrNoListeners if no UDP connections were provided. 
+func (s *Server) Listen() error {
+	if len(s.conns) == 0 {
+		return ErrNoListeners
+	}
+
+	// Spawn one reader goroutine per socket; each one is registered on
+	// s.wg before it starts so the Wait below cannot miss it.
+	for i := range s.conns {
+		udpConn := s.conns[i]
+		s.logger.Infof("STUN server listening on %s", udpConn.LocalAddr())
+		s.wg.Add(1)
+		go s.readLoop(udpConn)
+	}
+
+	// Block until every read loop has returned, then report a clean stop.
+	s.wg.Wait()
+	return ErrServerClosed
+}
+
+// readLoop receives datagrams from conn until the connection is closed and
+// passes each one to handlePacket.
+func (s *Server) readLoop(conn *net.UDPConn) {
+	defer s.wg.Done()
+
+	// One receive buffer (standard MTU size) is reused for every datagram.
+	// That is safe because each packet is handled synchronously before the
+	// next read overwrites the buffer.
+	packet := make([]byte, 1500)
+	for {
+		size, peer, readErr := conn.ReadFromUDP(packet)
+		switch {
+		case readErr == nil:
+			// Handled inline, in this goroutine.
+			s.handlePacket(conn, packet[:size], peer)
+		case errors.Is(readErr, net.ErrClosed):
+			// The socket was closed externally: this is the exit signal.
+			s.logger.Info("UDP connection closed, stopping read loop")
+			return
+		default:
+			// Any other read failure is logged and the loop keeps reading.
+			s.logger.Warnf("failed to read UDP packet: %v", readErr)
+		}
+	}
+}
+
+// handlePacket processes a STUN request and sends a response.
+func (s *Server) handlePacket(conn *net.UDPConn, data []byte, addr *net.UDPAddr) {
+	// The receiving socket's port prefixes every log line below.
+	port := conn.LocalAddr().(*net.UDPAddr).Port
+
+	s.logger.Debugf("[port:%d] received %d bytes from %s", port, len(data), addr)
+
+	// Anything that does not look like STUN is dropped without a reply.
+	if isSTUN := stun.IsMessage(data); !isSTUN {
+		preview := data[:min(len(data), 8)]
+		s.logger.Debugf("[port:%d] not a STUN message (first bytes: %x)", port, preview)
+		return
+	}
+
+	// Decode the datagram into a STUN message.
+	request := &stun.Message{Raw: data}
+	if decodeErr := request.Decode(); decodeErr != nil {
+		s.logger.Warnf("[port:%d] failed to decode STUN message from %s: %v", port, addr, decodeErr)
+		return
+	}
+
+	// Only the first 8 bytes of the transaction ID are logged.
+	s.logger.Debugf("[port:%d] received STUN %s from %s (tx=%x)", port, request.Type, addr, request.TransactionID[:8])
+
+	// Binding requests are the only message type that gets an answer.
+	if request.Type != stun.BindingRequest {
+		s.logger.Debugf("[port:%d] ignoring non-binding request: %s", port, request.Type)
+		return
+	}
+
+	// The reply echoes the request's transaction ID and reports the sender's
+	// own address as seen by this socket.
+	reflexive := &stun.XORMappedAddress{
+		IP:   addr.IP,
+		Port: addr.Port,
+	}
+	reply, buildErr := stun.Build(
+		stun.NewTransactionIDSetter(request.TransactionID),
+		stun.BindingSuccess,
+		reflexive,
+		stun.Fingerprint,
+	)
+	if buildErr != nil {
+		s.logger.Errorf("[port:%d] failed to build STUN response: %v", port, buildErr)
+		return
+	}
+
+	// Answer from the same socket the request arrived on.
+	written, writeErr := conn.WriteToUDP(reply.Raw, addr)
+	if writeErr != nil {
+		s.logger.Errorf("[port:%d] failed to send STUN response to %s: %v", port, addr, writeErr)
+		return
+	}
+
+	s.logger.Debugf("[port:%d] sent STUN BindingSuccess to %s (%d bytes) with XORMappedAddress %s:%d", port, addr, written, addr.IP, addr.Port)
+}
+
+// Shutdown gracefully stops the STUN server.
+func (s *Server) Shutdown() error {
+	s.logger.Info("shutting down STUN server")
+
+	var merr *multierror.Error
+
+	// Closing the sockets is what stops the read loops: readLoop returns once
+	// its read fails with net.ErrClosed. A connection that is already closed
+	// (net.ErrClosed from Close) is not treated as a failure, so callers may
+	// close their listeners before calling Shutdown.
+	for _, conn := range s.conns {
+		if err := conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
+			merr = multierror.Append(merr, fmt.Errorf("close STUN UDP connection: %w", err))
+		}
+	}
+
+	// Wait for all readLoops to finish
+	s.wg.Wait()
+	// NOTE(review): FormatErrorOrNil lives in client/errors; presumably it
+	// returns nil when no close error was collected — confirm.
+	return nberrors.FormatErrorOrNil(merr)
+}
diff --git a/stun/server_test.go b/stun/server_test.go
new file mode 100644
index 000000000..4fd949863
--- /dev/null
+++ b/stun/server_test.go
@@ -0,0 +1,479 @@
+package stun
+
+import (
+	"errors"
+	"fmt"
+	"math/rand"
+	"net"
+	"os"
+	"sync"
+	"sync/atomic"
+	"testing"
+	"time"
+
+	"github.com/pion/stun/v3"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+// createTestServer creates a STUN server listening on a random port for testing.
+// Returns the server, the listener connection (caller must close), and the server address.
+// The socket is bound to 127.0.0.1 with port 0 and the server is created at
+// debug log level. The server is not started; callers run Listen themselves.
+func createTestServer(t testing.TB) (*Server, *net.UDPConn, *net.UDPAddr) {
+	t.Helper()
+	conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0})
+	require.NoError(t, err)
+	server := NewServer([]*net.UDPConn{conn}, "debug")
+	return server, conn, conn.LocalAddr().(*net.UDPAddr)
+}
+
+// waitForServerReady polls the server with STUN binding requests until it responds.
+// This avoids flaky tests on slow CI machines that relied on time.Sleep.
+func waitForServerReady(t testing.TB, serverAddr *net.UDPAddr, timeout time.Duration) {
+	t.Helper()
+	deadline := time.Now().Add(timeout)
+	retryInterval := 10 * time.Millisecond
+
+	clientConn, err := net.DialUDP("udp", nil, serverAddr)
+	require.NoError(t, err)
+	defer clientConn.Close()
+
+	buf := make([]byte, 1500)
+	// Each iteration sends a fresh binding request and waits at most
+	// retryInterval for an answer before trying again.
+	for time.Now().Before(deadline) {
+		msg, err := stun.Build(stun.TransactionID, stun.BindingRequest)
+		require.NoError(t, err)
+
+		_, err = clientConn.Write(msg.Raw)
+		require.NoError(t, err)
+
+		_ = clientConn.SetReadDeadline(time.Now().Add(retryInterval))
+		n, err := clientConn.Read(buf)
+		if err != nil {
+			// Timeout or other error, retry
+			continue
+		}
+
+		response := &stun.Message{Raw: buf[:n]}
+		if err := response.Decode(); err != nil {
+			continue
+		}
+
+		if response.Type == stun.BindingSuccess {
+			return // Server is ready
+		}
+	}
+
+	t.Fatalf("server did not become ready within %v", timeout)
+}
+
+// TestServer_BindingRequest sends one binding request to a local server and
+// checks that the reply is a BindingSuccess whose XOR-MAPPED-ADDRESS equals
+// the client's own local IP and port.
+func TestServer_BindingRequest(t *testing.T) {
+	// Start the STUN server on a random port
+	server, listener, serverAddr := createTestServer(t)
+
+	// Start server in background
+	// The channel is buffered so the goroutine can finish even though this
+	// test never reads from it.
+	serverErrCh := make(chan error, 1)
+	go func() {
+		serverErrCh <- server.Listen()
+	}()
+
+	// Wait for server to be ready
+	waitForServerReady(t, serverAddr, 2*time.Second)
+
+	// Create a UDP client
+	clientConn, err := net.DialUDP("udp", nil, serverAddr)
+	require.NoError(t, err)
+	defer clientConn.Close()
+
+	// Build a STUN binding request
+	msg, err := stun.Build(stun.TransactionID, stun.BindingRequest)
+	require.NoError(t, err)
+
+	// Send the request
+	_, err = clientConn.Write(msg.Raw)
+	require.NoError(t, err)
+
+	// Read the response
+	buf := make([]byte, 1500)
+	_ = clientConn.SetReadDeadline(time.Now().Add(2 * time.Second))
+	n, err := clientConn.Read(buf)
+	require.NoError(t, err)
+
+	// Parse the response
+	response := &stun.Message{Raw: buf[:n]}
+	err = response.Decode()
+	require.NoError(t, err)
+
+	// Verify it's a binding success
+	assert.Equal(t, stun.BindingSuccess, response.Type)
+
+	// Extract the XOR-MAPPED-ADDRESS
+	var xorAddr stun.XORMappedAddress
+	err = xorAddr.GetFrom(response)
+	require.NoError(t, err)
+
+	// Verify the address matches our client's local address
+	clientAddr := clientConn.LocalAddr().(*net.UDPAddr)
+	assert.Equal(t, clientAddr.IP.String(), xorAddr.IP.String())
+	assert.Equal(t, clientAddr.Port, xorAddr.Port)
+
+	// Close listener first to unblock readLoop, then shutdown
+	_ = listener.Close()
+	err = server.Shutdown()
+	require.NoError(t, err)
+}
+
+// TestServer_IgnoresNonSTUNPackets checks that a payload that is not a STUN
+// message gets no reply: the client's read must fail within 200ms.
+func TestServer_IgnoresNonSTUNPackets(t *testing.T) {
+	server, listener, serverAddr := createTestServer(t)
+
+	go func() {
+		_ = server.Listen()
+	}()
+
+	waitForServerReady(t, serverAddr, 2*time.Second)
+
+	clientConn, err := net.DialUDP("udp", nil, serverAddr)
+	require.NoError(t, err)
+	defer clientConn.Close()
+
+	// Send non-STUN data
+	_, err = clientConn.Write([]byte("hello world"))
+	require.NoError(t, err)
+
+	// Try to read response (should timeout since server ignores non-STUN)
+	buf := make([]byte, 1500)
+	_ = clientConn.SetReadDeadline(time.Now().Add(200 * time.Millisecond))
+	_, err = clientConn.Read(buf)
+	assert.Error(t, err) // Should be a timeout error
+
+	// Close listener first to unblock readLoop, then shutdown
+	_ = listener.Close()
+	_ = server.Shutdown()
+}
+
+// TestServer_Shutdown checks that Shutdown succeeds after the listener has
+// been closed and that Listen then returns ErrServerClosed within 3 seconds.
+func TestServer_Shutdown(t *testing.T) {
+	server, listener, serverAddr := createTestServer(t)
+
+	// serverDone is closed once Listen has returned and its error was checked.
+	serverDone := make(chan struct{})
+	go func() {
+		err := server.Listen()
+		assert.True(t, errors.Is(err, ErrServerClosed))
+		close(serverDone)
+	}()
+
+	waitForServerReady(t, serverAddr, 2*time.Second)
+
+	// Close listener first to unblock readLoop, then shutdown
+	_ = listener.Close()
+
+	err := server.Shutdown()
+	require.NoError(t, err)
+
+	// Wait for Listen to return
+	select {
+	case <-serverDone:
+		// Success
+	case <-time.After(3 * time.Second):
+		t.Fatal("server did not shutdown in time")
+	}
+}
+
+// TestServer_MultipleRequests sends five sequential binding requests, each
+// from a fresh client socket, and expects a BindingSuccess for every one.
+func TestServer_MultipleRequests(t *testing.T) {
+	server, listener, serverAddr := createTestServer(t)
+
+	go func() {
+		_ = server.Listen()
+	}()
+
+	waitForServerReady(t, serverAddr, 2*time.Second)
+
+	// Create multiple clients and send requests
+	for i := 0; i < 5; i++ {
+		// The closure scopes the deferred Close to a single iteration.
+		func() {
+			clientConn, err := net.DialUDP("udp", nil, serverAddr)
+			require.NoError(t, err)
+			defer clientConn.Close()
+
+			msg, err := stun.Build(stun.TransactionID, stun.BindingRequest)
+			require.NoError(t, err)
+
+			_, err = clientConn.Write(msg.Raw)
+			require.NoError(t, err)
+
+			buf := make([]byte, 1500)
+			_ = clientConn.SetReadDeadline(time.Now().Add(2 * time.Second))
+			n, err := clientConn.Read(buf)
+			require.NoError(t, err)
+
+			response := &stun.Message{Raw: buf[:n]}
+			err = response.Decode()
+			require.NoError(t, err)
+
+			assert.Equal(t, stun.BindingSuccess, response.Type)
+		}()
+	}
+
+	// Close listener first to unblock readLoop, then shutdown
+	_ = listener.Close()
+	_ = server.Shutdown()
+}
+
+// TestServer_ConcurrentClients runs 100 clients in parallel, each sending 5
+// binding requests with random pauses, and requires at least 95% of all
+// requests to receive a BindingSuccess. When STUN_TEST_SERVER is set the
+// requests go to that address instead of a locally started server.
+func TestServer_ConcurrentClients(t *testing.T) {
+	numClients := 100
+	requestsPerClient := 5
+	maxStartDelay := 100 * time.Millisecond   // Random delay before client starts
+	maxRequestDelay := 500 * time.Millisecond // Random delay between requests
+
+	// Remote server to test against via env var STUN_TEST_SERVER
+	// Example: STUN_TEST_SERVER=example.netbird.io:3478 go test -v ./stun/... -run ConcurrentClients
+	remoteServer := os.Getenv("STUN_TEST_SERVER")
+
+	var serverAddr *net.UDPAddr
+	var server *Server
+	var listener *net.UDPConn
+
+	if remoteServer != "" {
+		// Use remote server
+		var err error
+		serverAddr, err = net.ResolveUDPAddr("udp", remoteServer)
+		require.NoError(t, err)
+		t.Logf("Testing against remote server: %s", remoteServer)
+	} else {
+		// Start local server
+		server, listener, serverAddr = createTestServer(t)
+		go func() {
+			_ = server.Listen()
+		}()
+		waitForServerReady(t, serverAddr, 2*time.Second)
+		t.Logf("Testing against local server: %s", serverAddr)
+	}
+
+	var wg sync.WaitGroup
+	// Both channels are sized for the worst case so workers never block on
+	// send: at most one error per request and one success count per client.
+	errorz := make(chan error, numClients*requestsPerClient)
+	successCount := make(chan int, numClients)
+
+	startTime := time.Now()
+
+	for i := 0; i < numClients; i++ {
+		wg.Add(1)
+		go func(clientID int) {
+			defer wg.Done()
+
+			// Random delay before starting
+			time.Sleep(time.Duration(rand.Int63n(int64(maxStartDelay))))
+
+			clientConn, err := net.DialUDP("udp", nil, serverAddr)
+			if err != nil {
+				errorz <- fmt.Errorf("client %d: failed to dial: %w", clientID, err)
+				return
+			}
+			defer clientConn.Close()
+
+			success := 0
+			for j := 0; j < requestsPerClient; j++ {
+				// Random delay between requests
+				if j > 0 {
+					time.Sleep(time.Duration(rand.Int63n(int64(maxRequestDelay))))
+				}
+
+				msg, err := stun.Build(stun.TransactionID, stun.BindingRequest)
+				if err != nil {
+					errorz <- fmt.Errorf("client %d: failed to build request: %w", clientID, err)
+					continue
+				}
+
+				_, err = clientConn.Write(msg.Raw)
+				if err != nil {
+					errorz <- fmt.Errorf("client %d: failed to write: %w", clientID, err)
+					continue
+				}
+
+				buf := make([]byte, 1500)
+				_ = clientConn.SetReadDeadline(time.Now().Add(5 * time.Second))
+				n, err := clientConn.Read(buf)
+				if err != nil {
+					errorz <- fmt.Errorf("client %d: failed to read: %w", clientID, err)
+					continue
+				}
+
+				response := &stun.Message{Raw: buf[:n]}
+				if err := response.Decode(); err != nil {
+					errorz <- fmt.Errorf("client %d: failed to decode: %w", clientID, err)
+					continue
+				}
+
+				if response.Type != stun.BindingSuccess {
+					errorz <- fmt.Errorf("client %d: unexpected response type: %s", clientID, response.Type)
+					continue
+				}
+
+				success++
+			}
+			successCount <- success
+		}(i)
+	}
+
+	// All senders are done after Wait, so closing the channels lets the
+	// range loops below terminate.
+	wg.Wait()
+	close(errorz)
+	close(successCount)
+
+	elapsed := time.Since(startTime)
+
+	totalSuccess := 0
+	for count := range successCount {
+		totalSuccess += count
+	}
+
+	var errs []error
+	for err := range errorz {
+		errs = append(errs, err)
+	}
+
+	totalRequests := numClients * requestsPerClient
+	t.Logf("Completed %d/%d requests in %v (%.2f req/s)",
+		totalSuccess, totalRequests, elapsed,
+		float64(totalSuccess)/elapsed.Seconds())
+
+	if len(errs) > 0 {
+		t.Logf("Errors (%d):", len(errs))
+		for i, err := range errs {
+			if i < 10 { // Only show first 10 errors
+				t.Logf(" - %v", err)
+			}
+		}
+	}
+
+	// Require at least 95% success rate
+	successRate := float64(totalSuccess) / float64(totalRequests)
+	require.GreaterOrEqual(t, successRate, 0.95, "success rate too low: %.2f%%", successRate*100)
+
+	// Cleanup local server if used
+	if server != nil {
+		// Close listener first to unblock readLoop, then shutdown
+		_ = listener.Close()
+		_ = server.Shutdown()
+	}
+}
+
+// TestServer_MultiplePorts runs one server over two loopback sockets and
+// checks that each port answers a binding request with the client's port.
+func TestServer_MultiplePorts(t *testing.T) {
+	// Create listeners on two random ports
+	conn1, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0})
+	require.NoError(t, err)
+	conn2, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0})
+	require.NoError(t, err)
+
+	addr1 := conn1.LocalAddr().(*net.UDPAddr)
+	addr2 := conn2.LocalAddr().(*net.UDPAddr)
+
+	server := NewServer([]*net.UDPConn{conn1, conn2}, "debug")
+
+	go func() {
+		_ = server.Listen()
+	}()
+
+	// Wait for server to be ready (checking first port is sufficient)
+	waitForServerReady(t, addr1, 2*time.Second)
+
+	// Test requests on both ports
+	for _, serverAddr := range []*net.UDPAddr{addr1, addr2} {
+		// The closure scopes the deferred Close to a single iteration.
+		func() {
+			clientConn, err := net.DialUDP("udp", nil, serverAddr)
+			require.NoError(t, err)
+			defer clientConn.Close()
+
+			msg, err := stun.Build(stun.TransactionID, stun.BindingRequest)
+			require.NoError(t, err)
+
+			_, err = clientConn.Write(msg.Raw)
+			require.NoError(t, err)
+
+			buf := make([]byte, 1500)
+			_ = clientConn.SetReadDeadline(time.Now().Add(2 * time.Second))
+			n, err := clientConn.Read(buf)
+			require.NoError(t, err)
+
+			response := &stun.Message{Raw: buf[:n]}
+			err = response.Decode()
+			require.NoError(t, err)
+
+			assert.Equal(t, stun.BindingSuccess, response.Type)
+
+			var xorAddr stun.XORMappedAddress
+			err = xorAddr.GetFrom(response)
+			require.NoError(t, err)
+
+			clientAddr := clientConn.LocalAddr().(*net.UDPAddr)
+			assert.Equal(t, clientAddr.Port, xorAddr.Port)
+		}()
+	}
+
+	// Close listeners first to unblock readLoops, then shutdown
+	_ = conn1.Close()
+	_ = conn2.Close()
+	_ = server.Shutdown()
+}
+
+// BenchmarkSTUNServer benchmarks the STUN server with concurrent clients
+// Each parallel worker opens its own client socket and performs one
+// request/response round trip per iteration.
+func BenchmarkSTUNServer(b *testing.B) {
+	server, listener, serverAddr := createTestServer(b)
+
+	go func() {
+		_ = server.Listen()
+	}()
+
+	waitForServerReady(b, serverAddr, 2*time.Second)
+
+	// Capture first error atomically - b.Fatal cannot be called from worker goroutines
+	var firstErr atomic.Pointer[error]
+	setErr := func(err error) {
+		// Only the first caller wins; later errors are dropped.
+		firstErr.CompareAndSwap(nil, &err)
+	}
+
+	b.ResetTimer()
+	b.RunParallel(func(pb *testing.PB) {
+		// Stop work if an error has occurred
+		if firstErr.Load() != nil {
+			return
+		}
+
+		clientConn, err := net.DialUDP("udp", nil, serverAddr)
+		if err != nil {
+			setErr(err)
+			return
+		}
+		defer clientConn.Close()
+
+		buf := make([]byte, 1500)
+
+		for pb.Next() {
+			if firstErr.Load() != nil {
+				return
+			}
+
+			// Build and write errors are deliberately ignored here; a
+			// failed send shows up as a read timeout below.
+			msg, _ := stun.Build(stun.TransactionID, stun.BindingRequest)
+			_, _ = clientConn.Write(msg.Raw)
+
+			_ = clientConn.SetReadDeadline(time.Now().Add(2 * time.Second))
+			n, err := clientConn.Read(buf)
+			if err != nil {
+				setErr(err)
+				return
+			}
+
+			response := &stun.Message{Raw: buf[:n]}
+			if err := response.Decode(); err != nil {
+				setErr(err)
+				return
+			}
+		}
+	})
+
+	b.StopTimer()
+
+	// Fail after RunParallel completes
+	if errPtr := firstErr.Load(); errPtr != nil {
+		b.Fatal(*errPtr)
+	}
+
+	// Close listener first to unblock readLoop, then shutdown
+	_ = listener.Close()
+	_ = server.Shutdown()
+}