[relay] Feature/relay integration (#2244)

This update adds new relay integration for NetBird clients. The new relay is based on web sockets and listens on a single port.

- Adds new relay implementation with websocket with single port relaying mechanism
- refactor peer connection logic, allowing upgrade and downgrade from/to P2P connection
- peer connections are faster since it connects first to relay and then upgrades to P2P
- maintains compatibility with old clients by not using the new relay
- updates infrastructure scripts with new relay service
This commit is contained in:
Zoltan Papp
2024-09-08 12:06:14 +02:00
committed by GitHub
parent fcac02a92f
commit 0c039274a4
120 changed files with 9879 additions and 1940 deletions

4
relay/Dockerfile Normal file
View File

@@ -0,0 +1,4 @@
FROM gcr.io/distroless/base:debug
ENTRYPOINT [ "/go/bin/netbird-relay" ]
ENV NB_LOG_FILE=console
COPY netbird-relay /go/bin/netbird-relay

View File

@@ -0,0 +1,12 @@
package allow
import "hash"
// Auth is a Validator that allows all connections.
// Used this for testing purposes only.
type Auth struct {
}
func (a *Auth) Validate(func() hash.Hash, any) error {
return nil
}

26
relay/auth/doc.go Normal file
View File

@@ -0,0 +1,26 @@
/*
Package auth manages the authentication process with the relay server.
Key Components:
Validator: The Validator interface defines the Validate method. Any type that provides this method can be used as a
Validator.
Methods:
Validate(func() hash.Hash, any): This method is defined in the Validator interface and is used to validate the authentication.
Usage:
To create a new AllowAllAuth validator, simply instantiate it:
validator := &allow.Auth{}
To validate the authentication, use the Validate method:
err := validator.Validate(sha256.New, any)
This package provides a simple and effective way to manage authentication with the relay server, ensuring that the
peers are authenticated properly.
*/
package auth

8
relay/auth/hmac/doc.go Normal file
View File

@@ -0,0 +1,8 @@
/*
This package uses a similar HMAC method for authentication with the TURN server. The Management server provides the
tokens for the peers. The peers manage these tokens in the token store. The token store is a simple thread safe store
that keeps the tokens in memory. These tokens are used to authenticate the peers with the Relay server in the hello
message.
*/
package hmac

36
relay/auth/hmac/store.go Normal file
View File

@@ -0,0 +1,36 @@
package hmac
import (
"sync"
log "github.com/sirupsen/logrus"
)
// TokenStore is a simple in-memory store for token
// With this can update the token in thread safe way
type TokenStore struct {
mu sync.Mutex
token []byte
}
func (a *TokenStore) UpdateToken(token *Token) error {
a.mu.Lock()
defer a.mu.Unlock()
if token == nil {
return nil
}
t, err := marshalToken(*token)
if err != nil {
log.Debugf("failed to marshal token: %s", err)
return err
}
a.token = t
return nil
}
func (a *TokenStore) TokenBinary() []byte {
a.mu.Lock()
defer a.mu.Unlock()
return a.token
}

105
relay/auth/hmac/token.go Normal file
View File

@@ -0,0 +1,105 @@
package hmac
import (
"bytes"
"crypto/hmac"
"encoding/base64"
"encoding/gob"
"fmt"
"hash"
"strconv"
"time"
log "github.com/sirupsen/logrus"
)
type Token struct {
Payload string
Signature string
}
func marshalToken(token Token) ([]byte, error) {
var buffer bytes.Buffer
encoder := gob.NewEncoder(&buffer)
err := encoder.Encode(token)
if err != nil {
log.Debugf("failed to marshal token: %s", err)
return nil, fmt.Errorf("failed to marshal token: %w", err)
}
return buffer.Bytes(), nil
}
func unmarshalToken(payload []byte) (Token, error) {
var creds Token
buffer := bytes.NewBuffer(payload)
decoder := gob.NewDecoder(buffer)
err := decoder.Decode(&creds)
return creds, err
}
// TimedHMAC generates a token with TTL and uses a pre-shared secret known to the relay server
type TimedHMAC struct {
secret string
timeToLive time.Duration
}
// NewTimedHMAC creates a new TimedHMAC instance
func NewTimedHMAC(secret string, timeToLive time.Duration) *TimedHMAC {
return &TimedHMAC{
secret: secret,
timeToLive: timeToLive,
}
}
// GenerateToken generates new time-based secret token - basically Payload is a unix timestamp and Signature is a HMAC
// hash of a timestamp with a preshared TURN secret
func (m *TimedHMAC) GenerateToken(algo func() hash.Hash) (*Token, error) {
timeAuth := time.Now().Add(m.timeToLive).Unix()
timeStamp := strconv.FormatInt(timeAuth, 10)
checksum, err := m.generate(algo, timeStamp)
if err != nil {
return nil, err
}
return &Token{
Payload: timeStamp,
Signature: base64.StdEncoding.EncodeToString(checksum),
}, nil
}
// Validate checks if the token is valid
func (m *TimedHMAC) Validate(algo func() hash.Hash, token Token) error {
expectedMAC, err := m.generate(algo, token.Payload)
if err != nil {
return err
}
expectedSignature := base64.StdEncoding.EncodeToString(expectedMAC)
if !hmac.Equal([]byte(expectedSignature), []byte(token.Signature)) {
return fmt.Errorf("signature mismatch")
}
timeAuthInt, err := strconv.ParseInt(token.Payload, 10, 64)
if err != nil {
return fmt.Errorf("invalid payload: %w", err)
}
if time.Now().Unix() > timeAuthInt {
return fmt.Errorf("expired token")
}
return nil
}
func (m *TimedHMAC) generate(algo func() hash.Hash, payload string) ([]byte, error) {
mac := hmac.New(algo, []byte(m.secret))
_, err := mac.Write([]byte(payload))
if err != nil {
log.Debugf("failed to generate token: %s", err)
return nil, fmt.Errorf("failed to generate token: %w", err)
}
return mac.Sum(nil), nil
}

View File

@@ -0,0 +1,105 @@
package hmac
import (
"crypto/sha1"
"crypto/sha256"
"encoding/base64"
"strconv"
"testing"
"time"
)
func TestGenerateCredentials(t *testing.T) {
secret := "secret"
timeToLive := 1 * time.Hour
v := NewTimedHMAC(secret, timeToLive)
creds, err := v.GenerateToken(sha1.New)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if creds.Payload == "" {
t.Fatalf("expected non-empty payload")
}
_, err = strconv.ParseInt(creds.Payload, 10, 64)
if err != nil {
t.Fatalf("expected payload to be a valid unix timestamp, got %v", err)
}
_, err = base64.StdEncoding.DecodeString(creds.Signature)
if err != nil {
t.Fatalf("expected signature to be base64 encoded, got %v", err)
}
}
func TestValidateCredentials(t *testing.T) {
secret := "supersecret"
timeToLive := 1 * time.Hour
manager := NewTimedHMAC(secret, timeToLive)
// Test valid token
creds, err := manager.GenerateToken(sha1.New)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if err := manager.Validate(sha1.New, *creds); err != nil {
t.Fatalf("expected valid token: %s", err)
}
}
func TestInvalidSignature(t *testing.T) {
secret := "supersecret"
timeToLive := 1 * time.Hour
manager := NewTimedHMAC(secret, timeToLive)
creds, err := manager.GenerateToken(sha256.New)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
invalidCreds := &Token{
Payload: creds.Payload,
Signature: "invalidsignature",
}
if err = manager.Validate(sha1.New, *invalidCreds); err == nil {
t.Fatalf("expected invalid token due to signature mismatch")
}
}
func TestExpired(t *testing.T) {
secret := "supersecret"
v := NewTimedHMAC(secret, -1*time.Hour)
expiredCreds, err := v.GenerateToken(sha256.New)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if err = v.Validate(sha1.New, *expiredCreds); err == nil {
t.Fatalf("expected invalid token due to expiration")
}
}
func TestInvalidPayload(t *testing.T) {
secret := "supersecret"
timeToLive := 1 * time.Hour
v := NewTimedHMAC(secret, timeToLive)
creds, err := v.GenerateToken(sha256.New)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
// Test invalid payload
invalidPayloadCreds := &Token{
Payload: "invalidtimestamp",
Signature: creds.Signature,
}
if err = v.Validate(sha1.New, *invalidPayloadCreds); err == nil {
t.Fatalf("expected invalid token due to invalid payload")
}
}

View File

@@ -0,0 +1,33 @@
package hmac
import (
"fmt"
"hash"
"time"
log "github.com/sirupsen/logrus"
)
type TimedHMACValidator struct {
*TimedHMAC
}
func NewTimedHMACValidator(secret string, duration time.Duration) *TimedHMACValidator {
ta := NewTimedHMAC(secret, duration)
return &TimedHMACValidator{
ta,
}
}
func (a *TimedHMACValidator) Validate(algo func() hash.Hash, credentials any) error {
b, ok := credentials.([]byte)
if !ok {
return fmt.Errorf("invalid credentials type")
}
c, err := unmarshalToken(b)
if err != nil {
log.Debugf("failed to unmarshal token: %s", err)
return err
}
return a.TimedHMAC.Validate(algo, c)
}

8
relay/auth/validator.go Normal file
View File

@@ -0,0 +1,8 @@
package auth
import "hash"
// Validator is an interface that defines the Validate method.
type Validator interface {
Validate(func() hash.Hash, any) error
}

13
relay/client/addr.go Normal file
View File

@@ -0,0 +1,13 @@
package client
type RelayAddr struct {
addr string
}
func (a RelayAddr) Network() string {
return "relay"
}
func (a RelayAddr) String() string {
return a.addr
}

553
relay/client/client.go Normal file
View File

@@ -0,0 +1,553 @@
package client
import (
"context"
"fmt"
"io"
"net"
"sync"
"time"
log "github.com/sirupsen/logrus"
auth "github.com/netbirdio/netbird/relay/auth/hmac"
"github.com/netbirdio/netbird/relay/client/dialer/ws"
"github.com/netbirdio/netbird/relay/healthcheck"
"github.com/netbirdio/netbird/relay/messages"
"github.com/netbirdio/netbird/relay/messages/address"
auth2 "github.com/netbirdio/netbird/relay/messages/auth"
)
const (
bufferSize = 8820
serverResponseTimeout = 8 * time.Second
)
var (
ErrConnAlreadyExists = fmt.Errorf("connection already exists")
)
type internalStopFlag struct {
sync.Mutex
stop bool
}
func newInternalStopFlag() *internalStopFlag {
return &internalStopFlag{}
}
func (isf *internalStopFlag) set() {
isf.Lock()
defer isf.Unlock()
isf.stop = true
}
func (isf *internalStopFlag) isSet() bool {
isf.Lock()
defer isf.Unlock()
return isf.stop
}
// Msg carry the payload from the server to the client. With this struct, the net.Conn can free the buffer.
type Msg struct {
Payload []byte
bufPool *sync.Pool
bufPtr *[]byte
}
func (m *Msg) Free() {
m.bufPool.Put(m.bufPtr)
}
type connContainer struct {
conn *Conn
messages chan Msg
msgChanLock sync.Mutex
closed bool // flag to check if channel is closed
}
func newConnContainer(conn *Conn, messages chan Msg) *connContainer {
return &connContainer{
conn: conn,
messages: messages,
}
}
func (cc *connContainer) writeMsg(msg Msg) {
cc.msgChanLock.Lock()
defer cc.msgChanLock.Unlock()
if cc.closed {
return
}
cc.messages <- msg
}
func (cc *connContainer) close() {
cc.msgChanLock.Lock()
defer cc.msgChanLock.Unlock()
if cc.closed {
return
}
close(cc.messages)
cc.closed = true
}
// Client is a client for the relay server. It is responsible for establishing a connection to the relay server and
// managing connections to other peers. All exported functions are safe to call concurrently. After close the connection,
// the client can be reused by calling Connect again. When the client is closed, all connections are closed too.
// While the Connect is in progress, the OpenConn function will block until the connection is established with relay server.
type Client struct {
log *log.Entry
parentCtx context.Context
connectionURL string
authTokenStore *auth.TokenStore
hashedID []byte
bufPool *sync.Pool
relayConn net.Conn
conns map[string]*connContainer
serviceIsRunning bool
mu sync.Mutex // protect serviceIsRunning and conns
readLoopMutex sync.Mutex
wgReadLoop sync.WaitGroup
instanceURL *RelayAddr
muInstanceURL sync.Mutex
onDisconnectListener func()
listenerMutex sync.Mutex
}
// NewClient creates a new client for the relay server. The client is not connected to the server until the Connect
func NewClient(ctx context.Context, serverURL string, authTokenStore *auth.TokenStore, peerID string) *Client {
hashedID, hashedStringId := messages.HashID(peerID)
return &Client{
log: log.WithField("client_id", hashedStringId),
parentCtx: ctx,
connectionURL: serverURL,
authTokenStore: authTokenStore,
hashedID: hashedID,
bufPool: &sync.Pool{
New: func() any {
buf := make([]byte, bufferSize)
return &buf
},
},
conns: make(map[string]*connContainer),
}
}
// Connect establishes a connection to the relay server. It blocks until the connection is established or an error occurs.
func (c *Client) Connect() error {
c.log.Infof("connecting to relay server: %s", c.connectionURL)
c.readLoopMutex.Lock()
defer c.readLoopMutex.Unlock()
c.mu.Lock()
defer c.mu.Unlock()
if c.serviceIsRunning {
return nil
}
err := c.connect()
if err != nil {
return err
}
c.serviceIsRunning = true
c.wgReadLoop.Add(1)
go c.readLoop(c.relayConn)
c.log.Infof("relay connection established with: %s", c.connectionURL)
return nil
}
// OpenConn create a new net.Conn for the destination peer ID. In case if the connection is in progress
// to the relay server, the function will block until the connection is established or timed out. Otherwise,
// it will return immediately.
// todo: what should happen if call with the same peerID with multiple times?
func (c *Client) OpenConn(dstPeerID string) (net.Conn, error) {
c.mu.Lock()
defer c.mu.Unlock()
if !c.serviceIsRunning {
return nil, fmt.Errorf("relay connection is not established")
}
hashedID, hashedStringID := messages.HashID(dstPeerID)
_, ok := c.conns[hashedStringID]
if ok {
return nil, ErrConnAlreadyExists
}
log.Infof("open connection to peer: %s", hashedStringID)
msgChannel := make(chan Msg, 2)
conn := NewConn(c, hashedID, hashedStringID, msgChannel, c.instanceURL)
c.conns[hashedStringID] = newConnContainer(conn, msgChannel)
return conn, nil
}
// ServerInstanceURL returns the address of the relay server. It could change after the close and reopen the connection.
func (c *Client) ServerInstanceURL() (string, error) {
c.muInstanceURL.Lock()
defer c.muInstanceURL.Unlock()
if c.instanceURL == nil {
return "", fmt.Errorf("relay connection is not established")
}
return c.instanceURL.String(), nil
}
// SetOnDisconnectListener sets a function that will be called when the connection to the relay server is closed.
func (c *Client) SetOnDisconnectListener(fn func()) {
c.listenerMutex.Lock()
defer c.listenerMutex.Unlock()
c.onDisconnectListener = fn
}
// HasConns returns true if there are connections.
func (c *Client) HasConns() bool {
c.mu.Lock()
defer c.mu.Unlock()
return len(c.conns) > 0
}
// Close closes the connection to the relay server and all connections to other peers.
func (c *Client) Close() error {
return c.close(true)
}
func (c *Client) connect() error {
conn, err := ws.Dial(c.connectionURL)
if err != nil {
return err
}
c.relayConn = conn
err = c.handShake()
if err != nil {
cErr := conn.Close()
if cErr != nil {
log.Errorf("failed to close connection: %s", cErr)
}
return err
}
return nil
}
func (c *Client) handShake() error {
authMsg := &auth2.Msg{
AuthAlgorithm: auth2.AlgoHMACSHA256,
AdditionalData: c.authTokenStore.TokenBinary(),
}
authData, err := authMsg.Marshal()
if err != nil {
return fmt.Errorf("marshal auth message: %w", err)
}
msg, err := messages.MarshalHelloMsg(c.hashedID, authData)
if err != nil {
log.Errorf("failed to marshal hello message: %s", err)
return err
}
_, err = c.relayConn.Write(msg)
if err != nil {
log.Errorf("failed to send hello message: %s", err)
return err
}
buf := make([]byte, messages.MaxHandshakeSize)
n, err := c.readWithTimeout(buf)
if err != nil {
log.Errorf("failed to read hello response: %s", err)
return err
}
_, err = messages.ValidateVersion(buf[:n])
if err != nil {
return fmt.Errorf("validate version: %w", err)
}
msgType, err := messages.DetermineServerMessageType(buf[messages.SizeOfVersionByte:n])
if err != nil {
log.Errorf("failed to determine message type: %s", err)
return err
}
if msgType != messages.MsgTypeHelloResponse {
log.Errorf("unexpected message type: %s", msgType)
return fmt.Errorf("unexpected message type")
}
additionalData, err := messages.UnmarshalHelloResponse(buf[messages.SizeOfProtoHeader:n])
if err != nil {
return err
}
addr, err := address.Unmarshal(additionalData)
if err != nil {
return fmt.Errorf("unmarshal address: %w", err)
}
c.muInstanceURL.Lock()
c.instanceURL = &RelayAddr{addr: addr.URL}
c.muInstanceURL.Unlock()
return nil
}
func (c *Client) readLoop(relayConn net.Conn) {
internallyStoppedFlag := newInternalStopFlag()
hc := healthcheck.NewReceiver()
go c.listenForStopEvents(hc, relayConn, internallyStoppedFlag)
var (
errExit error
n int
)
for {
bufPtr := c.bufPool.Get().(*[]byte)
buf := *bufPtr
n, errExit = relayConn.Read(buf)
if errExit != nil {
c.mu.Lock()
if c.serviceIsRunning && !internallyStoppedFlag.isSet() {
c.log.Debugf("failed to read message from relay server: %s", errExit)
}
c.mu.Unlock()
break
}
_, err := messages.ValidateVersion(buf[:n])
if err != nil {
c.log.Errorf("failed to validate protocol version: %s", err)
c.bufPool.Put(bufPtr)
continue
}
msgType, err := messages.DetermineServerMessageType(buf[messages.SizeOfVersionByte:n])
if err != nil {
c.log.Errorf("failed to determine message type: %s", err)
c.bufPool.Put(bufPtr)
continue
}
if !c.handleMsg(msgType, buf[messages.SizeOfProtoHeader:n], bufPtr, hc, internallyStoppedFlag) {
break
}
}
hc.Stop()
c.muInstanceURL.Lock()
c.instanceURL = nil
c.muInstanceURL.Unlock()
c.notifyDisconnected()
c.wgReadLoop.Done()
_ = c.close(false)
}
func (c *Client) handleMsg(msgType messages.MsgType, buf []byte, bufPtr *[]byte, hc *healthcheck.Receiver, internallyStoppedFlag *internalStopFlag) (continueLoop bool) {
switch msgType {
case messages.MsgTypeHealthCheck:
c.handleHealthCheck(hc, internallyStoppedFlag)
c.bufPool.Put(bufPtr)
case messages.MsgTypeTransport:
return c.handleTransportMsg(buf, bufPtr, internallyStoppedFlag)
case messages.MsgTypeClose:
log.Debugf("relay connection close by server")
c.bufPool.Put(bufPtr)
return false
}
return true
}
func (c *Client) handleHealthCheck(hc *healthcheck.Receiver, internallyStoppedFlag *internalStopFlag) {
msg := messages.MarshalHealthcheck()
_, wErr := c.relayConn.Write(msg)
if wErr != nil {
if c.serviceIsRunning && !internallyStoppedFlag.isSet() {
c.log.Errorf("failed to send heartbeat: %s", wErr)
}
}
hc.Heartbeat()
}
func (c *Client) handleTransportMsg(buf []byte, bufPtr *[]byte, internallyStoppedFlag *internalStopFlag) bool {
peerID, payload, err := messages.UnmarshalTransportMsg(buf)
if err != nil {
if c.serviceIsRunning && !internallyStoppedFlag.isSet() {
c.log.Errorf("failed to parse transport message: %v", err)
}
c.bufPool.Put(bufPtr)
return true
}
stringID := messages.HashIDToString(peerID)
c.mu.Lock()
if !c.serviceIsRunning {
c.mu.Unlock()
c.bufPool.Put(bufPtr)
return false
}
container, ok := c.conns[stringID]
c.mu.Unlock()
if !ok {
c.log.Errorf("peer not found: %s", stringID)
c.bufPool.Put(bufPtr)
return true
}
msg := Msg{
bufPool: c.bufPool,
bufPtr: bufPtr,
Payload: payload,
}
container.writeMsg(msg)
return true
}
func (c *Client) writeTo(connReference *Conn, id string, dstID []byte, payload []byte) (int, error) {
c.mu.Lock()
conn, ok := c.conns[id]
c.mu.Unlock()
if !ok {
return 0, io.EOF
}
if conn.conn != connReference {
return 0, io.EOF
}
// todo: use buffer pool instead of create new transport msg.
msg, err := messages.MarshalTransportMsg(dstID, payload)
if err != nil {
log.Errorf("failed to marshal transport message: %s", err)
return 0, err
}
// the write always return with 0 length because the underling does not support the size feedback.
_, err = c.relayConn.Write(msg)
if err != nil {
log.Errorf("failed to write transport message: %s", err)
}
return len(payload), err
}
func (c *Client) listenForStopEvents(hc *healthcheck.Receiver, conn net.Conn, internalStopFlag *internalStopFlag) {
for {
select {
case _, ok := <-hc.OnTimeout:
if !ok {
return
}
c.log.Errorf("health check timeout")
internalStopFlag.set()
_ = conn.Close() // ignore the err because the readLoop will handle it
return
case <-c.parentCtx.Done():
err := c.close(true)
if err != nil {
log.Errorf("failed to teardown connection: %s", err)
}
return
}
}
}
func (c *Client) closeAllConns() {
for _, container := range c.conns {
container.close()
}
c.conns = make(map[string]*connContainer)
}
func (c *Client) closeConn(connReference *Conn, id string) error {
c.mu.Lock()
defer c.mu.Unlock()
container, ok := c.conns[id]
if !ok {
return fmt.Errorf("connection already closed")
}
if container.conn != connReference {
return fmt.Errorf("conn reference mismatch")
}
container.close()
delete(c.conns, id)
return nil
}
func (c *Client) close(gracefullyExit bool) error {
c.readLoopMutex.Lock()
defer c.readLoopMutex.Unlock()
c.mu.Lock()
var err error
if !c.serviceIsRunning {
c.mu.Unlock()
return nil
}
c.serviceIsRunning = false
c.closeAllConns()
if gracefullyExit {
c.writeCloseMsg()
}
err = c.relayConn.Close()
c.mu.Unlock()
c.wgReadLoop.Wait()
c.log.Infof("relay connection closed with: %s", c.connectionURL)
return err
}
func (c *Client) notifyDisconnected() {
c.listenerMutex.Lock()
defer c.listenerMutex.Unlock()
if c.onDisconnectListener == nil {
return
}
go c.onDisconnectListener()
}
func (c *Client) writeCloseMsg() {
msg := messages.MarshalCloseMsg()
_, err := c.relayConn.Write(msg)
if err != nil {
c.log.Errorf("failed to send close message: %s", err)
}
}
func (c *Client) readWithTimeout(buf []byte) (int, error) {
ctx, cancel := context.WithTimeout(c.parentCtx, serverResponseTimeout)
defer cancel()
readDone := make(chan struct{})
var (
n int
err error
)
go func() {
n, err = c.relayConn.Read(buf)
close(readDone)
}()
select {
case <-ctx.Done():
return 0, fmt.Errorf("read operation timed out")
case <-readDone:
return n, err
}
}

631
relay/client/client_test.go Normal file
View File

@@ -0,0 +1,631 @@
package client
import (
"context"
"net"
"os"
"testing"
"time"
log "github.com/sirupsen/logrus"
"go.opentelemetry.io/otel"
"github.com/netbirdio/netbird/relay/auth/allow"
"github.com/netbirdio/netbird/relay/auth/hmac"
"github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/relay/server"
)
var (
av = &allow.Auth{}
hmacTokenStore = &hmac.TokenStore{}
serverListenAddr = "127.0.0.1:1234"
serverURL = "rel://127.0.0.1:1234"
)
func TestMain(m *testing.M) {
_ = util.InitLog("error", "console")
code := m.Run()
os.Exit(code)
}
func TestClient(t *testing.T) {
ctx := context.Background()
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av)
if err != nil {
t.Fatalf("failed to create server: %s", err)
}
errChan := make(chan error, 1)
go func() {
listenCfg := server.ListenerConfig{Address: serverListenAddr}
err := srv.Listen(listenCfg)
if err != nil {
errChan <- err
}
}()
defer func() {
err := srv.Shutdown(ctx)
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}()
// wait for server to start
if err := waitForServerToStart(errChan); err != nil {
t.Fatalf("failed to start server: %s", err)
}
t.Log("alice connecting to server")
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice")
err = clientAlice.Connect()
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
}
defer clientAlice.Close()
t.Log("placeholder connecting to server")
clientPlaceHolder := NewClient(ctx, serverURL, hmacTokenStore, "clientPlaceHolder")
err = clientPlaceHolder.Connect()
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
}
defer clientPlaceHolder.Close()
t.Log("Bob connecting to server")
clientBob := NewClient(ctx, serverURL, hmacTokenStore, "bob")
err = clientBob.Connect()
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
}
defer clientBob.Close()
t.Log("Alice open connection to Bob")
connAliceToBob, err := clientAlice.OpenConn("bob")
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}
t.Log("Bob open connection to Alice")
connBobToAlice, err := clientBob.OpenConn("alice")
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}
payload := "hello bob, I am alice"
_, err = connAliceToBob.Write([]byte(payload))
if err != nil {
t.Fatalf("failed to write to channel: %s", err)
}
log.Debugf("alice sent message to bob")
buf := make([]byte, 65535)
n, err := connBobToAlice.Read(buf)
if err != nil {
t.Fatalf("failed to read from channel: %s", err)
}
log.Debugf("on new message from alice to bob")
if payload != string(buf[:n]) {
t.Fatalf("expected %s, got %s", payload, string(buf[:n]))
}
}
func TestRegistration(t *testing.T) {
ctx := context.Background()
srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av)
if err != nil {
t.Fatalf("failed to create server: %s", err)
}
errChan := make(chan error, 1)
go func() {
err := srv.Listen(srvCfg)
if err != nil {
errChan <- err
}
}()
// wait for server to start
if err := waitForServerToStart(errChan); err != nil {
t.Fatalf("failed to start server: %s", err)
}
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice")
err = clientAlice.Connect()
if err != nil {
_ = srv.Shutdown(ctx)
t.Fatalf("failed to connect to server: %s", err)
}
err = clientAlice.Close()
if err != nil {
t.Errorf("failed to close conn: %s", err)
}
err = srv.Shutdown(ctx)
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}
func TestRegistrationTimeout(t *testing.T) {
ctx := context.Background()
fakeUDPListener, err := net.ListenUDP("udp", &net.UDPAddr{
Port: 1234,
IP: net.ParseIP("0.0.0.0"),
})
if err != nil {
t.Fatalf("failed to bind UDP server: %s", err)
}
defer func(fakeUDPListener *net.UDPConn) {
_ = fakeUDPListener.Close()
}(fakeUDPListener)
fakeTCPListener, err := net.ListenTCP("tcp", &net.TCPAddr{
Port: 1234,
IP: net.ParseIP("0.0.0.0"),
})
if err != nil {
t.Fatalf("failed to bind TCP server: %s", err)
}
defer func(fakeTCPListener *net.TCPListener) {
_ = fakeTCPListener.Close()
}(fakeTCPListener)
clientAlice := NewClient(ctx, "127.0.0.1:1234", hmacTokenStore, "alice")
err = clientAlice.Connect()
if err == nil {
t.Errorf("failed to connect to server: %s", err)
}
log.Debugf("%s", err)
err = clientAlice.Close()
if err != nil {
t.Errorf("failed to close conn: %s", err)
}
}
func TestEcho(t *testing.T) {
ctx := context.Background()
idAlice := "alice"
idBob := "bob"
srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av)
if err != nil {
t.Fatalf("failed to create server: %s", err)
}
errChan := make(chan error, 1)
go func() {
err := srv.Listen(srvCfg)
if err != nil {
errChan <- err
}
}()
defer func() {
err := srv.Shutdown(ctx)
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}()
// wait for servers to start
if err := waitForServerToStart(errChan); err != nil {
t.Fatalf("failed to start server: %s", err)
}
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, idAlice)
err = clientAlice.Connect()
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
}
defer func() {
err := clientAlice.Close()
if err != nil {
t.Errorf("failed to close Alice client: %s", err)
}
}()
clientBob := NewClient(ctx, serverURL, hmacTokenStore, idBob)
err = clientBob.Connect()
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
}
defer func() {
err := clientBob.Close()
if err != nil {
t.Errorf("failed to close Bob client: %s", err)
}
}()
connAliceToBob, err := clientAlice.OpenConn(idBob)
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}
connBobToAlice, err := clientBob.OpenConn(idAlice)
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}
payload := "hello bob, I am alice"
_, err = connAliceToBob.Write([]byte(payload))
if err != nil {
t.Fatalf("failed to write to channel: %s", err)
}
buf := make([]byte, 65535)
n, err := connBobToAlice.Read(buf)
if err != nil {
t.Fatalf("failed to read from channel: %s", err)
}
_, err = connBobToAlice.Write(buf[:n])
if err != nil {
t.Fatalf("failed to write to channel: %s", err)
}
n, err = connAliceToBob.Read(buf)
if err != nil {
t.Fatalf("failed to read from channel: %s", err)
}
if payload != string(buf[:n]) {
t.Fatalf("expected %s, got %s", payload, string(buf[:n]))
}
}
func TestBindToUnavailabePeer(t *testing.T) {
ctx := context.Background()
srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av)
if err != nil {
t.Fatalf("failed to create server: %s", err)
}
errChan := make(chan error, 1)
go func() {
err := srv.Listen(srvCfg)
if err != nil {
errChan <- err
}
}()
defer func() {
log.Infof("closing server")
err := srv.Shutdown(ctx)
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}()
// wait for servers to start
if err := waitForServerToStart(errChan); err != nil {
t.Fatalf("failed to start server: %s", err)
}
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice")
err = clientAlice.Connect()
if err != nil {
t.Errorf("failed to connect to server: %s", err)
}
_, err = clientAlice.OpenConn("bob")
if err != nil {
t.Errorf("failed to bind channel: %s", err)
}
log.Infof("closing client")
err = clientAlice.Close()
if err != nil {
t.Errorf("failed to close client: %s", err)
}
}
func TestBindReconnect(t *testing.T) {
ctx := context.Background()
srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av)
if err != nil {
t.Fatalf("failed to create server: %s", err)
}
errChan := make(chan error, 1)
go func() {
err := srv.Listen(srvCfg)
if err != nil {
errChan <- err
}
}()
defer func() {
log.Infof("closing server")
err := srv.Shutdown(ctx)
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}()
// wait for servers to start
if err := waitForServerToStart(errChan); err != nil {
t.Fatalf("failed to start server: %s", err)
}
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice")
err = clientAlice.Connect()
if err != nil {
t.Errorf("failed to connect to server: %s", err)
}
_, err = clientAlice.OpenConn("bob")
if err != nil {
t.Errorf("failed to bind channel: %s", err)
}
clientBob := NewClient(ctx, serverURL, hmacTokenStore, "bob")
err = clientBob.Connect()
if err != nil {
t.Errorf("failed to connect to server: %s", err)
}
chBob, err := clientBob.OpenConn("alice")
if err != nil {
t.Errorf("failed to bind channel: %s", err)
}
log.Infof("closing client Alice")
err = clientAlice.Close()
if err != nil {
t.Errorf("failed to close client: %s", err)
}
clientAlice = NewClient(ctx, serverURL, hmacTokenStore, "alice")
err = clientAlice.Connect()
if err != nil {
t.Errorf("failed to connect to server: %s", err)
}
chAlice, err := clientAlice.OpenConn("bob")
if err != nil {
t.Errorf("failed to bind channel: %s", err)
}
testString := "hello alice, I am bob"
_, err = chBob.Write([]byte(testString))
if err != nil {
t.Errorf("failed to write to channel: %s", err)
}
buf := make([]byte, 65535)
n, err := chAlice.Read(buf)
if err != nil {
t.Errorf("failed to read from channel: %s", err)
}
if testString != string(buf[:n]) {
t.Errorf("expected %s, got %s", testString, string(buf[:n]))
}
log.Infof("closing client")
err = clientAlice.Close()
if err != nil {
t.Errorf("failed to close client: %s", err)
}
}
func TestCloseConn(t *testing.T) {
ctx := context.Background()
srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av)
if err != nil {
t.Fatalf("failed to create server: %s", err)
}
errChan := make(chan error, 1)
go func() {
err := srv.Listen(srvCfg)
if err != nil {
errChan <- err
}
}()
defer func() {
log.Infof("closing server")
err := srv.Shutdown(ctx)
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}()
// wait for servers to start
if err := waitForServerToStart(errChan); err != nil {
t.Fatalf("failed to start server: %s", err)
}
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice")
err = clientAlice.Connect()
if err != nil {
t.Errorf("failed to connect to server: %s", err)
}
conn, err := clientAlice.OpenConn("bob")
if err != nil {
t.Errorf("failed to bind channel: %s", err)
}
log.Infof("closing connection")
err = conn.Close()
if err != nil {
t.Errorf("failed to close connection: %s", err)
}
_, err = conn.Read(make([]byte, 1))
if err == nil {
t.Errorf("unexpected reading from closed connection")
}
_, err = conn.Write([]byte("hello"))
if err == nil {
t.Errorf("unexpected writing from closed connection")
}
}
func TestCloseRelayConn(t *testing.T) {
ctx := context.Background()
srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av)
if err != nil {
t.Fatalf("failed to create server: %s", err)
}
errChan := make(chan error, 1)
go func() {
err := srv.Listen(srvCfg)
if err != nil {
errChan <- err
}
}()
defer func() {
err := srv.Shutdown(ctx)
if err != nil {
log.Errorf("failed to close server: %s", err)
}
}()
// wait for servers to start
if err := waitForServerToStart(errChan); err != nil {
t.Fatalf("failed to start server: %s", err)
}
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice")
err = clientAlice.Connect()
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
}
conn, err := clientAlice.OpenConn("bob")
if err != nil {
t.Errorf("failed to bind channel: %s", err)
}
_ = clientAlice.relayConn.Close()
_, err = conn.Read(make([]byte, 1))
if err == nil {
t.Errorf("unexpected reading from closed connection")
}
_, err = clientAlice.OpenConn("bob")
if err == nil {
t.Errorf("unexpected opening connection to closed server")
}
}
func TestCloseByServer(t *testing.T) {
ctx := context.Background()
srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv1, err := server.NewServer(otel.Meter(""), serverURL, false, av)
if err != nil {
t.Fatalf("failed to create server: %s", err)
}
errChan := make(chan error, 1)
go func() {
err := srv1.Listen(srvCfg)
if err != nil {
errChan <- err
}
}()
// wait for servers to start
if err := waitForServerToStart(errChan); err != nil {
t.Fatalf("failed to start server: %s", err)
}
idAlice := "alice"
log.Debugf("connect by alice")
relayClient := NewClient(ctx, serverURL, hmacTokenStore, idAlice)
err = relayClient.Connect()
if err != nil {
log.Fatalf("failed to connect to server: %s", err)
}
disconnected := make(chan struct{})
relayClient.SetOnDisconnectListener(func() {
log.Infof("client disconnected")
close(disconnected)
})
err = srv1.Shutdown(ctx)
if err != nil {
t.Fatalf("failed to close server: %s", err)
}
select {
case <-disconnected:
case <-time.After(3 * time.Second):
log.Fatalf("timeout waiting for client to disconnect")
}
_, err = relayClient.OpenConn("bob")
if err == nil {
t.Errorf("unexpected opening connection to closed server")
}
}
func TestCloseByClient(t *testing.T) {
ctx := context.Background()
srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av)
if err != nil {
t.Fatalf("failed to create server: %s", err)
}
errChan := make(chan error, 1)
go func() {
err := srv.Listen(srvCfg)
if err != nil {
errChan <- err
}
}()
// wait for servers to start
if err := waitForServerToStart(errChan); err != nil {
t.Fatalf("failed to start server: %s", err)
}
idAlice := "alice"
log.Debugf("connect by alice")
relayClient := NewClient(ctx, serverURL, hmacTokenStore, idAlice)
err = relayClient.Connect()
if err != nil {
log.Fatalf("failed to connect to server: %s", err)
}
err = relayClient.Close()
if err != nil {
t.Errorf("failed to close client: %s", err)
}
_, err = relayClient.OpenConn("bob")
if err == nil {
t.Errorf("unexpected opening connection to closed server")
}
err = srv.Shutdown(ctx)
if err != nil {
t.Fatalf("failed to close server: %s", err)
}
}
func waitForServerToStart(errChan chan error) error {
select {
case err := <-errChan:
if err != nil {
return err
}
case <-time.After(300 * time.Millisecond):
return nil
}
return nil
}

