diff --git a/relay/client/client.go b/relay/client/client.go index c8dfef617..2b2488cd1 100644 --- a/relay/client/client.go +++ b/relay/client/client.go @@ -175,8 +175,7 @@ func (c *Client) Connect() error { return nil } - err := c.connect() - if err != nil { + if err := c.connect(); err != nil { return err } @@ -266,8 +265,7 @@ func (c *Client) connect() error { } c.relayConn = conn - err = c.handShake() - if err != nil { + if err = c.handShake(); err != nil { cErr := conn.Close() if cErr != nil { c.log.Errorf("failed to close connection: %s", cErr) @@ -341,7 +339,7 @@ func (c *Client) readLoop(relayConn net.Conn) { c.log.Infof("start to Relay read loop exit") c.mu.Lock() if c.serviceIsRunning && !internallyStoppedFlag.isSet() { - c.log.Debugf("failed to read message from relay server: %s", errExit) + c.log.Errorf("failed to read message from relay server: %s", errExit) } c.mu.Unlock() break diff --git a/relay/client/dialer/quic/conn.go b/relay/client/dialer/quic/conn.go index 39e043b1e..408fa2cbc 100644 --- a/relay/client/dialer/quic/conn.go +++ b/relay/client/dialer/quic/conn.go @@ -2,11 +2,11 @@ package quic import ( "context" - "fmt" "net" "time" "github.com/quic-go/quic-go" + log "github.com/sirupsen/logrus" ) type QuicAddr struct { @@ -36,22 +36,21 @@ func NewConn(session quic.Connection, serverAddress string) net.Conn { } func (c *Conn) Read(b []byte) (n int, err error) { - // Use the QUIC stream's Read method directly dgram, err := c.session.ReceiveDatagram(c.ctx) if err != nil { - return 0, fmt.Errorf("failed to read from QUIC stream: %v", err) + log.Errorf("failed to read from QUIC session: %v", err) + return 0, err } - // Copy data to b, ensuring we don’t exceed the size of b n = copy(b, dgram) return n, nil } func (c *Conn) Write(b []byte) (int, error) { - // Use the QUIC stream's Write method directly err := c.session.SendDatagram(b) if err != nil { - return 0, fmt.Errorf("failed to write to QUIC stream: %v", err) + log.Errorf("failed to write to QUIC stream: %v", err) + return 0, err } return len(b), nil } diff --git a/relay/client/dialer/quic/quic.go b/relay/client/dialer/quic/quic.go index 201aa7ea6..40772be97 100644 --- a/relay/client/dialer/quic/quic.go +++ b/relay/client/dialer/quic/quic.go @@ -9,6 +9,7 @@ import ( "time" "github.com/quic-go/quic-go" + log "github.com/sirupsen/logrus" ) const ( @@ -35,9 +36,12 @@ func Dial(address string) (net.Conn, error) { EnableDatagrams: true, } + // todo add support for custom dialer + session, err := quic.DialAddr(ctx, quicURL, tlsConf, quicConfig) if err != nil { - return nil, fmt.Errorf("failed to dial QUIC server '%s': %v", quicURL, err) + log.Errorf("failed to dial to Relay server via QUIC '%s': %s", quicURL, err) + return nil, err } conn := NewConn(session, address) diff --git a/relay/client/dialer/ws/ws.go b/relay/client/dialer/ws/ws.go index 227d6953d..364676b88 100644 --- a/relay/client/dialer/ws/ws.go +++ b/relay/client/dialer/ws/ws.go @@ -32,8 +32,6 @@ func Dial(address string) (net.Conn, error) { } parsedURL.Path = ws.URLPath - log.Infof("------ Dialing to Relay server: %s", wsURL) - wsConn, resp, err := websocket.Dial(context.Background(), parsedURL.String(), opts) if err != nil { log.Errorf("failed to dial to Relay server '%s': %s", wsURL, err) diff --git a/relay/cmd/root.go b/relay/cmd/root.go index 7af536a61..d603ff73b 100644 --- a/relay/cmd/root.go +++ b/relay/cmd/root.go @@ -2,17 +2,10 @@ package cmd import ( "context" - "crypto/rand" - "crypto/rsa" "crypto/sha256" "crypto/tls" - "crypto/x509" - "crypto/x509/pkix" - "encoding/pem" "errors" "fmt" - "math/big" - "net" "net/http" "os" "os/signal" @@ -148,13 +141,6 @@ func execute(cmd *cobra.Command, args []string) error { hashedSecret := sha256.Sum256([]byte(cobraConfig.AuthSecret)) authenticator := auth.NewTimedHMACValidator(hashedSecret[:], 24*time.Hour) - tlsSupport = true - srvListenerCfg.TLSConfig, err = generateTestTLSConfig() - if err != nil { - log.Debugf("failed to generate test TLS config: %s", err) - return fmt.Errorf("failed to generate test TLS config: %s", err) - } - srv, err := server.NewServer(metricsServer.Meter, cobraConfig.ExposedAddress, tlsSupport, authenticator) if err != nil { log.Debugf("failed to create relay server: %v", err) @@ -227,57 +213,3 @@ func setupTLSCertManager(letsencryptDataDir string, letsencryptDomains ...string } return certManager.TLSConfig(), nil } - -// GenerateTestTLSConfig creates a self-signed certificate for testing -func generateTestTLSConfig() (*tls.Config, error) { - // Generate private key - privateKey, err := rsa.GenerateKey(rand.Reader, 2048) - if err != nil { - return nil, err - } - - // Create certificate template - template := x509.Certificate{ - SerialNumber: big.NewInt(1), - Subject: pkix.Name{ - Organization: []string{"Test Organization"}, - }, - NotBefore: time.Now(), - NotAfter: time.Now().Add(time.Hour * 24 * 180), // Valid for 180 days - KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, - ExtKeyUsage: []x509.ExtKeyUsage{ - x509.ExtKeyUsageServerAuth, - }, - BasicConstraintsValid: true, - DNSNames: []string{"localhost"}, - IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, - } - - // Create certificate - certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey) - if err != nil { - return nil, err - } - - // Encode certificate and private key to PEM format - certPEM := pem.EncodeToMemory(&pem.Block{ - Type: "CERTIFICATE", - Bytes: certDER, - }) - - privateKeyPEM := pem.EncodeToMemory(&pem.Block{ - Type: "RSA PRIVATE KEY", - Bytes: x509.MarshalPKCS1PrivateKey(privateKey), - }) - - // Create TLS certificate - tlsCert, err := tls.X509KeyPair(certPEM, privateKeyPEM) - if err != nil { - return nil, err - } - - return &tls.Config{ - Certificates: []tls.Certificate{tlsCert}, - NextProtos: []string{"netbird-relay"}, // Your application protocol - }, nil -} diff --git a/relay/server/listener/quic/listener.go b/relay/server/listener/quic/listener.go index 3b55409f4..b6e01994f 100644 --- a/relay/server/listener/quic/listener.go +++ b/relay/server/listener/quic/listener.go @@ -3,6 +3,7 @@ package quic import ( "context" "crypto/tls" + "errors" "fmt" "net" @@ -37,32 +38,27 @@ func (l *Listener) Listen(acceptFn func(conn net.Conn)) error { for { session, err := listener.Accept(context.Background()) if err != nil { - // Check if the listener was closed intentionally - if err.Error() == "server closed" { + if errors.Is(err, quic.ErrServerClosed) { return nil } + log.Errorf("Failed to accept QUIC session: %v", err) continue } - // Handle each session in a separate goroutine - go l.handleSession(session) + log.Infof("QUIC client connected from: %s", session.RemoteAddr()) + conn := NewConn(session) + l.acceptFn(conn) } } -func (l *Listener) handleSession(session quic.Connection) { - conn := NewConn(session) - l.acceptFn(conn) -} - func (l *Listener) Shutdown(ctx context.Context) error { if l.listener == nil { return nil } log.Infof("stopping QUIC listener") - err := l.listener.Close() - if err != nil { + if err := l.listener.Close(); err != nil { return fmt.Errorf("listener shutdown failed: %v", err) } log.Infof("QUIC listener stopped") diff --git a/relay/server/listener/ws/listener.go b/relay/server/listener/ws/listener.go index 1ad57d27a..219cfc08c 100644 --- a/relay/server/listener/ws/listener.go +++ b/relay/server/listener/ws/listener.go @@ -23,6 +23,7 @@ type Listener struct { server *http.Server acceptFn func(conn net.Conn) + log *log.Entry } func (l *Listener) Listen(acceptFn func(conn net.Conn)) error { @@ -88,6 +89,8 @@ func (l *Listener) onAccept(w http.ResponseWriter, r *http.Request) { return } + log.Infof("WS client connected from: %s", rAddr) + conn := NewConn(wsConn, lAddr, rAddr) l.acceptFn(conn) } diff --git a/relay/server/server.go b/relay/server/server.go index 456dc1ea6..a09f6c16a 100644 --- a/relay/server/server.go +++ b/relay/server/server.go @@ -2,14 +2,26 @@ package server import ( "context" + "crypto/rand" + "crypto/rsa" "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "math/big" + "net" + "sync" + "time" + "github.com/hashicorp/go-multierror" log "github.com/sirupsen/logrus" "go.opentelemetry.io/otel/metric" + nberrors "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/relay/auth" "github.com/netbirdio/netbird/relay/server/listener" "github.com/netbirdio/netbird/relay/server/listener/quic" + "github.com/netbirdio/netbird/relay/server/listener/ws" ) // ListenerConfig is the configuration for the listener. @@ -24,8 +36,8 @@ type ListenerConfig struct { // It is the gate between the WebSocket listener and the Relay server logic. // In a new HTTP connection, the server will accept the connection and pass it to the Relay server via the Accept method. type Server struct { - relay *Relay - wSListener listener.Listener + relay *Relay + listeners []listener.Listener } // NewServer creates a new relay server instance. @@ -39,38 +51,120 @@ func NewServer(meter metric.Meter, exposedAddress string, tlsSupport bool, authV return nil, err } return &Server{ - relay: relay, + relay: relay, + listeners: make([]listener.Listener, 0, 2), }, nil } // Listen starts the relay server. func (r *Server) Listen(cfg ListenerConfig) error { - r.wSListener = &quic.Listener{ + wSListener := &ws.Listener{ Address: cfg.Address, TLSConfig: cfg.TLSConfig, } + r.listeners = append(r.listeners, wSListener) - wslErr := r.wSListener.Listen(r.relay.Accept) - if wslErr != nil { - log.Errorf("failed to bind ws server: %s", wslErr) + quicListener := &quic.Listener{ + Address: cfg.Address, } - return wslErr + if cfg.TLSConfig != nil { + quicListener.TLSConfig = cfg.TLSConfig + } else { + tlsConfig, err := generateTestTLSConfig() + if err != nil { + return err + } + quicListener.TLSConfig = tlsConfig + } + r.listeners = append(r.listeners, quicListener) + + errChan := make(chan error, len(r.listeners)) + wg := sync.WaitGroup{} + for _, l := range r.listeners { + wg.Add(1) + go func(listener listener.Listener) { + defer wg.Done() + errChan <- listener.Listen(r.relay.Accept) + }(l) + } + + wg.Wait() + close(errChan) + var multiErr *multierror.Error + for err := range errChan { + multiErr = multierror.Append(multiErr, err) + } + + return nberrors.FormatErrorOrNil(multiErr) } // Shutdown stops the relay server. If there are active connections, they will be closed gracefully. In case of a context, // the connections will be forcefully closed. -func (r *Server) Shutdown(ctx context.Context) (err error) { - // stop service new connections - if r.wSListener != nil { - err = r.wSListener.Shutdown(ctx) +func (r *Server) Shutdown(ctx context.Context) error { + var multiErr *multierror.Error + for _, l := range r.listeners { + if err := l.Shutdown(ctx); err != nil { + multiErr = multierror.Append(multiErr, err) + } } r.relay.Shutdown(ctx) - return + return nberrors.FormatErrorOrNil(multiErr) } // InstanceURL returns the instance URL of the relay server. func (r *Server) InstanceURL() string { return r.relay.instanceURL } + +// GenerateTestTLSConfig creates a self-signed certificate for testing +func generateTestTLSConfig() (*tls.Config, error) { + log.Infof("generating test TLS config") + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, err + } + + template := x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + Organization: []string{"Test Organization"}, + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Hour * 24 * 180), // Valid for 180 days + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{ + x509.ExtKeyUsageServerAuth, + }, + BasicConstraintsValid: true, + DNSNames: []string{"localhost"}, + IPAddresses: []net.IP{net.ParseIP("192.168.0.10")}, + } + + // Create certificate + certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey) + if err != nil { + return nil, err + } + + certPEM := pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: certDER, + }) + + privateKeyPEM := pem.EncodeToMemory(&pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(privateKey), + }) + + tlsCert, err := tls.X509KeyPair(certPEM, privateKeyPEM) + if err != nil { + return nil, err + } + + return &tls.Config{ + Certificates: []tls.Certificate{tlsCert}, + NextProtos: []string{"netbird-relay"}, + }, nil +}