mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 07:16:38 +00:00
[relay] Replace net.Conn with context-aware Conn interface (#5770)
* [relay] Replace net.Conn with context-aware Conn interface for relay transports Introduce a listener.Conn interface with context-based Read/Write methods, replacing net.Conn throughout the relay server. This enables proper timeout propagation (e.g. handshake timeout) without goroutine-based workarounds and removes unused LocalAddr/SetDeadline methods from WS and QUIC conns. * [relay] Refactor Peer context management to ensure proper cleanup Integrate context creation (`context.WithCancel`) directly in `NewPeer` and remove redundant initialization in `Work`. Add `ctxCancel` calls to ensure context is properly canceled during `Close` operations.
This commit is contained in:
@@ -29,6 +29,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
"github.com/netbirdio/netbird/relay/healthcheck"
|
"github.com/netbirdio/netbird/relay/healthcheck"
|
||||||
relayServer "github.com/netbirdio/netbird/relay/server"
|
relayServer "github.com/netbirdio/netbird/relay/server"
|
||||||
|
"github.com/netbirdio/netbird/relay/server/listener"
|
||||||
"github.com/netbirdio/netbird/relay/server/listener/ws"
|
"github.com/netbirdio/netbird/relay/server/listener/ws"
|
||||||
sharedMetrics "github.com/netbirdio/netbird/shared/metrics"
|
sharedMetrics "github.com/netbirdio/netbird/shared/metrics"
|
||||||
"github.com/netbirdio/netbird/shared/relay/auth"
|
"github.com/netbirdio/netbird/shared/relay/auth"
|
||||||
@@ -523,7 +524,7 @@ func createManagementServer(cfg *CombinedConfig, mgmtConfig *nbconfig.Config) (*
|
|||||||
func createCombinedHandler(grpcServer *grpc.Server, httpHandler http.Handler, relaySrv *relayServer.Server, meter metric.Meter, cfg *CombinedConfig) http.Handler {
|
func createCombinedHandler(grpcServer *grpc.Server, httpHandler http.Handler, relaySrv *relayServer.Server, meter metric.Meter, cfg *CombinedConfig) http.Handler {
|
||||||
wsProxy := wsproxyserver.New(grpcServer, wsproxyserver.WithOTelMeter(meter))
|
wsProxy := wsproxyserver.New(grpcServer, wsproxyserver.WithOTelMeter(meter))
|
||||||
|
|
||||||
var relayAcceptFn func(conn net.Conn)
|
var relayAcceptFn func(conn listener.Conn)
|
||||||
if relaySrv != nil {
|
if relaySrv != nil {
|
||||||
relayAcceptFn = relaySrv.RelayAccept()
|
relayAcceptFn = relaySrv.RelayAccept()
|
||||||
}
|
}
|
||||||
@@ -563,7 +564,7 @@ func createCombinedHandler(grpcServer *grpc.Server, httpHandler http.Handler, re
|
|||||||
}
|
}
|
||||||
|
|
||||||
// handleRelayWebSocket handles incoming WebSocket connections for the relay service
|
// handleRelayWebSocket handles incoming WebSocket connections for the relay service
|
||||||
func handleRelayWebSocket(w http.ResponseWriter, r *http.Request, acceptFn func(conn net.Conn), cfg *CombinedConfig) {
|
func handleRelayWebSocket(w http.ResponseWriter, r *http.Request, acceptFn func(conn listener.Conn), cfg *CombinedConfig) {
|
||||||
acceptOptions := &websocket.AcceptOptions{
|
acceptOptions := &websocket.AcceptOptions{
|
||||||
OriginPatterns: []string{"*"},
|
OriginPatterns: []string{"*"},
|
||||||
}
|
}
|
||||||
@@ -585,15 +586,9 @@ func handleRelayWebSocket(w http.ResponseWriter, r *http.Request, acceptFn func(
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
lAddr, err := net.ResolveTCPAddr("tcp", cfg.Server.ListenAddress)
|
|
||||||
if err != nil {
|
|
||||||
_ = wsConn.Close(websocket.StatusInternalError, "internal error")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("Relay WS client connected from: %s", rAddr)
|
log.Debugf("Relay WS client connected from: %s", rAddr)
|
||||||
|
|
||||||
conn := ws.NewConn(wsConn, lAddr, rAddr)
|
conn := ws.NewConn(wsConn, rAddr)
|
||||||
acceptFn(conn)
|
acceptFn(conn)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,11 +1,13 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"time"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/relay/server/listener"
|
||||||
"github.com/netbirdio/netbird/shared/relay/messages"
|
"github.com/netbirdio/netbird/shared/relay/messages"
|
||||||
//nolint:staticcheck
|
//nolint:staticcheck
|
||||||
"github.com/netbirdio/netbird/shared/relay/messages/address"
|
"github.com/netbirdio/netbird/shared/relay/messages/address"
|
||||||
@@ -13,6 +15,12 @@ import (
|
|||||||
authmsg "github.com/netbirdio/netbird/shared/relay/messages/auth"
|
authmsg "github.com/netbirdio/netbird/shared/relay/messages/auth"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// handshakeTimeout bounds how long a connection may remain in the
|
||||||
|
// pre-authentication handshake phase before being closed.
|
||||||
|
handshakeTimeout = 10 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
type Validator interface {
|
type Validator interface {
|
||||||
Validate(any) error
|
Validate(any) error
|
||||||
// Deprecated: Use Validate instead.
|
// Deprecated: Use Validate instead.
|
||||||
@@ -58,7 +66,7 @@ func marshalResponseHelloMsg(instanceURL string) ([]byte, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type handshake struct {
|
type handshake struct {
|
||||||
conn net.Conn
|
conn listener.Conn
|
||||||
validator Validator
|
validator Validator
|
||||||
preparedMsg *preparedMsg
|
preparedMsg *preparedMsg
|
||||||
|
|
||||||
@@ -66,9 +74,9 @@ type handshake struct {
|
|||||||
peerID *messages.PeerID
|
peerID *messages.PeerID
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *handshake) handshakeReceive() (*messages.PeerID, error) {
|
func (h *handshake) handshakeReceive(ctx context.Context) (*messages.PeerID, error) {
|
||||||
buf := make([]byte, messages.MaxHandshakeSize)
|
buf := make([]byte, messages.MaxHandshakeSize)
|
||||||
n, err := h.conn.Read(buf)
|
n, err := h.conn.Read(ctx, buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("read from %s: %w", h.conn.RemoteAddr(), err)
|
return nil, fmt.Errorf("read from %s: %w", h.conn.RemoteAddr(), err)
|
||||||
}
|
}
|
||||||
@@ -103,7 +111,7 @@ func (h *handshake) handshakeReceive() (*messages.PeerID, error) {
|
|||||||
return peerID, nil
|
return peerID, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *handshake) handshakeResponse() error {
|
func (h *handshake) handshakeResponse(ctx context.Context) error {
|
||||||
var responseMsg []byte
|
var responseMsg []byte
|
||||||
if h.handshakeMethodAuth {
|
if h.handshakeMethodAuth {
|
||||||
responseMsg = h.preparedMsg.responseAuthMsg
|
responseMsg = h.preparedMsg.responseAuthMsg
|
||||||
@@ -111,7 +119,7 @@ func (h *handshake) handshakeResponse() error {
|
|||||||
responseMsg = h.preparedMsg.responseHelloMsg
|
responseMsg = h.preparedMsg.responseHelloMsg
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := h.conn.Write(responseMsg); err != nil {
|
if _, err := h.conn.Write(ctx, responseMsg); err != nil {
|
||||||
return fmt.Errorf("handshake response write to %s (%s): %w", h.peerID, h.conn.RemoteAddr(), err)
|
return fmt.Errorf("handshake response write to %s (%s): %w", h.peerID, h.conn.RemoteAddr(), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
14
relay/server/listener/conn.go
Normal file
14
relay/server/listener/conn.go
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
package listener
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Conn is the relay connection contract implemented by WS and QUIC transports.
|
||||||
|
type Conn interface {
|
||||||
|
Read(ctx context.Context, b []byte) (n int, err error)
|
||||||
|
Write(ctx context.Context, b []byte) (n int, err error)
|
||||||
|
RemoteAddr() net.Addr
|
||||||
|
Close() error
|
||||||
|
}
|
||||||
@@ -1,14 +0,0 @@
|
|||||||
package listener
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"net"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/relay/protocol"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Listener interface {
|
|
||||||
Listen(func(conn net.Conn)) error
|
|
||||||
Shutdown(ctx context.Context) error
|
|
||||||
Protocol() protocol.Protocol
|
|
||||||
}
|
|
||||||
@@ -3,33 +3,26 @@ package quic
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/quic-go/quic-go"
|
"github.com/quic-go/quic-go"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Conn struct {
|
type Conn struct {
|
||||||
session *quic.Conn
|
session *quic.Conn
|
||||||
closed bool
|
closed bool
|
||||||
closedMu sync.Mutex
|
closedMu sync.Mutex
|
||||||
ctx context.Context
|
|
||||||
ctxCancel context.CancelFunc
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewConn(session *quic.Conn) *Conn {
|
func NewConn(session *quic.Conn) *Conn {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
return &Conn{
|
return &Conn{
|
||||||
session: session,
|
session: session,
|
||||||
ctx: ctx,
|
|
||||||
ctxCancel: cancel,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) Read(b []byte) (n int, err error) {
|
func (c *Conn) Read(ctx context.Context, b []byte) (n int, err error) {
|
||||||
dgram, err := c.session.ReceiveDatagram(c.ctx)
|
dgram, err := c.session.ReceiveDatagram(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, c.remoteCloseErrHandling(err)
|
return 0, c.remoteCloseErrHandling(err)
|
||||||
}
|
}
|
||||||
@@ -38,33 +31,17 @@ func (c *Conn) Read(b []byte) (n int, err error) {
|
|||||||
return n, nil
|
return n, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) Write(b []byte) (int, error) {
|
func (c *Conn) Write(_ context.Context, b []byte) (int, error) {
|
||||||
if err := c.session.SendDatagram(b); err != nil {
|
if err := c.session.SendDatagram(b); err != nil {
|
||||||
return 0, c.remoteCloseErrHandling(err)
|
return 0, c.remoteCloseErrHandling(err)
|
||||||
}
|
}
|
||||||
return len(b), nil
|
return len(b), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) LocalAddr() net.Addr {
|
|
||||||
return c.session.LocalAddr()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Conn) RemoteAddr() net.Addr {
|
func (c *Conn) RemoteAddr() net.Addr {
|
||||||
return c.session.RemoteAddr()
|
return c.session.RemoteAddr()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) SetReadDeadline(t time.Time) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Conn) SetWriteDeadline(t time.Time) error {
|
|
||||||
return fmt.Errorf("SetWriteDeadline is not implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Conn) SetDeadline(t time.Time) error {
|
|
||||||
return fmt.Errorf("SetDeadline is not implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Conn) Close() error {
|
func (c *Conn) Close() error {
|
||||||
c.closedMu.Lock()
|
c.closedMu.Lock()
|
||||||
if c.closed {
|
if c.closed {
|
||||||
@@ -74,8 +51,6 @@ func (c *Conn) Close() error {
|
|||||||
c.closed = true
|
c.closed = true
|
||||||
c.closedMu.Unlock()
|
c.closedMu.Unlock()
|
||||||
|
|
||||||
c.ctxCancel() // Cancel the context
|
|
||||||
|
|
||||||
sessionErr := c.session.CloseWithError(0, "normal closure")
|
sessionErr := c.session.CloseWithError(0, "normal closure")
|
||||||
return sessionErr
|
return sessionErr
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,12 +5,12 @@ import (
|
|||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
|
||||||
|
|
||||||
"github.com/quic-go/quic-go"
|
"github.com/quic-go/quic-go"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/relay/protocol"
|
"github.com/netbirdio/netbird/relay/protocol"
|
||||||
|
relaylistener "github.com/netbirdio/netbird/relay/server/listener"
|
||||||
nbRelay "github.com/netbirdio/netbird/shared/relay"
|
nbRelay "github.com/netbirdio/netbird/shared/relay"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -25,7 +25,7 @@ type Listener struct {
|
|||||||
listener *quic.Listener
|
listener *quic.Listener
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *Listener) Listen(acceptFn func(conn net.Conn)) error {
|
func (l *Listener) Listen(acceptFn func(conn relaylistener.Conn)) error {
|
||||||
quicCfg := &quic.Config{
|
quicCfg := &quic.Config{
|
||||||
EnableDatagrams: true,
|
EnableDatagrams: true,
|
||||||
InitialPacketSize: nbRelay.QUICInitialPacketSize,
|
InitialPacketSize: nbRelay.QUICInitialPacketSize,
|
||||||
|
|||||||
@@ -18,25 +18,21 @@ const (
|
|||||||
|
|
||||||
type Conn struct {
|
type Conn struct {
|
||||||
*websocket.Conn
|
*websocket.Conn
|
||||||
lAddr *net.TCPAddr
|
|
||||||
rAddr *net.TCPAddr
|
rAddr *net.TCPAddr
|
||||||
|
|
||||||
closed bool
|
closed bool
|
||||||
closedMu sync.Mutex
|
closedMu sync.Mutex
|
||||||
ctx context.Context
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewConn(wsConn *websocket.Conn, lAddr, rAddr *net.TCPAddr) *Conn {
|
func NewConn(wsConn *websocket.Conn, rAddr *net.TCPAddr) *Conn {
|
||||||
return &Conn{
|
return &Conn{
|
||||||
Conn: wsConn,
|
Conn: wsConn,
|
||||||
lAddr: lAddr,
|
|
||||||
rAddr: rAddr,
|
rAddr: rAddr,
|
||||||
ctx: context.Background(),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) Read(b []byte) (n int, err error) {
|
func (c *Conn) Read(ctx context.Context, b []byte) (n int, err error) {
|
||||||
t, r, err := c.Reader(c.ctx)
|
t, r, err := c.Reader(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, c.ioErrHandling(err)
|
return 0, c.ioErrHandling(err)
|
||||||
}
|
}
|
||||||
@@ -56,34 +52,18 @@ func (c *Conn) Read(b []byte) (n int, err error) {
|
|||||||
// Write writes a binary message with the given payload.
|
// Write writes a binary message with the given payload.
|
||||||
// It does not block until fill the internal buffer.
|
// It does not block until fill the internal buffer.
|
||||||
// If the buffer filled up, wait until the buffer is drained or timeout.
|
// If the buffer filled up, wait until the buffer is drained or timeout.
|
||||||
func (c *Conn) Write(b []byte) (int, error) {
|
func (c *Conn) Write(ctx context.Context, b []byte) (int, error) {
|
||||||
ctx, ctxCancel := context.WithTimeout(c.ctx, writeTimeout)
|
ctx, ctxCancel := context.WithTimeout(ctx, writeTimeout)
|
||||||
defer ctxCancel()
|
defer ctxCancel()
|
||||||
|
|
||||||
err := c.Conn.Write(ctx, websocket.MessageBinary, b)
|
err := c.Conn.Write(ctx, websocket.MessageBinary, b)
|
||||||
return len(b), err
|
return len(b), err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) LocalAddr() net.Addr {
|
|
||||||
return c.lAddr
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Conn) RemoteAddr() net.Addr {
|
func (c *Conn) RemoteAddr() net.Addr {
|
||||||
return c.rAddr
|
return c.rAddr
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) SetReadDeadline(t time.Time) error {
|
|
||||||
return fmt.Errorf("SetReadDeadline is not implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Conn) SetWriteDeadline(t time.Time) error {
|
|
||||||
return fmt.Errorf("SetWriteDeadline is not implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Conn) SetDeadline(t time.Time) error {
|
|
||||||
return fmt.Errorf("SetDeadline is not implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Conn) Close() error {
|
func (c *Conn) Close() error {
|
||||||
c.closedMu.Lock()
|
c.closedMu.Lock()
|
||||||
c.closed = true
|
c.closed = true
|
||||||
|
|||||||
@@ -7,11 +7,13 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/coder/websocket"
|
"github.com/coder/websocket"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/relay/protocol"
|
"github.com/netbirdio/netbird/relay/protocol"
|
||||||
|
relaylistener "github.com/netbirdio/netbird/relay/server/listener"
|
||||||
"github.com/netbirdio/netbird/shared/relay"
|
"github.com/netbirdio/netbird/shared/relay"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -27,18 +29,19 @@ type Listener struct {
|
|||||||
TLSConfig *tls.Config
|
TLSConfig *tls.Config
|
||||||
|
|
||||||
server *http.Server
|
server *http.Server
|
||||||
acceptFn func(conn net.Conn)
|
acceptFn func(conn relaylistener.Conn)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *Listener) Listen(acceptFn func(conn net.Conn)) error {
|
func (l *Listener) Listen(acceptFn func(conn relaylistener.Conn)) error {
|
||||||
l.acceptFn = acceptFn
|
l.acceptFn = acceptFn
|
||||||
mux := http.NewServeMux()
|
mux := http.NewServeMux()
|
||||||
mux.HandleFunc(URLPath, l.onAccept)
|
mux.HandleFunc(URLPath, l.onAccept)
|
||||||
|
|
||||||
l.server = &http.Server{
|
l.server = &http.Server{
|
||||||
Addr: l.Address,
|
Addr: l.Address,
|
||||||
Handler: mux,
|
Handler: mux,
|
||||||
TLSConfig: l.TLSConfig,
|
TLSConfig: l.TLSConfig,
|
||||||
|
ReadHeaderTimeout: 5 * time.Second,
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infof("WS server listening address: %s", l.Address)
|
log.Infof("WS server listening address: %s", l.Address)
|
||||||
@@ -93,18 +96,9 @@ func (l *Listener) onAccept(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
lAddr, err := net.ResolveTCPAddr("tcp", l.server.Addr)
|
|
||||||
if err != nil {
|
|
||||||
err = wsConn.Close(websocket.StatusInternalError, "internal error")
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to close ws connection: %s", err)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Infof("WS client connected from: %s", rAddr)
|
log.Infof("WS client connected from: %s", rAddr)
|
||||||
|
|
||||||
conn := NewConn(wsConn, lAddr, rAddr)
|
conn := NewConn(wsConn, rAddr)
|
||||||
l.acceptFn(conn)
|
l.acceptFn(conn)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/relay/metrics"
|
"github.com/netbirdio/netbird/relay/metrics"
|
||||||
|
"github.com/netbirdio/netbird/relay/server/listener"
|
||||||
"github.com/netbirdio/netbird/relay/server/store"
|
"github.com/netbirdio/netbird/relay/server/store"
|
||||||
"github.com/netbirdio/netbird/shared/relay/healthcheck"
|
"github.com/netbirdio/netbird/shared/relay/healthcheck"
|
||||||
"github.com/netbirdio/netbird/shared/relay/messages"
|
"github.com/netbirdio/netbird/shared/relay/messages"
|
||||||
@@ -26,11 +27,14 @@ type Peer struct {
|
|||||||
metrics *metrics.Metrics
|
metrics *metrics.Metrics
|
||||||
log *log.Entry
|
log *log.Entry
|
||||||
id messages.PeerID
|
id messages.PeerID
|
||||||
conn net.Conn
|
conn listener.Conn
|
||||||
connMu sync.RWMutex
|
connMu sync.RWMutex
|
||||||
store *store.Store
|
store *store.Store
|
||||||
notifier *store.PeerNotifier
|
notifier *store.PeerNotifier
|
||||||
|
|
||||||
|
ctx context.Context
|
||||||
|
ctxCancel context.CancelFunc
|
||||||
|
|
||||||
peersListener *store.Listener
|
peersListener *store.Listener
|
||||||
|
|
||||||
// between the online peer collection step and the notification sending should not be sent offline notifications from another thread
|
// between the online peer collection step and the notification sending should not be sent offline notifications from another thread
|
||||||
@@ -38,14 +42,17 @@ type Peer struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewPeer creates a new Peer instance and prepare custom logging
|
// NewPeer creates a new Peer instance and prepare custom logging
|
||||||
func NewPeer(metrics *metrics.Metrics, id messages.PeerID, conn net.Conn, store *store.Store, notifier *store.PeerNotifier) *Peer {
|
func NewPeer(metrics *metrics.Metrics, id messages.PeerID, conn listener.Conn, store *store.Store, notifier *store.PeerNotifier) *Peer {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
p := &Peer{
|
p := &Peer{
|
||||||
metrics: metrics,
|
metrics: metrics,
|
||||||
log: log.WithField("peer_id", id.String()),
|
log: log.WithField("peer_id", id.String()),
|
||||||
id: id,
|
id: id,
|
||||||
conn: conn,
|
conn: conn,
|
||||||
store: store,
|
store: store,
|
||||||
notifier: notifier,
|
notifier: notifier,
|
||||||
|
ctx: ctx,
|
||||||
|
ctxCancel: cancel,
|
||||||
}
|
}
|
||||||
|
|
||||||
return p
|
return p
|
||||||
@@ -57,6 +64,7 @@ func NewPeer(metrics *metrics.Metrics, id messages.PeerID, conn net.Conn, store
|
|||||||
func (p *Peer) Work() {
|
func (p *Peer) Work() {
|
||||||
p.peersListener = p.notifier.NewListener(p.sendPeersOnline, p.sendPeersWentOffline)
|
p.peersListener = p.notifier.NewListener(p.sendPeersOnline, p.sendPeersWentOffline)
|
||||||
defer func() {
|
defer func() {
|
||||||
|
p.ctxCancel()
|
||||||
p.notifier.RemoveListener(p.peersListener)
|
p.notifier.RemoveListener(p.peersListener)
|
||||||
|
|
||||||
if err := p.conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
|
if err := p.conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
|
||||||
@@ -64,8 +72,7 @@ func (p *Peer) Work() {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx := p.ctx
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
hc := healthcheck.NewSender(p.log)
|
hc := healthcheck.NewSender(p.log)
|
||||||
go hc.StartHealthCheck(ctx)
|
go hc.StartHealthCheck(ctx)
|
||||||
@@ -73,7 +80,7 @@ func (p *Peer) Work() {
|
|||||||
|
|
||||||
buf := make([]byte, bufferSize)
|
buf := make([]byte, bufferSize)
|
||||||
for {
|
for {
|
||||||
n, err := p.conn.Read(buf)
|
n, err := p.conn.Read(ctx, buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if !errors.Is(err, net.ErrClosed) {
|
if !errors.Is(err, net.ErrClosed) {
|
||||||
p.log.Errorf("failed to read message: %s", err)
|
p.log.Errorf("failed to read message: %s", err)
|
||||||
@@ -131,10 +138,10 @@ func (p *Peer) handleMsgType(ctx context.Context, msgType messages.MsgType, hc *
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Write writes data to the connection
|
// Write writes data to the connection
|
||||||
func (p *Peer) Write(b []byte) (int, error) {
|
func (p *Peer) Write(ctx context.Context, b []byte) (int, error) {
|
||||||
p.connMu.RLock()
|
p.connMu.RLock()
|
||||||
defer p.connMu.RUnlock()
|
defer p.connMu.RUnlock()
|
||||||
return p.conn.Write(b)
|
return p.conn.Write(ctx, b)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CloseGracefully closes the connection with the peer gracefully. Send a close message to the client and close the
|
// CloseGracefully closes the connection with the peer gracefully. Send a close message to the client and close the
|
||||||
@@ -147,6 +154,7 @@ func (p *Peer) CloseGracefully(ctx context.Context) {
|
|||||||
p.log.Errorf("failed to send close message to peer: %s", p.String())
|
p.log.Errorf("failed to send close message to peer: %s", p.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
p.ctxCancel()
|
||||||
if err := p.conn.Close(); err != nil {
|
if err := p.conn.Close(); err != nil {
|
||||||
p.log.Errorf(errCloseConn, err)
|
p.log.Errorf(errCloseConn, err)
|
||||||
}
|
}
|
||||||
@@ -156,6 +164,7 @@ func (p *Peer) Close() {
|
|||||||
p.connMu.Lock()
|
p.connMu.Lock()
|
||||||
defer p.connMu.Unlock()
|
defer p.connMu.Unlock()
|
||||||
|
|
||||||
|
p.ctxCancel()
|
||||||
if err := p.conn.Close(); err != nil {
|
if err := p.conn.Close(); err != nil {
|
||||||
p.log.Errorf(errCloseConn, err)
|
p.log.Errorf(errCloseConn, err)
|
||||||
}
|
}
|
||||||
@@ -170,26 +179,15 @@ func (p *Peer) writeWithTimeout(ctx context.Context, buf []byte) error {
|
|||||||
ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
|
ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
writeDone := make(chan struct{})
|
_, err := p.conn.Write(ctx, buf)
|
||||||
var err error
|
return err
|
||||||
go func() {
|
|
||||||
_, err = p.conn.Write(buf)
|
|
||||||
close(writeDone)
|
|
||||||
}()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return ctx.Err()
|
|
||||||
case <-writeDone:
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Peer) handleHealthcheckEvents(ctx context.Context, hc *healthcheck.Sender) {
|
func (p *Peer) handleHealthcheckEvents(ctx context.Context, hc *healthcheck.Sender) {
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-hc.HealthCheck:
|
case <-hc.HealthCheck:
|
||||||
_, err := p.Write(messages.MarshalHealthcheck())
|
_, err := p.Write(ctx, messages.MarshalHealthcheck())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
p.log.Errorf("failed to send healthcheck message: %s", err)
|
p.log.Errorf("failed to send healthcheck message: %s", err)
|
||||||
return
|
return
|
||||||
@@ -228,12 +226,12 @@ func (p *Peer) handleTransportMsg(msg []byte) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
n, err := dp.Write(msg)
|
n, err := dp.Write(dp.ctx, msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
p.log.Errorf("failed to write transport message to: %s", dp.String())
|
p.log.Errorf("failed to write transport message to: %s", dp.String())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
p.metrics.TransferBytesSent.Add(context.Background(), int64(n))
|
p.metrics.TransferBytesSent.Add(p.ctx, int64(n))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Peer) handleSubscribePeerState(msg []byte) {
|
func (p *Peer) handleSubscribePeerState(msg []byte) {
|
||||||
@@ -276,7 +274,7 @@ func (p *Peer) sendPeersOnline(peers []messages.PeerID) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for n, msg := range msgs {
|
for n, msg := range msgs {
|
||||||
if _, err := p.Write(msg); err != nil {
|
if _, err := p.Write(p.ctx, msg); err != nil {
|
||||||
p.log.Errorf("failed to write %d. peers offline message: %s", n, err)
|
p.log.Errorf("failed to write %d. peers offline message: %s", n, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -293,7 +291,7 @@ func (p *Peer) sendPeersWentOffline(peers []messages.PeerID) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for n, msg := range msgs {
|
for n, msg := range msgs {
|
||||||
if _, err := p.Write(msg); err != nil {
|
if _, err := p.Write(p.ctx, msg); err != nil {
|
||||||
p.log.Errorf("failed to write %d. peers offline message: %s", n, err)
|
p.log.Errorf("failed to write %d. peers offline message: %s", n, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ package server
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
|
||||||
"net/url"
|
"net/url"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
@@ -13,11 +12,20 @@ import (
|
|||||||
"go.opentelemetry.io/otel/metric"
|
"go.opentelemetry.io/otel/metric"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/relay/healthcheck/peerid"
|
"github.com/netbirdio/netbird/relay/healthcheck/peerid"
|
||||||
|
"github.com/netbirdio/netbird/relay/protocol"
|
||||||
|
"github.com/netbirdio/netbird/relay/server/listener"
|
||||||
|
|
||||||
//nolint:staticcheck
|
//nolint:staticcheck
|
||||||
"github.com/netbirdio/netbird/relay/metrics"
|
"github.com/netbirdio/netbird/relay/metrics"
|
||||||
"github.com/netbirdio/netbird/relay/server/store"
|
"github.com/netbirdio/netbird/relay/server/store"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type Listener interface {
|
||||||
|
Listen(func(conn listener.Conn)) error
|
||||||
|
Shutdown(ctx context.Context) error
|
||||||
|
Protocol() protocol.Protocol
|
||||||
|
}
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
Meter metric.Meter
|
Meter metric.Meter
|
||||||
ExposedAddress string
|
ExposedAddress string
|
||||||
@@ -109,7 +117,7 @@ func NewRelay(config Config) (*Relay, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Accept start to handle a new peer connection
|
// Accept start to handle a new peer connection
|
||||||
func (r *Relay) Accept(conn net.Conn) {
|
func (r *Relay) Accept(conn listener.Conn) {
|
||||||
acceptTime := time.Now()
|
acceptTime := time.Now()
|
||||||
r.closeMu.RLock()
|
r.closeMu.RLock()
|
||||||
defer r.closeMu.RUnlock()
|
defer r.closeMu.RUnlock()
|
||||||
@@ -117,12 +125,15 @@ func (r *Relay) Accept(conn net.Conn) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
hsCtx, hsCancel := context.WithTimeout(context.Background(), handshakeTimeout)
|
||||||
|
defer hsCancel()
|
||||||
|
|
||||||
h := handshake{
|
h := handshake{
|
||||||
conn: conn,
|
conn: conn,
|
||||||
validator: r.validator,
|
validator: r.validator,
|
||||||
preparedMsg: r.preparedMsg,
|
preparedMsg: r.preparedMsg,
|
||||||
}
|
}
|
||||||
peerID, err := h.handshakeReceive()
|
peerID, err := h.handshakeReceive(hsCtx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if peerid.IsHealthCheck(peerID) {
|
if peerid.IsHealthCheck(peerID) {
|
||||||
log.Debugf("health check connection from %s", conn.RemoteAddr())
|
log.Debugf("health check connection from %s", conn.RemoteAddr())
|
||||||
@@ -154,7 +165,7 @@ func (r *Relay) Accept(conn net.Conn) {
|
|||||||
r.metrics.PeerDisconnected(peer.String())
|
r.metrics.PeerDisconnected(peer.String())
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if err := h.handshakeResponse(); err != nil {
|
if err := h.handshakeResponse(hsCtx); err != nil {
|
||||||
log.Errorf("failed to send handshake response, close peer: %s", err)
|
log.Errorf("failed to send handshake response, close peer: %s", err)
|
||||||
peer.Close()
|
peer.Close()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ package server
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"net"
|
|
||||||
"net/url"
|
"net/url"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
@@ -31,7 +30,7 @@ type ListenerConfig struct {
|
|||||||
// In a new HTTP connection, the server will accept the connection and pass it to the Relay server via the Accept method.
|
// In a new HTTP connection, the server will accept the connection and pass it to the Relay server via the Accept method.
|
||||||
type Server struct {
|
type Server struct {
|
||||||
relay *Relay
|
relay *Relay
|
||||||
listeners []listener.Listener
|
listeners []Listener
|
||||||
listenerMux sync.Mutex
|
listenerMux sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -56,7 +55,7 @@ func NewServer(config Config) (*Server, error) {
|
|||||||
}
|
}
|
||||||
return &Server{
|
return &Server{
|
||||||
relay: relay,
|
relay: relay,
|
||||||
listeners: make([]listener.Listener, 0, 2),
|
listeners: make([]Listener, 0, 2),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -86,7 +85,7 @@ func (r *Server) Listen(cfg ListenerConfig) error {
|
|||||||
wg := sync.WaitGroup{}
|
wg := sync.WaitGroup{}
|
||||||
for _, l := range r.listeners {
|
for _, l := range r.listeners {
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func(listener listener.Listener) {
|
go func(listener Listener) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
errChan <- listener.Listen(r.relay.Accept)
|
errChan <- listener.Listen(r.relay.Accept)
|
||||||
}(l)
|
}(l)
|
||||||
@@ -139,6 +138,6 @@ func (r *Server) InstanceURL() url.URL {
|
|||||||
// RelayAccept returns the relay's Accept function for handling incoming connections.
|
// RelayAccept returns the relay's Accept function for handling incoming connections.
|
||||||
// This allows external HTTP handlers to route connections to the relay without
|
// This allows external HTTP handlers to route connections to the relay without
|
||||||
// starting the relay's own listeners.
|
// starting the relay's own listeners.
|
||||||
func (r *Server) RelayAccept() func(conn net.Conn) {
|
func (r *Server) RelayAccept() func(conn listener.Conn) {
|
||||||
return r.relay.Accept
|
return r.relay.Accept
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user