76
relay/client/conn.go Normal file
View File

@@ -0,0 +1,76 @@
package client
import (
"io"
"net"
"time"
)
// Conn represent a connection to a relayed remote peer.
type Conn struct {
client *Client
dstID []byte
dstStringID string
messageChan chan Msg
instanceURL *RelayAddr
}
// NewConn creates a new connection to a relayed remote peer.
// client: the client instance, it used to send messages to the destination peer
// dstID: the destination peer ID
// dstStringID: the destination peer ID in string format
// messageChan: the channel where the messages will be received
// instanceURL: the relay instance URL, it used to get the proper server instance address for the remote peer
func NewConn(client *Client, dstID []byte, dstStringID string, messageChan chan Msg, instanceURL *RelayAddr) *Conn {
c := &Conn{
client: client,
dstID: dstID,
dstStringID: dstStringID,
messageChan: messageChan,
instanceURL: instanceURL,
}
return c
}
func (c *Conn) Write(p []byte) (n int, err error) {
return c.client.writeTo(c, c.dstStringID, c.dstID, p)
}
func (c *Conn) Read(b []byte) (n int, err error) {
msg, ok := <-c.messageChan
if !ok {
return 0, io.EOF
}
n = copy(b, msg.Payload)
msg.Free()
return n, nil
}
func (c *Conn) Close() error {
return c.client.closeConn(c, c.dstStringID)
}
func (c *Conn) LocalAddr() net.Addr {
return c.client.relayConn.LocalAddr()
}
func (c *Conn) RemoteAddr() net.Addr {
return c.instanceURL
}
func (c *Conn) SetDeadline(t time.Time) error {
//TODO implement me
panic("SetDeadline is not implemented")
}
func (c *Conn) SetReadDeadline(t time.Time) error {
//TODO implement me
panic("SetReadDeadline is not implemented")
}
func (c *Conn) SetWriteDeadline(t time.Time) error {
//TODO implement me
panic("SetReadDeadline is not implemented")
}

