mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 07:16:38 +00:00
[relay] Feature/relay integration (#2244)
This update adds new relay integration for NetBird clients. The new relay is based on web sockets and listens on a single port. - Adds new relay implementation with websocket with single port relaying mechanism - refactor peer connection logic, allowing upgrade and downgrade from/to P2P connection - peer connections are faster since it connects first to relay and then upgrades to P2P - maintains compatibility with old clients by not using the new relay - updates infrastructure scripts with new relay service
This commit is contained in:
11
relay/server/listener/listener.go
Normal file
11
relay/server/listener/listener.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package listener
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
)
|
||||
|
||||
type Listener interface {
|
||||
Listen(func(conn net.Conn)) error
|
||||
Shutdown(ctx context.Context) error
|
||||
}
|
||||
114
relay/server/listener/ws/conn.go
Normal file
114
relay/server/listener/ws/conn.go
Normal file
@@ -0,0 +1,114 @@
|
||||
package ws
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"nhooyr.io/websocket"
|
||||
)
|
||||
|
||||
const (
|
||||
writeTimeout = 10 * time.Second
|
||||
)
|
||||
|
||||
type Conn struct {
|
||||
*websocket.Conn
|
||||
lAddr *net.TCPAddr
|
||||
rAddr *net.TCPAddr
|
||||
|
||||
closed bool
|
||||
closedMu sync.Mutex
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func NewConn(wsConn *websocket.Conn, lAddr, rAddr *net.TCPAddr) *Conn {
|
||||
return &Conn{
|
||||
Conn: wsConn,
|
||||
lAddr: lAddr,
|
||||
rAddr: rAddr,
|
||||
ctx: context.Background(),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) Read(b []byte) (n int, err error) {
|
||||
t, r, err := c.Reader(c.ctx)
|
||||
if err != nil {
|
||||
return 0, c.ioErrHandling(err)
|
||||
}
|
||||
|
||||
if t != websocket.MessageBinary {
|
||||
log.Errorf("unexpected message type: %d", t)
|
||||
return 0, fmt.Errorf("unexpected message type")
|
||||
}
|
||||
|
||||
n, err = r.Read(b)
|
||||
if err != nil {
|
||||
return 0, c.ioErrHandling(err)
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
// Write writes a binary message with the given payload.
|
||||
// It does not block until fill the internal buffer.
|
||||
// If the buffer filled up, wait until the buffer is drained or timeout.
|
||||
func (c *Conn) Write(b []byte) (int, error) {
|
||||
ctx, ctxCancel := context.WithTimeout(c.ctx, writeTimeout)
|
||||
defer ctxCancel()
|
||||
|
||||
err := c.Conn.Write(ctx, websocket.MessageBinary, b)
|
||||
return len(b), err
|
||||
}
|
||||
|
||||
func (c *Conn) LocalAddr() net.Addr {
|
||||
return c.lAddr
|
||||
}
|
||||
|
||||
func (c *Conn) RemoteAddr() net.Addr {
|
||||
return c.rAddr
|
||||
}
|
||||
|
||||
func (c *Conn) SetReadDeadline(t time.Time) error {
|
||||
return fmt.Errorf("SetReadDeadline is not implemented")
|
||||
}
|
||||
|
||||
func (c *Conn) SetWriteDeadline(t time.Time) error {
|
||||
return fmt.Errorf("SetWriteDeadline is not implemented")
|
||||
}
|
||||
|
||||
func (c *Conn) SetDeadline(t time.Time) error {
|
||||
return fmt.Errorf("SetDeadline is not implemented")
|
||||
}
|
||||
|
||||
func (c *Conn) Close() error {
|
||||
c.closedMu.Lock()
|
||||
c.closed = true
|
||||
c.closedMu.Unlock()
|
||||
return c.Conn.CloseNow()
|
||||
}
|
||||
|
||||
func (c *Conn) isClosed() bool {
|
||||
c.closedMu.Lock()
|
||||
defer c.closedMu.Unlock()
|
||||
return c.closed
|
||||
}
|
||||
|
||||
func (c *Conn) ioErrHandling(err error) error {
|
||||
if c.isClosed() {
|
||||
return io.EOF
|
||||
}
|
||||
|
||||
var wErr *websocket.CloseError
|
||||
if !errors.As(err, &wErr) {
|
||||
return err
|
||||
}
|
||||
if wErr.Code == websocket.StatusNormalClosure {
|
||||
return io.EOF
|
||||
}
|
||||
return err
|
||||
}
|
||||
92
relay/server/listener/ws/listener.go
Normal file
92
relay/server/listener/ws/listener.go
Normal file
@@ -0,0 +1,92 @@
|
||||
package ws
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"nhooyr.io/websocket"
|
||||
)
|
||||
|
||||
// URLPath is the path for the websocket connection.
|
||||
const URLPath = "/relay"
|
||||
|
||||
type Listener struct {
|
||||
// Address is the address to listen on.
|
||||
Address string
|
||||
// TLSConfig is the TLS configuration for the server.
|
||||
TLSConfig *tls.Config
|
||||
|
||||
server *http.Server
|
||||
acceptFn func(conn net.Conn)
|
||||
}
|
||||
|
||||
func (l *Listener) Listen(acceptFn func(conn net.Conn)) error {
|
||||
l.acceptFn = acceptFn
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc(URLPath, l.onAccept)
|
||||
|
||||
l.server = &http.Server{
|
||||
Addr: l.Address,
|
||||
Handler: mux,
|
||||
TLSConfig: l.TLSConfig,
|
||||
}
|
||||
|
||||
log.Infof("WS server listening address: %s", l.Address)
|
||||
var err error
|
||||
if l.TLSConfig != nil {
|
||||
err = l.server.ListenAndServeTLS("", "")
|
||||
} else {
|
||||
err = l.server.ListenAndServe()
|
||||
}
|
||||
if errors.Is(err, http.ErrServerClosed) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (l *Listener) Shutdown(ctx context.Context) error {
|
||||
if l.server == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Infof("stop WS listener")
|
||||
if err := l.server.Shutdown(ctx); err != nil {
|
||||
return fmt.Errorf("server shutdown failed: %v", err)
|
||||
}
|
||||
log.Infof("WS listener stopped")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *Listener) onAccept(w http.ResponseWriter, r *http.Request) {
|
||||
wsConn, err := websocket.Accept(w, r, nil)
|
||||
if err != nil {
|
||||
log.Errorf("failed to accept ws connection from %s: %s", r.RemoteAddr, err)
|
||||
return
|
||||
}
|
||||
|
||||
rAddr, err := net.ResolveTCPAddr("tcp", r.RemoteAddr)
|
||||
if err != nil {
|
||||
err = wsConn.Close(websocket.StatusInternalError, "internal error")
|
||||
if err != nil {
|
||||
log.Errorf("failed to close ws connection: %s", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
lAddr, err := net.ResolveTCPAddr("tcp", l.server.Addr)
|
||||
if err != nil {
|
||||
err = wsConn.Close(websocket.StatusInternalError, "internal error")
|
||||
if err != nil {
|
||||
log.Errorf("failed to close ws connection: %s", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
conn := NewConn(wsConn, lAddr, rAddr)
|
||||
l.acceptFn(conn)
|
||||
}
|
||||
203
relay/server/peer.go
Normal file
203
relay/server/peer.go
Normal file
@@ -0,0 +1,203 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/relay/healthcheck"
|
||||
"github.com/netbirdio/netbird/relay/messages"
|
||||
"github.com/netbirdio/netbird/relay/metrics"
|
||||
)
|
||||
|
||||
const (
|
||||
bufferSize = 8820
|
||||
)
|
||||
|
||||
// Peer represents a peer connection
|
||||
type Peer struct {
|
||||
metrics *metrics.Metrics
|
||||
log *log.Entry
|
||||
idS string
|
||||
idB []byte
|
||||
conn net.Conn
|
||||
connMu sync.RWMutex
|
||||
store *Store
|
||||
}
|
||||
|
||||
// NewPeer creates a new Peer instance and prepare custom logging
|
||||
func NewPeer(metrics *metrics.Metrics, id []byte, conn net.Conn, store *Store) *Peer {
|
||||
stringID := messages.HashIDToString(id)
|
||||
return &Peer{
|
||||
metrics: metrics,
|
||||
log: log.WithField("peer_id", stringID),
|
||||
idS: stringID,
|
||||
idB: id,
|
||||
conn: conn,
|
||||
store: store,
|
||||
}
|
||||
}
|
||||
|
||||
// Work reads data from the connection
|
||||
// It manages the protocol (healthcheck, transport, close). Read the message and determine the message type and handle
|
||||
// the message accordingly.
|
||||
func (p *Peer) Work() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
hc := healthcheck.NewSender()
|
||||
go hc.StartHealthCheck(ctx)
|
||||
go p.handleHealthcheckEvents(ctx, hc)
|
||||
|
||||
buf := make([]byte, bufferSize)
|
||||
for {
|
||||
n, err := p.conn.Read(buf)
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
p.log.Errorf("failed to read message: %s", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if n == 0 {
|
||||
p.log.Errorf("received empty message")
|
||||
return
|
||||
}
|
||||
|
||||
msg := buf[:n]
|
||||
|
||||
_, err = messages.ValidateVersion(msg)
|
||||
if err != nil {
|
||||
p.log.Warnf("failed to validate protocol version: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
msgType, err := messages.DetermineClientMessageType(msg[messages.SizeOfVersionByte:])
|
||||
if err != nil {
|
||||
p.log.Errorf("failed to determine message type: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
p.handleMsgType(ctx, msgType, hc, n, msg)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Peer) handleMsgType(ctx context.Context, msgType messages.MsgType, hc *healthcheck.Sender, n int, msg []byte) {
|
||||
switch msgType {
|
||||
case messages.MsgTypeHealthCheck:
|
||||
hc.OnHCResponse()
|
||||
case messages.MsgTypeTransport:
|
||||
p.metrics.TransferBytesRecv.Add(ctx, int64(n))
|
||||
p.metrics.PeerActivity(p.String())
|
||||
p.handleTransportMsg(msg)
|
||||
case messages.MsgTypeClose:
|
||||
p.log.Infof("peer exited gracefully")
|
||||
if err := p.conn.Close(); err != nil {
|
||||
log.Errorf("failed to close connection to peer: %s", err)
|
||||
}
|
||||
default:
|
||||
p.log.Warnf("received unexpected message type: %s", msgType)
|
||||
}
|
||||
}
|
||||
|
||||
// Write writes data to the connection
|
||||
func (p *Peer) Write(b []byte) (int, error) {
|
||||
p.connMu.RLock()
|
||||
defer p.connMu.RUnlock()
|
||||
return p.conn.Write(b)
|
||||
}
|
||||
|
||||
// CloseGracefully closes the connection with the peer gracefully. Send a close message to the client and close the
|
||||
// connection.
|
||||
func (p *Peer) CloseGracefully(ctx context.Context) {
|
||||
p.connMu.Lock()
|
||||
err := p.writeWithTimeout(ctx, messages.MarshalCloseMsg())
|
||||
if err != nil {
|
||||
p.log.Errorf("failed to send close message to peer: %s", p.String())
|
||||
}
|
||||
|
||||
err = p.conn.Close()
|
||||
if err != nil {
|
||||
p.log.Errorf("failed to close connection to peer: %s", err)
|
||||
}
|
||||
|
||||
defer p.connMu.Unlock()
|
||||
}
|
||||
|
||||
// String returns the peer ID
|
||||
func (p *Peer) String() string {
|
||||
return p.idS
|
||||
}
|
||||
|
||||
func (p *Peer) writeWithTimeout(ctx context.Context, buf []byte) error {
|
||||
ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
writeDone := make(chan struct{})
|
||||
var err error
|
||||
go func() {
|
||||
_, err = p.conn.Write(buf)
|
||||
close(writeDone)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-writeDone:
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Peer) handleHealthcheckEvents(ctx context.Context, hc *healthcheck.Sender) {
|
||||
for {
|
||||
select {
|
||||
case <-hc.HealthCheck:
|
||||
_, err := p.Write(messages.MarshalHealthcheck())
|
||||
if err != nil {
|
||||
p.log.Errorf("failed to send healthcheck message: %s", err)
|
||||
return
|
||||
}
|
||||
case <-hc.Timeout:
|
||||
p.log.Errorf("peer healthcheck timeout")
|
||||
err := p.conn.Close()
|
||||
if err != nil {
|
||||
p.log.Errorf("failed to close connection to peer: %s", err)
|
||||
}
|
||||
return
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Peer) handleTransportMsg(msg []byte) {
|
||||
peerID, err := messages.UnmarshalTransportID(msg[messages.SizeOfProtoHeader:])
|
||||
if err != nil {
|
||||
p.log.Errorf("failed to unmarshal transport message: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
stringPeerID := messages.HashIDToString(peerID)
|
||||
dp, ok := p.store.Peer(stringPeerID)
|
||||
if !ok {
|
||||
p.log.Errorf("peer not found: %s", stringPeerID)
|
||||
return
|
||||
}
|
||||
|
||||
err = messages.UpdateTransportMsg(msg[messages.SizeOfProtoHeader:], p.idB)
|
||||
if err != nil {
|
||||
p.log.Errorf("failed to update transport message: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
n, err := dp.Write(msg)
|
||||
if err != nil {
|
||||
p.log.Errorf("failed to write transport message to: %s", dp.String())
|
||||
return
|
||||
}
|
||||
p.metrics.TransferBytesSent.Add(context.Background(), int64(n))
|
||||
}
|
||||
206
relay/server/relay.go
Normal file
206
relay/server/relay.go
Normal file
@@ -0,0 +1,206 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
|
||||
"github.com/netbirdio/netbird/relay/auth"
|
||||
"github.com/netbirdio/netbird/relay/messages"
|
||||
"github.com/netbirdio/netbird/relay/messages/address"
|
||||
authmsg "github.com/netbirdio/netbird/relay/messages/auth"
|
||||
"github.com/netbirdio/netbird/relay/metrics"
|
||||
)
|
||||
|
||||
// Relay represents the relay server
|
||||
type Relay struct {
|
||||
metrics *metrics.Metrics
|
||||
metricsCancel context.CancelFunc
|
||||
validator auth.Validator
|
||||
|
||||
store *Store
|
||||
instanceURL string
|
||||
|
||||
closed bool
|
||||
closeMu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewRelay creates a new Relay instance
|
||||
//
|
||||
// Parameters:
|
||||
// meter: An instance of metric.Meter from the go.opentelemetry.io/otel/metric package. It is used to create and manage
|
||||
// metrics for the relay server.
|
||||
// exposedAddress: A string representing the address that the relay server is exposed on. The client will use this
|
||||
// address as the relay server's instance URL.
|
||||
// tlsSupport: A boolean indicating whether the relay server supports TLS (Transport Layer Security) or not. The
|
||||
// instance URL depends on this value.
|
||||
// validator: An instance of auth.Validator from the auth package. It is used to validate the authentication of the
|
||||
// peers.
|
||||
//
|
||||
// Returns:
|
||||
// A pointer to a Relay instance and an error. If the Relay instance is successfully created, the error is nil.
|
||||
// Otherwise, the error contains the details of what went wrong.
|
||||
func NewRelay(meter metric.Meter, exposedAddress string, tlsSupport bool, validator auth.Validator) (*Relay, error) {
|
||||
ctx, metricsCancel := context.WithCancel(context.Background())
|
||||
m, err := metrics.NewMetrics(ctx, meter)
|
||||
if err != nil {
|
||||
metricsCancel()
|
||||
return nil, fmt.Errorf("creating app metrics: %v", err)
|
||||
}
|
||||
|
||||
r := &Relay{
|
||||
metrics: m,
|
||||
metricsCancel: metricsCancel,
|
||||
validator: validator,
|
||||
store: NewStore(),
|
||||
}
|
||||
|
||||
r.instanceURL, err = getInstanceURL(exposedAddress, tlsSupport)
|
||||
if err != nil {
|
||||
metricsCancel()
|
||||
return nil, fmt.Errorf("get instance URL: %v", err)
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// getInstanceURL checks if user supplied a URL scheme otherwise adds to the
|
||||
// provided address according to TLS definition and parses the address before returning it
|
||||
func getInstanceURL(exposedAddress string, tlsSupported bool) (string, error) {
|
||||
addr := exposedAddress
|
||||
split := strings.Split(exposedAddress, "://")
|
||||
switch {
|
||||
case len(split) == 1 && tlsSupported:
|
||||
addr = "rels://" + exposedAddress
|
||||
case len(split) == 1 && !tlsSupported:
|
||||
addr = "rel://" + exposedAddress
|
||||
case len(split) > 2:
|
||||
return "", fmt.Errorf("invalid exposed address: %s", exposedAddress)
|
||||
}
|
||||
|
||||
parsedURL, err := url.ParseRequestURI(addr)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid exposed address: %v", err)
|
||||
}
|
||||
|
||||
if parsedURL.Scheme != "rel" && parsedURL.Scheme != "rels" {
|
||||
return "", fmt.Errorf("invalid scheme: %s", parsedURL.Scheme)
|
||||
}
|
||||
|
||||
return parsedURL.String(), nil
|
||||
}
|
||||
|
||||
// Accept start to handle a new peer connection
|
||||
func (r *Relay) Accept(conn net.Conn) {
|
||||
r.closeMu.RLock()
|
||||
defer r.closeMu.RUnlock()
|
||||
if r.closed {
|
||||
return
|
||||
}
|
||||
|
||||
peerID, err := r.handshake(conn)
|
||||
if err != nil {
|
||||
log.Errorf("failed to handshake: %s", err)
|
||||
cErr := conn.Close()
|
||||
if cErr != nil {
|
||||
log.Errorf("failed to close connection, %s: %s", conn.RemoteAddr(), cErr)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
peer := NewPeer(r.metrics, peerID, conn, r.store)
|
||||
peer.log.Infof("peer connected from: %s", conn.RemoteAddr())
|
||||
r.store.AddPeer(peer)
|
||||
r.metrics.PeerConnected(peer.String())
|
||||
go func() {
|
||||
peer.Work()
|
||||
r.store.DeletePeer(peer)
|
||||
peer.log.Debugf("relay connection closed")
|
||||
r.metrics.PeerDisconnected(peer.String())
|
||||
}()
|
||||
}
|
||||
|
||||
// Shutdown closes the relay server
|
||||
// It closes the connection with all peers in gracefully and stops accepting new connections.
|
||||
func (r *Relay) Shutdown(ctx context.Context) {
|
||||
log.Infof("close connection with all peers")
|
||||
r.closeMu.Lock()
|
||||
wg := sync.WaitGroup{}
|
||||
peers := r.store.Peers()
|
||||
for _, peer := range peers {
|
||||
wg.Add(1)
|
||||
go func(p *Peer) {
|
||||
p.CloseGracefully(ctx)
|
||||
wg.Done()
|
||||
}(peer)
|
||||
}
|
||||
wg.Wait()
|
||||
r.metricsCancel()
|
||||
r.closeMu.Unlock()
|
||||
}
|
||||
|
||||
// InstanceURL returns the instance URL of the relay server
|
||||
func (r *Relay) InstanceURL() string {
|
||||
return r.instanceURL
|
||||
}
|
||||
|
||||
func (r *Relay) handshake(conn net.Conn) ([]byte, error) {
|
||||
buf := make([]byte, messages.MaxHandshakeSize)
|
||||
n, err := conn.Read(buf)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read from %s: %w", conn.RemoteAddr(), err)
|
||||
}
|
||||
|
||||
_, err = messages.ValidateVersion(buf[:n])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("validate version from %s: %w", conn.RemoteAddr(), err)
|
||||
}
|
||||
|
||||
msgType, err := messages.DetermineClientMessageType(buf[messages.SizeOfVersionByte:n])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("determine message type from %s: %w", conn.RemoteAddr(), err)
|
||||
}
|
||||
|
||||
if msgType != messages.MsgTypeHello {
|
||||
return nil, fmt.Errorf("invalid message type from %s", conn.RemoteAddr())
|
||||
}
|
||||
|
||||
peerID, authData, err := messages.UnmarshalHelloMsg(buf[messages.SizeOfProtoHeader:n])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unmarshal hello message: %w", err)
|
||||
}
|
||||
|
||||
authMsg, err := authmsg.UnmarshalMsg(authData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unmarshal auth message: %w", err)
|
||||
}
|
||||
|
||||
if err := r.validator.Validate(sha256.New, authMsg.AdditionalData); err != nil {
|
||||
return nil, fmt.Errorf("validate %s (%s): %w", peerID, conn.RemoteAddr(), err)
|
||||
}
|
||||
|
||||
addr := &address.Address{URL: r.instanceURL}
|
||||
addrData, err := addr.Marshal()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal addressc to %s (%s): %w", peerID, conn.RemoteAddr(), err)
|
||||
}
|
||||
|
||||
msg, err := messages.MarshalHelloResponse(addrData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal hello response to %s (%s): %w", peerID, conn.RemoteAddr(), err)
|
||||
}
|
||||
|
||||
_, err = conn.Write(msg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("write to %s (%s): %w", peerID, conn.RemoteAddr(), err)
|
||||
}
|
||||
|
||||
return peerID, nil
|
||||
}
|
||||
36
relay/server/relay_test.go
Normal file
36
relay/server/relay_test.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package server
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestGetInstanceURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
exposedAddress string
|
||||
tlsSupported bool
|
||||
expectedURL string
|
||||
expectError bool
|
||||
}{
|
||||
{"Valid address with TLS", "example.com", true, "rels://example.com", false},
|
||||
{"Valid address without TLS", "example.com", false, "rel://example.com", false},
|
||||
{"Valid address with scheme", "rel://example.com", false, "rel://example.com", false},
|
||||
{"Valid address with non TLS scheme and TLS true", "rel://example.com", true, "rel://example.com", false},
|
||||
{"Valid address with TLS scheme", "rels://example.com", true, "rels://example.com", false},
|
||||
{"Valid address with TLS scheme and TLS false", "rels://example.com", false, "rels://example.com", false},
|
||||
{"Valid address with TLS scheme and custom port", "rels://example.com:9300", true, "rels://example.com:9300", false},
|
||||
{"Invalid address with multiple schemes", "rel://rels://example.com", false, "", true},
|
||||
{"Invalid address with unsupported scheme", "http://example.com", false, "", true},
|
||||
{"Invalid address format", "://example.com", false, "", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
url, err := getInstanceURL(tt.exposedAddress, tt.tlsSupported)
|
||||
if (err != nil) != tt.expectError {
|
||||
t.Errorf("expected error: %v, got: %v", tt.expectError, err)
|
||||
}
|
||||
if url != tt.expectedURL {
|
||||
t.Errorf("expected URL: %s, got: %s", tt.expectedURL, url)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
76
relay/server/server.go
Normal file
76
relay/server/server.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
|
||||
"github.com/netbirdio/netbird/relay/auth"
|
||||
"github.com/netbirdio/netbird/relay/server/listener"
|
||||
"github.com/netbirdio/netbird/relay/server/listener/ws"
|
||||
)
|
||||
|
||||
// ListenerConfig is the configuration for the listener.
|
||||
// Address: the address to bind the listener to. It could be an address behind a reverse proxy.
|
||||
// TLSConfig: the TLS configuration for the listener.
|
||||
type ListenerConfig struct {
|
||||
Address string
|
||||
TLSConfig *tls.Config
|
||||
}
|
||||
|
||||
// Server is the main entry point for the relay server.
|
||||
// It is the gate between the WebSocket listener and the Relay server logic.
|
||||
// In a new HTTP connection, the server will accept the connection and pass it to the Relay server via the Accept method.
|
||||
type Server struct {
|
||||
relay *Relay
|
||||
wSListener listener.Listener
|
||||
}
|
||||
|
||||
// NewServer creates a new relay server instance.
|
||||
// meter: the OpenTelemetry meter
|
||||
// exposedAddress: this address will be used as the instance URL. It should be a domain:port format.
|
||||
// tlsSupport: if true, the server will support TLS
|
||||
// authValidator: the auth validator to use for the server
|
||||
func NewServer(meter metric.Meter, exposedAddress string, tlsSupport bool, authValidator auth.Validator) (*Server, error) {
|
||||
relay, err := NewRelay(meter, exposedAddress, tlsSupport, authValidator)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Server{
|
||||
relay: relay,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Listen starts the relay server.
|
||||
func (r *Server) Listen(cfg ListenerConfig) error {
|
||||
r.wSListener = &ws.Listener{
|
||||
Address: cfg.Address,
|
||||
TLSConfig: cfg.TLSConfig,
|
||||
}
|
||||
|
||||
wslErr := r.wSListener.Listen(r.relay.Accept)
|
||||
if wslErr != nil {
|
||||
log.Errorf("failed to bind ws server: %s", wslErr)
|
||||
}
|
||||
|
||||
return wslErr
|
||||
}
|
||||
|
||||
// Shutdown stops the relay server. If there are active connections, they will be closed gracefully. In case of a context,
|
||||
// the connections will be forcefully closed.
|
||||
func (r *Server) Shutdown(ctx context.Context) (err error) {
|
||||
// stop service new connections
|
||||
if r.wSListener != nil {
|
||||
err = r.wSListener.Shutdown(ctx)
|
||||
}
|
||||
|
||||
r.relay.Shutdown(ctx)
|
||||
return
|
||||
}
|
||||
|
||||
// InstanceURL returns the instance URL of the relay server.
|
||||
func (r *Server) InstanceURL() string {
|
||||
return r.relay.instanceURL
|
||||
}
|
||||
64
relay/server/store.go
Normal file
64
relay/server/store.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Store is a thread-safe store of peers
|
||||
// It is used to store the peers that are connected to the relay server
|
||||
type Store struct {
|
||||
peers map[string]*Peer // consider to use [32]byte as key. The Peer(id string) would be faster
|
||||
peersLock sync.RWMutex
|
||||
}
|
||||
|
||||
// NewStore creates a new Store instance
|
||||
func NewStore() *Store {
|
||||
return &Store{
|
||||
peers: make(map[string]*Peer),
|
||||
}
|
||||
}
|
||||
|
||||
// AddPeer adds a peer to the store
|
||||
// todo: consider to close peer conn if the peer already exists
|
||||
func (s *Store) AddPeer(peer *Peer) {
|
||||
s.peersLock.Lock()
|
||||
defer s.peersLock.Unlock()
|
||||
s.peers[peer.String()] = peer
|
||||
}
|
||||
|
||||
// DeletePeer deletes a peer from the store
|
||||
func (s *Store) DeletePeer(peer *Peer) {
|
||||
s.peersLock.Lock()
|
||||
defer s.peersLock.Unlock()
|
||||
|
||||
dp, ok := s.peers[peer.String()]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if dp != peer {
|
||||
return
|
||||
}
|
||||
|
||||
delete(s.peers, peer.String())
|
||||
}
|
||||
|
||||
// Peer returns a peer by its ID
|
||||
func (s *Store) Peer(id string) (*Peer, bool) {
|
||||
s.peersLock.RLock()
|
||||
defer s.peersLock.RUnlock()
|
||||
|
||||
p, ok := s.peers[id]
|
||||
return p, ok
|
||||
}
|
||||
|
||||
// Peers returns all the peers in the store
|
||||
func (s *Store) Peers() []*Peer {
|
||||
s.peersLock.RLock()
|
||||
defer s.peersLock.RUnlock()
|
||||
|
||||
peers := make([]*Peer, 0, len(s.peers))
|
||||
for _, p := range s.peers {
|
||||
peers = append(peers, p)
|
||||
}
|
||||
return peers
|
||||
}
|
||||
40
relay/server/store_test.go
Normal file
40
relay/server/store_test.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"go.opentelemetry.io/otel"
|
||||
|
||||
"github.com/netbirdio/netbird/relay/metrics"
|
||||
)
|
||||
|
||||
func TestStore_DeletePeer(t *testing.T) {
|
||||
s := NewStore()
|
||||
|
||||
m, _ := metrics.NewMetrics(context.Background(), otel.Meter(""))
|
||||
|
||||
p := NewPeer(m, []byte("peer_one"), nil, nil)
|
||||
s.AddPeer(p)
|
||||
s.DeletePeer(p)
|
||||
if _, ok := s.Peer(p.String()); ok {
|
||||
t.Errorf("peer was not deleted")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStore_DeleteDeprecatedPeer(t *testing.T) {
|
||||
s := NewStore()
|
||||
|
||||
m, _ := metrics.NewMetrics(context.Background(), otel.Meter(""))
|
||||
|
||||
p1 := NewPeer(m, []byte("peer_id"), nil, nil)
|
||||
p2 := NewPeer(m, []byte("peer_id"), nil, nil)
|
||||
|
||||
s.AddPeer(p1)
|
||||
s.AddPeer(p2)
|
||||
s.DeletePeer(p1)
|
||||
|
||||
if _, ok := s.Peer(p2.String()); !ok {
|
||||
t.Errorf("second peer was deleted")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user