mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 08:16:39 +00:00
[client, relay-server] Feature/relay notification (#4083)
- Clients now subscribe to peer status changes. - The server manages and maintains these subscriptions. - Replaced raw string peer IDs with a custom peer ID type for better type safety and clarity.
This commit is contained in:
@@ -6,7 +6,6 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/relay/auth"
|
||||
"github.com/netbirdio/netbird/relay/messages"
|
||||
//nolint:staticcheck
|
||||
"github.com/netbirdio/netbird/relay/messages/address"
|
||||
@@ -14,6 +13,12 @@ import (
|
||||
authmsg "github.com/netbirdio/netbird/relay/messages/auth"
|
||||
)
|
||||
|
||||
type Validator interface {
|
||||
Validate(any) error
|
||||
// Deprecated: Use Validate instead.
|
||||
ValidateHelloMsgType(any) error
|
||||
}
|
||||
|
||||
// preparedMsg contains the marshalled success response messages
|
||||
type preparedMsg struct {
|
||||
responseHelloMsg []byte
|
||||
@@ -54,14 +59,14 @@ func marshalResponseHelloMsg(instanceURL string) ([]byte, error) {
|
||||
|
||||
type handshake struct {
|
||||
conn net.Conn
|
||||
validator auth.Validator
|
||||
validator Validator
|
||||
preparedMsg *preparedMsg
|
||||
|
||||
handshakeMethodAuth bool
|
||||
peerID string
|
||||
peerID *messages.PeerID
|
||||
}
|
||||
|
||||
func (h *handshake) handshakeReceive() ([]byte, error) {
|
||||
func (h *handshake) handshakeReceive() (*messages.PeerID, error) {
|
||||
buf := make([]byte, messages.MaxHandshakeSize)
|
||||
n, err := h.conn.Read(buf)
|
||||
if err != nil {
|
||||
@@ -80,17 +85,14 @@ func (h *handshake) handshakeReceive() ([]byte, error) {
|
||||
return nil, fmt.Errorf("determine message type from %s: %w", h.conn.RemoteAddr(), err)
|
||||
}
|
||||
|
||||
var (
|
||||
bytePeerID []byte
|
||||
peerID string
|
||||
)
|
||||
var peerID *messages.PeerID
|
||||
switch msgType {
|
||||
//nolint:staticcheck
|
||||
case messages.MsgTypeHello:
|
||||
bytePeerID, peerID, err = h.handleHelloMsg(buf)
|
||||
peerID, err = h.handleHelloMsg(buf)
|
||||
case messages.MsgTypeAuth:
|
||||
h.handshakeMethodAuth = true
|
||||
bytePeerID, peerID, err = h.handleAuthMsg(buf)
|
||||
peerID, err = h.handleAuthMsg(buf)
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid message type %d from %s", msgType, h.conn.RemoteAddr())
|
||||
}
|
||||
@@ -98,7 +100,7 @@ func (h *handshake) handshakeReceive() ([]byte, error) {
|
||||
return nil, err
|
||||
}
|
||||
h.peerID = peerID
|
||||
return bytePeerID, nil
|
||||
return peerID, nil
|
||||
}
|
||||
|
||||
func (h *handshake) handshakeResponse() error {
|
||||
@@ -116,40 +118,37 @@ func (h *handshake) handshakeResponse() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *handshake) handleHelloMsg(buf []byte) ([]byte, string, error) {
|
||||
func (h *handshake) handleHelloMsg(buf []byte) (*messages.PeerID, error) {
|
||||
//nolint:staticcheck
|
||||
rawPeerID, authData, err := messages.UnmarshalHelloMsg(buf)
|
||||
peerID, authData, err := messages.UnmarshalHelloMsg(buf)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("unmarshal hello message: %w", err)
|
||||
return nil, fmt.Errorf("unmarshal hello message: %w", err)
|
||||
}
|
||||
|
||||
peerID := messages.HashIDToString(rawPeerID)
|
||||
log.Warnf("peer %s (%s) is using deprecated initial message type", peerID, h.conn.RemoteAddr())
|
||||
|
||||
authMsg, err := authmsg.UnmarshalMsg(authData)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("unmarshal auth message: %w", err)
|
||||
return nil, fmt.Errorf("unmarshal auth message: %w", err)
|
||||
}
|
||||
|
||||
//nolint:staticcheck
|
||||
if err := h.validator.ValidateHelloMsgType(authMsg.AdditionalData); err != nil {
|
||||
return nil, "", fmt.Errorf("validate %s (%s): %w", peerID, h.conn.RemoteAddr(), err)
|
||||
return nil, fmt.Errorf("validate %s (%s): %w", peerID, h.conn.RemoteAddr(), err)
|
||||
}
|
||||
|
||||
return rawPeerID, peerID, nil
|
||||
return peerID, nil
|
||||
}
|
||||
|
||||
func (h *handshake) handleAuthMsg(buf []byte) ([]byte, string, error) {
|
||||
func (h *handshake) handleAuthMsg(buf []byte) (*messages.PeerID, error) {
|
||||
rawPeerID, authPayload, err := messages.UnmarshalAuthMsg(buf)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("unmarshal hello message: %w", err)
|
||||
return nil, fmt.Errorf("unmarshal hello message: %w", err)
|
||||
}
|
||||
|
||||
peerID := messages.HashIDToString(rawPeerID)
|
||||
|
||||
if err := h.validator.Validate(authPayload); err != nil {
|
||||
return nil, "", fmt.Errorf("validate %s (%s): %w", peerID, h.conn.RemoteAddr(), err)
|
||||
return nil, fmt.Errorf("validate %s (%s): %w", rawPeerID.String(), h.conn.RemoteAddr(), err)
|
||||
}
|
||||
|
||||
return rawPeerID, peerID, nil
|
||||
return rawPeerID, nil
|
||||
}
|
||||
|
||||
@@ -12,43 +12,50 @@ import (
|
||||
"github.com/netbirdio/netbird/relay/healthcheck"
|
||||
"github.com/netbirdio/netbird/relay/messages"
|
||||
"github.com/netbirdio/netbird/relay/metrics"
|
||||
"github.com/netbirdio/netbird/relay/server/store"
|
||||
)
|
||||
|
||||
const (
|
||||
bufferSize = 8820
|
||||
bufferSize = messages.MaxMessageSize
|
||||
|
||||
errCloseConn = "failed to close connection to peer: %s"
|
||||
)
|
||||
|
||||
// Peer represents a peer connection
|
||||
type Peer struct {
|
||||
metrics *metrics.Metrics
|
||||
log *log.Entry
|
||||
idS string
|
||||
idB []byte
|
||||
conn net.Conn
|
||||
connMu sync.RWMutex
|
||||
store *Store
|
||||
metrics *metrics.Metrics
|
||||
log *log.Entry
|
||||
id messages.PeerID
|
||||
conn net.Conn
|
||||
connMu sync.RWMutex
|
||||
store *store.Store
|
||||
notifier *store.PeerNotifier
|
||||
|
||||
peersListener *store.Listener
|
||||
}
|
||||
|
||||
// NewPeer creates a new Peer instance and prepare custom logging
|
||||
func NewPeer(metrics *metrics.Metrics, id []byte, conn net.Conn, store *Store) *Peer {
|
||||
stringID := messages.HashIDToString(id)
|
||||
return &Peer{
|
||||
metrics: metrics,
|
||||
log: log.WithField("peer_id", stringID),
|
||||
idS: stringID,
|
||||
idB: id,
|
||||
conn: conn,
|
||||
store: store,
|
||||
func NewPeer(metrics *metrics.Metrics, id messages.PeerID, conn net.Conn, store *store.Store, notifier *store.PeerNotifier) *Peer {
|
||||
p := &Peer{
|
||||
metrics: metrics,
|
||||
log: log.WithField("peer_id", id.String()),
|
||||
id: id,
|
||||
conn: conn,
|
||||
store: store,
|
||||
notifier: notifier,
|
||||
}
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
// Work reads data from the connection
|
||||
// It manages the protocol (healthcheck, transport, close). Read the message and determine the message type and handle
|
||||
// the message accordingly.
|
||||
func (p *Peer) Work() {
|
||||
p.peersListener = p.notifier.NewListener(p.sendPeersOnline, p.sendPeersWentOffline)
|
||||
defer func() {
|
||||
p.notifier.RemoveListener(p.peersListener)
|
||||
|
||||
if err := p.conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
|
||||
p.log.Errorf(errCloseConn, err)
|
||||
}
|
||||
@@ -94,6 +101,10 @@ func (p *Peer) Work() {
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Peer) ID() messages.PeerID {
|
||||
return p.id
|
||||
}
|
||||
|
||||
func (p *Peer) handleMsgType(ctx context.Context, msgType messages.MsgType, hc *healthcheck.Sender, n int, msg []byte) {
|
||||
switch msgType {
|
||||
case messages.MsgTypeHealthCheck:
|
||||
@@ -107,6 +118,10 @@ func (p *Peer) handleMsgType(ctx context.Context, msgType messages.MsgType, hc *
|
||||
if err := p.conn.Close(); err != nil {
|
||||
log.Errorf(errCloseConn, err)
|
||||
}
|
||||
case messages.MsgTypeSubscribePeerState:
|
||||
p.handleSubscribePeerState(msg)
|
||||
case messages.MsgTypeUnsubscribePeerState:
|
||||
p.handleUnsubscribePeerState(msg)
|
||||
default:
|
||||
p.log.Warnf("received unexpected message type: %s", msgType)
|
||||
}
|
||||
@@ -145,7 +160,7 @@ func (p *Peer) Close() {
|
||||
|
||||
// String returns the peer ID
|
||||
func (p *Peer) String() string {
|
||||
return p.idS
|
||||
return p.id.String()
|
||||
}
|
||||
|
||||
func (p *Peer) writeWithTimeout(ctx context.Context, buf []byte) error {
|
||||
@@ -197,14 +212,14 @@ func (p *Peer) handleTransportMsg(msg []byte) {
|
||||
return
|
||||
}
|
||||
|
||||
stringPeerID := messages.HashIDToString(peerID)
|
||||
dp, ok := p.store.Peer(stringPeerID)
|
||||
item, ok := p.store.Peer(*peerID)
|
||||
if !ok {
|
||||
p.log.Debugf("peer not found: %s", stringPeerID)
|
||||
p.log.Debugf("peer not found: %s", peerID)
|
||||
return
|
||||
}
|
||||
dp := item.(*Peer)
|
||||
|
||||
err = messages.UpdateTransportMsg(msg, p.idB)
|
||||
err = messages.UpdateTransportMsg(msg, p.id)
|
||||
if err != nil {
|
||||
p.log.Errorf("failed to update transport message: %s", err)
|
||||
return
|
||||
@@ -217,3 +232,57 @@ func (p *Peer) handleTransportMsg(msg []byte) {
|
||||
}
|
||||
p.metrics.TransferBytesSent.Add(context.Background(), int64(n))
|
||||
}
|
||||
|
||||
func (p *Peer) handleSubscribePeerState(msg []byte) {
|
||||
peerIDs, err := messages.UnmarshalSubPeerStateMsg(msg)
|
||||
if err != nil {
|
||||
p.log.Errorf("failed to unmarshal open connection message: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
p.log.Debugf("received subscription message for %d peers", len(peerIDs))
|
||||
onlinePeers := p.peersListener.AddInterestedPeers(peerIDs)
|
||||
if len(onlinePeers) == 0 {
|
||||
return
|
||||
}
|
||||
p.log.Debugf("response with %d online peers", len(onlinePeers))
|
||||
p.sendPeersOnline(onlinePeers)
|
||||
}
|
||||
|
||||
func (p *Peer) handleUnsubscribePeerState(msg []byte) {
|
||||
peerIDs, err := messages.UnmarshalUnsubPeerStateMsg(msg)
|
||||
if err != nil {
|
||||
p.log.Errorf("failed to unmarshal open connection message: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
p.peersListener.RemoveInterestedPeer(peerIDs)
|
||||
}
|
||||
|
||||
func (p *Peer) sendPeersOnline(peers []messages.PeerID) {
|
||||
msgs, err := messages.MarshalPeersOnline(peers)
|
||||
if err != nil {
|
||||
p.log.Errorf("failed to marshal peer location message: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
for n, msg := range msgs {
|
||||
if _, err := p.Write(msg); err != nil {
|
||||
p.log.Errorf("failed to write %d. peers offline message: %s", n, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Peer) sendPeersWentOffline(peers []messages.PeerID) {
|
||||
msgs, err := messages.MarshalPeersWentOffline(peers)
|
||||
if err != nil {
|
||||
p.log.Errorf("failed to marshal peer location message: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
for n, msg := range msgs {
|
||||
if _, err := p.Write(msg); err != nil {
|
||||
p.log.Errorf("failed to write %d. peers offline message: %s", n, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,26 +4,55 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"go.opentelemetry.io/otel"
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
|
||||
"github.com/netbirdio/netbird/relay/auth"
|
||||
//nolint:staticcheck
|
||||
"github.com/netbirdio/netbird/relay/metrics"
|
||||
"github.com/netbirdio/netbird/relay/server/store"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Meter metric.Meter
|
||||
ExposedAddress string
|
||||
TLSSupport bool
|
||||
AuthValidator Validator
|
||||
|
||||
instanceURL string
|
||||
}
|
||||
|
||||
func (c *Config) validate() error {
|
||||
if c.Meter == nil {
|
||||
c.Meter = otel.Meter("")
|
||||
}
|
||||
if c.ExposedAddress == "" {
|
||||
return fmt.Errorf("exposed address is required")
|
||||
}
|
||||
|
||||
instanceURL, err := getInstanceURL(c.ExposedAddress, c.TLSSupport)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid url: %v", err)
|
||||
}
|
||||
c.instanceURL = instanceURL
|
||||
|
||||
if c.AuthValidator == nil {
|
||||
return fmt.Errorf("auth validator is required")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Relay represents the relay server
|
||||
type Relay struct {
|
||||
metrics *metrics.Metrics
|
||||
metricsCancel context.CancelFunc
|
||||
validator auth.Validator
|
||||
validator Validator
|
||||
|
||||
store *Store
|
||||
store *store.Store
|
||||
notifier *store.PeerNotifier
|
||||
instanceURL string
|
||||
preparedMsg *preparedMsg
|
||||
|
||||
@@ -31,40 +60,40 @@ type Relay struct {
|
||||
closeMu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewRelay creates a new Relay instance
|
||||
// NewRelay creates and returns a new Relay instance.
|
||||
//
|
||||
// Parameters:
|
||||
// meter: An instance of metric.Meter from the go.opentelemetry.io/otel/metric package. It is used to create and manage
|
||||
// metrics for the relay server.
|
||||
// exposedAddress: A string representing the address that the relay server is exposed on. The client will use this
|
||||
// address as the relay server's instance URL.
|
||||
// tlsSupport: A boolean indicating whether the relay server supports TLS (Transport Layer Security) or not. The
|
||||
// instance URL depends on this value.
|
||||
// validator: An instance of auth.Validator from the auth package. It is used to validate the authentication of the
|
||||
// peers.
|
||||
//
|
||||
// config: A Config struct that holds the configuration needed to initialize the relay server.
|
||||
// - Meter: A metric.Meter used for emitting metrics. If not set, a default no-op meter will be used.
|
||||
// - ExposedAddress: The external address clients use to reach this relay. Required.
|
||||
// - TLSSupport: A boolean indicating if the relay uses TLS. Affects the generated instance URL.
|
||||
// - AuthValidator: A Validator implementation used to authenticate peers. Required.
|
||||
//
|
||||
// Returns:
|
||||
// A pointer to a Relay instance and an error. If the Relay instance is successfully created, the error is nil.
|
||||
// Otherwise, the error contains the details of what went wrong.
|
||||
func NewRelay(meter metric.Meter, exposedAddress string, tlsSupport bool, validator auth.Validator) (*Relay, error) {
|
||||
//
|
||||
// A pointer to a Relay instance and an error. If initialization is successful, the error will be nil;
|
||||
// otherwise, it will contain the reason the relay could not be created (e.g., invalid configuration).
|
||||
func NewRelay(config Config) (*Relay, error) {
|
||||
if err := config.validate(); err != nil {
|
||||
return nil, fmt.Errorf("invalid config: %v", err)
|
||||
}
|
||||
|
||||
ctx, metricsCancel := context.WithCancel(context.Background())
|
||||
m, err := metrics.NewMetrics(ctx, meter)
|
||||
m, err := metrics.NewMetrics(ctx, config.Meter)
|
||||
if err != nil {
|
||||
metricsCancel()
|
||||
return nil, fmt.Errorf("creating app metrics: %v", err)
|
||||
}
|
||||
|
||||
peerStore := store.NewStore()
|
||||
r := &Relay{
|
||||
metrics: m,
|
||||
metricsCancel: metricsCancel,
|
||||
validator: validator,
|
||||
store: NewStore(),
|
||||
}
|
||||
|
||||
r.instanceURL, err = getInstanceURL(exposedAddress, tlsSupport)
|
||||
if err != nil {
|
||||
metricsCancel()
|
||||
return nil, fmt.Errorf("get instance URL: %v", err)
|
||||
validator: config.AuthValidator,
|
||||
instanceURL: config.instanceURL,
|
||||
store: peerStore,
|
||||
notifier: store.NewPeerNotifier(peerStore),
|
||||
}
|
||||
|
||||
r.preparedMsg, err = newPreparedMsg(r.instanceURL)
|
||||
@@ -76,32 +105,6 @@ func NewRelay(meter metric.Meter, exposedAddress string, tlsSupport bool, valida
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// getInstanceURL checks if user supplied a URL scheme otherwise adds to the
|
||||
// provided address according to TLS definition and parses the address before returning it
|
||||
func getInstanceURL(exposedAddress string, tlsSupported bool) (string, error) {
|
||||
addr := exposedAddress
|
||||
split := strings.Split(exposedAddress, "://")
|
||||
switch {
|
||||
case len(split) == 1 && tlsSupported:
|
||||
addr = "rels://" + exposedAddress
|
||||
case len(split) == 1 && !tlsSupported:
|
||||
addr = "rel://" + exposedAddress
|
||||
case len(split) > 2:
|
||||
return "", fmt.Errorf("invalid exposed address: %s", exposedAddress)
|
||||
}
|
||||
|
||||
parsedURL, err := url.ParseRequestURI(addr)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid exposed address: %v", err)
|
||||
}
|
||||
|
||||
if parsedURL.Scheme != "rel" && parsedURL.Scheme != "rels" {
|
||||
return "", fmt.Errorf("invalid scheme: %s", parsedURL.Scheme)
|
||||
}
|
||||
|
||||
return parsedURL.String(), nil
|
||||
}
|
||||
|
||||
// Accept start to handle a new peer connection
|
||||
func (r *Relay) Accept(conn net.Conn) {
|
||||
acceptTime := time.Now()
|
||||
@@ -125,14 +128,17 @@ func (r *Relay) Accept(conn net.Conn) {
|
||||
return
|
||||
}
|
||||
|
||||
peer := NewPeer(r.metrics, peerID, conn, r.store)
|
||||
peer := NewPeer(r.metrics, *peerID, conn, r.store, r.notifier)
|
||||
peer.log.Infof("peer connected from: %s", conn.RemoteAddr())
|
||||
storeTime := time.Now()
|
||||
r.store.AddPeer(peer)
|
||||
r.notifier.PeerCameOnline(peer.ID())
|
||||
|
||||
r.metrics.RecordPeerStoreTime(time.Since(storeTime))
|
||||
r.metrics.PeerConnected(peer.String())
|
||||
go func() {
|
||||
peer.Work()
|
||||
r.notifier.PeerWentOffline(peer.ID())
|
||||
r.store.DeletePeer(peer)
|
||||
peer.log.Debugf("relay connection closed")
|
||||
r.metrics.PeerDisconnected(peer.String())
|
||||
@@ -154,12 +160,12 @@ func (r *Relay) Shutdown(ctx context.Context) {
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
peers := r.store.Peers()
|
||||
for _, peer := range peers {
|
||||
for _, v := range peers {
|
||||
wg.Add(1)
|
||||
go func(p *Peer) {
|
||||
p.CloseGracefully(ctx)
|
||||
wg.Done()
|
||||
}(peer)
|
||||
}(v.(*Peer))
|
||||
}
|
||||
wg.Wait()
|
||||
r.metricsCancel()
|
||||
|
||||
@@ -6,15 +6,12 @@ import (
|
||||
"sync"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/relay/auth"
|
||||
"github.com/netbirdio/netbird/relay/server/listener"
|
||||
"github.com/netbirdio/netbird/relay/server/listener/quic"
|
||||
"github.com/netbirdio/netbird/relay/server/listener/ws"
|
||||
quictls "github.com/netbirdio/netbird/relay/tls"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// ListenerConfig is the configuration for the listener.
|
||||
@@ -33,13 +30,22 @@ type Server struct {
|
||||
listeners []listener.Listener
|
||||
}
|
||||
|
||||
// NewServer creates a new relay server instance.
|
||||
// meter: the OpenTelemetry meter
|
||||
// exposedAddress: this address will be used as the instance URL. It should be a domain:port format.
|
||||
// tlsSupport: if true, the server will support TLS
|
||||
// authValidator: the auth validator to use for the server
|
||||
func NewServer(meter metric.Meter, exposedAddress string, tlsSupport bool, authValidator auth.Validator) (*Server, error) {
|
||||
relay, err := NewRelay(meter, exposedAddress, tlsSupport, authValidator)
|
||||
// NewServer creates and returns a new relay server instance.
|
||||
//
|
||||
// Parameters:
|
||||
//
|
||||
// config: A Config struct containing the necessary configuration:
|
||||
// - Meter: An OpenTelemetry metric.Meter used for recording metrics. If nil, a default no-op meter is used.
|
||||
// - ExposedAddress: The public address (in domain:port format) used as the server's instance URL. Required.
|
||||
// - TLSSupport: A boolean indicating whether TLS is enabled for the server.
|
||||
// - AuthValidator: A Validator used to authenticate peers. Required.
|
||||
//
|
||||
// Returns:
|
||||
//
|
||||
// A pointer to a Server instance and an error. If the configuration is valid and initialization succeeds,
|
||||
// the returned error will be nil. Otherwise, the error will describe the problem.
|
||||
func NewServer(config Config) (*Server, error) {
|
||||
relay, err := NewRelay(config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
121
relay/server/store/listener.go
Normal file
121
relay/server/store/listener.go
Normal file
@@ -0,0 +1,121 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/netbirdio/netbird/relay/messages"
|
||||
)
|
||||
|
||||
type Listener struct {
|
||||
store *Store
|
||||
|
||||
onlineChan chan messages.PeerID
|
||||
offlineChan chan messages.PeerID
|
||||
interestedPeersForOffline map[messages.PeerID]struct{}
|
||||
interestedPeersForOnline map[messages.PeerID]struct{}
|
||||
mu sync.RWMutex
|
||||
|
||||
listenerCtx context.Context
|
||||
}
|
||||
|
||||
func newListener(store *Store) *Listener {
|
||||
l := &Listener{
|
||||
store: store,
|
||||
|
||||
onlineChan: make(chan messages.PeerID, 244), //244 is the message size limit in the relay protocol
|
||||
offlineChan: make(chan messages.PeerID, 244), //244 is the message size limit in the relay protocol
|
||||
interestedPeersForOffline: make(map[messages.PeerID]struct{}),
|
||||
interestedPeersForOnline: make(map[messages.PeerID]struct{}),
|
||||
}
|
||||
|
||||
return l
|
||||
}
|
||||
|
||||
func (l *Listener) AddInterestedPeers(peerIDs []messages.PeerID) []messages.PeerID {
|
||||
availablePeers := make([]messages.PeerID, 0)
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
for _, id := range peerIDs {
|
||||
l.interestedPeersForOnline[id] = struct{}{}
|
||||
l.interestedPeersForOffline[id] = struct{}{}
|
||||
}
|
||||
|
||||
// collect online peers to response back to the caller
|
||||
for _, id := range peerIDs {
|
||||
_, ok := l.store.Peer(id)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
availablePeers = append(availablePeers, id)
|
||||
}
|
||||
return availablePeers
|
||||
}
|
||||
|
||||
func (l *Listener) RemoveInterestedPeer(peerIDs []messages.PeerID) {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
for _, id := range peerIDs {
|
||||
delete(l.interestedPeersForOffline, id)
|
||||
delete(l.interestedPeersForOnline, id)
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Listener) listenForEvents(ctx context.Context, onPeersComeOnline, onPeersWentOffline func([]messages.PeerID)) {
|
||||
l.listenerCtx = ctx
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case pID := <-l.onlineChan:
|
||||
peers := make([]messages.PeerID, 0)
|
||||
peers = append(peers, pID)
|
||||
|
||||
for len(l.onlineChan) > 0 {
|
||||
pID = <-l.onlineChan
|
||||
peers = append(peers, pID)
|
||||
}
|
||||
|
||||
onPeersComeOnline(peers)
|
||||
case pID := <-l.offlineChan:
|
||||
peers := make([]messages.PeerID, 0)
|
||||
peers = append(peers, pID)
|
||||
|
||||
for len(l.offlineChan) > 0 {
|
||||
pID = <-l.offlineChan
|
||||
peers = append(peers, pID)
|
||||
}
|
||||
|
||||
onPeersWentOffline(peers)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Listener) peerWentOffline(peerID messages.PeerID) {
|
||||
l.mu.RLock()
|
||||
defer l.mu.RUnlock()
|
||||
|
||||
if _, ok := l.interestedPeersForOffline[peerID]; ok {
|
||||
select {
|
||||
case l.offlineChan <- peerID:
|
||||
case <-l.listenerCtx.Done():
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Listener) peerComeOnline(peerID messages.PeerID) {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
if _, ok := l.interestedPeersForOnline[peerID]; ok {
|
||||
select {
|
||||
case l.onlineChan <- peerID:
|
||||
case <-l.listenerCtx.Done():
|
||||
}
|
||||
delete(l.interestedPeersForOnline, peerID)
|
||||
}
|
||||
}
|
||||
64
relay/server/store/notifier.go
Normal file
64
relay/server/store/notifier.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/netbirdio/netbird/relay/messages"
|
||||
)
|
||||
|
||||
type PeerNotifier struct {
|
||||
store *Store
|
||||
|
||||
listeners map[*Listener]context.CancelFunc
|
||||
listenersMutex sync.RWMutex
|
||||
}
|
||||
|
||||
func NewPeerNotifier(store *Store) *PeerNotifier {
|
||||
pn := &PeerNotifier{
|
||||
store: store,
|
||||
listeners: make(map[*Listener]context.CancelFunc),
|
||||
}
|
||||
return pn
|
||||
}
|
||||
|
||||
func (pn *PeerNotifier) NewListener(onPeersComeOnline, onPeersWentOffline func([]messages.PeerID)) *Listener {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
listener := newListener(pn.store)
|
||||
go listener.listenForEvents(ctx, onPeersComeOnline, onPeersWentOffline)
|
||||
|
||||
pn.listenersMutex.Lock()
|
||||
pn.listeners[listener] = cancel
|
||||
pn.listenersMutex.Unlock()
|
||||
return listener
|
||||
}
|
||||
|
||||
func (pn *PeerNotifier) RemoveListener(listener *Listener) {
|
||||
pn.listenersMutex.Lock()
|
||||
defer pn.listenersMutex.Unlock()
|
||||
|
||||
cancel, ok := pn.listeners[listener]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
cancel()
|
||||
delete(pn.listeners, listener)
|
||||
}
|
||||
|
||||
func (pn *PeerNotifier) PeerWentOffline(peerID messages.PeerID) {
|
||||
pn.listenersMutex.RLock()
|
||||
defer pn.listenersMutex.RUnlock()
|
||||
|
||||
for listener := range pn.listeners {
|
||||
listener.peerWentOffline(peerID)
|
||||
}
|
||||
}
|
||||
|
||||
func (pn *PeerNotifier) PeerCameOnline(peerID messages.PeerID) {
|
||||
pn.listenersMutex.RLock()
|
||||
defer pn.listenersMutex.RUnlock()
|
||||
|
||||
for listener := range pn.listeners {
|
||||
listener.peerComeOnline(peerID)
|
||||
}
|
||||
}
|
||||
@@ -1,41 +1,48 @@
|
||||
package server
|
||||
package store
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/netbirdio/netbird/relay/messages"
|
||||
)
|
||||
|
||||
type IPeer interface {
|
||||
Close()
|
||||
ID() messages.PeerID
|
||||
}
|
||||
|
||||
// Store is a thread-safe store of peers
|
||||
// It is used to store the peers that are connected to the relay server
|
||||
type Store struct {
|
||||
peers map[string]*Peer // consider to use [32]byte as key. The Peer(id string) would be faster
|
||||
peers map[messages.PeerID]IPeer
|
||||
peersLock sync.RWMutex
|
||||
}
|
||||
|
||||
// NewStore creates a new Store instance
|
||||
func NewStore() *Store {
|
||||
return &Store{
|
||||
peers: make(map[string]*Peer),
|
||||
peers: make(map[messages.PeerID]IPeer),
|
||||
}
|
||||
}
|
||||
|
||||
// AddPeer adds a peer to the store
|
||||
func (s *Store) AddPeer(peer *Peer) {
|
||||
func (s *Store) AddPeer(peer IPeer) {
|
||||
s.peersLock.Lock()
|
||||
defer s.peersLock.Unlock()
|
||||
odlPeer, ok := s.peers[peer.String()]
|
||||
odlPeer, ok := s.peers[peer.ID()]
|
||||
if ok {
|
||||
odlPeer.Close()
|
||||
}
|
||||
|
||||
s.peers[peer.String()] = peer
|
||||
s.peers[peer.ID()] = peer
|
||||
}
|
||||
|
||||
// DeletePeer deletes a peer from the store
|
||||
func (s *Store) DeletePeer(peer *Peer) {
|
||||
func (s *Store) DeletePeer(peer IPeer) {
|
||||
s.peersLock.Lock()
|
||||
defer s.peersLock.Unlock()
|
||||
|
||||
dp, ok := s.peers[peer.String()]
|
||||
dp, ok := s.peers[peer.ID()]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
@@ -43,11 +50,11 @@ func (s *Store) DeletePeer(peer *Peer) {
|
||||
return
|
||||
}
|
||||
|
||||
delete(s.peers, peer.String())
|
||||
delete(s.peers, peer.ID())
|
||||
}
|
||||
|
||||
// Peer returns a peer by its ID
|
||||
func (s *Store) Peer(id string) (*Peer, bool) {
|
||||
func (s *Store) Peer(id messages.PeerID) (IPeer, bool) {
|
||||
s.peersLock.RLock()
|
||||
defer s.peersLock.RUnlock()
|
||||
|
||||
@@ -56,11 +63,11 @@ func (s *Store) Peer(id string) (*Peer, bool) {
|
||||
}
|
||||
|
||||
// Peers returns all the peers in the store
|
||||
func (s *Store) Peers() []*Peer {
|
||||
func (s *Store) Peers() []IPeer {
|
||||
s.peersLock.RLock()
|
||||
defer s.peersLock.RUnlock()
|
||||
|
||||
peers := make([]*Peer, 0, len(s.peers))
|
||||
peers := make([]IPeer, 0, len(s.peers))
|
||||
for _, p := range s.peers {
|
||||
peers = append(peers, p)
|
||||
}
|
||||
49
relay/server/store/store_test.go
Normal file
49
relay/server/store/store_test.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/netbirdio/netbird/relay/messages"
|
||||
)
|
||||
|
||||
type MocPeer struct {
|
||||
id messages.PeerID
|
||||
}
|
||||
|
||||
func (m *MocPeer) Close() {
|
||||
|
||||
}
|
||||
|
||||
func (m *MocPeer) ID() messages.PeerID {
|
||||
return m.id
|
||||
}
|
||||
|
||||
func TestStore_DeletePeer(t *testing.T) {
|
||||
s := NewStore()
|
||||
|
||||
pID := messages.HashID("peer_one")
|
||||
p := &MocPeer{id: pID}
|
||||
s.AddPeer(p)
|
||||
s.DeletePeer(p)
|
||||
if _, ok := s.Peer(pID); ok {
|
||||
t.Errorf("peer was not deleted")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStore_DeleteDeprecatedPeer(t *testing.T) {
|
||||
s := NewStore()
|
||||
|
||||
pID1 := messages.HashID("peer_one")
|
||||
pID2 := messages.HashID("peer_one")
|
||||
|
||||
p1 := &MocPeer{id: pID1}
|
||||
p2 := &MocPeer{id: pID2}
|
||||
|
||||
s.AddPeer(p1)
|
||||
s.AddPeer(p2)
|
||||
s.DeletePeer(p1)
|
||||
|
||||
if _, ok := s.Peer(pID2); !ok {
|
||||
t.Errorf("second peer was deleted")
|
||||
}
|
||||
}
|
||||
@@ -1,85 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"go.opentelemetry.io/otel"
|
||||
|
||||
"github.com/netbirdio/netbird/relay/metrics"
|
||||
)
|
||||
|
||||
type mockConn struct {
|
||||
}
|
||||
|
||||
func (m mockConn) Read(b []byte) (n int, err error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (m mockConn) Write(b []byte) (n int, err error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (m mockConn) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m mockConn) LocalAddr() net.Addr {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (m mockConn) RemoteAddr() net.Addr {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (m mockConn) SetDeadline(t time.Time) error {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (m mockConn) SetReadDeadline(t time.Time) error {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (m mockConn) SetWriteDeadline(t time.Time) error {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func TestStore_DeletePeer(t *testing.T) {
|
||||
s := NewStore()
|
||||
|
||||
m, _ := metrics.NewMetrics(context.Background(), otel.Meter(""))
|
||||
|
||||
p := NewPeer(m, []byte("peer_one"), nil, nil)
|
||||
s.AddPeer(p)
|
||||
s.DeletePeer(p)
|
||||
if _, ok := s.Peer(p.String()); ok {
|
||||
t.Errorf("peer was not deleted")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStore_DeleteDeprecatedPeer(t *testing.T) {
|
||||
s := NewStore()
|
||||
|
||||
m, _ := metrics.NewMetrics(context.Background(), otel.Meter(""))
|
||||
|
||||
conn := &mockConn{}
|
||||
p1 := NewPeer(m, []byte("peer_id"), conn, nil)
|
||||
p2 := NewPeer(m, []byte("peer_id"), conn, nil)
|
||||
|
||||
s.AddPeer(p1)
|
||||
s.AddPeer(p2)
|
||||
s.DeletePeer(p1)
|
||||
|
||||
if _, ok := s.Peer(p2.String()); !ok {
|
||||
t.Errorf("second peer was deleted")
|
||||
}
|
||||
}
|
||||
33
relay/server/url.go
Normal file
33
relay/server/url.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// getInstanceURL checks if user supplied a URL scheme otherwise adds to the
|
||||
// provided address according to TLS definition and parses the address before returning it
|
||||
func getInstanceURL(exposedAddress string, tlsSupported bool) (string, error) {
|
||||
addr := exposedAddress
|
||||
split := strings.Split(exposedAddress, "://")
|
||||
switch {
|
||||
case len(split) == 1 && tlsSupported:
|
||||
addr = "rels://" + exposedAddress
|
||||
case len(split) == 1 && !tlsSupported:
|
||||
addr = "rel://" + exposedAddress
|
||||
case len(split) > 2:
|
||||
return "", fmt.Errorf("invalid exposed address: %s", exposedAddress)
|
||||
}
|
||||
|
||||
parsedURL, err := url.ParseRequestURI(addr)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid exposed address: %v", err)
|
||||
}
|
||||
|
||||
if parsedURL.Scheme != "rel" && parsedURL.Scheme != "rels" {
|
||||
return "", fmt.Errorf("invalid scheme: %s", parsedURL.Scheme)
|
||||
}
|
||||
|
||||
return parsedURL.String(), nil
|
||||
}
|
||||
Reference in New Issue
Block a user