View File

@@ -0,0 +1,13 @@
package ws
type WebsocketAddr struct {
addr string
}
func (a WebsocketAddr) Network() string {
return "websocket"
}
func (a WebsocketAddr) String() string {
return a.addr
}

View File

@@ -0,0 +1,66 @@
package ws
import (
"context"
"fmt"
"net"
"time"
"nhooyr.io/websocket"
)
type Conn struct {
ctx context.Context
*websocket.Conn
remoteAddr WebsocketAddr
}
func NewConn(wsConn *websocket.Conn, serverAddress string) net.Conn {
return &Conn{
ctx: context.Background(),
Conn: wsConn,
remoteAddr: WebsocketAddr{serverAddress},
}
}
func (c *Conn) Read(b []byte) (n int, err error) {
t, ioReader, err := c.Conn.Reader(c.ctx)
if err != nil {
return 0, err
}
if t != websocket.MessageBinary {
return 0, fmt.Errorf("unexpected message type")
}
return ioReader.Read(b)
}
func (c *Conn) Write(b []byte) (n int, err error) {
err = c.Conn.Write(c.ctx, websocket.MessageBinary, b)
return 0, err
}
func (c *Conn) RemoteAddr() net.Addr {
return c.remoteAddr
}
func (c *Conn) LocalAddr() net.Addr {
return WebsocketAddr{addr: "unknown"}
}
func (c *Conn) SetReadDeadline(t time.Time) error {
return fmt.Errorf("SetReadDeadline is not implemented")
}
func (c *Conn) SetWriteDeadline(t time.Time) error {
return fmt.Errorf("SetWriteDeadline is not implemented")
}
func (c *Conn) SetDeadline(t time.Time) error {
return fmt.Errorf("SetDeadline is not implemented")
}
func (c *Conn) Close() error {
return c.Conn.CloseNow()
}

View File

@@ -0,0 +1,67 @@
package ws
import (
"context"
"fmt"
"net"
"net/http"
"net/url"
"strings"
log "github.com/sirupsen/logrus"
"nhooyr.io/websocket"
"github.com/netbirdio/netbird/relay/server/listener/ws"
nbnet "github.com/netbirdio/netbird/util/net"
)
func Dial(address string) (net.Conn, error) {
wsURL, err := prepareURL(address)
if err != nil {
return nil, err
}
opts := &websocket.DialOptions{
HTTPClient: httpClientNbDialer(),
}
parsedURL, err := url.Parse(wsURL)
if err != nil {
return nil, err
}
parsedURL.Path = ws.URLPath
wsConn, resp, err := websocket.Dial(context.Background(), parsedURL.String(), opts)
if err != nil {
log.Errorf("failed to dial to Relay server '%s': %s", wsURL, err)
return nil, err
}
if resp.Body != nil {
_ = resp.Body.Close()
}
conn := NewConn(wsConn, address)
return conn, nil
}
func prepareURL(address string) (string, error) {
if !strings.HasPrefix(address, "rel:") && !strings.HasPrefix(address, "rels:") {
return "", fmt.Errorf("unsupported scheme: %s", address)
}
return strings.Replace(address, "rel", "ws", 1), nil
}
func httpClientNbDialer() *http.Client {
customDialer := nbnet.NewDialer()
customTransport := &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return customDialer.DialContext(ctx, network, addr)
},
}
return &http.Client{
Transport: customTransport,
}
}

12
relay/client/doc.go Normal file
View File

@@ -0,0 +1,12 @@
/*
Package client contains the implementation of the Relay client.
The Relay client is responsible for establishing a connection with the Relay server and sending and receiving messages,
Keep persistent connection with the Relay server and handle the connection issues.
It uses the WebSocket protocol for communication and optionally supports TLS (Transport Layer Security).
If a peer wants to communicate with a peer on a different relay server, the manager will establish a new connection to
the relay server. The connection with these relay servers will be closed if there is no active connection. The peers
negotiate the common relay instance via signaling service.
*/
package client

48
relay/client/guard.go Normal file
View File

@@ -0,0 +1,48 @@
package client
import (
"context"
"time"
log "github.com/sirupsen/logrus"
)
var (
reconnectingTimeout = 5 * time.Second
)
// Guard manage the reconnection tries to the Relay server in case of disconnection event.
type Guard struct {
ctx context.Context
relayClient *Client
}
// NewGuard creates a new guard for the relay client.
func NewGuard(context context.Context, relayClient *Client) *Guard {
g := &Guard{
ctx: context,
relayClient: relayClient,
}
return g
}
// OnDisconnected is called when the relay client is disconnected from the relay server. It will trigger the reconnection
// todo prevent multiple reconnection instances. In the current usage it should not happen, but it is better to prevent
func (g *Guard) OnDisconnected() {
ticker := time.NewTicker(reconnectingTimeout)
defer ticker.Stop()
for {
select {
case <-ticker.C:
err := g.relayClient.Connect()
if err != nil {
log.Errorf("failed to reconnect to relay server: %s", err)
continue
}
return
case <-g.ctx.Done():
return
}
}
}

365
relay/client/manager.go Normal file
View File

@@ -0,0 +1,365 @@
package client
import (
"container/list"
"context"
"errors"
"fmt"
"net"
"reflect"
"sync"
"time"
log "github.com/sirupsen/logrus"
relayAuth "github.com/netbirdio/netbird/relay/auth/hmac"
)
var (
relayCleanupInterval = 60 * time.Second
connectionTimeout = 30 * time.Second
maxConcurrentServers = 7
ErrRelayClientNotConnected = fmt.Errorf("relay client not connected")
)
// RelayTrack hold the relay clients for the foreign relay servers.
// With the mutex can ensure we can open new connection in case the relay connection has been established with
// the relay server.
type RelayTrack struct {
sync.RWMutex
relayClient *Client
}
func NewRelayTrack() *RelayTrack {
return &RelayTrack{}
}
type OnServerCloseListener func()
// ManagerService is the interface for the relay manager.
type ManagerService interface {
Serve() error
OpenConn(serverAddress, peerKey string) (net.Conn, error)
AddCloseListener(serverAddress string, onClosedListener OnServerCloseListener) error
RelayInstanceAddress() (string, error)
ServerURLs() []string
HasRelayAddress() bool
UpdateToken(token *relayAuth.Token) error
}
// Manager is a manager for the relay client instances. It establishes one persistent connection to the given relay URL
// and automatically reconnect to them in case disconnection.
// The manager also manage temporary relay connection. If a client wants to communicate with a client on a
// different relay server, the manager will establish a new connection to the relay server. The connection with these
// relay servers will be closed if there is no active connection. Periodically the manager will check if there is any
// unused relay connection and close it.
type Manager struct {
ctx context.Context
serverURLs []string
peerID string
tokenStore *relayAuth.TokenStore
relayClient *Client
reconnectGuard *Guard
relayClients map[string]*RelayTrack
relayClientsMutex sync.RWMutex
onDisconnectedListeners map[string]*list.List
listenerLock sync.Mutex
}
// NewManager creates a new manager instance.
// The serverURL address can be empty. In this case, the manager will not serve.
func NewManager(ctx context.Context, serverURLs []string, peerID string) *Manager {
return &Manager{
ctx: ctx,
serverURLs: serverURLs,
peerID: peerID,
tokenStore: &relayAuth.TokenStore{},
relayClients: make(map[string]*RelayTrack),
onDisconnectedListeners: make(map[string]*list.List),
}
}
// Serve starts the manager. It will establish a connection to the relay server and start the relay cleanup loop for
// the unused relay connections. The manager will automatically reconnect to the relay server in case of disconnection.
func (m *Manager) Serve() error {
if m.relayClient != nil {
return fmt.Errorf("manager already serving")
}
log.Debugf("starting relay client manager with %v relay servers", m.serverURLs)
totalServers := len(m.serverURLs)
successChan := make(chan *Client, 1)
errChan := make(chan error, len(m.serverURLs))
ctx, cancel := context.WithTimeout(m.ctx, connectionTimeout)
defer cancel()
sem := make(chan struct{}, maxConcurrentServers)
for _, url := range m.serverURLs {
sem <- struct{}{}
go func(url string) {
defer func() { <-sem }()
m.connect(m.ctx, url, successChan, errChan)
}(url)
}
var errCount int
for {
select {
case client := <-successChan:
log.Infof("Successfully connected to relay server: %s", client.connectionURL)
m.relayClient = client
m.reconnectGuard = NewGuard(m.ctx, m.relayClient)
m.relayClient.SetOnDisconnectListener(func() {
m.onServerDisconnected(client.connectionURL)
})
m.startCleanupLoop()
return nil
case err := <-errChan:
errCount++
log.Warnf("Connection attempt failed: %v", err)
if errCount == totalServers {
return errors.New("failed to connect to any relay server: all attempts failed")
}
case <-ctx.Done():
return fmt.Errorf("failed to connect to any relay server: %w", ctx.Err())
}
}
}
func (m *Manager) connect(ctx context.Context, serverURL string, successChan chan<- *Client, errChan chan<- error) {
// TODO: abort the connection if another connection was successful
relayClient := NewClient(ctx, serverURL, m.tokenStore, m.peerID)
if err := relayClient.Connect(); err != nil {
errChan <- fmt.Errorf("failed to connect to %s: %w", serverURL, err)
return
}
select {
case successChan <- relayClient:
// This client was the first to connect successfully
default:
if err := relayClient.Close(); err != nil {
log.Debugf("failed to close relay client: %s", err)
}
}
}
// OpenConn opens a connection to the given peer key. If the peer is on the same relay server, the connection will be
// established via the relay server. If the peer is on a different relay server, the manager will establish a new
// connection to the relay server. It returns back with a net.Conn what represent the remote peer connection.
func (m *Manager) OpenConn(serverAddress, peerKey string) (net.Conn, error) {
if m.relayClient == nil {
return nil, ErrRelayClientNotConnected
}
foreign, err := m.isForeignServer(serverAddress)
if err != nil {
return nil, err
}
var (
netConn net.Conn
)
if !foreign {
log.Debugf("open peer connection via permanent server: %s", peerKey)
netConn, err = m.relayClient.OpenConn(peerKey)
} else {
log.Debugf("open peer connection via foreign server: %s", serverAddress)
netConn, err = m.openConnVia(serverAddress, peerKey)
}
if err != nil {
return nil, err
}
return netConn, err
}
// AddCloseListener adds a listener to the given server instance address. The listener will be called if the connection
// closed.
func (m *Manager) AddCloseListener(serverAddress string, onClosedListener OnServerCloseListener) error {
foreign, err := m.isForeignServer(serverAddress)
if err != nil {
return err
}
var listenerAddr string
if foreign {
listenerAddr = serverAddress
} else {
listenerAddr = m.relayClient.connectionURL
}
m.addListener(listenerAddr, onClosedListener)
return nil
}
// RelayInstanceAddress returns the address of the permanent relay server. It could change if the network connection is
// lost. This address will be sent to the target peer to choose the common relay server for the communication.
func (m *Manager) RelayInstanceAddress() (string, error) {
if m.relayClient == nil {
return "", ErrRelayClientNotConnected
}
return m.relayClient.ServerInstanceURL()
}
// ServerURLs returns the addresses of the relay servers.
func (m *Manager) ServerURLs() []string {
return m.serverURLs
}
// HasRelayAddress returns true if the manager is serving. With this method can check if the peer can communicate with
// Relay service.
func (m *Manager) HasRelayAddress() bool {
return len(m.serverURLs) > 0
}
// UpdateToken updates the token in the token store.
func (m *Manager) UpdateToken(token *relayAuth.Token) error {
return m.tokenStore.UpdateToken(token)
}
func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) {
// check if already has a connection to the desired relay server
m.relayClientsMutex.RLock()
rt, ok := m.relayClients[serverAddress]
if ok {
rt.RLock()
m.relayClientsMutex.RUnlock()
defer rt.RUnlock()
return rt.relayClient.OpenConn(peerKey)
}
m.relayClientsMutex.RUnlock()
// if not, establish a new connection but check it again (because changed the lock type) before starting the
// connection
m.relayClientsMutex.Lock()
rt, ok = m.relayClients[serverAddress]
if ok {
rt.RLock()
m.relayClientsMutex.Unlock()
defer rt.RUnlock()
return rt.relayClient.OpenConn(peerKey)
}
// create a new relay client and store it in the relayClients map
rt = NewRelayTrack()
rt.Lock()
m.relayClients[serverAddress] = rt
m.relayClientsMutex.Unlock()
relayClient := NewClient(m.ctx, serverAddress, m.tokenStore, m.peerID)
err := relayClient.Connect()
if err != nil {
rt.Unlock()
m.relayClientsMutex.Lock()
delete(m.relayClients, serverAddress)
m.relayClientsMutex.Unlock()
return nil, err
}
// if connection closed then delete the relay client from the list
relayClient.SetOnDisconnectListener(func() {
m.onServerDisconnected(serverAddress)
})
rt.relayClient = relayClient
rt.Unlock()
conn, err := relayClient.OpenConn(peerKey)
if err != nil {
return nil, err
}
return conn, nil
}
func (m *Manager) onServerDisconnected(serverAddress string) {
if serverAddress == m.relayClient.connectionURL {
go m.reconnectGuard.OnDisconnected()
}
m.notifyOnDisconnectListeners(serverAddress)
}
func (m *Manager) isForeignServer(address string) (bool, error) {
rAddr, err := m.relayClient.ServerInstanceURL()
if err != nil {
return false, fmt.Errorf("relay client not connected")
}
return rAddr != address, nil
}
func (m *Manager) startCleanupLoop() {
if m.ctx.Err() != nil {
return
}
ticker := time.NewTicker(relayCleanupInterval)
go func() {
defer ticker.Stop()
for {
select {
case <-m.ctx.Done():
return
case <-ticker.C:
m.cleanUpUnusedRelays()
}
}
}()
}
func (m *Manager) cleanUpUnusedRelays() {
m.relayClientsMutex.Lock()
defer m.relayClientsMutex.Unlock()
for addr, rt := range m.relayClients {
rt.Lock()
if rt.relayClient.HasConns() {
rt.Unlock()
continue
}
rt.relayClient.SetOnDisconnectListener(nil)
go func() {
_ = rt.relayClient.Close()
}()
log.Debugf("clean up unused relay server connection: %s", addr)
delete(m.relayClients, addr)
rt.Unlock()
}
}
func (m *Manager) addListener(serverAddress string, onClosedListener OnServerCloseListener) {
m.listenerLock.Lock()
defer m.listenerLock.Unlock()
l, ok := m.onDisconnectedListeners[serverAddress]
if !ok {
l = list.New()
}
for e := l.Front(); e != nil; e = e.Next() {
if reflect.ValueOf(e.Value).Pointer() == reflect.ValueOf(onClosedListener).Pointer() {
return
}
}
l.PushBack(onClosedListener)
m.onDisconnectedListeners[serverAddress] = l
}
func (m *Manager) notifyOnDisconnectListeners(serverAddress string) {
m.listenerLock.Lock()
defer m.listenerLock.Unlock()
l, ok := m.onDisconnectedListeners[serverAddress]
if !ok {
return
}
for e := l.Front(); e != nil; e = e.Next() {
go e.Value.(OnServerCloseListener)()
}
delete(m.onDisconnectedListeners, serverAddress)
}

View File

