mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 07:16:38 +00:00
Feature/embedded STUN (#5062)
This commit is contained in:
8
go.mod
8
go.mod
@@ -78,8 +78,8 @@ require (
|
|||||||
github.com/pion/logging v0.2.4
|
github.com/pion/logging v0.2.4
|
||||||
github.com/pion/randutil v0.1.0
|
github.com/pion/randutil v0.1.0
|
||||||
github.com/pion/stun/v2 v2.0.0
|
github.com/pion/stun/v2 v2.0.0
|
||||||
github.com/pion/stun/v3 v3.0.0
|
github.com/pion/stun/v3 v3.1.0
|
||||||
github.com/pion/transport/v3 v3.0.7
|
github.com/pion/transport/v3 v3.1.1
|
||||||
github.com/pion/turn/v3 v3.0.1
|
github.com/pion/turn/v3 v3.0.1
|
||||||
github.com/pkg/sftp v1.13.9
|
github.com/pkg/sftp v1.13.9
|
||||||
github.com/prometheus/client_golang v1.23.2
|
github.com/prometheus/client_golang v1.23.2
|
||||||
@@ -241,7 +241,7 @@ require (
|
|||||||
github.com/opencontainers/go-digest v1.0.0 // indirect
|
github.com/opencontainers/go-digest v1.0.0 // indirect
|
||||||
github.com/opencontainers/image-spec v1.1.0 // indirect
|
github.com/opencontainers/image-spec v1.1.0 // indirect
|
||||||
github.com/pion/dtls/v2 v2.2.10 // indirect
|
github.com/pion/dtls/v2 v2.2.10 // indirect
|
||||||
github.com/pion/dtls/v3 v3.0.7 // indirect
|
github.com/pion/dtls/v3 v3.0.9 // indirect
|
||||||
github.com/pion/mdns/v2 v2.0.7 // indirect
|
github.com/pion/mdns/v2 v2.0.7 // indirect
|
||||||
github.com/pion/transport/v2 v2.2.4 // indirect
|
github.com/pion/transport/v2 v2.2.4 // indirect
|
||||||
github.com/pion/turn/v4 v4.1.1 // indirect
|
github.com/pion/turn/v4 v4.1.1 // indirect
|
||||||
@@ -263,7 +263,7 @@ require (
|
|||||||
github.com/tklauser/numcpus v0.8.0 // indirect
|
github.com/tklauser/numcpus v0.8.0 // indirect
|
||||||
github.com/vishvananda/netns v0.0.5 // indirect
|
github.com/vishvananda/netns v0.0.5 // indirect
|
||||||
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
|
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
|
||||||
github.com/wlynxg/anet v0.0.3 // indirect
|
github.com/wlynxg/anet v0.0.5 // indirect
|
||||||
github.com/yuin/goldmark v1.7.8 // indirect
|
github.com/yuin/goldmark v1.7.8 // indirect
|
||||||
github.com/zeebo/blake3 v0.2.3 // indirect
|
github.com/zeebo/blake3 v0.2.3 // indirect
|
||||||
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
|
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
|
||||||
|
|||||||
16
go.sum
16
go.sum
@@ -444,8 +444,8 @@ github.com/petermattis/goid v0.0.0-20250303134427-723919f7f203/go.mod h1:pxMtw7c
|
|||||||
github.com/pion/dtls/v2 v2.2.7/go.mod h1:8WiMkebSHFD0T+dIU+UeBaoV7kDhOW5oDCzZ7WZ/F9s=
|
github.com/pion/dtls/v2 v2.2.7/go.mod h1:8WiMkebSHFD0T+dIU+UeBaoV7kDhOW5oDCzZ7WZ/F9s=
|
||||||
github.com/pion/dtls/v2 v2.2.10 h1:u2Axk+FyIR1VFTPurktB+1zoEPGIW3bmyj3LEFrXjAA=
|
github.com/pion/dtls/v2 v2.2.10 h1:u2Axk+FyIR1VFTPurktB+1zoEPGIW3bmyj3LEFrXjAA=
|
||||||
github.com/pion/dtls/v2 v2.2.10/go.mod h1:d9SYc9fch0CqK90mRk1dC7AkzzpwJj6u2GU3u+9pqFE=
|
github.com/pion/dtls/v2 v2.2.10/go.mod h1:d9SYc9fch0CqK90mRk1dC7AkzzpwJj6u2GU3u+9pqFE=
|
||||||
github.com/pion/dtls/v3 v3.0.7 h1:bItXtTYYhZwkPFk4t1n3Kkf5TDrfj6+4wG+CZR8uI9Q=
|
github.com/pion/dtls/v3 v3.0.9 h1:4AijfFRm8mAjd1gfdlB1wzJF3fjjR/VPIpJgkEtvYmM=
|
||||||
github.com/pion/dtls/v3 v3.0.7/go.mod h1:uDlH5VPrgOQIw59irKYkMudSFprY9IEFCqz/eTz16f8=
|
github.com/pion/dtls/v3 v3.0.9/go.mod h1:abApPjgadS/ra1wvUzHLc3o2HvoxppAh+NZkyApL4Os=
|
||||||
github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms=
|
github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms=
|
||||||
github.com/pion/logging v0.2.4 h1:tTew+7cmQ+Mc1pTBLKH2puKsOvhm32dROumOZ655zB8=
|
github.com/pion/logging v0.2.4 h1:tTew+7cmQ+Mc1pTBLKH2puKsOvhm32dROumOZ655zB8=
|
||||||
github.com/pion/logging v0.2.4/go.mod h1:DffhXTKYdNZU+KtJ5pyQDjvOAh/GsNSyv1lbkFbe3so=
|
github.com/pion/logging v0.2.4/go.mod h1:DffhXTKYdNZU+KtJ5pyQDjvOAh/GsNSyv1lbkFbe3so=
|
||||||
@@ -455,14 +455,14 @@ github.com/pion/randutil v0.1.0 h1:CFG1UdESneORglEsnimhUjf33Rwjubwj6xfiOXBa3mA=
|
|||||||
github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TBZq+j8=
|
github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TBZq+j8=
|
||||||
github.com/pion/stun/v2 v2.0.0 h1:A5+wXKLAypxQri59+tmQKVs7+l6mMM+3d+eER9ifRU0=
|
github.com/pion/stun/v2 v2.0.0 h1:A5+wXKLAypxQri59+tmQKVs7+l6mMM+3d+eER9ifRU0=
|
||||||
github.com/pion/stun/v2 v2.0.0/go.mod h1:22qRSh08fSEttYUmJZGlriq9+03jtVmXNODgLccj8GQ=
|
github.com/pion/stun/v2 v2.0.0/go.mod h1:22qRSh08fSEttYUmJZGlriq9+03jtVmXNODgLccj8GQ=
|
||||||
github.com/pion/stun/v3 v3.0.0 h1:4h1gwhWLWuZWOJIJR9s2ferRO+W3zA/b6ijOI6mKzUw=
|
github.com/pion/stun/v3 v3.1.0 h1:bS1jjT3tGWZ4UPmIUeyalOylamTMTFg1OvXtY/r6seM=
|
||||||
github.com/pion/stun/v3 v3.0.0/go.mod h1:HvCN8txt8mwi4FBvS3EmDghW6aQJ24T+y+1TKjB5jyU=
|
github.com/pion/stun/v3 v3.1.0/go.mod h1:egmx1CUcfSSGJxQCOjtVlomfPqmQ58BibPyuOWNGQEU=
|
||||||
github.com/pion/transport/v2 v2.2.1/go.mod h1:cXXWavvCnFF6McHTft3DWS9iic2Mftcz1Aq29pGcU5g=
|
github.com/pion/transport/v2 v2.2.1/go.mod h1:cXXWavvCnFF6McHTft3DWS9iic2Mftcz1Aq29pGcU5g=
|
||||||
github.com/pion/transport/v2 v2.2.4 h1:41JJK6DZQYSeVLxILA2+F4ZkKb4Xd/tFJZRFZQ9QAlo=
|
github.com/pion/transport/v2 v2.2.4 h1:41JJK6DZQYSeVLxILA2+F4ZkKb4Xd/tFJZRFZQ9QAlo=
|
||||||
github.com/pion/transport/v2 v2.2.4/go.mod h1:q2U/tf9FEfnSBGSW6w5Qp5PFWRLRj3NjLhCCgpRK4p0=
|
github.com/pion/transport/v2 v2.2.4/go.mod h1:q2U/tf9FEfnSBGSW6w5Qp5PFWRLRj3NjLhCCgpRK4p0=
|
||||||
github.com/pion/transport/v3 v3.0.1/go.mod h1:UY7kiITrlMv7/IKgd5eTUcaahZx5oUN3l9SzK5f5xE0=
|
github.com/pion/transport/v3 v3.0.1/go.mod h1:UY7kiITrlMv7/IKgd5eTUcaahZx5oUN3l9SzK5f5xE0=
|
||||||
github.com/pion/transport/v3 v3.0.7 h1:iRbMH05BzSNwhILHoBoAPxoB9xQgOaJk+591KC9P1o0=
|
github.com/pion/transport/v3 v3.1.1 h1:Tr684+fnnKlhPceU+ICdrw6KKkTms+5qHMgw6bIkYOM=
|
||||||
github.com/pion/transport/v3 v3.0.7/go.mod h1:YleKiTZ4vqNxVwh77Z0zytYi7rXHl7j6uPLGhhz9rwo=
|
github.com/pion/transport/v3 v3.1.1/go.mod h1:+c2eewC5WJQHiAA46fkMMzoYZSuGzA/7E2FPrOYHctQ=
|
||||||
github.com/pion/turn/v3 v3.0.1 h1:wLi7BTQr6/Q20R0vt/lHbjv6y4GChFtC33nkYbasoT8=
|
github.com/pion/turn/v3 v3.0.1 h1:wLi7BTQr6/Q20R0vt/lHbjv6y4GChFtC33nkYbasoT8=
|
||||||
github.com/pion/turn/v3 v3.0.1/go.mod h1:MrJDKgqryDyWy1/4NT9TWfXWGMC7UHT6pJIv1+gMeNE=
|
github.com/pion/turn/v3 v3.0.1/go.mod h1:MrJDKgqryDyWy1/4NT9TWfXWGMC7UHT6pJIv1+gMeNE=
|
||||||
github.com/pion/turn/v4 v4.1.1 h1:9UnY2HB99tpDyz3cVVZguSxcqkJ1DsTSZ+8TGruh4fc=
|
github.com/pion/turn/v4 v4.1.1 h1:9UnY2HB99tpDyz3cVVZguSxcqkJ1DsTSZ+8TGruh4fc=
|
||||||
@@ -574,8 +574,8 @@ github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IU
|
|||||||
github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok=
|
github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok=
|
||||||
github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
|
github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
|
||||||
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
|
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
|
||||||
github.com/wlynxg/anet v0.0.3 h1:PvR53psxFXstc12jelG6f1Lv4MWqE0tI76/hHGjh9rg=
|
github.com/wlynxg/anet v0.0.5 h1:J3VJGi1gvo0JwZ/P1/Yc/8p63SoW98B5dHkYDmpgvvU=
|
||||||
github.com/wlynxg/anet v0.0.3/go.mod h1:eay5PRQr7fIVAMbTbchTnO9gG65Hg/uYGdc7mguHxoA=
|
github.com/wlynxg/anet v0.0.5/go.mod h1:eay5PRQr7fIVAMbTbchTnO9gG65Hg/uYGdc7mguHxoA=
|
||||||
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||||
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||||
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
|
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
@@ -22,6 +23,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/relay/server"
|
"github.com/netbirdio/netbird/relay/server"
|
||||||
"github.com/netbirdio/netbird/shared/relay/auth"
|
"github.com/netbirdio/netbird/shared/relay/auth"
|
||||||
"github.com/netbirdio/netbird/signal/metrics"
|
"github.com/netbirdio/netbird/signal/metrics"
|
||||||
|
"github.com/netbirdio/netbird/stun"
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -43,6 +45,10 @@ type Config struct {
|
|||||||
LogLevel string
|
LogLevel string
|
||||||
LogFile string
|
LogFile string
|
||||||
HealthcheckListenAddress string
|
HealthcheckListenAddress string
|
||||||
|
// STUN server configuration
|
||||||
|
EnableSTUN bool
|
||||||
|
STUNPorts []int
|
||||||
|
STUNLogLevel string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c Config) Validate() error {
|
func (c Config) Validate() error {
|
||||||
@@ -52,6 +58,25 @@ func (c Config) Validate() error {
|
|||||||
if c.AuthSecret == "" {
|
if c.AuthSecret == "" {
|
||||||
return fmt.Errorf("auth secret is required")
|
return fmt.Errorf("auth secret is required")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Validate STUN configuration
|
||||||
|
if c.EnableSTUN {
|
||||||
|
if len(c.STUNPorts) == 0 {
|
||||||
|
return fmt.Errorf("--stun-ports is required when --enable-stun is set")
|
||||||
|
}
|
||||||
|
|
||||||
|
seen := make(map[int]bool)
|
||||||
|
for _, port := range c.STUNPorts {
|
||||||
|
if port <= 0 || port > 65535 {
|
||||||
|
return fmt.Errorf("invalid STUN port %d: must be between 1 and 65535", port)
|
||||||
|
}
|
||||||
|
if seen[port] {
|
||||||
|
return fmt.Errorf("duplicate STUN port %d", port)
|
||||||
|
}
|
||||||
|
seen[port] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -91,6 +116,9 @@ func init() {
|
|||||||
rootCmd.PersistentFlags().StringVar(&cobraConfig.LogLevel, "log-level", "info", "log level")
|
rootCmd.PersistentFlags().StringVar(&cobraConfig.LogLevel, "log-level", "info", "log level")
|
||||||
rootCmd.PersistentFlags().StringVar(&cobraConfig.LogFile, "log-file", "console", "log file")
|
rootCmd.PersistentFlags().StringVar(&cobraConfig.LogFile, "log-file", "console", "log file")
|
||||||
rootCmd.PersistentFlags().StringVarP(&cobraConfig.HealthcheckListenAddress, "health-listen-address", "H", ":9000", "listen address of healthcheck server")
|
rootCmd.PersistentFlags().StringVarP(&cobraConfig.HealthcheckListenAddress, "health-listen-address", "H", ":9000", "listen address of healthcheck server")
|
||||||
|
rootCmd.PersistentFlags().BoolVar(&cobraConfig.EnableSTUN, "enable-stun", false, "enable embedded STUN server")
|
||||||
|
rootCmd.PersistentFlags().IntSliceVar(&cobraConfig.STUNPorts, "stun-ports", []int{3478}, "ports for the embedded STUN server (can be specified multiple times or comma-separated)")
|
||||||
|
rootCmd.PersistentFlags().StringVar(&cobraConfig.STUNLogLevel, "stun-log-level", "info", "log level for STUN server (panic, fatal, error, warn, info, debug, trace)")
|
||||||
|
|
||||||
setFlagsFromEnvVars(rootCmd)
|
setFlagsFromEnvVars(rootCmd)
|
||||||
}
|
}
|
||||||
@@ -119,21 +147,14 @@ func execute(cmd *cobra.Command, args []string) error {
|
|||||||
return fmt.Errorf("failed to initialize log: %s", err)
|
return fmt.Errorf("failed to initialize log: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Resource creation phase (fail fast before starting any goroutines)
|
||||||
|
|
||||||
metricsServer, err := metrics.NewServer(cobraConfig.MetricsPort, "")
|
metricsServer, err := metrics.NewServer(cobraConfig.MetricsPort, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("setup metrics: %v", err)
|
log.Debugf("setup metrics: %v", err)
|
||||||
return fmt.Errorf("setup metrics: %v", err)
|
return fmt.Errorf("setup metrics: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
wg.Add(1)
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
log.Infof("running metrics server: %s%s", metricsServer.Addr, metricsServer.Endpoint)
|
|
||||||
if err := metricsServer.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
|
|
||||||
log.Fatalf("Failed to start metrics server: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
srvListenerCfg := server.ListenerConfig{
|
srvListenerCfg := server.ListenerConfig{
|
||||||
Address: cobraConfig.ListenAddress,
|
Address: cobraConfig.ListenAddress,
|
||||||
}
|
}
|
||||||
@@ -145,6 +166,12 @@ func execute(cmd *cobra.Command, args []string) error {
|
|||||||
}
|
}
|
||||||
srvListenerCfg.TLSConfig = tlsConfig
|
srvListenerCfg.TLSConfig = tlsConfig
|
||||||
|
|
||||||
|
// Create STUN listeners early to fail fast
|
||||||
|
stunListeners, err := createSTUNListeners()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
hashedSecret := sha256.Sum256([]byte(cobraConfig.AuthSecret))
|
hashedSecret := sha256.Sum256([]byte(cobraConfig.AuthSecret))
|
||||||
authenticator := auth.NewTimedHMACValidator(hashedSecret[:], 24*time.Hour)
|
authenticator := auth.NewTimedHMACValidator(hashedSecret[:], 24*time.Hour)
|
||||||
|
|
||||||
@@ -155,60 +182,145 @@ func execute(cmd *cobra.Command, args []string) error {
|
|||||||
TLSSupport: tlsSupport,
|
TLSSupport: tlsSupport,
|
||||||
}
|
}
|
||||||
|
|
||||||
srv, err := server.NewServer(cfg)
|
srv, err := createRelayServer(cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("failed to create relay server: %v", err)
|
cleanupSTUNListeners(stunListeners)
|
||||||
return fmt.Errorf("failed to create relay server: %v", err)
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
hCfg := healthcheck.Config{
|
||||||
|
ListenAddress: cobraConfig.HealthcheckListenAddress,
|
||||||
|
ServiceChecker: srv,
|
||||||
|
}
|
||||||
|
httpHealthcheck, err := createHealthCheck(hCfg)
|
||||||
|
if err != nil {
|
||||||
|
cleanupSTUNListeners(stunListeners)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var stunServer *stun.Server
|
||||||
|
if len(stunListeners) > 0 {
|
||||||
|
stunServer = stun.NewServer(stunListeners, cobraConfig.STUNLogLevel)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start all servers (only after all resources are successfully created)
|
||||||
|
startServers(&wg, metricsServer, srv, srvListenerCfg, httpHealthcheck, stunServer)
|
||||||
|
|
||||||
|
waitForExitSignal()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
err = shutdownServers(ctx, metricsServer, srv, httpHealthcheck, stunServer)
|
||||||
|
wg.Wait()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func startServers(wg *sync.WaitGroup, metricsServer *metrics.Metrics, srv *server.Server, srvListenerCfg server.ListenerConfig, httpHealthcheck *healthcheck.Server, stunServer *stun.Server) {
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
log.Infof("running metrics server: %s%s", metricsServer.Addr, metricsServer.Endpoint)
|
||||||
|
if err := metricsServer.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
|
||||||
|
log.Fatalf("failed to start metrics server: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
instanceURL := srv.InstanceURL()
|
instanceURL := srv.InstanceURL()
|
||||||
log.Infof("server will be available on: %s", instanceURL.String())
|
log.Infof("server will be available on: %s", instanceURL.String())
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
if err := srv.Listen(srvListenerCfg); err != nil {
|
if err := srv.Listen(srvListenerCfg); err != nil {
|
||||||
log.Fatalf("failed to bind server: %s", err)
|
log.Fatalf("failed to bind relay server: %s", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
hCfg := healthcheck.Config{
|
|
||||||
ListenAddress: cobraConfig.HealthcheckListenAddress,
|
|
||||||
ServiceChecker: srv,
|
|
||||||
}
|
|
||||||
httpHealthcheck, err := healthcheck.NewServer(hCfg)
|
|
||||||
if err != nil {
|
|
||||||
log.Debugf("failed to create healthcheck server: %v", err)
|
|
||||||
return fmt.Errorf("failed to create healthcheck server: %v", err)
|
|
||||||
}
|
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
if err := httpHealthcheck.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
|
if err := httpHealthcheck.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
|
||||||
log.Fatalf("Failed to start healthcheck server: %v", err)
|
log.Fatalf("failed to start healthcheck server: %v", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// it will block until exit signal
|
if stunServer != nil {
|
||||||
waitForExitSignal()
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
if err := stunServer.Listen(); err != nil {
|
||||||
|
if errors.Is(err, stun.ErrServerClosed) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Errorf("STUN server error: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
func shutdownServers(ctx context.Context, metricsServer *metrics.Metrics, srv *server.Server, httpHealthcheck *healthcheck.Server, stunServer *stun.Server) error {
|
||||||
defer cancel()
|
var errs error
|
||||||
|
|
||||||
var shutDownErrors error
|
|
||||||
if err := httpHealthcheck.Shutdown(ctx); err != nil {
|
if err := httpHealthcheck.Shutdown(ctx); err != nil {
|
||||||
shutDownErrors = multierror.Append(shutDownErrors, fmt.Errorf("failed to close healthcheck server: %v", err))
|
errs = multierror.Append(errs, fmt.Errorf("failed to close healthcheck server: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
if stunServer != nil {
|
||||||
|
if err := stunServer.Shutdown(); err != nil {
|
||||||
|
errs = multierror.Append(errs, fmt.Errorf("failed to close STUN server: %w", err))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := srv.Shutdown(ctx); err != nil {
|
if err := srv.Shutdown(ctx); err != nil {
|
||||||
shutDownErrors = multierror.Append(shutDownErrors, fmt.Errorf("failed to close server: %s", err))
|
errs = multierror.Append(errs, fmt.Errorf("failed to close relay server: %w", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infof("shutting down metrics server")
|
log.Infof("shutting down metrics server")
|
||||||
if err := metricsServer.Shutdown(ctx); err != nil {
|
if err := metricsServer.Shutdown(ctx); err != nil {
|
||||||
shutDownErrors = multierror.Append(shutDownErrors, fmt.Errorf("failed to close metrics server: %v", err))
|
errs = multierror.Append(errs, fmt.Errorf("failed to close metrics server: %w", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
wg.Wait()
|
return errs
|
||||||
return shutDownErrors
|
}
|
||||||
|
|
||||||
|
func createHealthCheck(hCfg healthcheck.Config) (*healthcheck.Server, error) {
|
||||||
|
httpHealthcheck, err := healthcheck.NewServer(hCfg)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("failed to create healthcheck server: %v", err)
|
||||||
|
return nil, fmt.Errorf("failed to create healthcheck server: %v", err)
|
||||||
|
}
|
||||||
|
return httpHealthcheck, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func createRelayServer(cfg server.Config) (*server.Server, error) {
|
||||||
|
srv, err := server.NewServer(cfg)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create relay server: %v", err)
|
||||||
|
}
|
||||||
|
return srv, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func cleanupSTUNListeners(stunListeners []*net.UDPConn) {
|
||||||
|
for _, l := range stunListeners {
|
||||||
|
_ = l.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func createSTUNListeners() ([]*net.UDPConn, error) {
|
||||||
|
var stunListeners []*net.UDPConn
|
||||||
|
if cobraConfig.EnableSTUN {
|
||||||
|
for _, port := range cobraConfig.STUNPorts {
|
||||||
|
listener, err := net.ListenUDP("udp", &net.UDPAddr{Port: port})
|
||||||
|
if err != nil {
|
||||||
|
// Close already opened listeners on failure
|
||||||
|
cleanupSTUNListeners(stunListeners)
|
||||||
|
log.Debugf("failed to create STUN listener on port %d: %v", port, err)
|
||||||
|
return nil, fmt.Errorf("failed to create STUN listener on port %d: %v", port, err)
|
||||||
|
}
|
||||||
|
stunListeners = append(stunListeners, listener)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return stunListeners, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleTLSConfig(cfg *Config) (*tls.Config, bool, error) {
|
func handleTLSConfig(cfg *Config) (*tls.Config, bool, error) {
|
||||||
|
|||||||
170
stun/server.go
Normal file
170
stun/server.go
Normal file
@@ -0,0 +1,170 @@
|
|||||||
|
// Package stun provides an embedded STUN server for NAT traversal discovery.
|
||||||
|
package stun
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/hashicorp/go-multierror"
|
||||||
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/formatter"
|
||||||
|
"github.com/pion/stun/v3"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ErrServerClosed is returned by Listen when the server is shut down gracefully.
|
||||||
|
var ErrServerClosed = errors.New("stun: server closed")
|
||||||
|
|
||||||
|
// ErrNoListeners is returned by Listen when no UDP connections were provided.
|
||||||
|
var ErrNoListeners = errors.New("stun: no listeners configured")
|
||||||
|
|
||||||
|
// Server implements a STUN server that responds to binding requests
|
||||||
|
// with the client's reflexive transport address.
|
||||||
|
type Server struct {
|
||||||
|
conns []*net.UDPConn
|
||||||
|
logger *log.Entry
|
||||||
|
logLevel log.Level
|
||||||
|
|
||||||
|
wg sync.WaitGroup
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewServer creates a new STUN server with the given UDP listeners.
|
||||||
|
// The caller is responsible for creating and providing the listeners.
|
||||||
|
// logLevel can be: panic, fatal, error, warn, info, debug, trace
|
||||||
|
func NewServer(conns []*net.UDPConn, logLevel string) *Server {
|
||||||
|
level, err := log.ParseLevel(logLevel)
|
||||||
|
if err != nil {
|
||||||
|
level = log.InfoLevel
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a separate logger with its own level setting
|
||||||
|
// This allows --stun-log-level to work independently of --log-level
|
||||||
|
stunLogger := log.New()
|
||||||
|
stunLogger.SetOutput(log.StandardLogger().Out)
|
||||||
|
stunLogger.SetLevel(level)
|
||||||
|
// Use the formatter package to set up formatter, ReportCaller, and context hook
|
||||||
|
formatter.SetTextFormatter(stunLogger)
|
||||||
|
|
||||||
|
logger := stunLogger.WithField("component", "stun-server")
|
||||||
|
logger.Infof("STUN server log level set to: %s", level.String())
|
||||||
|
|
||||||
|
return &Server{
|
||||||
|
conns: conns,
|
||||||
|
logger: logger,
|
||||||
|
logLevel: level,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Listen starts the STUN server and blocks until the server is shut down.
|
||||||
|
// Returns ErrServerClosed when shut down gracefully via Shutdown.
|
||||||
|
// Returns ErrNoListeners if no UDP connections were provided.
|
||||||
|
func (s *Server) Listen() error {
|
||||||
|
if len(s.conns) == 0 {
|
||||||
|
return ErrNoListeners
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start a read loop for each listener
|
||||||
|
for _, conn := range s.conns {
|
||||||
|
s.logger.Infof("STUN server listening on %s", conn.LocalAddr())
|
||||||
|
s.wg.Add(1)
|
||||||
|
go s.readLoop(conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.wg.Wait()
|
||||||
|
return ErrServerClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
// readLoop continuously reads UDP packets and handles STUN requests.
|
||||||
|
func (s *Server) readLoop(conn *net.UDPConn) {
|
||||||
|
defer s.wg.Done()
|
||||||
|
buf := make([]byte, 1500) // Standard MTU size
|
||||||
|
for {
|
||||||
|
n, remoteAddr, err := conn.ReadFromUDP(buf)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
// Check if the connection was closed externally
|
||||||
|
if errors.Is(err, net.ErrClosed) {
|
||||||
|
s.logger.Info("UDP connection closed, stopping read loop")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.logger.Warnf("failed to read UDP packet: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle packet in the same goroutine to avoid complexity
|
||||||
|
// STUN responses are small and fast
|
||||||
|
s.handlePacket(conn, buf[:n], remoteAddr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handlePacket processes a STUN request and sends a response.
|
||||||
|
func (s *Server) handlePacket(conn *net.UDPConn, data []byte, addr *net.UDPAddr) {
|
||||||
|
localPort := conn.LocalAddr().(*net.UDPAddr).Port
|
||||||
|
|
||||||
|
s.logger.Debugf("[port:%d] received %d bytes from %s", localPort, len(data), addr)
|
||||||
|
|
||||||
|
// Check if it's a STUN message
|
||||||
|
if !stun.IsMessage(data) {
|
||||||
|
s.logger.Debugf("[port:%d] not a STUN message (first bytes: %x)", localPort, data[:min(len(data), 8)])
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse the STUN message
|
||||||
|
msg := &stun.Message{Raw: data}
|
||||||
|
if err := msg.Decode(); err != nil {
|
||||||
|
s.logger.Warnf("[port:%d] failed to decode STUN message from %s: %v", localPort, addr, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.logger.Debugf("[port:%d] received STUN %s from %s (tx=%x)", localPort, msg.Type, addr, msg.TransactionID[:8])
|
||||||
|
|
||||||
|
// Only handle binding requests
|
||||||
|
if msg.Type != stun.BindingRequest {
|
||||||
|
s.logger.Debugf("[port:%d] ignoring non-binding request: %s", localPort, msg.Type)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build the response
|
||||||
|
response, err := stun.Build(
|
||||||
|
stun.NewTransactionIDSetter(msg.TransactionID),
|
||||||
|
stun.BindingSuccess,
|
||||||
|
&stun.XORMappedAddress{
|
||||||
|
IP: addr.IP,
|
||||||
|
Port: addr.Port,
|
||||||
|
},
|
||||||
|
stun.Fingerprint,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Errorf("[port:%d] failed to build STUN response: %v", localPort, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send the response on the same connection it was received on
|
||||||
|
n, err := conn.WriteToUDP(response.Raw, addr)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Errorf("[port:%d] failed to send STUN response to %s: %v", localPort, addr, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.logger.Debugf("[port:%d] sent STUN BindingSuccess to %s (%d bytes) with XORMappedAddress %s:%d", localPort, addr, n, addr.IP, addr.Port)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Shutdown gracefully stops the STUN server.
|
||||||
|
func (s *Server) Shutdown() error {
|
||||||
|
s.logger.Info("shutting down STUN server")
|
||||||
|
|
||||||
|
var merr *multierror.Error
|
||||||
|
|
||||||
|
for _, conn := range s.conns {
|
||||||
|
if err := conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("close STUN UDP connection: %w", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for all readLoops to finish
|
||||||
|
s.wg.Wait()
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
479
stun/server_test.go
Normal file
479
stun/server_test.go
Normal file
@@ -0,0 +1,479 @@
|
|||||||
|
package stun
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"math/rand"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pion/stun/v3"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// createTestServer creates a STUN server listening on a random port for testing.
|
||||||
|
// Returns the server, the listener connection (caller must close), and the server address.
|
||||||
|
func createTestServer(t testing.TB) (*Server, *net.UDPConn, *net.UDPAddr) {
|
||||||
|
t.Helper()
|
||||||
|
conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0})
|
||||||
|
require.NoError(t, err)
|
||||||
|
server := NewServer([]*net.UDPConn{conn}, "debug")
|
||||||
|
return server, conn, conn.LocalAddr().(*net.UDPAddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// waitForServerReady polls the server with STUN binding requests until it responds.
|
||||||
|
// This avoids flaky tests on slow CI machines that relied on time.Sleep.
|
||||||
|
func waitForServerReady(t testing.TB, serverAddr *net.UDPAddr, timeout time.Duration) {
|
||||||
|
t.Helper()
|
||||||
|
deadline := time.Now().Add(timeout)
|
||||||
|
retryInterval := 10 * time.Millisecond
|
||||||
|
|
||||||
|
clientConn, err := net.DialUDP("udp", nil, serverAddr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer clientConn.Close()
|
||||||
|
|
||||||
|
buf := make([]byte, 1500)
|
||||||
|
for time.Now().Before(deadline) {
|
||||||
|
msg, err := stun.Build(stun.TransactionID, stun.BindingRequest)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = clientConn.Write(msg.Raw)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_ = clientConn.SetReadDeadline(time.Now().Add(retryInterval))
|
||||||
|
n, err := clientConn.Read(buf)
|
||||||
|
if err != nil {
|
||||||
|
// Timeout or other error, retry
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
response := &stun.Message{Raw: buf[:n]}
|
||||||
|
if err := response.Decode(); err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if response.Type == stun.BindingSuccess {
|
||||||
|
return // Server is ready
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Fatalf("server did not become ready within %v", timeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_BindingRequest(t *testing.T) {
|
||||||
|
// Start the STUN server on a random port
|
||||||
|
server, listener, serverAddr := createTestServer(t)
|
||||||
|
|
||||||
|
// Start server in background
|
||||||
|
serverErrCh := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
serverErrCh <- server.Listen()
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Wait for server to be ready
|
||||||
|
waitForServerReady(t, serverAddr, 2*time.Second)
|
||||||
|
|
||||||
|
// Create a UDP client
|
||||||
|
clientConn, err := net.DialUDP("udp", nil, serverAddr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer clientConn.Close()
|
||||||
|
|
||||||
|
// Build a STUN binding request
|
||||||
|
msg, err := stun.Build(stun.TransactionID, stun.BindingRequest)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Send the request
|
||||||
|
_, err = clientConn.Write(msg.Raw)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Read the response
|
||||||
|
buf := make([]byte, 1500)
|
||||||
|
_ = clientConn.SetReadDeadline(time.Now().Add(2 * time.Second))
|
||||||
|
n, err := clientConn.Read(buf)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Parse the response
|
||||||
|
response := &stun.Message{Raw: buf[:n]}
|
||||||
|
err = response.Decode()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify it's a binding success
|
||||||
|
assert.Equal(t, stun.BindingSuccess, response.Type)
|
||||||
|
|
||||||
|
// Extract the XOR-MAPPED-ADDRESS
|
||||||
|
var xorAddr stun.XORMappedAddress
|
||||||
|
err = xorAddr.GetFrom(response)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify the address matches our client's local address
|
||||||
|
clientAddr := clientConn.LocalAddr().(*net.UDPAddr)
|
||||||
|
assert.Equal(t, clientAddr.IP.String(), xorAddr.IP.String())
|
||||||
|
assert.Equal(t, clientAddr.Port, xorAddr.Port)
|
||||||
|
|
||||||
|
// Close listener first to unblock readLoop, then shutdown
|
||||||
|
_ = listener.Close()
|
||||||
|
err = server.Shutdown()
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_IgnoresNonSTUNPackets(t *testing.T) {
|
||||||
|
server, listener, serverAddr := createTestServer(t)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
_ = server.Listen()
|
||||||
|
}()
|
||||||
|
|
||||||
|
waitForServerReady(t, serverAddr, 2*time.Second)
|
||||||
|
|
||||||
|
clientConn, err := net.DialUDP("udp", nil, serverAddr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer clientConn.Close()
|
||||||
|
|
||||||
|
// Send non-STUN data
|
||||||
|
_, err = clientConn.Write([]byte("hello world"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Try to read response (should timeout since server ignores non-STUN)
|
||||||
|
buf := make([]byte, 1500)
|
||||||
|
_ = clientConn.SetReadDeadline(time.Now().Add(200 * time.Millisecond))
|
||||||
|
_, err = clientConn.Read(buf)
|
||||||
|
assert.Error(t, err) // Should be a timeout error
|
||||||
|
|
||||||
|
// Close listener first to unblock readLoop, then shutdown
|
||||||
|
_ = listener.Close()
|
||||||
|
_ = server.Shutdown()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_Shutdown(t *testing.T) {
|
||||||
|
server, listener, serverAddr := createTestServer(t)
|
||||||
|
|
||||||
|
serverDone := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
err := server.Listen()
|
||||||
|
assert.True(t, errors.Is(err, ErrServerClosed))
|
||||||
|
close(serverDone)
|
||||||
|
}()
|
||||||
|
|
||||||
|
waitForServerReady(t, serverAddr, 2*time.Second)
|
||||||
|
|
||||||
|
// Close listener first to unblock readLoop, then shutdown
|
||||||
|
_ = listener.Close()
|
||||||
|
|
||||||
|
err := server.Shutdown()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Wait for Listen to return
|
||||||
|
select {
|
||||||
|
case <-serverDone:
|
||||||
|
// Success
|
||||||
|
case <-time.After(3 * time.Second):
|
||||||
|
t.Fatal("server did not shutdown in time")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_MultipleRequests(t *testing.T) {
|
||||||
|
server, listener, serverAddr := createTestServer(t)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
_ = server.Listen()
|
||||||
|
}()
|
||||||
|
|
||||||
|
waitForServerReady(t, serverAddr, 2*time.Second)
|
||||||
|
|
||||||
|
// Create multiple clients and send requests
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
func() {
|
||||||
|
clientConn, err := net.DialUDP("udp", nil, serverAddr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer clientConn.Close()
|
||||||
|
|
||||||
|
msg, err := stun.Build(stun.TransactionID, stun.BindingRequest)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = clientConn.Write(msg.Raw)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
buf := make([]byte, 1500)
|
||||||
|
_ = clientConn.SetReadDeadline(time.Now().Add(2 * time.Second))
|
||||||
|
n, err := clientConn.Read(buf)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
response := &stun.Message{Raw: buf[:n]}
|
||||||
|
err = response.Decode()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, stun.BindingSuccess, response.Type)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close listener first to unblock readLoop, then shutdown
|
||||||
|
_ = listener.Close()
|
||||||
|
_ = server.Shutdown()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_ConcurrentClients(t *testing.T) {
|
||||||
|
numClients := 100
|
||||||
|
requestsPerClient := 5
|
||||||
|
maxStartDelay := 100 * time.Millisecond // Random delay before client starts
|
||||||
|
maxRequestDelay := 500 * time.Millisecond // Random delay between requests
|
||||||
|
|
||||||
|
// Remote server to test against via env var STUN_TEST_SERVER
|
||||||
|
// Example: STUN_TEST_SERVER=example.netbird.io:3478 go test -v ./stun/... -run ConcurrentClients
|
||||||
|
remoteServer := os.Getenv("STUN_TEST_SERVER")
|
||||||
|
|
||||||
|
var serverAddr *net.UDPAddr
|
||||||
|
var server *Server
|
||||||
|
var listener *net.UDPConn
|
||||||
|
|
||||||
|
if remoteServer != "" {
|
||||||
|
// Use remote server
|
||||||
|
var err error
|
||||||
|
serverAddr, err = net.ResolveUDPAddr("udp", remoteServer)
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Logf("Testing against remote server: %s", remoteServer)
|
||||||
|
} else {
|
||||||
|
// Start local server
|
||||||
|
server, listener, serverAddr = createTestServer(t)
|
||||||
|
go func() {
|
||||||
|
_ = server.Listen()
|
||||||
|
}()
|
||||||
|
waitForServerReady(t, serverAddr, 2*time.Second)
|
||||||
|
t.Logf("Testing against local server: %s", serverAddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
errorz := make(chan error, numClients*requestsPerClient)
|
||||||
|
successCount := make(chan int, numClients)
|
||||||
|
|
||||||
|
startTime := time.Now()
|
||||||
|
|
||||||
|
for i := 0; i < numClients; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(clientID int) {
|
||||||
|
defer wg.Done()
|
||||||
|
|
||||||
|
// Random delay before starting
|
||||||
|
time.Sleep(time.Duration(rand.Int63n(int64(maxStartDelay))))
|
||||||
|
|
||||||
|
clientConn, err := net.DialUDP("udp", nil, serverAddr)
|
||||||
|
if err != nil {
|
||||||
|
errorz <- fmt.Errorf("client %d: failed to dial: %w", clientID, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer clientConn.Close()
|
||||||
|
|
||||||
|
success := 0
|
||||||
|
for j := 0; j < requestsPerClient; j++ {
|
||||||
|
// Random delay between requests
|
||||||
|
if j > 0 {
|
||||||
|
time.Sleep(time.Duration(rand.Int63n(int64(maxRequestDelay))))
|
||||||
|
}
|
||||||
|
|
||||||
|
msg, err := stun.Build(stun.TransactionID, stun.BindingRequest)
|
||||||
|
if err != nil {
|
||||||
|
errorz <- fmt.Errorf("client %d: failed to build request: %w", clientID, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = clientConn.Write(msg.Raw)
|
||||||
|
if err != nil {
|
||||||
|
errorz <- fmt.Errorf("client %d: failed to write: %w", clientID, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := make([]byte, 1500)
|
||||||
|
_ = clientConn.SetReadDeadline(time.Now().Add(5 * time.Second))
|
||||||
|
n, err := clientConn.Read(buf)
|
||||||
|
if err != nil {
|
||||||
|
errorz <- fmt.Errorf("client %d: failed to read: %w", clientID, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
response := &stun.Message{Raw: buf[:n]}
|
||||||
|
if err := response.Decode(); err != nil {
|
||||||
|
errorz <- fmt.Errorf("client %d: failed to decode: %w", clientID, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if response.Type != stun.BindingSuccess {
|
||||||
|
errorz <- fmt.Errorf("client %d: unexpected response type: %s", clientID, response.Type)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
success++
|
||||||
|
}
|
||||||
|
successCount <- success
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
close(errorz)
|
||||||
|
close(successCount)
|
||||||
|
|
||||||
|
elapsed := time.Since(startTime)
|
||||||
|
|
||||||
|
totalSuccess := 0
|
||||||
|
for count := range successCount {
|
||||||
|
totalSuccess += count
|
||||||
|
}
|
||||||
|
|
||||||
|
var errs []error
|
||||||
|
for err := range errorz {
|
||||||
|
errs = append(errs, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
totalRequests := numClients * requestsPerClient
|
||||||
|
t.Logf("Completed %d/%d requests in %v (%.2f req/s)",
|
||||||
|
totalSuccess, totalRequests, elapsed,
|
||||||
|
float64(totalSuccess)/elapsed.Seconds())
|
||||||
|
|
||||||
|
if len(errs) > 0 {
|
||||||
|
t.Logf("Errors (%d):", len(errs))
|
||||||
|
for i, err := range errs {
|
||||||
|
if i < 10 { // Only show first 10 errors
|
||||||
|
t.Logf(" - %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Require at least 95% success rate
|
||||||
|
successRate := float64(totalSuccess) / float64(totalRequests)
|
||||||
|
require.GreaterOrEqual(t, successRate, 0.95, "success rate too low: %.2f%%", successRate*100)
|
||||||
|
|
||||||
|
// Cleanup local server if used
|
||||||
|
if server != nil {
|
||||||
|
// Close listener first to unblock readLoop, then shutdown
|
||||||
|
_ = listener.Close()
|
||||||
|
_ = server.Shutdown()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_MultiplePorts(t *testing.T) {
|
||||||
|
// Create listeners on two random ports
|
||||||
|
conn1, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0})
|
||||||
|
require.NoError(t, err)
|
||||||
|
conn2, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
addr1 := conn1.LocalAddr().(*net.UDPAddr)
|
||||||
|
addr2 := conn2.LocalAddr().(*net.UDPAddr)
|
||||||
|
|
||||||
|
server := NewServer([]*net.UDPConn{conn1, conn2}, "debug")
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
_ = server.Listen()
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Wait for server to be ready (checking first port is sufficient)
|
||||||
|
waitForServerReady(t, addr1, 2*time.Second)
|
||||||
|
|
||||||
|
// Test requests on both ports
|
||||||
|
for _, serverAddr := range []*net.UDPAddr{addr1, addr2} {
|
||||||
|
func() {
|
||||||
|
clientConn, err := net.DialUDP("udp", nil, serverAddr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer clientConn.Close()
|
||||||
|
|
||||||
|
msg, err := stun.Build(stun.TransactionID, stun.BindingRequest)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = clientConn.Write(msg.Raw)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
buf := make([]byte, 1500)
|
||||||
|
_ = clientConn.SetReadDeadline(time.Now().Add(2 * time.Second))
|
||||||
|
n, err := clientConn.Read(buf)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
response := &stun.Message{Raw: buf[:n]}
|
||||||
|
err = response.Decode()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, stun.BindingSuccess, response.Type)
|
||||||
|
|
||||||
|
var xorAddr stun.XORMappedAddress
|
||||||
|
err = xorAddr.GetFrom(response)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
clientAddr := clientConn.LocalAddr().(*net.UDPAddr)
|
||||||
|
assert.Equal(t, clientAddr.Port, xorAddr.Port)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close listeners first to unblock readLoops, then shutdown
|
||||||
|
_ = conn1.Close()
|
||||||
|
_ = conn2.Close()
|
||||||
|
_ = server.Shutdown()
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkSTUNServer benchmarks the STUN server with concurrent clients
|
||||||
|
func BenchmarkSTUNServer(b *testing.B) {
|
||||||
|
server, listener, serverAddr := createTestServer(b)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
_ = server.Listen()
|
||||||
|
}()
|
||||||
|
|
||||||
|
waitForServerReady(b, serverAddr, 2*time.Second)
|
||||||
|
|
||||||
|
// Capture first error atomically - b.Fatal cannot be called from worker goroutines
|
||||||
|
var firstErr atomic.Pointer[error]
|
||||||
|
setErr := func(err error) {
|
||||||
|
firstErr.CompareAndSwap(nil, &err)
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
b.RunParallel(func(pb *testing.PB) {
|
||||||
|
// Stop work if an error has occurred
|
||||||
|
if firstErr.Load() != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
clientConn, err := net.DialUDP("udp", nil, serverAddr)
|
||||||
|
if err != nil {
|
||||||
|
setErr(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer clientConn.Close()
|
||||||
|
|
||||||
|
buf := make([]byte, 1500)
|
||||||
|
|
||||||
|
for pb.Next() {
|
||||||
|
if firstErr.Load() != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
msg, _ := stun.Build(stun.TransactionID, stun.BindingRequest)
|
||||||
|
_, _ = clientConn.Write(msg.Raw)
|
||||||
|
|
||||||
|
_ = clientConn.SetReadDeadline(time.Now().Add(2 * time.Second))
|
||||||
|
n, err := clientConn.Read(buf)
|
||||||
|
if err != nil {
|
||||||
|
setErr(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response := &stun.Message{Raw: buf[:n]}
|
||||||
|
if err := response.Decode(); err != nil {
|
||||||
|
setErr(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.StopTimer()
|
||||||
|
|
||||||
|
// Fail after RunParallel completes
|
||||||
|
if errPtr := firstErr.Load(); errPtr != nil {
|
||||||
|
b.Fatal(*errPtr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close listener first to unblock readLoop, then shutdown
|
||||||
|
_ = listener.Close()
|
||||||
|
_ = server.Shutdown()
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user