[relay] Replace net.Conn with context-aware Conn interface (#5770)

* [relay] Replace net.Conn with context-aware Conn interface for relay transports

Introduce a listener.Conn interface with context-based Read/Write methods,
replacing net.Conn throughout the relay server. This enables proper timeout
propagation (e.g. handshake timeout) without goroutine-based workarounds
and removes unused LocalAddr/SetDeadline methods from WS and QUIC conns.

* [relay] Refactor Peer context management to ensure proper cleanup

Integrate context creation (`context.WithCancel`) directly in `NewPeer` and remove redundant initialization in `Work`. Add `ctxCancel` calls to ensure context is properly canceled during `Close` operations.
This commit is contained in:
Zoltan Papp
2026-04-08 09:38:31 +02:00
committed by GitHub
parent d33cd4c95b
commit 96806bf55f
11 changed files with 103 additions and 143 deletions

View File

@@ -29,6 +29,7 @@ import (
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/relay/healthcheck" "github.com/netbirdio/netbird/relay/healthcheck"
relayServer "github.com/netbirdio/netbird/relay/server" relayServer "github.com/netbirdio/netbird/relay/server"
"github.com/netbirdio/netbird/relay/server/listener"
"github.com/netbirdio/netbird/relay/server/listener/ws" "github.com/netbirdio/netbird/relay/server/listener/ws"
sharedMetrics "github.com/netbirdio/netbird/shared/metrics" sharedMetrics "github.com/netbirdio/netbird/shared/metrics"
"github.com/netbirdio/netbird/shared/relay/auth" "github.com/netbirdio/netbird/shared/relay/auth"
@@ -523,7 +524,7 @@ func createManagementServer(cfg *CombinedConfig, mgmtConfig *nbconfig.Config) (*
func createCombinedHandler(grpcServer *grpc.Server, httpHandler http.Handler, relaySrv *relayServer.Server, meter metric.Meter, cfg *CombinedConfig) http.Handler { func createCombinedHandler(grpcServer *grpc.Server, httpHandler http.Handler, relaySrv *relayServer.Server, meter metric.Meter, cfg *CombinedConfig) http.Handler {
wsProxy := wsproxyserver.New(grpcServer, wsproxyserver.WithOTelMeter(meter)) wsProxy := wsproxyserver.New(grpcServer, wsproxyserver.WithOTelMeter(meter))
var relayAcceptFn func(conn net.Conn) var relayAcceptFn func(conn listener.Conn)
if relaySrv != nil { if relaySrv != nil {
relayAcceptFn = relaySrv.RelayAccept() relayAcceptFn = relaySrv.RelayAccept()
} }
@@ -563,7 +564,7 @@ func createCombinedHandler(grpcServer *grpc.Server, httpHandler http.Handler, re
} }
// handleRelayWebSocket handles incoming WebSocket connections for the relay service // handleRelayWebSocket handles incoming WebSocket connections for the relay service
func handleRelayWebSocket(w http.ResponseWriter, r *http.Request, acceptFn func(conn net.Conn), cfg *CombinedConfig) { func handleRelayWebSocket(w http.ResponseWriter, r *http.Request, acceptFn func(conn listener.Conn), cfg *CombinedConfig) {
acceptOptions := &websocket.AcceptOptions{ acceptOptions := &websocket.AcceptOptions{
OriginPatterns: []string{"*"}, OriginPatterns: []string{"*"},
} }
@@ -585,15 +586,9 @@ func handleRelayWebSocket(w http.ResponseWriter, r *http.Request, acceptFn func(
return return
} }
lAddr, err := net.ResolveTCPAddr("tcp", cfg.Server.ListenAddress)
if err != nil {
_ = wsConn.Close(websocket.StatusInternalError, "internal error")
return
}
log.Debugf("Relay WS client connected from: %s", rAddr) log.Debugf("Relay WS client connected from: %s", rAddr)
conn := ws.NewConn(wsConn, lAddr, rAddr) conn := ws.NewConn(wsConn, rAddr)
acceptFn(conn) acceptFn(conn)
} }

View File

@@ -1,11 +1,13 @@
package server package server
import ( import (
"context"
"fmt" "fmt"
"net" "time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/relay/server/listener"
"github.com/netbirdio/netbird/shared/relay/messages" "github.com/netbirdio/netbird/shared/relay/messages"
//nolint:staticcheck //nolint:staticcheck
"github.com/netbirdio/netbird/shared/relay/messages/address" "github.com/netbirdio/netbird/shared/relay/messages/address"
@@ -13,6 +15,12 @@ import (
authmsg "github.com/netbirdio/netbird/shared/relay/messages/auth" authmsg "github.com/netbirdio/netbird/shared/relay/messages/auth"
) )
const (
// handshakeTimeout bounds how long a connection may remain in the
// pre-authentication handshake phase before being closed.
handshakeTimeout = 10 * time.Second
)
type Validator interface { type Validator interface {
Validate(any) error Validate(any) error
// Deprecated: Use Validate instead. // Deprecated: Use Validate instead.
@@ -58,7 +66,7 @@ func marshalResponseHelloMsg(instanceURL string) ([]byte, error) {
} }
type handshake struct { type handshake struct {
conn net.Conn conn listener.Conn
validator Validator validator Validator
preparedMsg *preparedMsg preparedMsg *preparedMsg
@@ -66,9 +74,9 @@ type handshake struct {
peerID *messages.PeerID peerID *messages.PeerID
} }
func (h *handshake) handshakeReceive() (*messages.PeerID, error) { func (h *handshake) handshakeReceive(ctx context.Context) (*messages.PeerID, error) {
buf := make([]byte, messages.MaxHandshakeSize) buf := make([]byte, messages.MaxHandshakeSize)
n, err := h.conn.Read(buf) n, err := h.conn.Read(ctx, buf)
if err != nil { if err != nil {
return nil, fmt.Errorf("read from %s: %w", h.conn.RemoteAddr(), err) return nil, fmt.Errorf("read from %s: %w", h.conn.RemoteAddr(), err)
} }
@@ -103,7 +111,7 @@ func (h *handshake) handshakeReceive() (*messages.PeerID, error) {
return peerID, nil return peerID, nil
} }
func (h *handshake) handshakeResponse() error { func (h *handshake) handshakeResponse(ctx context.Context) error {
var responseMsg []byte var responseMsg []byte
if h.handshakeMethodAuth { if h.handshakeMethodAuth {
responseMsg = h.preparedMsg.responseAuthMsg responseMsg = h.preparedMsg.responseAuthMsg
@@ -111,7 +119,7 @@ func (h *handshake) handshakeResponse() error {
responseMsg = h.preparedMsg.responseHelloMsg responseMsg = h.preparedMsg.responseHelloMsg
} }
if _, err := h.conn.Write(responseMsg); err != nil { if _, err := h.conn.Write(ctx, responseMsg); err != nil {
return fmt.Errorf("handshake response write to %s (%s): %w", h.peerID, h.conn.RemoteAddr(), err) return fmt.Errorf("handshake response write to %s (%s): %w", h.peerID, h.conn.RemoteAddr(), err)
} }

View File

@@ -0,0 +1,14 @@
package listener
import (
"context"
"net"
)
// Conn is the relay connection contract implemented by WS and QUIC transports.
type Conn interface {
Read(ctx context.Context, b []byte) (n int, err error)
Write(ctx context.Context, b []byte) (n int, err error)
RemoteAddr() net.Addr
Close() error
}

View File

@@ -1,14 +0,0 @@
package listener
import (
"context"
"net"
"github.com/netbirdio/netbird/relay/protocol"
)
type Listener interface {
Listen(func(conn net.Conn)) error
Shutdown(ctx context.Context) error
Protocol() protocol.Protocol
}

View File

@@ -3,33 +3,26 @@ package quic
import ( import (
"context" "context"
"errors" "errors"
"fmt"
"net" "net"
"sync" "sync"
"time"
"github.com/quic-go/quic-go" "github.com/quic-go/quic-go"
) )
type Conn struct { type Conn struct {
session *quic.Conn session *quic.Conn
closed bool closed bool
closedMu sync.Mutex closedMu sync.Mutex
ctx context.Context
ctxCancel context.CancelFunc
} }
func NewConn(session *quic.Conn) *Conn { func NewConn(session *quic.Conn) *Conn {
ctx, cancel := context.WithCancel(context.Background())
return &Conn{ return &Conn{
session: session, session: session,
ctx: ctx,
ctxCancel: cancel,
} }
} }
func (c *Conn) Read(b []byte) (n int, err error) { func (c *Conn) Read(ctx context.Context, b []byte) (n int, err error) {
dgram, err := c.session.ReceiveDatagram(c.ctx) dgram, err := c.session.ReceiveDatagram(ctx)
if err != nil { if err != nil {
return 0, c.remoteCloseErrHandling(err) return 0, c.remoteCloseErrHandling(err)
} }
@@ -38,33 +31,17 @@ func (c *Conn) Read(b []byte) (n int, err error) {
return n, nil return n, nil
} }
func (c *Conn) Write(b []byte) (int, error) { func (c *Conn) Write(_ context.Context, b []byte) (int, error) {
if err := c.session.SendDatagram(b); err != nil { if err := c.session.SendDatagram(b); err != nil {
return 0, c.remoteCloseErrHandling(err) return 0, c.remoteCloseErrHandling(err)
} }
return len(b), nil return len(b), nil
} }
func (c *Conn) LocalAddr() net.Addr {
return c.session.LocalAddr()
}
func (c *Conn) RemoteAddr() net.Addr { func (c *Conn) RemoteAddr() net.Addr {
return c.session.RemoteAddr() return c.session.RemoteAddr()
} }
func (c *Conn) SetReadDeadline(t time.Time) error {
return nil
}
func (c *Conn) SetWriteDeadline(t time.Time) error {
return fmt.Errorf("SetWriteDeadline is not implemented")
}
func (c *Conn) SetDeadline(t time.Time) error {
return fmt.Errorf("SetDeadline is not implemented")
}
func (c *Conn) Close() error { func (c *Conn) Close() error {
c.closedMu.Lock() c.closedMu.Lock()
if c.closed { if c.closed {
@@ -74,8 +51,6 @@ func (c *Conn) Close() error {
c.closed = true c.closed = true
c.closedMu.Unlock() c.closedMu.Unlock()
c.ctxCancel() // Cancel the context
sessionErr := c.session.CloseWithError(0, "normal closure") sessionErr := c.session.CloseWithError(0, "normal closure")
return sessionErr return sessionErr
} }

View File

@@ -5,12 +5,12 @@ import (
"crypto/tls" "crypto/tls"
"errors" "errors"
"fmt" "fmt"
"net"
"github.com/quic-go/quic-go" "github.com/quic-go/quic-go"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/relay/protocol" "github.com/netbirdio/netbird/relay/protocol"
relaylistener "github.com/netbirdio/netbird/relay/server/listener"
nbRelay "github.com/netbirdio/netbird/shared/relay" nbRelay "github.com/netbirdio/netbird/shared/relay"
) )
@@ -25,7 +25,7 @@ type Listener struct {
listener *quic.Listener listener *quic.Listener
} }
func (l *Listener) Listen(acceptFn func(conn net.Conn)) error { func (l *Listener) Listen(acceptFn func(conn relaylistener.Conn)) error {
quicCfg := &quic.Config{ quicCfg := &quic.Config{
EnableDatagrams: true, EnableDatagrams: true,
InitialPacketSize: nbRelay.QUICInitialPacketSize, InitialPacketSize: nbRelay.QUICInitialPacketSize,

View File

@@ -18,25 +18,21 @@ const (
type Conn struct { type Conn struct {
*websocket.Conn *websocket.Conn
lAddr *net.TCPAddr
rAddr *net.TCPAddr rAddr *net.TCPAddr
closed bool closed bool
closedMu sync.Mutex closedMu sync.Mutex
ctx context.Context
} }
func NewConn(wsConn *websocket.Conn, lAddr, rAddr *net.TCPAddr) *Conn { func NewConn(wsConn *websocket.Conn, rAddr *net.TCPAddr) *Conn {
return &Conn{ return &Conn{
Conn: wsConn, Conn: wsConn,
lAddr: lAddr,
rAddr: rAddr, rAddr: rAddr,
ctx: context.Background(),
} }
} }
func (c *Conn) Read(b []byte) (n int, err error) { func (c *Conn) Read(ctx context.Context, b []byte) (n int, err error) {
t, r, err := c.Reader(c.ctx) t, r, err := c.Reader(ctx)
if err != nil { if err != nil {
return 0, c.ioErrHandling(err) return 0, c.ioErrHandling(err)
} }
@@ -56,34 +52,18 @@ func (c *Conn) Read(b []byte) (n int, err error) {
// Write writes a binary message with the given payload. // Write writes a binary message with the given payload.
// It does not block until fill the internal buffer. // It does not block until fill the internal buffer.
// If the buffer filled up, wait until the buffer is drained or timeout. // If the buffer filled up, wait until the buffer is drained or timeout.
func (c *Conn) Write(b []byte) (int, error) { func (c *Conn) Write(ctx context.Context, b []byte) (int, error) {
ctx, ctxCancel := context.WithTimeout(c.ctx, writeTimeout) ctx, ctxCancel := context.WithTimeout(ctx, writeTimeout)
defer ctxCancel() defer ctxCancel()
err := c.Conn.Write(ctx, websocket.MessageBinary, b) err := c.Conn.Write(ctx, websocket.MessageBinary, b)
return len(b), err return len(b), err
} }
func (c *Conn) LocalAddr() net.Addr {
return c.lAddr
}
func (c *Conn) RemoteAddr() net.Addr { func (c *Conn) RemoteAddr() net.Addr {
return c.rAddr return c.rAddr
} }
func (c *Conn) SetReadDeadline(t time.Time) error {
return fmt.Errorf("SetReadDeadline is not implemented")
}
func (c *Conn) SetWriteDeadline(t time.Time) error {
return fmt.Errorf("SetWriteDeadline is not implemented")
}
func (c *Conn) SetDeadline(t time.Time) error {
return fmt.Errorf("SetDeadline is not implemented")
}
func (c *Conn) Close() error { func (c *Conn) Close() error {
c.closedMu.Lock() c.closedMu.Lock()
c.closed = true c.closed = true

View File

@@ -7,11 +7,13 @@ import (
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
"time"
"github.com/coder/websocket" "github.com/coder/websocket"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/relay/protocol" "github.com/netbirdio/netbird/relay/protocol"
relaylistener "github.com/netbirdio/netbird/relay/server/listener"
"github.com/netbirdio/netbird/shared/relay" "github.com/netbirdio/netbird/shared/relay"
) )
@@ -27,18 +29,19 @@ type Listener struct {
TLSConfig *tls.Config TLSConfig *tls.Config
server *http.Server server *http.Server
acceptFn func(conn net.Conn) acceptFn func(conn relaylistener.Conn)
} }
func (l *Listener) Listen(acceptFn func(conn net.Conn)) error { func (l *Listener) Listen(acceptFn func(conn relaylistener.Conn)) error {
l.acceptFn = acceptFn l.acceptFn = acceptFn
mux := http.NewServeMux() mux := http.NewServeMux()
mux.HandleFunc(URLPath, l.onAccept) mux.HandleFunc(URLPath, l.onAccept)
l.server = &http.Server{ l.server = &http.Server{
Addr: l.Address, Addr: l.Address,
Handler: mux, Handler: mux,
TLSConfig: l.TLSConfig, TLSConfig: l.TLSConfig,
ReadHeaderTimeout: 5 * time.Second,
} }
log.Infof("WS server listening address: %s", l.Address) log.Infof("WS server listening address: %s", l.Address)
@@ -93,18 +96,9 @@ func (l *Listener) onAccept(w http.ResponseWriter, r *http.Request) {
return return
} }
lAddr, err := net.ResolveTCPAddr("tcp", l.server.Addr)
if err != nil {
err = wsConn.Close(websocket.StatusInternalError, "internal error")
if err != nil {
log.Errorf("failed to close ws connection: %s", err)
}
return
}
log.Infof("WS client connected from: %s", rAddr) log.Infof("WS client connected from: %s", rAddr)
conn := NewConn(wsConn, lAddr, rAddr) conn := NewConn(wsConn, rAddr)
l.acceptFn(conn) l.acceptFn(conn)
} }

View File

@@ -10,6 +10,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/relay/metrics" "github.com/netbirdio/netbird/relay/metrics"
"github.com/netbirdio/netbird/relay/server/listener"
"github.com/netbirdio/netbird/relay/server/store" "github.com/netbirdio/netbird/relay/server/store"
"github.com/netbirdio/netbird/shared/relay/healthcheck" "github.com/netbirdio/netbird/shared/relay/healthcheck"
"github.com/netbirdio/netbird/shared/relay/messages" "github.com/netbirdio/netbird/shared/relay/messages"
@@ -26,11 +27,14 @@ type Peer struct {
metrics *metrics.Metrics metrics *metrics.Metrics
log *log.Entry log *log.Entry
id messages.PeerID id messages.PeerID
conn net.Conn conn listener.Conn
connMu sync.RWMutex connMu sync.RWMutex
store *store.Store store *store.Store
notifier *store.PeerNotifier notifier *store.PeerNotifier
ctx context.Context
ctxCancel context.CancelFunc
peersListener *store.Listener peersListener *store.Listener
// between the online peer collection step and the notification sending should not be sent offline notifications from another thread // between the online peer collection step and the notification sending should not be sent offline notifications from another thread
@@ -38,14 +42,17 @@ type Peer struct {
} }
// NewPeer creates a new Peer instance and prepare custom logging // NewPeer creates a new Peer instance and prepare custom logging
func NewPeer(metrics *metrics.Metrics, id messages.PeerID, conn net.Conn, store *store.Store, notifier *store.PeerNotifier) *Peer { func NewPeer(metrics *metrics.Metrics, id messages.PeerID, conn listener.Conn, store *store.Store, notifier *store.PeerNotifier) *Peer {
ctx, cancel := context.WithCancel(context.Background())
p := &Peer{ p := &Peer{
metrics: metrics, metrics: metrics,
log: log.WithField("peer_id", id.String()), log: log.WithField("peer_id", id.String()),
id: id, id: id,
conn: conn, conn: conn,
store: store, store: store,
notifier: notifier, notifier: notifier,
ctx: ctx,
ctxCancel: cancel,
} }
return p return p
@@ -57,6 +64,7 @@ func NewPeer(metrics *metrics.Metrics, id messages.PeerID, conn net.Conn, store
func (p *Peer) Work() { func (p *Peer) Work() {
p.peersListener = p.notifier.NewListener(p.sendPeersOnline, p.sendPeersWentOffline) p.peersListener = p.notifier.NewListener(p.sendPeersOnline, p.sendPeersWentOffline)
defer func() { defer func() {
p.ctxCancel()
p.notifier.RemoveListener(p.peersListener) p.notifier.RemoveListener(p.peersListener)
if err := p.conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) { if err := p.conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
@@ -64,8 +72,7 @@ func (p *Peer) Work() {
} }
}() }()
ctx, cancel := context.WithCancel(context.Background()) ctx := p.ctx
defer cancel()
hc := healthcheck.NewSender(p.log) hc := healthcheck.NewSender(p.log)
go hc.StartHealthCheck(ctx) go hc.StartHealthCheck(ctx)
@@ -73,7 +80,7 @@ func (p *Peer) Work() {
buf := make([]byte, bufferSize) buf := make([]byte, bufferSize)
for { for {
n, err := p.conn.Read(buf) n, err := p.conn.Read(ctx, buf)
if err != nil { if err != nil {
if !errors.Is(err, net.ErrClosed) { if !errors.Is(err, net.ErrClosed) {
p.log.Errorf("failed to read message: %s", err) p.log.Errorf("failed to read message: %s", err)
@@ -131,10 +138,10 @@ func (p *Peer) handleMsgType(ctx context.Context, msgType messages.MsgType, hc *
} }
// Write writes data to the connection // Write writes data to the connection
func (p *Peer) Write(b []byte) (int, error) { func (p *Peer) Write(ctx context.Context, b []byte) (int, error) {
p.connMu.RLock() p.connMu.RLock()
defer p.connMu.RUnlock() defer p.connMu.RUnlock()
return p.conn.Write(b) return p.conn.Write(ctx, b)
} }
// CloseGracefully closes the connection with the peer gracefully. Send a close message to the client and close the // CloseGracefully closes the connection with the peer gracefully. Send a close message to the client and close the
@@ -147,6 +154,7 @@ func (p *Peer) CloseGracefully(ctx context.Context) {
p.log.Errorf("failed to send close message to peer: %s", p.String()) p.log.Errorf("failed to send close message to peer: %s", p.String())
} }
p.ctxCancel()
if err := p.conn.Close(); err != nil { if err := p.conn.Close(); err != nil {
p.log.Errorf(errCloseConn, err) p.log.Errorf(errCloseConn, err)
} }
@@ -156,6 +164,7 @@ func (p *Peer) Close() {
p.connMu.Lock() p.connMu.Lock()
defer p.connMu.Unlock() defer p.connMu.Unlock()
p.ctxCancel()
if err := p.conn.Close(); err != nil { if err := p.conn.Close(); err != nil {
p.log.Errorf(errCloseConn, err) p.log.Errorf(errCloseConn, err)
} }
@@ -170,26 +179,15 @@ func (p *Peer) writeWithTimeout(ctx context.Context, buf []byte) error {
ctx, cancel := context.WithTimeout(ctx, 3*time.Second) ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
defer cancel() defer cancel()
writeDone := make(chan struct{}) _, err := p.conn.Write(ctx, buf)
var err error return err
go func() {
_, err = p.conn.Write(buf)
close(writeDone)
}()
select {
case <-ctx.Done():
return ctx.Err()
case <-writeDone:
return err
}
} }
func (p *Peer) handleHealthcheckEvents(ctx context.Context, hc *healthcheck.Sender) { func (p *Peer) handleHealthcheckEvents(ctx context.Context, hc *healthcheck.Sender) {
for { for {
select { select {
case <-hc.HealthCheck: case <-hc.HealthCheck:
_, err := p.Write(messages.MarshalHealthcheck()) _, err := p.Write(ctx, messages.MarshalHealthcheck())
if err != nil { if err != nil {
p.log.Errorf("failed to send healthcheck message: %s", err) p.log.Errorf("failed to send healthcheck message: %s", err)
return return
@@ -228,12 +226,12 @@ func (p *Peer) handleTransportMsg(msg []byte) {
return return
} }
n, err := dp.Write(msg) n, err := dp.Write(dp.ctx, msg)
if err != nil { if err != nil {
p.log.Errorf("failed to write transport message to: %s", dp.String()) p.log.Errorf("failed to write transport message to: %s", dp.String())
return return
} }
p.metrics.TransferBytesSent.Add(context.Background(), int64(n)) p.metrics.TransferBytesSent.Add(p.ctx, int64(n))
} }
func (p *Peer) handleSubscribePeerState(msg []byte) { func (p *Peer) handleSubscribePeerState(msg []byte) {
@@ -276,7 +274,7 @@ func (p *Peer) sendPeersOnline(peers []messages.PeerID) {
} }
for n, msg := range msgs { for n, msg := range msgs {
if _, err := p.Write(msg); err != nil { if _, err := p.Write(p.ctx, msg); err != nil {
p.log.Errorf("failed to write %d. peers offline message: %s", n, err) p.log.Errorf("failed to write %d. peers offline message: %s", n, err)
} }
} }
@@ -293,7 +291,7 @@ func (p *Peer) sendPeersWentOffline(peers []messages.PeerID) {
} }
for n, msg := range msgs { for n, msg := range msgs {
if _, err := p.Write(msg); err != nil { if _, err := p.Write(p.ctx, msg); err != nil {
p.log.Errorf("failed to write %d. peers offline message: %s", n, err) p.log.Errorf("failed to write %d. peers offline message: %s", n, err)
} }
} }

View File

@@ -3,7 +3,6 @@ package server
import ( import (
"context" "context"
"fmt" "fmt"
"net"
"net/url" "net/url"
"sync" "sync"
"time" "time"
@@ -13,11 +12,20 @@ import (
"go.opentelemetry.io/otel/metric" "go.opentelemetry.io/otel/metric"
"github.com/netbirdio/netbird/relay/healthcheck/peerid" "github.com/netbirdio/netbird/relay/healthcheck/peerid"
"github.com/netbirdio/netbird/relay/protocol"
"github.com/netbirdio/netbird/relay/server/listener"
//nolint:staticcheck //nolint:staticcheck
"github.com/netbirdio/netbird/relay/metrics" "github.com/netbirdio/netbird/relay/metrics"
"github.com/netbirdio/netbird/relay/server/store" "github.com/netbirdio/netbird/relay/server/store"
) )
type Listener interface {
Listen(func(conn listener.Conn)) error
Shutdown(ctx context.Context) error
Protocol() protocol.Protocol
}
type Config struct { type Config struct {
Meter metric.Meter Meter metric.Meter
ExposedAddress string ExposedAddress string
@@ -109,7 +117,7 @@ func NewRelay(config Config) (*Relay, error) {
} }
// Accept start to handle a new peer connection // Accept start to handle a new peer connection
func (r *Relay) Accept(conn net.Conn) { func (r *Relay) Accept(conn listener.Conn) {
acceptTime := time.Now() acceptTime := time.Now()
r.closeMu.RLock() r.closeMu.RLock()
defer r.closeMu.RUnlock() defer r.closeMu.RUnlock()
@@ -117,12 +125,15 @@ func (r *Relay) Accept(conn net.Conn) {
return return
} }
hsCtx, hsCancel := context.WithTimeout(context.Background(), handshakeTimeout)
defer hsCancel()
h := handshake{ h := handshake{
conn: conn, conn: conn,
validator: r.validator, validator: r.validator,
preparedMsg: r.preparedMsg, preparedMsg: r.preparedMsg,
} }
peerID, err := h.handshakeReceive() peerID, err := h.handshakeReceive(hsCtx)
if err != nil { if err != nil {
if peerid.IsHealthCheck(peerID) { if peerid.IsHealthCheck(peerID) {
log.Debugf("health check connection from %s", conn.RemoteAddr()) log.Debugf("health check connection from %s", conn.RemoteAddr())
@@ -154,7 +165,7 @@ func (r *Relay) Accept(conn net.Conn) {
r.metrics.PeerDisconnected(peer.String()) r.metrics.PeerDisconnected(peer.String())
}() }()
if err := h.handshakeResponse(); err != nil { if err := h.handshakeResponse(hsCtx); err != nil {
log.Errorf("failed to send handshake response, close peer: %s", err) log.Errorf("failed to send handshake response, close peer: %s", err)
peer.Close() peer.Close()
} }

View File

@@ -3,7 +3,6 @@ package server
import ( import (
"context" "context"
"crypto/tls" "crypto/tls"
"net"
"net/url" "net/url"
"sync" "sync"
@@ -31,7 +30,7 @@ type ListenerConfig struct {
// In a new HTTP connection, the server will accept the connection and pass it to the Relay server via the Accept method. // In a new HTTP connection, the server will accept the connection and pass it to the Relay server via the Accept method.
type Server struct { type Server struct {
relay *Relay relay *Relay
listeners []listener.Listener listeners []Listener
listenerMux sync.Mutex listenerMux sync.Mutex
} }
@@ -56,7 +55,7 @@ func NewServer(config Config) (*Server, error) {
} }
return &Server{ return &Server{
relay: relay, relay: relay,
listeners: make([]listener.Listener, 0, 2), listeners: make([]Listener, 0, 2),
}, nil }, nil
} }
@@ -86,7 +85,7 @@ func (r *Server) Listen(cfg ListenerConfig) error {
wg := sync.WaitGroup{} wg := sync.WaitGroup{}
for _, l := range r.listeners { for _, l := range r.listeners {
wg.Add(1) wg.Add(1)
go func(listener listener.Listener) { go func(listener Listener) {
defer wg.Done() defer wg.Done()
errChan <- listener.Listen(r.relay.Accept) errChan <- listener.Listen(r.relay.Accept)
}(l) }(l)
@@ -139,6 +138,6 @@ func (r *Server) InstanceURL() url.URL {
// RelayAccept returns the relay's Accept function for handling incoming connections. // RelayAccept returns the relay's Accept function for handling incoming connections.
// This allows external HTTP handlers to route connections to the relay without // This allows external HTTP handlers to route connections to the relay without
// starting the relay's own listeners. // starting the relay's own listeners.
func (r *Server) RelayAccept() func(conn net.Conn) { func (r *Server) RelayAccept() func(conn listener.Conn) {
return r.relay.Accept return r.relay.Accept
} }