@@ -0,0 +1,432 @@
package client
import (
"context"
"testing"
"time"
log "github.com/sirupsen/logrus"
"go.opentelemetry.io/otel"
"github.com/netbirdio/netbird/relay/server"
)
func TestEmptyURL(t *testing.T) {
mgr := NewManager(context.Background(), nil, "alice")
err := mgr.Serve()
if err == nil {
t.Errorf("expected error, got nil")
}
}
func TestForeignConn(t *testing.T) {
ctx := context.Background()
srvCfg1 := server.ListenerConfig{
Address: "localhost:1234",
}
srv1, err := server.NewServer(otel.Meter(""), srvCfg1.Address, false, av)
if err != nil {
t.Fatalf("failed to create server: %s", err)
}
errChan := make(chan error, 1)
go func() {
err := srv1.Listen(srvCfg1)
if err != nil {
errChan <- err
}
}()
defer func() {
err := srv1.Shutdown(ctx)
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}()
if err := waitForServerToStart(errChan); err != nil {
t.Fatalf("failed to start server: %s", err)
}
srvCfg2 := server.ListenerConfig{
Address: "localhost:2234",
}
srv2, err := server.NewServer(otel.Meter(""), srvCfg2.Address, false, av)
if err != nil {
t.Fatalf("failed to create server: %s", err)
}
errChan2 := make(chan error, 1)
go func() {
err := srv2.Listen(srvCfg2)
if err != nil {
errChan2 <- err
}
}()
defer func() {
err := srv2.Shutdown(ctx)
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}()
if err := waitForServerToStart(errChan2); err != nil {
t.Fatalf("failed to start server: %s", err)
}
idAlice := "alice"
log.Debugf("connect by alice")
mCtx, cancel := context.WithCancel(ctx)
defer cancel()
clientAlice := NewManager(mCtx, toURL(srvCfg1), idAlice)
err = clientAlice.Serve()
if err != nil {
t.Fatalf("failed to serve manager: %s", err)
}
idBob := "bob"
log.Debugf("connect by bob")
clientBob := NewManager(mCtx, toURL(srvCfg2), idBob)
err = clientBob.Serve()
if err != nil {
t.Fatalf("failed to serve manager: %s", err)
}
bobsSrvAddr, err := clientBob.RelayInstanceAddress()
if err != nil {
t.Fatalf("failed to get relay address: %s", err)
}
connAliceToBob, err := clientAlice.OpenConn(bobsSrvAddr, idBob)
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}
connBobToAlice, err := clientBob.OpenConn(bobsSrvAddr, idAlice)
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}
payload := "hello bob, I am alice"
_, err = connAliceToBob.Write([]byte(payload))
if err != nil {
t.Fatalf("failed to write to channel: %s", err)
}
buf := make([]byte, 65535)
n, err := connBobToAlice.Read(buf)
if err != nil {
t.Fatalf("failed to read from channel: %s", err)
}
_, err = connBobToAlice.Write(buf[:n])
if err != nil {
t.Fatalf("failed to write to channel: %s", err)
}
n, err = connAliceToBob.Read(buf)
if err != nil {
t.Fatalf("failed to read from channel: %s", err)
}
if payload != string(buf[:n]) {
t.Fatalf("expected %s, got %s", payload, string(buf[:n]))
}
}
func TestForeginConnClose(t *testing.T) {
ctx := context.Background()
srvCfg1 := server.ListenerConfig{
Address: "localhost:1234",
}
srv1, err := server.NewServer(otel.Meter(""), srvCfg1.Address, false, av)
if err != nil {
t.Fatalf("failed to create server: %s", err)
}
errChan := make(chan error, 1)
go func() {
err := srv1.Listen(srvCfg1)
if err != nil {
errChan <- err
}
}()
defer func() {
err := srv1.Shutdown(ctx)
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}()
if err := waitForServerToStart(errChan); err != nil {
t.Fatalf("failed to start server: %s", err)
}
srvCfg2 := server.ListenerConfig{
Address: "localhost:2234",
}
srv2, err := server.NewServer(otel.Meter(""), srvCfg2.Address, false, av)
if err != nil {
t.Fatalf("failed to create server: %s", err)
}
errChan2 := make(chan error, 1)
go func() {
err := srv2.Listen(srvCfg2)
if err != nil {
errChan2 <- err
}
}()
defer func() {
err := srv2.Shutdown(ctx)
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}()
if err := waitForServerToStart(errChan2); err != nil {
t.Fatalf("failed to start server: %s", err)
}
idAlice := "alice"
log.Debugf("connect by alice")
mCtx, cancel := context.WithCancel(ctx)
defer cancel()
mgr := NewManager(mCtx, toURL(srvCfg1), idAlice)
err = mgr.Serve()
if err != nil {
t.Fatalf("failed to serve manager: %s", err)
}
conn, err := mgr.OpenConn(toURL(srvCfg2)[0], "anotherpeer")
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}
err = conn.Close()
if err != nil {
t.Fatalf("failed to close connection: %s", err)
}
}
func TestForeginAutoClose(t *testing.T) {
ctx := context.Background()
relayCleanupInterval = 1 * time.Second
srvCfg1 := server.ListenerConfig{
Address: "localhost:1234",
}
srv1, err := server.NewServer(otel.Meter(""), srvCfg1.Address, false, av)
if err != nil {
t.Fatalf("failed to create server: %s", err)
}
errChan := make(chan error, 1)
go func() {
t.Log("binding server 1.")
err := srv1.Listen(srvCfg1)
if err != nil {
errChan <- err
}
}()
defer func() {
t.Logf("closing server 1.")
err := srv1.Shutdown(ctx)
if err != nil {
t.Errorf("failed to close server: %s", err)
}
t.Logf("server 1. closed")
}()
if err := waitForServerToStart(errChan); err != nil {
t.Fatalf("failed to start server: %s", err)
}
srvCfg2 := server.ListenerConfig{
Address: "localhost:2234",
}
srv2, err := server.NewServer(otel.Meter(""), srvCfg2.Address, false, av)
if err != nil {
t.Fatalf("failed to create server: %s", err)
}
errChan2 := make(chan error, 1)
go func() {
t.Log("binding server 2.")
err := srv2.Listen(srvCfg2)
if err != nil {
errChan2 <- err
}
}()
defer func() {
t.Logf("closing server 2.")
err := srv2.Shutdown(ctx)
if err != nil {
t.Errorf("failed to close server: %s", err)
}
t.Logf("server 2 closed.")
}()
if err := waitForServerToStart(errChan2); err != nil {
t.Fatalf("failed to start server: %s", err)
}
idAlice := "alice"
t.Log("connect to server 1.")
mCtx, cancel := context.WithCancel(ctx)
defer cancel()
mgr := NewManager(mCtx, toURL(srvCfg1), idAlice)
err = mgr.Serve()
if err != nil {
t.Fatalf("failed to serve manager: %s", err)
}
t.Log("open connection to another peer")
conn, err := mgr.OpenConn(toURL(srvCfg2)[0], "anotherpeer")
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}
t.Log("close conn")
err = conn.Close()
if err != nil {
t.Fatalf("failed to close connection: %s", err)
}
t.Logf("waiting for relay cleanup: %s", relayCleanupInterval+1*time.Second)
time.Sleep(relayCleanupInterval + 1*time.Second)
if len(mgr.relayClients) != 0 {
t.Errorf("expected 0, got %d", len(mgr.relayClients))
}
t.Logf("closing manager")
}
func TestAutoReconnect(t *testing.T) {
ctx := context.Background()
reconnectingTimeout = 2 * time.Second
srvCfg := server.ListenerConfig{
Address: "localhost:1234",
}
srv, err := server.NewServer(otel.Meter(""), srvCfg.Address, false, av)
if err != nil {
t.Fatalf("failed to create server: %s", err)
}
errChan := make(chan error, 1)
go func() {
err := srv.Listen(srvCfg)
if err != nil {
errChan <- err
}
}()
defer func() {
err := srv.Shutdown(ctx)
if err != nil {
log.Errorf("failed to close server: %s", err)
}
}()
if err := waitForServerToStart(errChan); err != nil {
t.Fatalf("failed to start server: %s", err)
}
mCtx, cancel := context.WithCancel(ctx)
defer cancel()
clientAlice := NewManager(mCtx, toURL(srvCfg), "alice")
err = clientAlice.Serve()
if err != nil {
t.Fatalf("failed to serve manager: %s", err)
}
ra, err := clientAlice.RelayInstanceAddress()
if err != nil {
t.Errorf("failed to get relay address: %s", err)
}
conn, err := clientAlice.OpenConn(ra, "bob")
if err != nil {
t.Errorf("failed to bind channel: %s", err)
}
t.Log("closing client relay connection")
// todo figure out moc server
_ = clientAlice.relayClient.relayConn.Close()
t.Log("start test reading")
_, err = conn.Read(make([]byte, 1))
if err == nil {
t.Errorf("unexpected reading from closed connection")
}
log.Infof("waiting for reconnection")
time.Sleep(reconnectingTimeout + 1*time.Second)
log.Infof("reopent the connection")
_, err = clientAlice.OpenConn(ra, "bob")
if err != nil {
t.Errorf("failed to open channel: %s", err)
}
}
func TestNotifierDoubleAdd(t *testing.T) {
ctx := context.Background()
srvCfg1 := server.ListenerConfig{
Address: "localhost:1234",
}
srv1, err := server.NewServer(otel.Meter(""), srvCfg1.Address, false, av)
if err != nil {
t.Fatalf("failed to create server: %s", err)
}
errChan := make(chan error, 1)
go func() {
err := srv1.Listen(srvCfg1)
if err != nil {
errChan <- err
}
}()
defer func() {
err := srv1.Shutdown(ctx)
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}()
if err := waitForServerToStart(errChan); err != nil {
t.Fatalf("failed to start server: %s", err)
}
idAlice := "alice"
log.Debugf("connect by alice")
mCtx, cancel := context.WithCancel(ctx)
defer cancel()
clientAlice := NewManager(mCtx, toURL(srvCfg1), idAlice)
err = clientAlice.Serve()
if err != nil {
t.Fatalf("failed to serve manager: %s", err)
}
conn1, err := clientAlice.OpenConn(clientAlice.ServerURLs()[0], "idBob")
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}
fnCloseListener := OnServerCloseListener(func() {
log.Infof("close listener")
})
err = clientAlice.AddCloseListener(clientAlice.ServerURLs()[0], fnCloseListener)
if err != nil {
t.Fatalf("failed to add close listener: %s", err)
}
err = clientAlice.AddCloseListener(clientAlice.ServerURLs()[0], fnCloseListener)
if err != nil {
t.Fatalf("failed to add close listener: %s", err)
}
err = conn1.Close()
if err != nil {
t.Errorf("failed to close connection: %s", err)
}
}
func toURL(address server.ListenerConfig) []string {
return []string{"rel://" + address.Address}
}

35
relay/cmd/env.go Normal file
View File

@@ -0,0 +1,35 @@
package cmd
import (
"os"
"strings"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
)
// setFlagsFromEnvVars reads and updates flag values from environment variables with prefix NB_
func setFlagsFromEnvVars(cmd *cobra.Command) {
flags := cmd.PersistentFlags()
flags.VisitAll(func(f *pflag.Flag) {
newEnvVar := flagNameToEnvVar(f.Name, "NB_")
value, present := os.LookupEnv(newEnvVar)
if !present {
return
}
err := flags.Set(f.Name, value)
if err != nil {
log.Infof("unable to configure flag %s using variable %s, err: %v", f.Name, newEnvVar, err)
}
})
}
// flagNameToEnvVar converts flag name to environment var name adding a prefix,
// replacing dashes and making all uppercase (e.g. setup-keys is converted to NB_SETUP_KEYS according to the input prefix)
func flagNameToEnvVar(cmdFlag string, prefix string) string {
parsed := strings.ReplaceAll(cmdFlag, "-", "_")
upper := strings.ToUpper(parsed)
return prefix + upper
}

214
relay/cmd/root.go Normal file
View File

@@ -0,0 +1,214 @@
package cmd
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net/http"
"os"
"os/signal"
"syscall"
"time"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/encryption"
auth "github.com/netbirdio/netbird/relay/auth/hmac"
"github.com/netbirdio/netbird/relay/server"
"github.com/netbirdio/netbird/signal/metrics"
"github.com/netbirdio/netbird/util"
)
const (
metricsPort = 9090
)
type Config struct {
ListenAddress string
// in HA every peer connect to a common domain, the instance domain has been distributed during the p2p connection
// it is a domain:port or ip:port
ExposedAddress string
LetsencryptEmail string
LetsencryptDataDir string
LetsencryptDomains []string
// in case of using Route 53 for DNS challenge the credentials should be provided in the environment variables or
// in the AWS credentials file
LetsencryptAWSRoute53 bool
TlsCertFile string
TlsKeyFile string
AuthSecret string
LogLevel string
LogFile string
}
func (c Config) Validate() error {
if c.ExposedAddress == "" {
return fmt.Errorf("exposed address is required")
}
if c.AuthSecret == "" {
return fmt.Errorf("auth secret is required")
}
return nil
}
func (c Config) HasCertConfig() bool {
return c.TlsCertFile != "" && c.TlsKeyFile != ""
}
func (c Config) HasLetsEncrypt() bool {
return c.LetsencryptDataDir != "" && c.LetsencryptDomains != nil && len(c.LetsencryptDomains) > 0
}
var (
cobraConfig *Config
rootCmd = &cobra.Command{
Use: "relay",
Short: "Relay service",
Long: "Relay service for Netbird agents",
SilenceUsage: true,
SilenceErrors: true,
RunE: execute,
}
)
func init() {
_ = util.InitLog("trace", "console")
cobraConfig = &Config{}
rootCmd.PersistentFlags().StringVarP(&cobraConfig.ListenAddress, "listen-address", "l", ":443", "listen address")
rootCmd.PersistentFlags().StringVarP(&cobraConfig.ExposedAddress, "exposed-address", "e", "", "instance domain address (or ip) and port, it will be distributes between peers")
rootCmd.PersistentFlags().StringVarP(&cobraConfig.LetsencryptDataDir, "letsencrypt-data-dir", "d", "", "a directory to store Let's Encrypt data. Required if Let's Encrypt is enabled.")
rootCmd.PersistentFlags().StringSliceVarP(&cobraConfig.LetsencryptDomains, "letsencrypt-domains", "a", nil, "list of domains to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS")
rootCmd.PersistentFlags().StringVar(&cobraConfig.LetsencryptEmail, "letsencrypt-email", "", "email address to use for Let's Encrypt certificate registration")
rootCmd.PersistentFlags().BoolVar(&cobraConfig.LetsencryptAWSRoute53, "letsencrypt-aws-route53", false, "use AWS Route 53 for Let's Encrypt DNS challenge")
rootCmd.PersistentFlags().StringVarP(&cobraConfig.TlsCertFile, "tls-cert-file", "c", "", "")
rootCmd.PersistentFlags().StringVarP(&cobraConfig.TlsKeyFile, "tls-key-file", "k", "", "")
rootCmd.PersistentFlags().StringVarP(&cobraConfig.AuthSecret, "auth-secret", "s", "", "auth secret")
rootCmd.PersistentFlags().StringVar(&cobraConfig.LogLevel, "log-level", "info", "log level")
rootCmd.PersistentFlags().StringVar(&cobraConfig.LogFile, "log-file", "console", "log file")
setFlagsFromEnvVars(rootCmd)
}
func Execute() error {
return rootCmd.Execute()
}
func waitForExitSignal() {
osSigs := make(chan os.Signal, 1)
signal.Notify(osSigs, syscall.SIGINT, syscall.SIGTERM)
<-osSigs
}
func execute(cmd *cobra.Command, args []string) error {
err := cobraConfig.Validate()
if err != nil {
log.Debugf("invalid config: %s", err)
return fmt.Errorf("invalid config: %s", err)
}
err = util.InitLog(cobraConfig.LogLevel, cobraConfig.LogFile)
if err != nil {
log.Debugf("failed to initialize log: %s", err)
return fmt.Errorf("failed to initialize log: %s", err)
}
metricsServer, err := metrics.NewServer(metricsPort, "")
if err != nil {
log.Debugf("setup metrics: %v", err)
return fmt.Errorf("setup metrics: %v", err)
}
go func() {
log.Infof("running metrics server: %s%s", metricsServer.Addr, metricsServer.Endpoint)
if err := metricsServer.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
log.Fatalf("Failed to start metrics server: %v", err)
}
}()
srvListenerCfg := server.ListenerConfig{
Address: cobraConfig.ListenAddress,
}
tlsConfig, tlsSupport, err := handleTLSConfig(cobraConfig)
if err != nil {
log.Debugf("failed to setup TLS config: %s", err)
return fmt.Errorf("failed to setup TLS config: %s", err)
}
srvListenerCfg.TLSConfig = tlsConfig
authenticator := auth.NewTimedHMACValidator(cobraConfig.AuthSecret, 24*time.Hour)
srv, err := server.NewServer(metricsServer.Meter, cobraConfig.ExposedAddress, tlsSupport, authenticator)
if err != nil {
log.Debugf("failed to create relay server: %v", err)
return fmt.Errorf("failed to create relay server: %v", err)
}
log.Infof("server will be available on: %s", srv.InstanceURL())
go func() {
if err := srv.Listen(srvListenerCfg); err != nil {
log.Fatalf("failed to bind server: %s", err)
}
}()
// it will block until exit signal
waitForExitSignal()
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
var shutDownErrors error
if err := srv.Shutdown(ctx); err != nil {
shutDownErrors = multierror.Append(shutDownErrors, fmt.Errorf("failed to close server: %s", err))
}
log.Infof("shutting down metrics server")
if err := metricsServer.Shutdown(ctx); err != nil {
shutDownErrors = multierror.Append(shutDownErrors, fmt.Errorf("failed to close metrics server: %v", err))
}
return shutDownErrors
}
func handleTLSConfig(cfg *Config) (*tls.Config, bool, error) {
if cfg.LetsencryptAWSRoute53 {
log.Debugf("using Let's Encrypt DNS resolver with Route 53 support")
r53 := encryption.Route53TLS{
DataDir: cfg.LetsencryptDataDir,
Email: cfg.LetsencryptEmail,
Domains: cfg.LetsencryptDomains,
}
tlsCfg, err := r53.GetCertificate()
if err != nil {
return nil, false, fmt.Errorf("%s", err)
}
return tlsCfg, true, nil
}
if cfg.HasLetsEncrypt() {
log.Infof("setting up TLS with Let's Encrypt.")
tlsCfg, err := setupTLSCertManager(cfg.LetsencryptDataDir, cfg.LetsencryptDomains...)
if err != nil {
return nil, false, fmt.Errorf("%s", err)
}
return tlsCfg, true, nil
}
if cfg.HasCertConfig() {
log.Debugf("using file based TLS config")
tlsCfg, err := encryption.LoadTLSConfig(cfg.TlsCertFile, cfg.TlsKeyFile)
if err != nil {
return nil, false, fmt.Errorf("%s", err)
}
return tlsCfg, true, nil
}
return nil, false, nil
}
func setupTLSCertManager(letsencryptDataDir string, letsencryptDomains ...string) (*tls.Config, error) {
certManager, err := encryption.CreateCertManager(letsencryptDataDir, letsencryptDomains...)
if err != nil {
return nil, fmt.Errorf("failed creating LetsEncrypt cert manager: %v", err)
}
return certManager.TLSConfig(), nil
}

14
relay/doc.go Normal file
View File

@@ -0,0 +1,14 @@
//Package main
/*
The `relay` package contains the implementation of the Relay server and client. The Relay server can be used to relay
messages between peers on a single network channel. In this implementation the transport layer is the WebSocket
protocol.
Between the server and client communication has been design a custom protocol and message format. These messages are
transported over the WebSocket connection. Optionally the server can use TLS to secure the communication.
The service can support multiple Relay server instances. For this purpose the peers must know the server instance URL.
This URL will be sent to the target peer to choose the common Relay server for the communication via Signal service.
*/
package main

17
relay/healthcheck/doc.go Normal file
View File

@@ -0,0 +1,17 @@
/*
The `healthcheck` package is responsible for managing the health checks between the client and the relay server. It
ensures that the connection between the client and the server are alive and functioning properly.
The `Sender` struct is responsible for sending health check signals to the receiver. The receiver listens for these
signals and sends a new signal back to the sender to acknowledge that the signal has been received. If the sender does
not receive an acknowledgment signal within a certain time frame, it will send a timeout signal via timeout channel
and stop working.
The `Receiver` struct is responsible for receiving the health check signals from the sender. If the receiver does not
receive a signal within a certain time frame, it will send a timeout signal via the OnTimeout channel and stop working.
In the Relay usage the signal is sent to the peer in message type Healthcheck. In case of timeout the connection is
closed and the peer is removed from the relay.
*/
package healthcheck

View File

@@ -0,0 +1,82 @@
package healthcheck
import (
"context"
"time"
)
var (
heartbeatTimeout = healthCheckInterval + 3*time.Second
)
// Receiver is a healthcheck receiver
// It will listen for heartbeat and check if the heartbeat is not received in a certain time
// If the heartbeat is not received in a certain time, it will send a timeout signal and stop to work
// The heartbeat timeout is a bit longer than the sender's healthcheck interval
type Receiver struct {
OnTimeout chan struct{}
ctx context.Context
ctxCancel context.CancelFunc
heartbeat chan struct{}
alive bool
}
// NewReceiver creates a new healthcheck receiver and start the timer in the background
func NewReceiver() *Receiver {
ctx, ctxCancel := context.WithCancel(context.Background())
r := &Receiver{
OnTimeout: make(chan struct{}, 1),
ctx: ctx,
ctxCancel: ctxCancel,
heartbeat: make(chan struct{}, 1),
}
go r.waitForHealthcheck()
return r
}
// Heartbeat acknowledge the heartbeat has been received
func (r *Receiver) Heartbeat() {
select {
case r.heartbeat <- struct{}{}:
default:
}
}
// Stop check the timeout and do not send new notifications
func (r *Receiver) Stop() {
r.ctxCancel()
}
func (r *Receiver) waitForHealthcheck() {
ticker := time.NewTicker(heartbeatTimeout)
defer ticker.Stop()
defer r.ctxCancel()
defer close(r.OnTimeout)
for {
select {
case <-r.heartbeat:
r.alive = true
case <-ticker.C:
if r.alive {
r.alive = false
continue
}
r.notifyTimeout()
return
case <-r.ctx.Done():
return
}
}
}
func (r *Receiver) notifyTimeout() {
select {
case r.OnTimeout <- struct{}{}:
default:
}
}

