[relay] Replace net.Conn with context-aware Conn interface (#5770)

* [relay] Replace net.Conn with context-aware Conn interface for relay transports

Introduce a listener.Conn interface with context-based Read/Write methods,
replacing net.Conn throughout the relay server. This enables proper timeout
propagation (e.g. handshake timeout) without goroutine-based workarounds
and removes unused LocalAddr/SetDeadline methods from WS and QUIC conns.

* [relay] Refactor Peer context management to ensure proper cleanup

Integrate context creation (`context.WithCancel`) directly in `NewPeer` and remove redundant initialization in `Work`. Add `ctxCancel` calls to ensure context is properly canceled during `Close` operations.
This commit is contained in:
Zoltan Papp
2026-04-08 09:38:31 +02:00
committed by GitHub
parent d33cd4c95b
commit 96806bf55f
11 changed files with 103 additions and 143 deletions

View File

@@ -1,11 +1,13 @@
package server
import (
"context"
"fmt"
"net"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/relay/server/listener"
"github.com/netbirdio/netbird/shared/relay/messages"
//nolint:staticcheck
"github.com/netbirdio/netbird/shared/relay/messages/address"
@@ -13,6 +15,12 @@ import (
authmsg "github.com/netbirdio/netbird/shared/relay/messages/auth"
)
const (
// handshakeTimeout bounds how long a connection may remain in the
// pre-authentication handshake phase before being closed.
handshakeTimeout = 10 * time.Second
)
type Validator interface {
Validate(any) error
// Deprecated: Use Validate instead.
@@ -58,7 +66,7 @@ func marshalResponseHelloMsg(instanceURL string) ([]byte, error) {
}
type handshake struct {
conn net.Conn
conn listener.Conn
validator Validator
preparedMsg *preparedMsg
@@ -66,9 +74,9 @@ type handshake struct {
peerID *messages.PeerID
}
func (h *handshake) handshakeReceive() (*messages.PeerID, error) {
func (h *handshake) handshakeReceive(ctx context.Context) (*messages.PeerID, error) {
buf := make([]byte, messages.MaxHandshakeSize)
n, err := h.conn.Read(buf)
n, err := h.conn.Read(ctx, buf)
if err != nil {
return nil, fmt.Errorf("read from %s: %w", h.conn.RemoteAddr(), err)
}
@@ -103,7 +111,7 @@ func (h *handshake) handshakeReceive() (*messages.PeerID, error) {
return peerID, nil
}
func (h *handshake) handshakeResponse() error {
func (h *handshake) handshakeResponse(ctx context.Context) error {
var responseMsg []byte
if h.handshakeMethodAuth {
responseMsg = h.preparedMsg.responseAuthMsg
@@ -111,7 +119,7 @@ func (h *handshake) handshakeResponse() error {
responseMsg = h.preparedMsg.responseHelloMsg
}
if _, err := h.conn.Write(responseMsg); err != nil {
if _, err := h.conn.Write(ctx, responseMsg); err != nil {
return fmt.Errorf("handshake response write to %s (%s): %w", h.peerID, h.conn.RemoteAddr(), err)
}