View File

@@ -0,0 +1,42 @@
package healthcheck
import (
"testing"
"time"
)
func TestNewReceiver(t *testing.T) {
heartbeatTimeout = 5 * time.Second
r := NewReceiver()
select {
case <-r.OnTimeout:
t.Error("unexpected timeout")
case <-time.After(1 * time.Second):
}
}
func TestNewReceiverNotReceive(t *testing.T) {
heartbeatTimeout = 1 * time.Second
r := NewReceiver()
select {
case <-r.OnTimeout:
case <-time.After(2 * time.Second):
t.Error("timeout not received")
}
}
func TestNewReceiverAck(t *testing.T) {
heartbeatTimeout = 2 * time.Second
r := NewReceiver()
r.Heartbeat()
select {
case <-r.OnTimeout:
t.Error("unexpected timeout")
case <-time.After(3 * time.Second):
}
}

View File

@@ -0,0 +1,68 @@
package healthcheck
import (
"context"
"time"
)
var (
healthCheckInterval = 25 * time.Second
healthCheckTimeout = 5 * time.Second
)
// Sender is a healthcheck sender
// It will send healthcheck signal to the receiver
// If the receiver does not receive the signal in a certain time, it will send a timeout signal and stop to work
// It will also stop if the context is canceled
type Sender struct {
// HealthCheck is a channel to send health check signal to the peer
HealthCheck chan struct{}
// Timeout is a channel to the health check signal is not received in a certain time
Timeout chan struct{}
ack chan struct{}
}
// NewSender creates a new healthcheck sender
func NewSender() *Sender {
hc := &Sender{
HealthCheck: make(chan struct{}, 1),
Timeout: make(chan struct{}, 1),
ack: make(chan struct{}, 1),
}
return hc
}
// OnHCResponse sends an acknowledgment signal to the sender
func (hc *Sender) OnHCResponse() {
select {
case hc.ack <- struct{}{}:
default:
}
}
func (hc *Sender) StartHealthCheck(ctx context.Context) {
ticker := time.NewTicker(healthCheckInterval)
defer ticker.Stop()
timeoutTimer := time.NewTimer(healthCheckInterval + healthCheckTimeout)
defer timeoutTimer.Stop()
defer close(hc.HealthCheck)
defer close(hc.Timeout)
for {
select {
case <-ticker.C:
hc.HealthCheck <- struct{}{}
case <-timeoutTimer.C:
hc.Timeout <- struct{}{}
return
case <-hc.ack:
timeoutTimer.Reset(healthCheckInterval + healthCheckTimeout)
case <-ctx.Done():
return
}
}
}

View File

@@ -0,0 +1,103 @@
package healthcheck
import (
"context"
"os"
"testing"
"time"
)
func TestMain(m *testing.M) {
// override the health check interval to speed up the test
healthCheckInterval = 2 * time.Second
healthCheckTimeout = 100 * time.Millisecond
code := m.Run()
os.Exit(code)
}
func TestNewHealthPeriod(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
hc := NewSender()
go hc.StartHealthCheck(ctx)
iterations := 0
for i := 0; i < 3; i++ {
select {
case <-hc.HealthCheck:
iterations++
hc.OnHCResponse()
case <-hc.Timeout:
t.Fatalf("health check is timed out")
case <-time.After(healthCheckInterval + 100*time.Millisecond):
t.Fatalf("health check not received")
}
}
}
func TestNewHealthFailed(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
hc := NewSender()
go hc.StartHealthCheck(ctx)
select {
case <-hc.Timeout:
case <-time.After(healthCheckInterval + healthCheckTimeout + 100*time.Millisecond):
t.Fatalf("health check is not timed out")
}
}
func TestNewHealthcheckStop(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
hc := NewSender()
go hc.StartHealthCheck(ctx)
time.Sleep(100 * time.Millisecond)
cancel()
select {
case _, ok := <-hc.HealthCheck:
if ok {
t.Fatalf("health check on received")
}
case _, ok := <-hc.Timeout:
if ok {
t.Fatalf("health check on received")
}
case <-ctx.Done():
// expected
case <-time.After(10 * time.Second):
t.Fatalf("is not exited")
}
}
func TestTimeoutReset(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
hc := NewSender()
go hc.StartHealthCheck(ctx)
iterations := 0
for i := 0; i < 3; i++ {
select {
case <-hc.HealthCheck:
iterations++
hc.OnHCResponse()
case <-hc.Timeout:
t.Fatalf("health check is timed out")
case <-time.After(healthCheckInterval + 100*time.Millisecond):
t.Fatalf("health check not received")
}
}
select {
case <-hc.HealthCheck:
case <-hc.Timeout:
// expected
case <-ctx.Done():
t.Fatalf("context is done")
case <-time.After(10 * time.Second):
t.Fatalf("is not exited")
}
}

13
relay/main.go Normal file
View File

@@ -0,0 +1,13 @@
package main
import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/relay/cmd"
)
func main() {
if err := cmd.Execute(); err != nil {
log.Fatalf("failed to execute command: %v", err)
}
}

View File

@@ -0,0 +1,30 @@
package address
import (
"bytes"
"encoding/gob"
"fmt"
)
type Address struct {
URL string
}
func (addr *Address) Marshal() ([]byte, error) {
var buf bytes.Buffer
enc := gob.NewEncoder(&buf)
if err := enc.Encode(addr); err != nil {
return nil, fmt.Errorf("encode Address: %w", err)
}
return buf.Bytes(), nil
}
func Unmarshal(data []byte) (*Address, error) {
var addr Address
buf := bytes.NewBuffer(data)
dec := gob.NewDecoder(buf)
if err := dec.Decode(&addr); err != nil {
return nil, fmt.Errorf("decode Address: %w", err)
}
return &addr, nil
}

View File

@@ -0,0 +1,51 @@
package auth
import (
"bytes"
"encoding/gob"
"fmt"
)
type Algorithm int
const (
AlgoUnknown Algorithm = iota
AlgoHMACSHA256
AlgoHMACSHA512
)
func (a Algorithm) String() string {
switch a {
case AlgoHMACSHA256:
return "HMAC-SHA256"
case AlgoHMACSHA512:
return "HMAC-SHA512"
default:
return "Unknown"
}
}
type Msg struct {
AuthAlgorithm Algorithm
AdditionalData []byte
}
func (msg *Msg) Marshal() ([]byte, error) {
var buf bytes.Buffer
enc := gob.NewEncoder(&buf)
if err := enc.Encode(msg); err != nil {
return nil, fmt.Errorf("encode Msg: %w", err)
}
return buf.Bytes(), nil
}
func UnmarshalMsg(data []byte) (*Msg, error) {
var msg *Msg
buf := bytes.NewBuffer(data)
dec := gob.NewDecoder(buf)
if err := dec.Decode(&msg); err != nil {
return nil, fmt.Errorf("decode Msg: %w", err)
}
return msg, nil
}

5
relay/messages/doc.go Normal file
View File

@@ -0,0 +1,5 @@
/*
Package messages provides the message types that are used to communicate between the relay and the client.
This package is used to determine the type of message that is being sent and received between the relay and the client.
*/
package messages

31
relay/messages/id.go Normal file
View File

@@ -0,0 +1,31 @@
package messages
import (
"crypto/sha256"
"encoding/base64"
"fmt"
)
const (
prefixLength = 4
IDSize = prefixLength + sha256.Size
)
var (
prefix = []byte("sha-") // 4 bytes
)
// HashID generates a sha256 hash from the peerID and returns the hash and the human-readable string
func HashID(peerID string) ([]byte, string) {
idHash := sha256.Sum256([]byte(peerID))
idHashString := string(prefix) + base64.StdEncoding.EncodeToString(idHash[:])
var prefixedHash []byte
prefixedHash = append(prefixedHash, prefix...)
prefixedHash = append(prefixedHash, idHash[:]...)
return prefixedHash, idHashString
}
// HashIDToString converts a hash to a human-readable string
func HashIDToString(idHash []byte) string {
return fmt.Sprintf("%s%s", idHash[:prefixLength], base64.StdEncoding.EncodeToString(idHash[prefixLength:]))
}

13
relay/messages/id_test.go Normal file
View File

@@ -0,0 +1,13 @@
package messages
import (
"testing"
)
func TestHashID(t *testing.T) {
hashedID, hashedStringId := HashID("alice")
enc := HashIDToString(hashedID)
if enc != hashedStringId {
t.Errorf("expected %s, got %s", hashedStringId, enc)
}
}

239
relay/messages/message.go Normal file
View File

@@ -0,0 +1,239 @@
package messages
import (
"bytes"
"errors"
"fmt"
)
const (
MsgTypeUnknown MsgType = 0
MsgTypeHello MsgType = 1
MsgTypeHelloResponse MsgType = 2
MsgTypeTransport MsgType = 3
MsgTypeClose MsgType = 4
MsgTypeHealthCheck MsgType = 5
SizeOfVersionByte = 1
SizeOfMsgType = 1
SizeOfProtoHeader = SizeOfVersionByte + SizeOfMsgType
sizeOfMagicByte = 4
headerSizeTransport = IDSize
headerSizeHello = sizeOfMagicByte + IDSize
headerSizeHelloResp = 0
MaxHandshakeSize = 8192
CurrentProtocolVersion = 1
)
var (
ErrInvalidMessageLength = errors.New("invalid message length")
ErrUnsupportedVersion = errors.New("unsupported version")
magicHeader = []byte{0x21, 0x12, 0xA4, 0x42}
healthCheckMsg = []byte{byte(CurrentProtocolVersion), byte(MsgTypeHealthCheck)}
)
type MsgType byte
func (m MsgType) String() string {
switch m {
case MsgTypeHello:
return "hello"
case MsgTypeHelloResponse:
return "hello response"
case MsgTypeTransport:
return "transport"
case MsgTypeClose:
return "close"
case MsgTypeHealthCheck:
return "health check"
default:
return "unknown"
}
}
type HelloResponse struct {
InstanceAddress string
}
// ValidateVersion checks if the given version is supported by the protocol
func ValidateVersion(msg []byte) (int, error) {
if len(msg) < SizeOfVersionByte {
return 0, ErrInvalidMessageLength
}
version := int(msg[0])
if version != CurrentProtocolVersion {
return 0, fmt.Errorf("%d: %w", version, ErrUnsupportedVersion)
}
return version, nil
}
// DetermineClientMessageType determines the message type from the first the message
func DetermineClientMessageType(msg []byte) (MsgType, error) {
if len(msg) < SizeOfMsgType {
return 0, ErrInvalidMessageLength
}
msgType := MsgType(msg[0])
switch msgType {
case
MsgTypeHello,
MsgTypeTransport,
MsgTypeClose,
MsgTypeHealthCheck:
return msgType, nil
default:
return MsgTypeUnknown, fmt.Errorf("invalid msg type %d", msgType)
}
}
// DetermineServerMessageType determines the message type from the first the message
func DetermineServerMessageType(msg []byte) (MsgType, error) {
if len(msg) < SizeOfMsgType {
return 0, ErrInvalidMessageLength
}
msgType := MsgType(msg[0])
switch msgType {
case
MsgTypeHelloResponse,
MsgTypeTransport,
MsgTypeClose,
MsgTypeHealthCheck:
return msgType, nil
default:
return MsgTypeUnknown, fmt.Errorf("invalid msg type %d", msgType)
}
}
// MarshalHelloMsg initial hello message
// The Hello message is the first message sent by a client after establishing a connection with the Relay server. This
// message is used to authenticate the client with the server. The authentication is done using an HMAC method.
// The protocol does not limit to use HMAC, it can be any other method. If the authentication failed the server will
// close the network connection without any response.
func MarshalHelloMsg(peerID []byte, additions []byte) ([]byte, error) {
if len(peerID) != IDSize {
return nil, fmt.Errorf("invalid peerID length: %d", len(peerID))
}
msg := make([]byte, SizeOfProtoHeader+sizeOfMagicByte, SizeOfProtoHeader+headerSizeHello+len(additions))
msg[0] = byte(CurrentProtocolVersion)
msg[1] = byte(MsgTypeHello)
copy(msg[SizeOfProtoHeader:SizeOfProtoHeader+sizeOfMagicByte], magicHeader)
msg = append(msg, peerID...)
msg = append(msg, additions...)
return msg, nil
}
// UnmarshalHelloMsg extracts peerID and the additional data from the hello message. The Additional data is used to
// authenticate the client with the server.
func UnmarshalHelloMsg(msg []byte) ([]byte, []byte, error) {
if len(msg) < headerSizeHello {
return nil, nil, ErrInvalidMessageLength
}
if !bytes.Equal(msg[:sizeOfMagicByte], magicHeader) {
return nil, nil, errors.New("invalid magic header")
}
return msg[sizeOfMagicByte:headerSizeHello], msg[headerSizeHello:], nil
}
// MarshalHelloResponse creates a response message to the hello message.
// In case of success connection the server response with a Hello Response message. This message contains the server's
// instance URL. This URL will be used by choose the common Relay server in case if the peers are in different Relay
// servers.
func MarshalHelloResponse(additionalData []byte) ([]byte, error) {
msg := make([]byte, SizeOfProtoHeader, SizeOfProtoHeader+headerSizeHelloResp+len(additionalData))
msg[0] = byte(CurrentProtocolVersion)
msg[1] = byte(MsgTypeHelloResponse)
msg = append(msg, additionalData...)
return msg, nil
}
// UnmarshalHelloResponse extracts the additional data from the hello response message.
func UnmarshalHelloResponse(msg []byte) ([]byte, error) {
if len(msg) < headerSizeHelloResp {
return nil, ErrInvalidMessageLength
}
return msg, nil
}
// MarshalCloseMsg creates a close message.
// The close message is used to close the connection gracefully between the client and the server. The server and the
// client can send this message. After receiving this message, the server or client will close the connection.
func MarshalCloseMsg() []byte {
msg := make([]byte, SizeOfProtoHeader)
msg[0] = byte(CurrentProtocolVersion)
msg[1] = byte(MsgTypeClose)
return msg
}
// MarshalTransportMsg creates a transport message.
// The transport message is used to exchange data between peers. The message contains the data to be exchanged and the
// destination peer hashed ID.
func MarshalTransportMsg(peerID []byte, payload []byte) ([]byte, error) {
if len(peerID) != IDSize {
return nil, fmt.Errorf("invalid peerID length: %d", len(peerID))
}
msg := make([]byte, SizeOfProtoHeader+headerSizeTransport, SizeOfProtoHeader+headerSizeTransport+len(payload))
msg[0] = byte(CurrentProtocolVersion)
msg[1] = byte(MsgTypeTransport)
copy(msg[SizeOfProtoHeader:], peerID)
msg = append(msg, payload...)
return msg, nil
}
// UnmarshalTransportMsg extracts the peerID and the payload from the transport message.
func UnmarshalTransportMsg(buf []byte) ([]byte, []byte, error) {
if len(buf) < headerSizeTransport {
return nil, nil, ErrInvalidMessageLength
}
return buf[:headerSizeTransport], buf[headerSizeTransport:], nil
}
// UnmarshalTransportID extracts the peerID from the transport message.
func UnmarshalTransportID(buf []byte) ([]byte, error) {
if len(buf) < headerSizeTransport {
return nil, ErrInvalidMessageLength
}
return buf[:headerSizeTransport], nil
}
// UpdateTransportMsg updates the peerID in the transport message.
// With this function the server can reuse the given byte slice to update the peerID in the transport message. So do
// need to allocate a new byte slice.
func UpdateTransportMsg(msg []byte, peerID []byte) error {
if len(msg) < len(peerID) {
return ErrInvalidMessageLength
}
copy(msg, peerID)
return nil
}
// MarshalHealthcheck creates a health check message.
// Health check message is sent by the server periodically. The client will respond with a health check response
// message. If the client does not respond to the health check message, the server will close the connection.
func MarshalHealthcheck() []byte {
return healthCheckMsg
}

View File

@@ -0,0 +1,43 @@
package messages
import (
"testing"
)
func TestMarshalHelloMsg(t *testing.T) {
peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
bHello, err := MarshalHelloMsg(peerID, nil)
if err != nil {
t.Fatalf("error: %v", err)
}
receivedPeerID, _, err := UnmarshalHelloMsg(bHello[SizeOfProtoHeader:])
if err != nil {
t.Fatalf("error: %v", err)
}
if string(receivedPeerID) != string(peerID) {
t.Errorf("expected %s, got %s", peerID, receivedPeerID)
}
}
func TestMarshalTransportMsg(t *testing.T) {
peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
payload := []byte("payload")
msg, err := MarshalTransportMsg(peerID, payload)
if err != nil {
t.Fatalf("error: %v", err)
}
id, respPayload, err := UnmarshalTransportMsg(msg[SizeOfProtoHeader:])
if err != nil {
t.Fatalf("error: %v", err)
}
if string(id) != string(peerID) {
t.Errorf("expected %s, got %s", peerID, id)
}
if string(respPayload) != string(payload) {
t.Errorf("expected %s, got %s", payload, respPayload)
}
}

136
relay/metrics/realy.go Normal file
View File

@@ -0,0 +1,136 @@
package metrics
import (
"context"
"sync"
"time"
log "github.com/sirupsen/logrus"
"go.opentelemetry.io/otel/metric"
)
const (
idleTimeout = 30 * time.Second
)
type Metrics struct {
metric.Meter
TransferBytesSent metric.Int64Counter
TransferBytesRecv metric.Int64Counter
peers metric.Int64UpDownCounter
peerActivityChan chan string
peerLastActive map[string]time.Time
mutexActivity sync.Mutex
ctx context.Context
}
func NewMetrics(ctx context.Context, meter metric.Meter) (*Metrics, error) {
bytesSent, err := meter.Int64Counter("relay_transfer_sent_bytes_total")
if err != nil {
return nil, err
}
bytesRecv, err := meter.Int64Counter("relay_transfer_received_bytes_total")
if err != nil {
return nil, err
}
peers, err := meter.Int64UpDownCounter("relay_peers")
if err != nil {
return nil, err
}
peersActive, err := meter.Int64ObservableGauge("relay_peers_active")
if err != nil {
return nil, err
}
peersIdle, err := meter.Int64ObservableGauge("relay_peers_idle")
if err != nil {
return nil, err
}
m := &Metrics{
Meter: meter,
TransferBytesSent: bytesSent,
TransferBytesRecv: bytesRecv,
peers: peers,
ctx: ctx,
peerActivityChan: make(chan string, 10),
peerLastActive: make(map[string]time.Time),
}
_, err = meter.RegisterCallback(
func(ctx context.Context, o metric.Observer) error {
active, idle := m.calculateActiveIdleConnections()
o.ObserveInt64(peersActive, active)
o.ObserveInt64(peersIdle, idle)
return nil
},
peersActive, peersIdle,
)
if err != nil {
return nil, err
}
go m.readPeerActivity()
return m, nil
}
// PeerConnected increments the number of connected peers and increments number of idle connections
func (m *Metrics) PeerConnected(id string) {
m.peers.Add(m.ctx, 1)
m.mutexActivity.Lock()
defer m.mutexActivity.Unlock()
m.peerLastActive[id] = time.Time{}
}
// PeerDisconnected decrements the number of connected peers and decrements number of idle or active connections
func (m *Metrics) PeerDisconnected(id string) {
m.peers.Add(m.ctx, -1)
m.mutexActivity.Lock()
defer m.mutexActivity.Unlock()
delete(m.peerLastActive, id)
}
// PeerActivity increases the active connections
func (m *Metrics) PeerActivity(peerID string) {
select {
case m.peerActivityChan <- peerID:
default:
log.Errorf("peer activity channel is full, dropping activity metrics for peer %s", peerID)
}
}
func (m *Metrics) calculateActiveIdleConnections() (int64, int64) {
active, idle := int64(0), int64(0)
m.mutexActivity.Lock()
defer m.mutexActivity.Unlock()
for _, lastActive := range m.peerLastActive {
if time.Since(lastActive) > idleTimeout {
idle++
} else {
active++
}
}
return active, idle
}
func (m *Metrics) readPeerActivity() {
for {
select {
case peerID := <-m.peerActivityChan:
m.mutexActivity.Lock()
m.peerLastActive[peerID] = time.Now()
m.mutexActivity.Unlock()
case <-m.ctx.Done():
return
}
}
}

View File

@@ -0,0 +1,11 @@
package listener
import (
"context"
"net"
)
type Listener interface {
Listen(func(conn net.Conn)) error
Shutdown(ctx context.Context) error
}

View File

@@ -0,0 +1,114 @@
package ws
import (
"context"
"errors"
"fmt"
"io"
"net"
"sync"
"time"
log "github.com/sirupsen/logrus"
"nhooyr.io/websocket"
)
const (
writeTimeout = 10 * time.Second
)
type Conn struct {
*websocket.Conn
lAddr *net.TCPAddr
rAddr *net.TCPAddr
closed bool
closedMu sync.Mutex
ctx context.Context
}
func NewConn(wsConn *websocket.Conn, lAddr, rAddr *net.TCPAddr) *Conn {
return &Conn{
Conn: wsConn,
lAddr: lAddr,
rAddr: rAddr,
ctx: context.Background(),
}
}
func (c *Conn) Read(b []byte) (n int, err error) {
t, r, err := c.Reader(c.ctx)
if err != nil {
return 0, c.ioErrHandling(err)
}
if t != websocket.MessageBinary {
log.Errorf("unexpected message type: %d", t)
return 0, fmt.Errorf("unexpected message type")
}
n, err = r.Read(b)
if err != nil {
return 0, c.ioErrHandling(err)
}
return n, err
}
// Write writes a binary message with the given payload.
// It does not block until fill the internal buffer.
// If the buffer filled up, wait until the buffer is drained or timeout.
func (c *Conn) Write(b []byte) (int, error) {
ctx, ctxCancel := context.WithTimeout(c.ctx, writeTimeout)
defer ctxCancel()
err := c.Conn.Write(ctx, websocket.MessageBinary, b)
return len(b), err
}
func (c *Conn) LocalAddr() net.Addr {
return c.lAddr
}
func (c *Conn) RemoteAddr() net.Addr {
return c.rAddr
}
func (c *Conn) SetReadDeadline(t time.Time) error {
return fmt.Errorf("SetReadDeadline is not implemented")
}
func (c *Conn) SetWriteDeadline(t time.Time) error {
return fmt.Errorf("SetWriteDeadline is not implemented")
}
func (c *Conn) SetDeadline(t time.Time) error {
return fmt.Errorf("SetDeadline is not implemented")
}
func (c *Conn) Close() error {
c.closedMu.Lock()
c.closed = true
c.closedMu.Unlock()
return c.Conn.CloseNow()
}
func (c *Conn) isClosed() bool {
c.closedMu.Lock()
defer c.closedMu.Unlock()
return c.closed
}
func (c *Conn) ioErrHandling(err error) error {
if c.isClosed() {
return io.EOF
}
var wErr *websocket.CloseError
if !errors.As(err, &wErr) {
return err
}
if wErr.Code == websocket.StatusNormalClosure {
return io.EOF
}
return err
}

View File

@@ -0,0 +1,92 @@
package ws
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"net/http"
log "github.com/sirupsen/logrus"
"nhooyr.io/websocket"
)
// URLPath is the path for the websocket connection.
const URLPath = "/relay"
type Listener struct {
// Address is the address to listen on.
Address string
// TLSConfig is the TLS configuration for the server.
TLSConfig *tls.Config
server *http.Server
acceptFn func(conn net.Conn)
}
func (l *Listener) Listen(acceptFn func(conn net.Conn)) error {
l.acceptFn = acceptFn
mux := http.NewServeMux()
mux.HandleFunc(URLPath, l.onAccept)
l.server = &http.Server{
Addr: l.Address,
Handler: mux,
TLSConfig: l.TLSConfig,
}
log.Infof("WS server listening address: %s", l.Address)
var err error
if l.TLSConfig != nil {
err = l.server.ListenAndServeTLS("", "")
} else {
err = l.server.ListenAndServe()
}
if errors.Is(err, http.ErrServerClosed) {
return nil
}
return err
}
func (l *Listener) Shutdown(ctx context.Context) error {
if l.server == nil {
return nil
}
log.Infof("stop WS listener")
if err := l.server.Shutdown(ctx); err != nil {
return fmt.Errorf("server shutdown failed: %v", err)
}
log.Infof("WS listener stopped")
return nil
}
func (l *Listener) onAccept(w http.ResponseWriter, r *http.Request) {
wsConn, err := websocket.Accept(w, r, nil)
if err != nil {
log.Errorf("failed to accept ws connection from %s: %s", r.RemoteAddr, err)
return
}
rAddr, err := net.ResolveTCPAddr("tcp", r.RemoteAddr)
if err != nil {
err = wsConn.Close(websocket.StatusInternalError, "internal error")
if err != nil {
log.Errorf("failed to close ws connection: %s", err)
}
return
}
lAddr, err := net.ResolveTCPAddr("tcp", l.server.Addr)
if err != nil {
err = wsConn.Close(websocket.StatusInternalError, "internal error")
if err != nil {
log.Errorf("failed to close ws connection: %s", err)
}
return
}
conn := NewConn(wsConn, lAddr, rAddr)
l.acceptFn(conn)
}

203
relay/server/peer.go Normal file
View File

@@ -0,0 +1,203 @@
package server
import (
"context"
"io"
"net"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/relay/healthcheck"
"github.com/netbirdio/netbird/relay/messages"
"github.com/netbirdio/netbird/relay/metrics"
)
const (
bufferSize = 8820
)
// Peer represents a peer connection
type Peer struct {
metrics *metrics.Metrics
log *log.Entry
idS string
idB []byte
conn net.Conn
connMu sync.RWMutex
store *Store
}
// NewPeer creates a new Peer instance and prepare custom logging
func NewPeer(metrics *metrics.Metrics, id []byte, conn net.Conn, store *Store) *Peer {
stringID := messages.HashIDToString(id)
return &Peer{
metrics: metrics,
log: log.WithField("peer_id", stringID),
idS: stringID,
idB: id,
conn: conn,
store: store,
}
}
// Work reads data from the connection
// It manages the protocol (healthcheck, transport, close). Read the message and determine the message type and handle
// the message accordingly.
func (p *Peer) Work() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
hc := healthcheck.NewSender()
go hc.StartHealthCheck(ctx)
go p.handleHealthcheckEvents(ctx, hc)
buf := make([]byte, bufferSize)
for {
n, err := p.conn.Read(buf)
if err != nil {
if err != io.EOF {
p.log.Errorf("failed to read message: %s", err)
}
return
}
if n == 0 {
p.log.Errorf("received empty message")
return
}
msg := buf[:n]
_, err = messages.ValidateVersion(msg)
if err != nil {
p.log.Warnf("failed to validate protocol version: %s", err)
return
}
msgType, err := messages.DetermineClientMessageType(msg[messages.SizeOfVersionByte:])
if err != nil {
p.log.Errorf("failed to determine message type: %s", err)
return
}
p.handleMsgType(ctx, msgType, hc, n, msg)
}
}
func (p *Peer) handleMsgType(ctx context.Context, msgType messages.MsgType, hc *healthcheck.Sender, n int, msg []byte) {
switch msgType {
case messages.MsgTypeHealthCheck:
hc.OnHCResponse()
case messages.MsgTypeTransport:
p.metrics.TransferBytesRecv.Add(ctx, int64(n))
p.metrics.PeerActivity(p.String())
p.handleTransportMsg(msg)
case messages.MsgTypeClose:
p.log.Infof("peer exited gracefully")
if err := p.conn.Close(); err != nil {
log.Errorf("failed to close connection to peer: %s", err)
}
default:
p.log.Warnf("received unexpected message type: %s", msgType)
}
}
// Write writes data to the connection
func (p *Peer) Write(b []byte) (int, error) {
p.connMu.RLock()
defer p.connMu.RUnlock()
return p.conn.Write(b)
}
// CloseGracefully closes the connection with the peer gracefully. Send a close message to the client and close the
// connection.
func (p *Peer) CloseGracefully(ctx context.Context) {
p.connMu.Lock()
err := p.writeWithTimeout(ctx, messages.MarshalCloseMsg())
if err != nil {
p.log.Errorf("failed to send close message to peer: %s", p.String())
}
err = p.conn.Close()
if err != nil {
p.log.Errorf("failed to close connection to peer: %s", err)
}
defer p.connMu.Unlock()
}
// String returns the peer ID
func (p *Peer) String() string {
return p.idS
}
func (p *Peer) writeWithTimeout(ctx context.Context, buf []byte) error {
ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
defer cancel()
writeDone := make(chan struct{})
var err error
go func() {
_, err = p.conn.Write(buf)
close(writeDone)
}()
select {
case <-ctx.Done():
return ctx.Err()
case <-writeDone:
return err
}
}
func (p *Peer) handleHealthcheckEvents(ctx context.Context, hc *healthcheck.Sender) {
for {
select {
case <-hc.HealthCheck:
_, err := p.Write(messages.MarshalHealthcheck())
if err != nil {
p.log.Errorf("failed to send healthcheck message: %s", err)
return
}
case <-hc.Timeout:
p.log.Errorf("peer healthcheck timeout")
err := p.conn.Close()
if err != nil {
p.log.Errorf("failed to close connection to peer: %s", err)
}
return
case <-ctx.Done():
return
}
}
}
func (p *Peer) handleTransportMsg(msg []byte) {
peerID, err := messages.UnmarshalTransportID(msg[messages.SizeOfProtoHeader:])
if err != nil {
p.log.Errorf("failed to unmarshal transport message: %s", err)
return
}
stringPeerID := messages.HashIDToString(peerID)
dp, ok := p.store.Peer(stringPeerID)
if !ok {
p.log.Errorf("peer not found: %s", stringPeerID)
return
}
err = messages.UpdateTransportMsg(msg[messages.SizeOfProtoHeader:], p.idB)
if err != nil {
p.log.Errorf("failed to update transport message: %s", err)
return
}
n, err := dp.Write(msg)
if err != nil {
p.log.Errorf("failed to write transport message to: %s", dp.String())
return
}
p.metrics.TransferBytesSent.Add(context.Background(), int64(n))
}

206
relay/server/relay.go Normal file
View File

@@ -0,0 +1,206 @@
package server
import (
"context"
"crypto/sha256"
"fmt"
"net"
"net/url"
"strings"
"sync"
log "github.com/sirupsen/logrus"
"go.opentelemetry.io/otel/metric"
"github.com/netbirdio/netbird/relay/auth"
"github.com/netbirdio/netbird/relay/messages"
"github.com/netbirdio/netbird/relay/messages/address"
authmsg "github.com/netbirdio/netbird/relay/messages/auth"
"github.com/netbirdio/netbird/relay/metrics"
)
// Relay represents the relay server
type Relay struct {
metrics *metrics.Metrics
metricsCancel context.CancelFunc
validator auth.Validator
store *Store
instanceURL string
closed bool
closeMu sync.RWMutex
}
// NewRelay creates a new Relay instance
//
// Parameters:
// meter: An instance of metric.Meter from the go.opentelemetry.io/otel/metric package. It is used to create and manage
// metrics for the relay server.
// exposedAddress: A string representing the address that the relay server is exposed on. The client will use this
// address as the relay server's instance URL.
// tlsSupport: A boolean indicating whether the relay server supports TLS (Transport Layer Security) or not. The
// instance URL depends on this value.
// validator: An instance of auth.Validator from the auth package. It is used to validate the authentication of the
// peers.
//
// Returns:
// A pointer to a Relay instance and an error. If the Relay instance is successfully created, the error is nil.
// Otherwise, the error contains the details of what went wrong.
func NewRelay(meter metric.Meter, exposedAddress string, tlsSupport bool, validator auth.Validator) (*Relay, error) {
ctx, metricsCancel := context.WithCancel(context.Background())
m, err := metrics.NewMetrics(ctx, meter)
if err != nil {
metricsCancel()
return nil, fmt.Errorf("creating app metrics: %v", err)
}
r := &Relay{
metrics: m,
metricsCancel: metricsCancel,
validator: validator,
store: NewStore(),
}
r.instanceURL, err = getInstanceURL(exposedAddress, tlsSupport)
if err != nil {
metricsCancel()
return nil, fmt.Errorf("get instance URL: %v", err)
}
return r, nil
}
// getInstanceURL checks if user supplied a URL scheme otherwise adds to the
// provided address according to TLS definition and parses the address before returning it
func getInstanceURL(exposedAddress string, tlsSupported bool) (string, error) {
addr := exposedAddress
split := strings.Split(exposedAddress, "://")
switch {
case len(split) == 1 && tlsSupported:
addr = "rels://" + exposedAddress
case len(split) == 1 && !tlsSupported:
addr = "rel://" + exposedAddress
case len(split) > 2:
return "", fmt.Errorf("invalid exposed address: %s", exposedAddress)
}
parsedURL, err := url.ParseRequestURI(addr)
if err != nil {
return "", fmt.Errorf("invalid exposed address: %v", err)
}
if parsedURL.Scheme != "rel" && parsedURL.Scheme != "rels" {
return "", fmt.Errorf("invalid scheme: %s", parsedURL.Scheme)
}
return parsedURL.String(), nil
}
// Accept start to handle a new peer connection
func (r *Relay) Accept(conn net.Conn) {
r.closeMu.RLock()
defer r.closeMu.RUnlock()
if r.closed {
return
}
peerID, err := r.handshake(conn)
if err != nil {
log.Errorf("failed to handshake: %s", err)
cErr := conn.Close()
if cErr != nil {
log.Errorf("failed to close connection, %s: %s", conn.RemoteAddr(), cErr)
}
return
}
peer := NewPeer(r.metrics, peerID, conn, r.store)
peer.log.Infof("peer connected from: %s", conn.RemoteAddr())
r.store.AddPeer(peer)
r.metrics.PeerConnected(peer.String())
go func() {
peer.Work()
r.store.DeletePeer(peer)
peer.log.Debugf("relay connection closed")
r.metrics.PeerDisconnected(peer.String())
}()
}
// Shutdown closes the relay server
// It closes the connection with all peers in gracefully and stops accepting new connections.
func (r *Relay) Shutdown(ctx context.Context) {
log.Infof("close connection with all peers")
r.closeMu.Lock()
wg := sync.WaitGroup{}
peers := r.store.Peers()
for _, peer := range peers {
wg.Add(1)
go func(p *Peer) {
p.CloseGracefully(ctx)
wg.Done()
}(peer)
}
wg.Wait()
r.metricsCancel()
r.closeMu.Unlock()
}
// InstanceURL returns the instance URL of the relay server
func (r *Relay) InstanceURL() string {
return r.instanceURL
}
func (r *Relay) handshake(conn net.Conn) ([]byte, error) {
buf := make([]byte, messages.MaxHandshakeSize)
n, err := conn.Read(buf)
if err != nil {
return nil, fmt.Errorf("read from %s: %w", conn.RemoteAddr(), err)
}
_, err = messages.ValidateVersion(buf[:n])
if err != nil {
return nil, fmt.Errorf("validate version from %s: %w", conn.RemoteAddr(), err)
}
msgType, err := messages.DetermineClientMessageType(buf[messages.SizeOfVersionByte:n])
if err != nil {
return nil, fmt.Errorf("determine message type from %s: %w", conn.RemoteAddr(), err)
}
if msgType != messages.MsgTypeHello {
return nil, fmt.Errorf("invalid message type from %s", conn.RemoteAddr())
}
peerID, authData, err := messages.UnmarshalHelloMsg(buf[messages.SizeOfProtoHeader:n])
if err != nil {
return nil, fmt.Errorf("unmarshal hello message: %w", err)
}
authMsg, err := authmsg.UnmarshalMsg(authData)
if err != nil {
return nil, fmt.Errorf("unmarshal auth message: %w", err)
}
if err := r.validator.Validate(sha256.New, authMsg.AdditionalData); err != nil {
return nil, fmt.Errorf("validate %s (%s): %w", peerID, conn.RemoteAddr(), err)
}
addr := &address.Address{URL: r.instanceURL}
addrData, err := addr.Marshal()
if err != nil {
return nil, fmt.Errorf("marshal addressc to %s (%s): %w", peerID, conn.RemoteAddr(), err)
}
msg, err := messages.MarshalHelloResponse(addrData)
if err != nil {
return nil, fmt.Errorf("marshal hello response to %s (%s): %w", peerID, conn.RemoteAddr(), err)
}
_, err = conn.Write(msg)
if err != nil {
return nil, fmt.Errorf("write to %s (%s): %w", peerID, conn.RemoteAddr(), err)
}
return peerID, nil
}

View File

@@ -0,0 +1,36 @@
package server
import "testing"
func TestGetInstanceURL(t *testing.T) {
tests := []struct {
name string
exposedAddress string
tlsSupported bool
expectedURL string
expectError bool
}{
{"Valid address with TLS", "example.com", true, "rels://example.com", false},
{"Valid address without TLS", "example.com", false, "rel://example.com", false},
{"Valid address with scheme", "rel://example.com", false, "rel://example.com", false},
{"Valid address with non TLS scheme and TLS true", "rel://example.com", true, "rel://example.com", false},
{"Valid address with TLS scheme", "rels://example.com", true, "rels://example.com", false},
{"Valid address with TLS scheme and TLS false", "rels://example.com", false, "rels://example.com", false},
{"Valid address with TLS scheme and custom port", "rels://example.com:9300", true, "rels://example.com:9300", false},
{"Invalid address with multiple schemes", "rel://rels://example.com", false, "", true},
{"Invalid address with unsupported scheme", "http://example.com", false, "", true},
{"Invalid address format", "://example.com", false, "", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
url, err := getInstanceURL(tt.exposedAddress, tt.tlsSupported)
if (err != nil) != tt.expectError {
t.Errorf("expected error: %v, got: %v", tt.expectError, err)
}
if url != tt.expectedURL {
t.Errorf("expected URL: %s, got: %s", tt.expectedURL, url)
}
})
}
}

76
relay/server/server.go Normal file
View File

@@ -0,0 +1,76 @@
package server
import (
"context"
"crypto/tls"
log "github.com/sirupsen/logrus"
"go.opentelemetry.io/otel/metric"
"github.com/netbirdio/netbird/relay/auth"
"github.com/netbirdio/netbird/relay/server/listener"
"github.com/netbirdio/netbird/relay/server/listener/ws"
)
// ListenerConfig is the configuration for the listener.
// Address: the address to bind the listener to. It could be an address behind a reverse proxy.
// TLSConfig: the TLS configuration for the listener.
type ListenerConfig struct {
Address string
TLSConfig *tls.Config
}
// Server is the main entry point for the relay server.
// It is the gate between the WebSocket listener and the Relay server logic.
// In a new HTTP connection, the server will accept the connection and pass it to the Relay server via the Accept method.
type Server struct {
relay *Relay
wSListener listener.Listener
}
// NewServer creates a new relay server instance.
// meter: the OpenTelemetry meter
// exposedAddress: this address will be used as the instance URL. It should be a domain:port format.
// tlsSupport: if true, the server will support TLS
// authValidator: the auth validator to use for the server
func NewServer(meter metric.Meter, exposedAddress string, tlsSupport bool, authValidator auth.Validator) (*Server, error) {
relay, err := NewRelay(meter, exposedAddress, tlsSupport, authValidator)
if err != nil {
return nil, err
}
return &Server{
relay: relay,
}, nil
}
// Listen starts the relay server.
func (r *Server) Listen(cfg ListenerConfig) error {
r.wSListener = &ws.Listener{
Address: cfg.Address,
TLSConfig: cfg.TLSConfig,
}
wslErr := r.wSListener.Listen(r.relay.Accept)
if wslErr != nil {
log.Errorf("failed to bind ws server: %s", wslErr)
}
return wslErr
}
// Shutdown stops the relay server. If there are active connections, they will be closed gracefully. In case of a context,
// the connections will be forcefully closed.
func (r *Server) Shutdown(ctx context.Context) (err error) {
// stop service new connections
if r.wSListener != nil {
err = r.wSListener.Shutdown(ctx)
}
r.relay.Shutdown(ctx)
return
}
// InstanceURL returns the instance URL of the relay server.
func (r *Server) InstanceURL() string {
return r.relay.instanceURL
}

64
relay/server/store.go Normal file
View File

@@ -0,0 +1,64 @@
package server
import (
"sync"
)
// Store is a thread-safe store of peers
// It is used to store the peers that are connected to the relay server
type Store struct {
peers map[string]*Peer // consider to use [32]byte as key. The Peer(id string) would be faster
peersLock sync.RWMutex
}
// NewStore creates a new Store instance
func NewStore() *Store {
return &Store{
peers: make(map[string]*Peer),
}
}
// AddPeer adds a peer to the store
// todo: consider to close peer conn if the peer already exists
func (s *Store) AddPeer(peer *Peer) {
s.peersLock.Lock()
defer s.peersLock.Unlock()
s.peers[peer.String()] = peer
}
// DeletePeer deletes a peer from the store
func (s *Store) DeletePeer(peer *Peer) {
s.peersLock.Lock()
defer s.peersLock.Unlock()
dp, ok := s.peers[peer.String()]
if !ok {
return
}
if dp != peer {
return
}
delete(s.peers, peer.String())
}
// Peer returns a peer by its ID
func (s *Store) Peer(id string) (*Peer, bool) {
s.peersLock.RLock()
defer s.peersLock.RUnlock()
p, ok := s.peers[id]
return p, ok
}
// Peers returns all the peers in the store
func (s *Store) Peers() []*Peer {
s.peersLock.RLock()
defer s.peersLock.RUnlock()
peers := make([]*Peer, 0, len(s.peers))
for _, p := range s.peers {
peers = append(peers, p)
}
return peers
}

View File

@@ -0,0 +1,40 @@
package server
import (
"context"
"testing"
"go.opentelemetry.io/otel"
"github.com/netbirdio/netbird/relay/metrics"
)
func TestStore_DeletePeer(t *testing.T) {
s := NewStore()
m, _ := metrics.NewMetrics(context.Background(), otel.Meter(""))
p := NewPeer(m, []byte("peer_one"), nil, nil)
s.AddPeer(p)
s.DeletePeer(p)
if _, ok := s.Peer(p.String()); ok {
t.Errorf("peer was not deleted")
}
}
func TestStore_DeleteDeprecatedPeer(t *testing.T) {
s := NewStore()
m, _ := metrics.NewMetrics(context.Background(), otel.Meter(""))
p1 := NewPeer(m, []byte("peer_id"), nil, nil)
p2 := NewPeer(m, []byte("peer_id"), nil, nil)
s.AddPeer(p1)
s.AddPeer(p2)
s.DeletePeer(p1)
if _, ok := s.Peer(p2.String()); !ok {
t.Errorf("second peer was deleted")
}
}

View File

@@ -0,0 +1,386 @@
package test
import (
"context"
"crypto/rand"
"fmt"
"net"
"os"
"sync"
"testing"
"time"
"github.com/pion/logging"
"github.com/pion/turn/v3"
"go.opentelemetry.io/otel"
"github.com/netbirdio/netbird/relay/auth/allow"
"github.com/netbirdio/netbird/relay/auth/hmac"
"github.com/netbirdio/netbird/relay/client"
"github.com/netbirdio/netbird/relay/server"
"github.com/netbirdio/netbird/util"
)
var (
av = &allow.Auth{}
hmacTokenStore = &hmac.TokenStore{}
pairs = []int{1, 5, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100}
dataSize = 1024 * 1024 * 10
)
func TestMain(m *testing.M) {
_ = util.InitLog("error", "console")
code := m.Run()
os.Exit(code)
}
func TestRelayDataTransfer(t *testing.T) {
t.SkipNow() // skip this test on CI because it is a benchmark test
testData, err := seedRandomData(dataSize)
if err != nil {
t.Fatalf("failed to seed random data: %s", err)
}
for _, peerPairs := range pairs {
t.Run(fmt.Sprintf("peerPairs-%d", peerPairs), func(t *testing.T) {
transfer(t, testData, peerPairs)
})
}
}
// TestTurnDataTransfer run turn server:
// docker run --rm --name coturn -d --network=host coturn/coturn --user test:test
func TestTurnDataTransfer(t *testing.T) {
t.SkipNow() // skip this test on CI because it is a benchmark test
testData, err := seedRandomData(dataSize)
if err != nil {
t.Fatalf("failed to seed random data: %s", err)
}
for _, peerPairs := range pairs {
t.Run(fmt.Sprintf("peerPairs-%d", peerPairs), func(t *testing.T) {
runTurnTest(t, testData, peerPairs)
})
}
}
func transfer(t *testing.T, testData []byte, peerPairs int) {
t.Helper()
ctx := context.Background()
port := 35000 + peerPairs
serverAddress := fmt.Sprintf("127.0.0.1:%d", port)
serverConnURL := fmt.Sprintf("rel://%s", serverAddress)
srv, err := server.NewServer(otel.Meter(""), serverConnURL, false, av)
if err != nil {
t.Fatalf("failed to create server: %s", err)
}
errChan := make(chan error, 1)
go func() {
listenCfg := server.ListenerConfig{Address: serverAddress}
err := srv.Listen(listenCfg)
if err != nil {
errChan <- err
}
}()
defer func() {
err := srv.Shutdown(ctx)
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}()
// wait for server to start
if err := waitForServerToStart(errChan); err != nil {
t.Fatalf("failed to start server: %s", err)
}
clientsSender := make([]*client.Client, peerPairs)
for i := 0; i < cap(clientsSender); i++ {
c := client.NewClient(ctx, serverConnURL, hmacTokenStore, "sender-"+fmt.Sprint(i))
err := c.Connect()
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
}
clientsSender[i] = c
}
clientsReceiver := make([]*client.Client, peerPairs)
for i := 0; i < cap(clientsReceiver); i++ {
c := client.NewClient(ctx, serverConnURL, hmacTokenStore, "receiver-"+fmt.Sprint(i))
err := c.Connect()
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
}
clientsReceiver[i] = c
}
connsSender := make([]net.Conn, 0, peerPairs)
connsReceiver := make([]net.Conn, 0, peerPairs)
for i := 0; i < len(clientsSender); i++ {
conn, err := clientsSender[i].OpenConn("receiver-" + fmt.Sprint(i))
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}
connsSender = append(connsSender, conn)
conn, err = clientsReceiver[i].OpenConn("sender-" + fmt.Sprint(i))
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}
connsReceiver = append(connsReceiver, conn)
}
var transferDuration []time.Duration
wg := sync.WaitGroup{}
var writeErr error
var readErr error
for i := 0; i < len(connsSender); i++ {
wg.Add(2)
start := time.Now()
go func(i int) {
defer wg.Done()
pieceSize := 1024
testDataLen := len(testData)
for j := 0; j < testDataLen; j += pieceSize {
end := j + pieceSize
if end > testDataLen {
end = testDataLen
}
_, writeErr = connsSender[i].Write(testData[j:end])
if writeErr != nil {
return
}
}
}(i)
go func(i int, start time.Time) {
defer wg.Done()
buf := make([]byte, 8192)
rcv := 0
var n int
for receivedSize := 0; receivedSize < len(testData); {
n, readErr = connsReceiver[i].Read(buf)
if readErr != nil {
return
}
receivedSize += n
rcv += n
}
transferDuration = append(transferDuration, time.Since(start))
}(i, start)
}
wg.Wait()
if writeErr != nil {
t.Fatalf("failed to write to channel: %s", err)
}
if readErr != nil {
t.Fatalf("failed to read from channel: %s", err)
}
// calculate the megabytes per second from the average transferDuration against the dataSize
var totalDuration time.Duration
for _, d := range transferDuration {
totalDuration += d
}
avgDuration := totalDuration / time.Duration(len(transferDuration))
mbps := float64(len(testData)) / avgDuration.Seconds() / 1024 / 1024
t.Logf("average transfer duration: %s", avgDuration)
t.Logf("average transfer speed: %.2f MB/s", mbps)
for i := 0; i < len(connsSender); i++ {
err := connsSender[i].Close()
if err != nil {
t.Errorf("failed to close connection: %s", err)
}
err = connsReceiver[i].Close()
if err != nil {
t.Errorf("failed to close connection: %s", err)
}
}
}
func runTurnTest(t *testing.T, testData []byte, maxPairs int) {
t.Helper()
var transferDuration []time.Duration
var wg sync.WaitGroup
for i := 0; i < maxPairs; i++ {
wg.Add(1)
go func() {
defer wg.Done()
d := runTurnDataTransfer(t, testData)
transferDuration = append(transferDuration, d)
}()
}
wg.Wait()
var totalDuration time.Duration
for _, d := range transferDuration {
totalDuration += d
}
avgDuration := totalDuration / time.Duration(len(transferDuration))
mbps := float64(len(testData)) / avgDuration.Seconds() / 1024 / 1024
t.Logf("average transfer duration: %s", avgDuration)
t.Logf("average transfer speed: %.2f MB/s", mbps)
}
func runTurnDataTransfer(t *testing.T, testData []byte) time.Duration {
t.Helper()
testDataLen := len(testData)
relayAddress := "192.168.0.10:3478"
conn, err := net.Dial("tcp", relayAddress)
if err != nil {
t.Fatal(err)
}
defer func(conn net.Conn) {
_ = conn.Close()
}(conn)
turnClient, err := getTurnClient(t, relayAddress, conn)
if err != nil {
t.Fatal(err)
}
defer turnClient.Close()
relayConn, err := turnClient.Allocate()
if err != nil {
t.Fatal(err)
}
defer func(relayConn net.PacketConn) {
_ = relayConn.Close()
}(relayConn)
receiverConn, err := net.Dial("udp", relayConn.LocalAddr().String())
if err != nil {
t.Fatal(err)
}
defer func(receiverConn net.Conn) {
_ = receiverConn.Close()
}(receiverConn)
var (
tb int
start time.Time
timerInit bool
readDone = make(chan struct{})
ack = make([]byte, 1)
)
go func() {
defer func() {
readDone <- struct{}{}
}()
buff := make([]byte, 8192)
for {
n, e := receiverConn.Read(buff)
if e != nil {
return
}
if !timerInit {
start = time.Now()
timerInit = true
}
tb += n
_, _ = receiverConn.Write(ack)
if tb >= testDataLen {
return
}
}
}()
pieceSize := 1024
ackBuff := make([]byte, 1)
pipelineSize := 10
for j := 0; j < testDataLen; j += pieceSize {
end := j + pieceSize
if end > testDataLen {
end = testDataLen
}
_, err := relayConn.WriteTo(testData[j:end], receiverConn.LocalAddr())
if err != nil {
t.Fatalf("failed to write to channel: %s", err)
}
if pipelineSize == 0 {
_, _, _ = relayConn.ReadFrom(ackBuff)
} else {
pipelineSize--
}
}
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
defer cancel()
select {
case <-readDone:
if tb != testDataLen {
t.Fatalf("failed to read all data: %d/%d", tb, testDataLen)
}
case <-ctx.Done():
t.Fatal("timeout")
}
return time.Since(start)
}
func getTurnClient(t *testing.T, address string, conn net.Conn) (*turn.Client, error) {
t.Helper()
// Dial TURN Server
addrStr := fmt.Sprintf("%s:%d", address, 443)
fac := logging.NewDefaultLoggerFactory()
//fac.DefaultLogLevel = logging.LogLevelTrace
// Start a new TURN Client and wrap our net.Conn in a STUNConn
// This allows us to simulate datagram based communication over a net.Conn
cfg := &turn.ClientConfig{
TURNServerAddr: address,
Conn: turn.NewSTUNConn(conn),
Username: "test",
Password: "test",
LoggerFactory: fac,
}
client, err := turn.NewClient(cfg)
if err != nil {
return nil, fmt.Errorf("failed to create TURN client for server %s: %s", addrStr, err)
}
// Start listening on the conn provided.
err = client.Listen()
if err != nil {
client.Close()
return nil, fmt.Errorf("failed to listen on TURN client for server %s: %s", addrStr, err)
}
return client, nil
}
func seedRandomData(size int) ([]byte, error) {
token := make([]byte, size)
_, err := rand.Read(token)
if err != nil {
return nil, err
}
return token, nil
}
func waitForServerToStart(errChan chan error) error {
select {
case err := <-errChan:
if err != nil {
return err
}
case <-time.After(300 * time.Millisecond):
return nil
}
return nil
}

258
relay/testec2/main.go Normal file
View File

@@ -0,0 +1,258 @@
//go:build linux || darwin
package main
import (
"crypto/rand"
"flag"
"fmt"
"net"
"os"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/util"
)
const (
errMsgFailedReadTCP = "failed to read from tcp: %s"
)
var (
dataSize = 1024 * 1024 * 50 // 50MB
pairs = []int{1, 5, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100}
signalListenAddress = ":8081"
relaySrvAddress string
turnSrvAddress string
signalURL string
udpListener string // used for TURN test
)
type testResult struct {
numOfPairs int
duration time.Duration
speed float64
}
func (tr testResult) Speed() string {
speed := tr.speed
var unit string
switch {
case speed < 1024:
unit = "B/s"
case speed < 1048576:
speed /= 1024
unit = "KB/s"
case speed < 1073741824:
speed /= 1048576
unit = "MB/s"
default:
speed /= 1073741824
unit = "GB/s"
}
return fmt.Sprintf("%.2f %s", speed, unit)
}
func seedRandomData(size int) ([]byte, error) {
token := make([]byte, size)
_, err := rand.Read(token)
if err != nil {
return nil, err
}
return token, nil
}
func avg(transferDuration []time.Duration) (time.Duration, float64) {
var totalDuration time.Duration
for _, d := range transferDuration {
totalDuration += d
}
avgDuration := totalDuration / time.Duration(len(transferDuration))
bps := float64(dataSize) / avgDuration.Seconds()
return avgDuration, bps
}
func RelayReceiverMain() []testResult {
testResults := make([]testResult, 0, len(pairs))
for _, p := range pairs {
tr := testResult{numOfPairs: p}
td := relayReceive(relaySrvAddress, p)
tr.duration, tr.speed = avg(td)
testResults = append(testResults, tr)
}
return testResults
}
func RelaySenderMain() {
log.Infof("starting sender")
log.Infof("starting seed phase")
testData, err := seedRandomData(dataSize)
if err != nil {
log.Fatalf("failed to seed random data: %s", err)
}
log.Infof("data size: %d", len(testData))
for n, p := range pairs {
log.Infof("running test with %d pairs", p)
relayTransfer(relaySrvAddress, testData, p)
// grant time to prepare new receivers
if n < len(pairs)-1 {
time.Sleep(3 * time.Second)
}
}
}
// TRUNSenderMain is the sender
// - allocate turn clients
// - send relayed addresses to signal server in batch
// - wait for signal server to send back addresses in a map
// - send test data to each address in parallel
func TRUNSenderMain() {
log.Infof("starting TURN sender test")
log.Infof("starting seed random data: %d", dataSize)
testData, err := seedRandomData(dataSize)
if err != nil {
log.Fatalf("failed to seed random data: %s", err)
}
ss := SignalClient{signalURL}
for _, p := range pairs {
log.Infof("running test with %d pairs", p)
turnSender := &TurnSender{}
createTurnConns(p, turnSender)
log.Infof("send addresses via signal server: %d", len(turnSender.addresses))
clientAddresses, err := ss.SendAddress(turnSender.addresses)
if err != nil {
log.Fatalf("failed to send address: %s", err)
}
log.Infof("received addresses: %v", clientAddresses.Address)
createSenderDevices(turnSender, clientAddresses)
log.Infof("waiting for tcpListeners to be ready")
time.Sleep(2 * time.Second)
tcpConns := make([]net.Conn, 0, len(turnSender.devices))
for i := range turnSender.devices {
addr := fmt.Sprintf("10.0.%d.2:9999", i)
log.Infof("dialing: %s", addr)
tcpConn, err := net.Dial("tcp", addr)
if err != nil {
log.Fatalf("failed to dial tcp: %s", err)
}
tcpConns = append(tcpConns, tcpConn)
}
log.Infof("start test data transfer for %d pairs", p)
testDataLen := len(testData)
wg := sync.WaitGroup{}
wg.Add(len(tcpConns))
for i, tcpConn := range tcpConns {
log.Infof("sending test data to device: %d", i)
go runTurnWriting(tcpConn, testData, testDataLen, &wg)
}
wg.Wait()
for _, d := range turnSender.devices {
_ = d.Close()
}
log.Infof("test finished with %d pairs", p)
}
}
func TURNReaderMain() []testResult {
log.Infof("starting TURN receiver test")
si := NewSignalService()
go func() {
log.Infof("starting signal server")
err := si.Listen(signalListenAddress)
if err != nil {
log.Errorf("failed to listen: %s", err)
}
}()
testResults := make([]testResult, 0, len(pairs))
for range pairs {
addresses := <-si.AddressesChan
instanceNumber := len(addresses)
log.Infof("received addresses: %d", instanceNumber)
turnReceiver := &TurnReceiver{}
err := createDevices(addresses, turnReceiver)
if err != nil {
log.Fatalf("%s", err)
}
// send client addresses back via signal server
si.ClientAddressChan <- turnReceiver.clientAddresses
durations := make(chan time.Duration, instanceNumber)
for _, device := range turnReceiver.devices {
go runTurnReading(device, durations)
}
durationsList := make([]time.Duration, 0, instanceNumber)
for d := range durations {
durationsList = append(durationsList, d)
if len(durationsList) == instanceNumber {
close(durations)
}
}
avgDuration, avgSpeed := avg(durationsList)
ts := testResult{
numOfPairs: len(durationsList),
duration: avgDuration,
speed: avgSpeed,
}
testResults = append(testResults, ts)
for _, d := range turnReceiver.devices {
_ = d.Close()
}
}
return testResults
}
func main() {
var mode string
_ = util.InitLog("debug", "console")
flag.StringVar(&mode, "mode", "sender", "sender or receiver mode")
flag.Parse()
relaySrvAddress = os.Getenv("TEST_RELAY_SERVER") // rel://ip:port
turnSrvAddress = os.Getenv("TEST_TURN_SERVER") // ip:3478
signalURL = os.Getenv("TEST_SIGNAL_URL") // http://receiver_ip:8081
udpListener = os.Getenv("TEST_UDP_LISTENER") // IP:0
if mode == "receiver" {
relayResult := RelayReceiverMain()
turnResults := TURNReaderMain()
for i := 0; i < len(turnResults); i++ {
log.Infof("pairs: %d,\tRelay speed:\t%s,\trelay duration:\t%s", relayResult[i].numOfPairs, relayResult[i].Speed(), relayResult[i].duration)
log.Infof("pairs: %d,\tTURN speed:\t%s,\tturn duration:\t%s", turnResults[i].numOfPairs, turnResults[i].Speed(), turnResults[i].duration)
}
} else {
RelaySenderMain()
// grant time for receiver to start
time.Sleep(3 * time.Second)
TRUNSenderMain()
}
}

176
relay/testec2/relay.go Normal file
View File

@@ -0,0 +1,176 @@
//go:build linux || darwin
package main
import (
"context"
"fmt"
"net"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/relay/auth/hmac"
"github.com/netbirdio/netbird/relay/client"
)
var (
hmacTokenStore = &hmac.TokenStore{}
)
func relayTransfer(serverConnURL string, testData []byte, peerPairs int) {
connsSender := prepareConnsSender(serverConnURL, peerPairs)
defer func() {
for i := 0; i < len(connsSender); i++ {
err := connsSender[i].Close()
if err != nil {
log.Errorf("failed to close connection: %s", err)
}
}
}()
wg := sync.WaitGroup{}
wg.Add(len(connsSender))
for _, conn := range connsSender {
go func(conn net.Conn) {
defer wg.Done()
runWriter(conn, testData)
}(conn)
}
wg.Wait()
}
func runWriter(conn net.Conn, testData []byte) {
si := NewStartInidication(time.Now(), len(testData))
_, err := conn.Write(si)
if err != nil {
log.Errorf("failed to write to channel: %s", err)
return
}
log.Infof("sent start indication")
pieceSize := 1024
testDataLen := len(testData)
for j := 0; j < testDataLen; j += pieceSize {
end := j + pieceSize
if end > testDataLen {
end = testDataLen
}
_, writeErr := conn.Write(testData[j:end])
if writeErr != nil {
log.Errorf("failed to write to channel: %s", writeErr)
return
}
}
}
func prepareConnsSender(serverConnURL string, peerPairs int) []net.Conn {
ctx := context.Background()
clientsSender := make([]*client.Client, peerPairs)
for i := 0; i < cap(clientsSender); i++ {
c := client.NewClient(ctx, serverConnURL, hmacTokenStore, "sender-"+fmt.Sprint(i))
if err := c.Connect(); err != nil {
log.Fatalf("failed to connect to server: %s", err)
}
clientsSender[i] = c
}
connsSender := make([]net.Conn, 0, peerPairs)
for i := 0; i < len(clientsSender); i++ {
conn, err := clientsSender[i].OpenConn("receiver-" + fmt.Sprint(i))
if err != nil {
log.Fatalf("failed to bind channel: %s", err)
}
connsSender = append(connsSender, conn)
}
return connsSender
}
func relayReceive(serverConnURL string, peerPairs int) []time.Duration {
connsReceiver := prepareConnsReceiver(serverConnURL, peerPairs)
defer func() {
for i := 0; i < len(connsReceiver); i++ {
if err := connsReceiver[i].Close(); err != nil {
log.Errorf("failed to close connection: %s", err)
}
}
}()
durations := make(chan time.Duration, len(connsReceiver))
wg := sync.WaitGroup{}
for _, conn := range connsReceiver {
wg.Add(1)
go func(conn net.Conn) {
defer wg.Done()
duration := runReader(conn)
durations <- duration
}(conn)
}
wg.Wait()
durationsList := make([]time.Duration, 0, len(connsReceiver))
for d := range durations {
durationsList = append(durationsList, d)
if len(durationsList) == len(connsReceiver) {
close(durations)
}
}
return durationsList
}
func runReader(conn net.Conn) time.Duration {
buf := make([]byte, 8192)
n, readErr := conn.Read(buf)
if readErr != nil {
log.Errorf("failed to read from channel: %s", readErr)
return 0
}
si := DecodeStartIndication(buf[:n])
log.Infof("received start indication: %v", si)
receivedSize, err := conn.Read(buf)
if err != nil {
log.Fatalf("failed to read from relay: %s", err)
}
now := time.Now()
rcv := 0
for receivedSize < si.TransferSize {
n, readErr = conn.Read(buf)
if readErr != nil {
log.Errorf("failed to read from channel: %s", readErr)
return 0
}
receivedSize += n
rcv += n
}
return time.Since(now)
}
func prepareConnsReceiver(serverConnURL string, peerPairs int) []net.Conn {
clientsReceiver := make([]*client.Client, peerPairs)
for i := 0; i < cap(clientsReceiver); i++ {
c := client.NewClient(context.Background(), serverConnURL, hmacTokenStore, "receiver-"+fmt.Sprint(i))
err := c.Connect()
if err != nil {
log.Fatalf("failed to connect to server: %s", err)
}
clientsReceiver[i] = c
}
connsReceiver := make([]net.Conn, 0, peerPairs)
for i := 0; i < len(clientsReceiver); i++ {
conn, err := clientsReceiver[i].OpenConn("sender-" + fmt.Sprint(i))
if err != nil {
log.Fatalf("failed to bind channel: %s", err)
}
connsReceiver = append(connsReceiver, conn)
}
return connsReceiver
}

91
relay/testec2/signal.go Normal file
View File

@@ -0,0 +1,91 @@
//go:build linux || darwin
package main
import (
"bytes"
"encoding/json"
"net/http"
log "github.com/sirupsen/logrus"
)
type PeerAddr struct {
Address []string
}
type ClientPeerAddr struct {
Address map[string]string
}
type Signal struct {
AddressesChan chan []string
ClientAddressChan chan map[string]string
}
func NewSignalService() *Signal {
return &Signal{
AddressesChan: make(chan []string),
ClientAddressChan: make(chan map[string]string),
}
}
func (rs *Signal) Listen(listenAddr string) error {
http.HandleFunc("/", rs.onNewAddresses)
return http.ListenAndServe(listenAddr, nil)
}
func (rs *Signal) onNewAddresses(w http.ResponseWriter, r *http.Request) {
var msg PeerAddr
err := json.NewDecoder(r.Body).Decode(&msg)
if err != nil {
log.Errorf("Error decoding message: %v", err)
}
log.Infof("received addresses: %d", len(msg.Address))
rs.AddressesChan <- msg.Address
clientAddresses := <-rs.ClientAddressChan
respMsg := ClientPeerAddr{
Address: clientAddresses,
}
data, err := json.Marshal(respMsg)
if err != nil {
log.Errorf("Error marshalling message: %v", err)
return
}
_, err = w.Write(data)
if err != nil {
log.Errorf("Error writing response: %v", err)
}
}
type SignalClient struct {
SignalURL string
}
func (ss SignalClient) SendAddress(addresses []string) (*ClientPeerAddr, error) {
msg := PeerAddr{
Address: addresses,
}
data, err := json.Marshal(msg)
if err != nil {
return nil, err
}
response, err := http.Post(ss.SignalURL, "application/json", bytes.NewBuffer(data))
if err != nil {
return nil, err
}
defer response.Body.Close()
log.Debugf("wait for signal response")
var respPeerAddress ClientPeerAddr
err = json.NewDecoder(response.Body).Decode(&respPeerAddress)
if err != nil {
return nil, err
}
return &respPeerAddress, nil
}

View File

@@ -0,0 +1,39 @@
//go:build linux || darwin
package main
import (
"bytes"
"encoding/gob"
"time"
log "github.com/sirupsen/logrus"
)
type StartIndication struct {
Started time.Time
TransferSize int
}
func NewStartInidication(started time.Time, transferSize int) []byte {
si := StartIndication{
Started: started,
TransferSize: transferSize,
}
var data bytes.Buffer
err := gob.NewEncoder(&data).Encode(si)
if err != nil {
log.Fatal("encode error:", err)
}
return data.Bytes()
}
func DecodeStartIndication(data []byte) StartIndication {
var si StartIndication
err := gob.NewDecoder(bytes.NewReader(data)).Decode(&si)
if err != nil {
log.Fatal("decode error:", err)
}
return si
}

View File

@@ -0,0 +1,72 @@
//go:build linux || darwin
package tun
import (
"net"
"sync/atomic"
log "github.com/sirupsen/logrus"
)
type Proxy struct {
Device *Device
PConn net.PacketConn
DstAddr net.Addr
shutdownFlag atomic.Bool
}
func (p *Proxy) Start() {
go p.readFromDevice()
go p.readFromConn()
}
func (p *Proxy) Close() {
p.shutdownFlag.Store(true)
}
func (p *Proxy) readFromDevice() {
buf := make([]byte, 1500)
for {
n, err := p.Device.Read(buf)
if err != nil {
if p.shutdownFlag.Load() {
return
}
log.Errorf("failed to read from device: %s", err)
return
}
_, err = p.PConn.WriteTo(buf[:n], p.DstAddr)
if err != nil {
if p.shutdownFlag.Load() {
return
}
log.Errorf("failed to write to conn: %s", err)
return
}
}
}
func (p *Proxy) readFromConn() {
buf := make([]byte, 1500)
for {
n, _, err := p.PConn.ReadFrom(buf)
if err != nil {
if p.shutdownFlag.Load() {
return
}
log.Errorf("failed to read from conn: %s", err)
return
}
_, err = p.Device.Write(buf[:n])
if err != nil {
if p.shutdownFlag.Load() {
return
}
log.Errorf("failed to write to device: %s", err)
return
}
}
}

110
relay/testec2/tun/tun.go Normal file
View File

@@ -0,0 +1,110 @@
//go:build linux || darwin
package tun
import (
"net"
log "github.com/sirupsen/logrus"
"github.com/songgao/water"
"github.com/vishvananda/netlink"
)
type Device struct {
Name string
IP string
PConn net.PacketConn
DstAddr net.Addr
iFace *water.Interface
proxy *Proxy
}
func (d *Device) Up() error {
cfg := water.Config{
DeviceType: water.TUN,
PlatformSpecificParams: water.PlatformSpecificParams{
Name: d.Name,
},
}
iFace, err := water.New(cfg)
if err != nil {
return err
}
d.iFace = iFace
err = d.assignIP()
if err != nil {
return err
}
err = d.bringUp()
if err != nil {
return err
}
d.proxy = &Proxy{
Device: d,
PConn: d.PConn,
DstAddr: d.DstAddr,
}
d.proxy.Start()
return nil
}
func (d *Device) Close() error {
if d.proxy != nil {
d.proxy.Close()
}
if d.iFace != nil {
return d.iFace.Close()
}
return nil
}
func (d *Device) Read(b []byte) (int, error) {
return d.iFace.Read(b)
}
func (d *Device) Write(b []byte) (int, error) {
return d.iFace.Write(b)
}
func (d *Device) assignIP() error {
iface, err := netlink.LinkByName(d.Name)
if err != nil {
log.Errorf("failed to get TUN device: %v", err)
return err
}
ip := net.IPNet{
IP: net.ParseIP(d.IP),
Mask: net.CIDRMask(24, 32),
}
addr := &netlink.Addr{
IPNet: &ip,
}
err = netlink.AddrAdd(iface, addr)
if err != nil {
log.Errorf("failed to add IP address: %v", err)
return err
}
return nil
}
func (d *Device) bringUp() error {
iface, err := netlink.LinkByName(d.Name)
if err != nil {
log.Errorf("failed to get device: %v", err)
return err
}
// Bring the interface up
err = netlink.LinkSetUp(iface)
if err != nil {
log.Errorf("failed to set device up: %v", err)
return err
}
return nil
}

181
relay/testec2/turn.go Normal file
View File

@@ -0,0 +1,181 @@
//go:build linux || darwin
package main
import (
"fmt"
"net"
"sync"
"time"
"github.com/netbirdio/netbird/relay/testec2/tun"
log "github.com/sirupsen/logrus"
)
type TurnReceiver struct {
conns []*net.UDPConn
clientAddresses map[string]string
devices []*tun.Device
}
type TurnSender struct {
turnConns map[string]*TurnConn
addresses []string
devices []*tun.Device
}
func runTurnWriting(tcpConn net.Conn, testData []byte, testDataLen int, wg *sync.WaitGroup) {
defer wg.Done()
defer tcpConn.Close()
log.Infof("start to sending test data: %s", tcpConn.RemoteAddr())
si := NewStartInidication(time.Now(), testDataLen)
_, err := tcpConn.Write(si)
if err != nil {
log.Errorf("failed to write to tcp: %s", err)
return
}
pieceSize := 1024
for j := 0; j < testDataLen; j += pieceSize {
end := j + pieceSize
if end > testDataLen {
end = testDataLen
}
_, writeErr := tcpConn.Write(testData[j:end])
if writeErr != nil {
log.Errorf("failed to write to tcp conn: %s", writeErr)
return
}
}
// grant time to flush out packages
time.Sleep(3 * time.Second)
}
func createSenderDevices(sender *TurnSender, clientAddresses *ClientPeerAddr) {
var i int
devices := make([]*tun.Device, 0, len(clientAddresses.Address))
for k, v := range clientAddresses.Address {
tc, ok := sender.turnConns[k]
if !ok {
log.Fatalf("failed to find turn conn: %s", k)
}
addr, err := net.ResolveUDPAddr("udp", v)
if err != nil {
log.Fatalf("failed to resolve udp address: %s", err)
}
device := &tun.Device{
Name: fmt.Sprintf("mtun-sender-%d", i),
IP: fmt.Sprintf("10.0.%d.1", i),
PConn: tc.relayConn,
DstAddr: addr,
}
err = device.Up()
if err != nil {
log.Fatalf("failed to bring up device: %s", err)
}
devices = append(devices, device)
i++
}
sender.devices = devices
}
func createTurnConns(p int, sender *TurnSender) {
turnConns := make(map[string]*TurnConn)
addresses := make([]string, 0, len(pairs))
for i := 0; i < p; i++ {
tc := AllocateTurnClient(turnSrvAddress)
log.Infof("allocated turn client: %s", tc.Address().String())
turnConns[tc.Address().String()] = tc
addresses = append(addresses, tc.Address().String())
}
sender.turnConns = turnConns
sender.addresses = addresses
}
func runTurnReading(d *tun.Device, durations chan time.Duration) {
tcpListener, err := net.Listen("tcp", d.IP+":9999")
if err != nil {
log.Fatalf("failed to listen on tcp: %s", err)
}
log := log.WithField("device", tcpListener.Addr())
tcpConn, err := tcpListener.Accept()
if err != nil {
_ = tcpListener.Close()
log.Fatalf("failed to accept connection: %s", err)
}
log.Infof("remote peer connected")
buf := make([]byte, 103)
n, err := tcpConn.Read(buf)
if err != nil {
_ = tcpListener.Close()
log.Fatalf(errMsgFailedReadTCP, err)
}
si := DecodeStartIndication(buf[:n])
log.Infof("received start indication: %v, %d", si, n)
buf = make([]byte, 8192)
i, err := tcpConn.Read(buf)
if err != nil {
_ = tcpListener.Close()
log.Fatalf(errMsgFailedReadTCP, err)
}
now := time.Now()
for i < si.TransferSize {
n, err := tcpConn.Read(buf)
if err != nil {
_ = tcpListener.Close()
log.Fatalf(errMsgFailedReadTCP, err)
}
i += n
}
durations <- time.Since(now)
}
func createDevices(addresses []string, receiver *TurnReceiver) error {
receiver.conns = make([]*net.UDPConn, 0, len(addresses))
receiver.clientAddresses = make(map[string]string, len(addresses))
receiver.devices = make([]*tun.Device, 0, len(addresses))
for i, addr := range addresses {
localAddr, err := net.ResolveUDPAddr("udp", udpListener)
if err != nil {
return fmt.Errorf("failed to resolve UDP address: %s", err)
}
conn, err := net.ListenUDP("udp", localAddr)
if err != nil {
return fmt.Errorf("failed to create UDP connection: %s", err)
}
receiver.conns = append(receiver.conns, conn)
receiver.clientAddresses[addr] = conn.LocalAddr().String()
dstAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return fmt.Errorf("failed to resolve address: %s", err)
}
device := &tun.Device{
Name: fmt.Sprintf("mtun-%d", i),
IP: fmt.Sprintf("10.0.%d.2", i),
PConn: conn,
DstAddr: dstAddr,
}
if err = device.Up(); err != nil {
return fmt.Errorf("failed to bring up device: %s, %s", device.Name, err)
}
receiver.devices = append(receiver.devices, device)
}
return nil
}

View File

@@ -0,0 +1,83 @@
//go:build linux || darwin
package main
import (
"fmt"
"net"
"github.com/pion/logging"
"github.com/pion/turn/v3"
log "github.com/sirupsen/logrus"
)
type TurnConn struct {
conn net.Conn
turnClient *turn.Client
relayConn net.PacketConn
}
func (tc *TurnConn) Address() net.Addr {
return tc.relayConn.LocalAddr()
}
func (tc *TurnConn) Close() {
_ = tc.relayConn.Close()
tc.turnClient.Close()
_ = tc.conn.Close()
}
func AllocateTurnClient(serverAddr string) *TurnConn {
conn, err := net.Dial("tcp", serverAddr)
if err != nil {
log.Fatal(err)
}
turnClient, err := getTurnClient(serverAddr, conn)
if err != nil {
log.Fatal(err)
}
relayConn, err := turnClient.Allocate()
if err != nil {
log.Fatal(err)
}
return &TurnConn{
conn: conn,
turnClient: turnClient,
relayConn: relayConn,
}
}
func getTurnClient(address string, conn net.Conn) (*turn.Client, error) {
// Dial TURN Server
addrStr := fmt.Sprintf("%s:%d", address, 443)
fac := logging.NewDefaultLoggerFactory()
//fac.DefaultLogLevel = logging.LogLevelTrace
// Start a new TURN Client and wrap our net.Conn in a STUNConn
// This allows us to simulate datagram based communication over a net.Conn
cfg := &turn.ClientConfig{
TURNServerAddr: address,
Conn: turn.NewSTUNConn(conn),
Username: "test",
Password: "test",
LoggerFactory: fac,
}
client, err := turn.NewClient(cfg)
if err != nil {
return nil, fmt.Errorf("failed to create TURN client for server %s: %s", addrStr, err)
}
// Start listening on the conn provided.
err = client.Listen()
if err != nil {
client.Close()
return nil, fmt.Errorf("failed to listen on TURN client for server %s: %s", addrStr, err)
}
return client, nil
}