mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 08:16:39 +00:00
[relay] Feature/relay integration (#2244)
This update adds new relay integration for NetBird clients. The new relay is based on web sockets and listens on a single port. - Adds new relay implementation with websocket with single port relaying mechanism - refactor peer connection logic, allowing upgrade and downgrade from/to P2P connection - peer connections are faster since it connects first to relay and then upgrades to P2P - maintains compatibility with old clients by not using the new relay - updates infrastructure scripts with new relay service
This commit is contained in:
4
relay/Dockerfile
Normal file
4
relay/Dockerfile
Normal file
@@ -0,0 +1,4 @@
|
||||
FROM gcr.io/distroless/base:debug
|
||||
ENTRYPOINT [ "/go/bin/netbird-relay" ]
|
||||
ENV NB_LOG_FILE=console
|
||||
COPY netbird-relay /go/bin/netbird-relay
|
||||
12
relay/auth/allow/allow_all.go
Normal file
12
relay/auth/allow/allow_all.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package allow
|
||||
|
||||
import "hash"
|
||||
|
||||
// Auth is a Validator that allows all connections.
|
||||
// Used this for testing purposes only.
|
||||
type Auth struct {
|
||||
}
|
||||
|
||||
func (a *Auth) Validate(func() hash.Hash, any) error {
|
||||
return nil
|
||||
}
|
||||
26
relay/auth/doc.go
Normal file
26
relay/auth/doc.go
Normal file
@@ -0,0 +1,26 @@
|
||||
/*
|
||||
Package auth manages the authentication process with the relay server.
|
||||
|
||||
Key Components:
|
||||
|
||||
Validator: The Validator interface defines the Validate method. Any type that provides this method can be used as a
|
||||
Validator.
|
||||
|
||||
Methods:
|
||||
|
||||
Validate(func() hash.Hash, any): This method is defined in the Validator interface and is used to validate the authentication.
|
||||
|
||||
Usage:
|
||||
|
||||
To create a new AllowAllAuth validator, simply instantiate it:
|
||||
|
||||
validator := &allow.Auth{}
|
||||
|
||||
To validate the authentication, use the Validate method:
|
||||
|
||||
err := validator.Validate(sha256.New, any)
|
||||
|
||||
This package provides a simple and effective way to manage authentication with the relay server, ensuring that the
|
||||
peers are authenticated properly.
|
||||
*/
|
||||
package auth
|
||||
8
relay/auth/hmac/doc.go
Normal file
8
relay/auth/hmac/doc.go
Normal file
@@ -0,0 +1,8 @@
|
||||
/*
|
||||
This package uses a similar HMAC method for authentication with the TURN server. The Management server provides the
|
||||
tokens for the peers. The peers manage these tokens in the token store. The token store is a simple thread safe store
|
||||
that keeps the tokens in memory. These tokens are used to authenticate the peers with the Relay server in the hello
|
||||
message.
|
||||
*/
|
||||
|
||||
package hmac
|
||||
36
relay/auth/hmac/store.go
Normal file
36
relay/auth/hmac/store.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package hmac
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// TokenStore is a simple in-memory store for token
|
||||
// With this can update the token in thread safe way
|
||||
type TokenStore struct {
|
||||
mu sync.Mutex
|
||||
token []byte
|
||||
}
|
||||
|
||||
func (a *TokenStore) UpdateToken(token *Token) error {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
if token == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
t, err := marshalToken(*token)
|
||||
if err != nil {
|
||||
log.Debugf("failed to marshal token: %s", err)
|
||||
return err
|
||||
}
|
||||
a.token = t
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *TokenStore) TokenBinary() []byte {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
return a.token
|
||||
}
|
||||
105
relay/auth/hmac/token.go
Normal file
105
relay/auth/hmac/token.go
Normal file
@@ -0,0 +1,105 @@
|
||||
package hmac
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/hmac"
|
||||
"encoding/base64"
|
||||
"encoding/gob"
|
||||
"fmt"
|
||||
"hash"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type Token struct {
|
||||
Payload string
|
||||
Signature string
|
||||
}
|
||||
|
||||
func marshalToken(token Token) ([]byte, error) {
|
||||
var buffer bytes.Buffer
|
||||
encoder := gob.NewEncoder(&buffer)
|
||||
err := encoder.Encode(token)
|
||||
if err != nil {
|
||||
log.Debugf("failed to marshal token: %s", err)
|
||||
return nil, fmt.Errorf("failed to marshal token: %w", err)
|
||||
}
|
||||
return buffer.Bytes(), nil
|
||||
}
|
||||
|
||||
func unmarshalToken(payload []byte) (Token, error) {
|
||||
var creds Token
|
||||
buffer := bytes.NewBuffer(payload)
|
||||
decoder := gob.NewDecoder(buffer)
|
||||
err := decoder.Decode(&creds)
|
||||
return creds, err
|
||||
}
|
||||
|
||||
// TimedHMAC generates a token with TTL and uses a pre-shared secret known to the relay server
|
||||
type TimedHMAC struct {
|
||||
secret string
|
||||
timeToLive time.Duration
|
||||
}
|
||||
|
||||
// NewTimedHMAC creates a new TimedHMAC instance
|
||||
func NewTimedHMAC(secret string, timeToLive time.Duration) *TimedHMAC {
|
||||
return &TimedHMAC{
|
||||
secret: secret,
|
||||
timeToLive: timeToLive,
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateToken generates new time-based secret token - basically Payload is a unix timestamp and Signature is a HMAC
|
||||
// hash of a timestamp with a preshared TURN secret
|
||||
func (m *TimedHMAC) GenerateToken(algo func() hash.Hash) (*Token, error) {
|
||||
timeAuth := time.Now().Add(m.timeToLive).Unix()
|
||||
timeStamp := strconv.FormatInt(timeAuth, 10)
|
||||
|
||||
checksum, err := m.generate(algo, timeStamp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Token{
|
||||
Payload: timeStamp,
|
||||
Signature: base64.StdEncoding.EncodeToString(checksum),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Validate checks if the token is valid
|
||||
func (m *TimedHMAC) Validate(algo func() hash.Hash, token Token) error {
|
||||
expectedMAC, err := m.generate(algo, token.Payload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
expectedSignature := base64.StdEncoding.EncodeToString(expectedMAC)
|
||||
|
||||
if !hmac.Equal([]byte(expectedSignature), []byte(token.Signature)) {
|
||||
return fmt.Errorf("signature mismatch")
|
||||
}
|
||||
|
||||
timeAuthInt, err := strconv.ParseInt(token.Payload, 10, 64)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid payload: %w", err)
|
||||
}
|
||||
|
||||
if time.Now().Unix() > timeAuthInt {
|
||||
return fmt.Errorf("expired token")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *TimedHMAC) generate(algo func() hash.Hash, payload string) ([]byte, error) {
|
||||
mac := hmac.New(algo, []byte(m.secret))
|
||||
_, err := mac.Write([]byte(payload))
|
||||
if err != nil {
|
||||
log.Debugf("failed to generate token: %s", err)
|
||||
return nil, fmt.Errorf("failed to generate token: %w", err)
|
||||
}
|
||||
|
||||
return mac.Sum(nil), nil
|
||||
}
|
||||
105
relay/auth/hmac/token_test.go
Normal file
105
relay/auth/hmac/token_test.go
Normal file
@@ -0,0 +1,105 @@
|
||||
package hmac
|
||||
|
||||
import (
|
||||
"crypto/sha1"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestGenerateCredentials(t *testing.T) {
|
||||
secret := "secret"
|
||||
timeToLive := 1 * time.Hour
|
||||
v := NewTimedHMAC(secret, timeToLive)
|
||||
|
||||
creds, err := v.GenerateToken(sha1.New)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
|
||||
if creds.Payload == "" {
|
||||
t.Fatalf("expected non-empty payload")
|
||||
}
|
||||
|
||||
_, err = strconv.ParseInt(creds.Payload, 10, 64)
|
||||
if err != nil {
|
||||
t.Fatalf("expected payload to be a valid unix timestamp, got %v", err)
|
||||
}
|
||||
|
||||
_, err = base64.StdEncoding.DecodeString(creds.Signature)
|
||||
if err != nil {
|
||||
t.Fatalf("expected signature to be base64 encoded, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateCredentials(t *testing.T) {
|
||||
secret := "supersecret"
|
||||
timeToLive := 1 * time.Hour
|
||||
manager := NewTimedHMAC(secret, timeToLive)
|
||||
|
||||
// Test valid token
|
||||
creds, err := manager.GenerateToken(sha1.New)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
|
||||
if err := manager.Validate(sha1.New, *creds); err != nil {
|
||||
t.Fatalf("expected valid token: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInvalidSignature(t *testing.T) {
|
||||
secret := "supersecret"
|
||||
timeToLive := 1 * time.Hour
|
||||
manager := NewTimedHMAC(secret, timeToLive)
|
||||
|
||||
creds, err := manager.GenerateToken(sha256.New)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
|
||||
invalidCreds := &Token{
|
||||
Payload: creds.Payload,
|
||||
Signature: "invalidsignature",
|
||||
}
|
||||
|
||||
if err = manager.Validate(sha1.New, *invalidCreds); err == nil {
|
||||
t.Fatalf("expected invalid token due to signature mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpired(t *testing.T) {
|
||||
secret := "supersecret"
|
||||
v := NewTimedHMAC(secret, -1*time.Hour)
|
||||
expiredCreds, err := v.GenerateToken(sha256.New)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
|
||||
if err = v.Validate(sha1.New, *expiredCreds); err == nil {
|
||||
t.Fatalf("expected invalid token due to expiration")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInvalidPayload(t *testing.T) {
|
||||
secret := "supersecret"
|
||||
timeToLive := 1 * time.Hour
|
||||
v := NewTimedHMAC(secret, timeToLive)
|
||||
|
||||
creds, err := v.GenerateToken(sha256.New)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
|
||||
// Test invalid payload
|
||||
invalidPayloadCreds := &Token{
|
||||
Payload: "invalidtimestamp",
|
||||
Signature: creds.Signature,
|
||||
}
|
||||
|
||||
if err = v.Validate(sha1.New, *invalidPayloadCreds); err == nil {
|
||||
t.Fatalf("expected invalid token due to invalid payload")
|
||||
}
|
||||
}
|
||||
33
relay/auth/hmac/validator.go
Normal file
33
relay/auth/hmac/validator.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package hmac
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"hash"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type TimedHMACValidator struct {
|
||||
*TimedHMAC
|
||||
}
|
||||
|
||||
func NewTimedHMACValidator(secret string, duration time.Duration) *TimedHMACValidator {
|
||||
ta := NewTimedHMAC(secret, duration)
|
||||
return &TimedHMACValidator{
|
||||
ta,
|
||||
}
|
||||
}
|
||||
|
||||
func (a *TimedHMACValidator) Validate(algo func() hash.Hash, credentials any) error {
|
||||
b, ok := credentials.([]byte)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid credentials type")
|
||||
}
|
||||
c, err := unmarshalToken(b)
|
||||
if err != nil {
|
||||
log.Debugf("failed to unmarshal token: %s", err)
|
||||
return err
|
||||
}
|
||||
return a.TimedHMAC.Validate(algo, c)
|
||||
}
|
||||
8
relay/auth/validator.go
Normal file
8
relay/auth/validator.go
Normal file
@@ -0,0 +1,8 @@
|
||||
package auth
|
||||
|
||||
import "hash"
|
||||
|
||||
// Validator is an interface that defines the Validate method.
|
||||
type Validator interface {
|
||||
Validate(func() hash.Hash, any) error
|
||||
}
|
||||
13
relay/client/addr.go
Normal file
13
relay/client/addr.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package client
|
||||
|
||||
type RelayAddr struct {
|
||||
addr string
|
||||
}
|
||||
|
||||
func (a RelayAddr) Network() string {
|
||||
return "relay"
|
||||
}
|
||||
|
||||
func (a RelayAddr) String() string {
|
||||
return a.addr
|
||||
}
|
||||
553
relay/client/client.go
Normal file
553
relay/client/client.go
Normal file
@@ -0,0 +1,553 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
auth "github.com/netbirdio/netbird/relay/auth/hmac"
|
||||
"github.com/netbirdio/netbird/relay/client/dialer/ws"
|
||||
"github.com/netbirdio/netbird/relay/healthcheck"
|
||||
"github.com/netbirdio/netbird/relay/messages"
|
||||
"github.com/netbirdio/netbird/relay/messages/address"
|
||||
auth2 "github.com/netbirdio/netbird/relay/messages/auth"
|
||||
)
|
||||
|
||||
const (
|
||||
bufferSize = 8820
|
||||
serverResponseTimeout = 8 * time.Second
|
||||
)
|
||||
|
||||
var (
|
||||
ErrConnAlreadyExists = fmt.Errorf("connection already exists")
|
||||
)
|
||||
|
||||
type internalStopFlag struct {
|
||||
sync.Mutex
|
||||
stop bool
|
||||
}
|
||||
|
||||
func newInternalStopFlag() *internalStopFlag {
|
||||
return &internalStopFlag{}
|
||||
}
|
||||
|
||||
func (isf *internalStopFlag) set() {
|
||||
isf.Lock()
|
||||
defer isf.Unlock()
|
||||
isf.stop = true
|
||||
}
|
||||
|
||||
func (isf *internalStopFlag) isSet() bool {
|
||||
isf.Lock()
|
||||
defer isf.Unlock()
|
||||
return isf.stop
|
||||
}
|
||||
|
||||
// Msg carry the payload from the server to the client. With this struct, the net.Conn can free the buffer.
|
||||
type Msg struct {
|
||||
Payload []byte
|
||||
|
||||
bufPool *sync.Pool
|
||||
bufPtr *[]byte
|
||||
}
|
||||
|
||||
func (m *Msg) Free() {
|
||||
m.bufPool.Put(m.bufPtr)
|
||||
}
|
||||
|
||||
type connContainer struct {
|
||||
conn *Conn
|
||||
messages chan Msg
|
||||
msgChanLock sync.Mutex
|
||||
closed bool // flag to check if channel is closed
|
||||
}
|
||||
|
||||
func newConnContainer(conn *Conn, messages chan Msg) *connContainer {
|
||||
return &connContainer{
|
||||
conn: conn,
|
||||
messages: messages,
|
||||
}
|
||||
}
|
||||
|
||||
func (cc *connContainer) writeMsg(msg Msg) {
|
||||
cc.msgChanLock.Lock()
|
||||
defer cc.msgChanLock.Unlock()
|
||||
if cc.closed {
|
||||
return
|
||||
}
|
||||
cc.messages <- msg
|
||||
}
|
||||
|
||||
func (cc *connContainer) close() {
|
||||
cc.msgChanLock.Lock()
|
||||
defer cc.msgChanLock.Unlock()
|
||||
if cc.closed {
|
||||
return
|
||||
}
|
||||
close(cc.messages)
|
||||
cc.closed = true
|
||||
}
|
||||
|
||||
// Client is a client for the relay server. It is responsible for establishing a connection to the relay server and
|
||||
// managing connections to other peers. All exported functions are safe to call concurrently. After close the connection,
|
||||
// the client can be reused by calling Connect again. When the client is closed, all connections are closed too.
|
||||
// While the Connect is in progress, the OpenConn function will block until the connection is established with relay server.
|
||||
type Client struct {
|
||||
log *log.Entry
|
||||
parentCtx context.Context
|
||||
connectionURL string
|
||||
authTokenStore *auth.TokenStore
|
||||
hashedID []byte
|
||||
|
||||
bufPool *sync.Pool
|
||||
|
||||
relayConn net.Conn
|
||||
conns map[string]*connContainer
|
||||
serviceIsRunning bool
|
||||
mu sync.Mutex // protect serviceIsRunning and conns
|
||||
readLoopMutex sync.Mutex
|
||||
wgReadLoop sync.WaitGroup
|
||||
instanceURL *RelayAddr
|
||||
muInstanceURL sync.Mutex
|
||||
|
||||
onDisconnectListener func()
|
||||
listenerMutex sync.Mutex
|
||||
}
|
||||
|
||||
// NewClient creates a new client for the relay server. The client is not connected to the server until the Connect
|
||||
func NewClient(ctx context.Context, serverURL string, authTokenStore *auth.TokenStore, peerID string) *Client {
|
||||
hashedID, hashedStringId := messages.HashID(peerID)
|
||||
return &Client{
|
||||
log: log.WithField("client_id", hashedStringId),
|
||||
parentCtx: ctx,
|
||||
connectionURL: serverURL,
|
||||
authTokenStore: authTokenStore,
|
||||
hashedID: hashedID,
|
||||
bufPool: &sync.Pool{
|
||||
New: func() any {
|
||||
buf := make([]byte, bufferSize)
|
||||
return &buf
|
||||
},
|
||||
},
|
||||
conns: make(map[string]*connContainer),
|
||||
}
|
||||
}
|
||||
|
||||
// Connect establishes a connection to the relay server. It blocks until the connection is established or an error occurs.
|
||||
func (c *Client) Connect() error {
|
||||
c.log.Infof("connecting to relay server: %s", c.connectionURL)
|
||||
c.readLoopMutex.Lock()
|
||||
defer c.readLoopMutex.Unlock()
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.serviceIsRunning {
|
||||
return nil
|
||||
}
|
||||
|
||||
err := c.connect()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.serviceIsRunning = true
|
||||
|
||||
c.wgReadLoop.Add(1)
|
||||
go c.readLoop(c.relayConn)
|
||||
|
||||
c.log.Infof("relay connection established with: %s", c.connectionURL)
|
||||
return nil
|
||||
}
|
||||
|
||||
// OpenConn create a new net.Conn for the destination peer ID. In case if the connection is in progress
|
||||
// to the relay server, the function will block until the connection is established or timed out. Otherwise,
|
||||
// it will return immediately.
|
||||
// todo: what should happen if call with the same peerID with multiple times?
|
||||
func (c *Client) OpenConn(dstPeerID string) (net.Conn, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if !c.serviceIsRunning {
|
||||
return nil, fmt.Errorf("relay connection is not established")
|
||||
}
|
||||
|
||||
hashedID, hashedStringID := messages.HashID(dstPeerID)
|
||||
_, ok := c.conns[hashedStringID]
|
||||
if ok {
|
||||
return nil, ErrConnAlreadyExists
|
||||
}
|
||||
|
||||
log.Infof("open connection to peer: %s", hashedStringID)
|
||||
msgChannel := make(chan Msg, 2)
|
||||
conn := NewConn(c, hashedID, hashedStringID, msgChannel, c.instanceURL)
|
||||
|
||||
c.conns[hashedStringID] = newConnContainer(conn, msgChannel)
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// ServerInstanceURL returns the address of the relay server. It could change after the close and reopen the connection.
|
||||
func (c *Client) ServerInstanceURL() (string, error) {
|
||||
c.muInstanceURL.Lock()
|
||||
defer c.muInstanceURL.Unlock()
|
||||
if c.instanceURL == nil {
|
||||
return "", fmt.Errorf("relay connection is not established")
|
||||
}
|
||||
return c.instanceURL.String(), nil
|
||||
}
|
||||
|
||||
// SetOnDisconnectListener sets a function that will be called when the connection to the relay server is closed.
|
||||
func (c *Client) SetOnDisconnectListener(fn func()) {
|
||||
c.listenerMutex.Lock()
|
||||
defer c.listenerMutex.Unlock()
|
||||
c.onDisconnectListener = fn
|
||||
}
|
||||
|
||||
// HasConns returns true if there are connections.
|
||||
func (c *Client) HasConns() bool {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return len(c.conns) > 0
|
||||
}
|
||||
|
||||
// Close closes the connection to the relay server and all connections to other peers.
|
||||
func (c *Client) Close() error {
|
||||
return c.close(true)
|
||||
}
|
||||
|
||||
func (c *Client) connect() error {
|
||||
conn, err := ws.Dial(c.connectionURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.relayConn = conn
|
||||
|
||||
err = c.handShake()
|
||||
if err != nil {
|
||||
cErr := conn.Close()
|
||||
if cErr != nil {
|
||||
log.Errorf("failed to close connection: %s", cErr)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) handShake() error {
|
||||
authMsg := &auth2.Msg{
|
||||
AuthAlgorithm: auth2.AlgoHMACSHA256,
|
||||
AdditionalData: c.authTokenStore.TokenBinary(),
|
||||
}
|
||||
|
||||
authData, err := authMsg.Marshal()
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal auth message: %w", err)
|
||||
}
|
||||
|
||||
msg, err := messages.MarshalHelloMsg(c.hashedID, authData)
|
||||
if err != nil {
|
||||
log.Errorf("failed to marshal hello message: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = c.relayConn.Write(msg)
|
||||
if err != nil {
|
||||
log.Errorf("failed to send hello message: %s", err)
|
||||
return err
|
||||
}
|
||||
buf := make([]byte, messages.MaxHandshakeSize)
|
||||
n, err := c.readWithTimeout(buf)
|
||||
if err != nil {
|
||||
log.Errorf("failed to read hello response: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = messages.ValidateVersion(buf[:n])
|
||||
if err != nil {
|
||||
return fmt.Errorf("validate version: %w", err)
|
||||
}
|
||||
|
||||
msgType, err := messages.DetermineServerMessageType(buf[messages.SizeOfVersionByte:n])
|
||||
if err != nil {
|
||||
log.Errorf("failed to determine message type: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
if msgType != messages.MsgTypeHelloResponse {
|
||||
log.Errorf("unexpected message type: %s", msgType)
|
||||
return fmt.Errorf("unexpected message type")
|
||||
}
|
||||
|
||||
additionalData, err := messages.UnmarshalHelloResponse(buf[messages.SizeOfProtoHeader:n])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
addr, err := address.Unmarshal(additionalData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unmarshal address: %w", err)
|
||||
}
|
||||
|
||||
c.muInstanceURL.Lock()
|
||||
c.instanceURL = &RelayAddr{addr: addr.URL}
|
||||
c.muInstanceURL.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) readLoop(relayConn net.Conn) {
|
||||
internallyStoppedFlag := newInternalStopFlag()
|
||||
hc := healthcheck.NewReceiver()
|
||||
go c.listenForStopEvents(hc, relayConn, internallyStoppedFlag)
|
||||
|
||||
var (
|
||||
errExit error
|
||||
n int
|
||||
)
|
||||
for {
|
||||
bufPtr := c.bufPool.Get().(*[]byte)
|
||||
buf := *bufPtr
|
||||
n, errExit = relayConn.Read(buf)
|
||||
if errExit != nil {
|
||||
c.mu.Lock()
|
||||
if c.serviceIsRunning && !internallyStoppedFlag.isSet() {
|
||||
c.log.Debugf("failed to read message from relay server: %s", errExit)
|
||||
}
|
||||
c.mu.Unlock()
|
||||
break
|
||||
}
|
||||
|
||||
_, err := messages.ValidateVersion(buf[:n])
|
||||
if err != nil {
|
||||
c.log.Errorf("failed to validate protocol version: %s", err)
|
||||
c.bufPool.Put(bufPtr)
|
||||
continue
|
||||
}
|
||||
|
||||
msgType, err := messages.DetermineServerMessageType(buf[messages.SizeOfVersionByte:n])
|
||||
if err != nil {
|
||||
c.log.Errorf("failed to determine message type: %s", err)
|
||||
c.bufPool.Put(bufPtr)
|
||||
continue
|
||||
}
|
||||
|
||||
if !c.handleMsg(msgType, buf[messages.SizeOfProtoHeader:n], bufPtr, hc, internallyStoppedFlag) {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
hc.Stop()
|
||||
|
||||
c.muInstanceURL.Lock()
|
||||
c.instanceURL = nil
|
||||
c.muInstanceURL.Unlock()
|
||||
|
||||
c.notifyDisconnected()
|
||||
c.wgReadLoop.Done()
|
||||
_ = c.close(false)
|
||||
}
|
||||
|
||||
func (c *Client) handleMsg(msgType messages.MsgType, buf []byte, bufPtr *[]byte, hc *healthcheck.Receiver, internallyStoppedFlag *internalStopFlag) (continueLoop bool) {
|
||||
switch msgType {
|
||||
case messages.MsgTypeHealthCheck:
|
||||
c.handleHealthCheck(hc, internallyStoppedFlag)
|
||||
c.bufPool.Put(bufPtr)
|
||||
case messages.MsgTypeTransport:
|
||||
return c.handleTransportMsg(buf, bufPtr, internallyStoppedFlag)
|
||||
case messages.MsgTypeClose:
|
||||
log.Debugf("relay connection close by server")
|
||||
c.bufPool.Put(bufPtr)
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *Client) handleHealthCheck(hc *healthcheck.Receiver, internallyStoppedFlag *internalStopFlag) {
|
||||
msg := messages.MarshalHealthcheck()
|
||||
_, wErr := c.relayConn.Write(msg)
|
||||
if wErr != nil {
|
||||
if c.serviceIsRunning && !internallyStoppedFlag.isSet() {
|
||||
c.log.Errorf("failed to send heartbeat: %s", wErr)
|
||||
}
|
||||
}
|
||||
hc.Heartbeat()
|
||||
}
|
||||
|
||||
func (c *Client) handleTransportMsg(buf []byte, bufPtr *[]byte, internallyStoppedFlag *internalStopFlag) bool {
|
||||
peerID, payload, err := messages.UnmarshalTransportMsg(buf)
|
||||
if err != nil {
|
||||
if c.serviceIsRunning && !internallyStoppedFlag.isSet() {
|
||||
c.log.Errorf("failed to parse transport message: %v", err)
|
||||
}
|
||||
|
||||
c.bufPool.Put(bufPtr)
|
||||
return true
|
||||
}
|
||||
|
||||
stringID := messages.HashIDToString(peerID)
|
||||
|
||||
c.mu.Lock()
|
||||
if !c.serviceIsRunning {
|
||||
c.mu.Unlock()
|
||||
c.bufPool.Put(bufPtr)
|
||||
return false
|
||||
}
|
||||
container, ok := c.conns[stringID]
|
||||
c.mu.Unlock()
|
||||
if !ok {
|
||||
c.log.Errorf("peer not found: %s", stringID)
|
||||
c.bufPool.Put(bufPtr)
|
||||
return true
|
||||
}
|
||||
msg := Msg{
|
||||
bufPool: c.bufPool,
|
||||
bufPtr: bufPtr,
|
||||
Payload: payload,
|
||||
}
|
||||
container.writeMsg(msg)
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *Client) writeTo(connReference *Conn, id string, dstID []byte, payload []byte) (int, error) {
|
||||
c.mu.Lock()
|
||||
conn, ok := c.conns[id]
|
||||
c.mu.Unlock()
|
||||
if !ok {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
if conn.conn != connReference {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
// todo: use buffer pool instead of create new transport msg.
|
||||
msg, err := messages.MarshalTransportMsg(dstID, payload)
|
||||
if err != nil {
|
||||
log.Errorf("failed to marshal transport message: %s", err)
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// the write always return with 0 length because the underling does not support the size feedback.
|
||||
_, err = c.relayConn.Write(msg)
|
||||
if err != nil {
|
||||
log.Errorf("failed to write transport message: %s", err)
|
||||
}
|
||||
return len(payload), err
|
||||
}
|
||||
|
||||
func (c *Client) listenForStopEvents(hc *healthcheck.Receiver, conn net.Conn, internalStopFlag *internalStopFlag) {
|
||||
for {
|
||||
select {
|
||||
case _, ok := <-hc.OnTimeout:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
c.log.Errorf("health check timeout")
|
||||
internalStopFlag.set()
|
||||
_ = conn.Close() // ignore the err because the readLoop will handle it
|
||||
return
|
||||
case <-c.parentCtx.Done():
|
||||
err := c.close(true)
|
||||
if err != nil {
|
||||
log.Errorf("failed to teardown connection: %s", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) closeAllConns() {
|
||||
for _, container := range c.conns {
|
||||
container.close()
|
||||
}
|
||||
c.conns = make(map[string]*connContainer)
|
||||
}
|
||||
|
||||
func (c *Client) closeConn(connReference *Conn, id string) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
container, ok := c.conns[id]
|
||||
if !ok {
|
||||
return fmt.Errorf("connection already closed")
|
||||
}
|
||||
|
||||
if container.conn != connReference {
|
||||
return fmt.Errorf("conn reference mismatch")
|
||||
}
|
||||
container.close()
|
||||
delete(c.conns, id)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) close(gracefullyExit bool) error {
|
||||
c.readLoopMutex.Lock()
|
||||
defer c.readLoopMutex.Unlock()
|
||||
|
||||
c.mu.Lock()
|
||||
var err error
|
||||
if !c.serviceIsRunning {
|
||||
c.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
c.serviceIsRunning = false
|
||||
c.closeAllConns()
|
||||
if gracefullyExit {
|
||||
c.writeCloseMsg()
|
||||
}
|
||||
err = c.relayConn.Close()
|
||||
c.mu.Unlock()
|
||||
|
||||
c.wgReadLoop.Wait()
|
||||
c.log.Infof("relay connection closed with: %s", c.connectionURL)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Client) notifyDisconnected() {
|
||||
c.listenerMutex.Lock()
|
||||
defer c.listenerMutex.Unlock()
|
||||
|
||||
if c.onDisconnectListener == nil {
|
||||
return
|
||||
}
|
||||
go c.onDisconnectListener()
|
||||
}
|
||||
|
||||
func (c *Client) writeCloseMsg() {
|
||||
msg := messages.MarshalCloseMsg()
|
||||
_, err := c.relayConn.Write(msg)
|
||||
if err != nil {
|
||||
c.log.Errorf("failed to send close message: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) readWithTimeout(buf []byte) (int, error) {
|
||||
ctx, cancel := context.WithTimeout(c.parentCtx, serverResponseTimeout)
|
||||
defer cancel()
|
||||
|
||||
readDone := make(chan struct{})
|
||||
var (
|
||||
n int
|
||||
err error
|
||||
)
|
||||
|
||||
go func() {
|
||||
n, err = c.relayConn.Read(buf)
|
||||
close(readDone)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return 0, fmt.Errorf("read operation timed out")
|
||||
case <-readDone:
|
||||
return n, err
|
||||
}
|
||||
}
|
||||
631
relay/client/client_test.go
Normal file
631
relay/client/client_test.go
Normal file
@@ -0,0 +1,631 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"go.opentelemetry.io/otel"
|
||||
|
||||
"github.com/netbirdio/netbird/relay/auth/allow"
|
||||
"github.com/netbirdio/netbird/relay/auth/hmac"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
|
||||
"github.com/netbirdio/netbird/relay/server"
|
||||
)
|
||||
|
||||
var (
|
||||
av = &allow.Auth{}
|
||||
hmacTokenStore = &hmac.TokenStore{}
|
||||
serverListenAddr = "127.0.0.1:1234"
|
||||
serverURL = "rel://127.0.0.1:1234"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
_ = util.InitLog("error", "console")
|
||||
code := m.Run()
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
func TestClient(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create server: %s", err)
|
||||
}
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
listenCfg := server.ListenerConfig{Address: serverListenAddr}
|
||||
err := srv.Listen(listenCfg)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
err := srv.Shutdown(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("failed to close server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// wait for server to start
|
||||
if err := waitForServerToStart(errChan); err != nil {
|
||||
t.Fatalf("failed to start server: %s", err)
|
||||
}
|
||||
t.Log("alice connecting to server")
|
||||
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice")
|
||||
err = clientAlice.Connect()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect to server: %s", err)
|
||||
}
|
||||
defer clientAlice.Close()
|
||||
|
||||
t.Log("placeholder connecting to server")
|
||||
clientPlaceHolder := NewClient(ctx, serverURL, hmacTokenStore, "clientPlaceHolder")
|
||||
err = clientPlaceHolder.Connect()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect to server: %s", err)
|
||||
}
|
||||
defer clientPlaceHolder.Close()
|
||||
|
||||
t.Log("Bob connecting to server")
|
||||
clientBob := NewClient(ctx, serverURL, hmacTokenStore, "bob")
|
||||
err = clientBob.Connect()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect to server: %s", err)
|
||||
}
|
||||
defer clientBob.Close()
|
||||
|
||||
t.Log("Alice open connection to Bob")
|
||||
connAliceToBob, err := clientAlice.OpenConn("bob")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind channel: %s", err)
|
||||
}
|
||||
|
||||
t.Log("Bob open connection to Alice")
|
||||
connBobToAlice, err := clientBob.OpenConn("alice")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind channel: %s", err)
|
||||
}
|
||||
|
||||
payload := "hello bob, I am alice"
|
||||
_, err = connAliceToBob.Write([]byte(payload))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to write to channel: %s", err)
|
||||
}
|
||||
log.Debugf("alice sent message to bob")
|
||||
|
||||
buf := make([]byte, 65535)
|
||||
n, err := connBobToAlice.Read(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read from channel: %s", err)
|
||||
}
|
||||
log.Debugf("on new message from alice to bob")
|
||||
|
||||
if payload != string(buf[:n]) {
|
||||
t.Fatalf("expected %s, got %s", payload, string(buf[:n]))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistration(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
srvCfg := server.ListenerConfig{Address: serverListenAddr}
|
||||
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create server: %s", err)
|
||||
}
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
err := srv.Listen(srvCfg)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
}()
|
||||
|
||||
// wait for server to start
|
||||
if err := waitForServerToStart(errChan); err != nil {
|
||||
t.Fatalf("failed to start server: %s", err)
|
||||
}
|
||||
|
||||
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice")
|
||||
err = clientAlice.Connect()
|
||||
if err != nil {
|
||||
_ = srv.Shutdown(ctx)
|
||||
t.Fatalf("failed to connect to server: %s", err)
|
||||
}
|
||||
err = clientAlice.Close()
|
||||
if err != nil {
|
||||
t.Errorf("failed to close conn: %s", err)
|
||||
}
|
||||
err = srv.Shutdown(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("failed to close server: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistrationTimeout(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
fakeUDPListener, err := net.ListenUDP("udp", &net.UDPAddr{
|
||||
Port: 1234,
|
||||
IP: net.ParseIP("0.0.0.0"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind UDP server: %s", err)
|
||||
}
|
||||
defer func(fakeUDPListener *net.UDPConn) {
|
||||
_ = fakeUDPListener.Close()
|
||||
}(fakeUDPListener)
|
||||
|
||||
fakeTCPListener, err := net.ListenTCP("tcp", &net.TCPAddr{
|
||||
Port: 1234,
|
||||
IP: net.ParseIP("0.0.0.0"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind TCP server: %s", err)
|
||||
}
|
||||
defer func(fakeTCPListener *net.TCPListener) {
|
||||
_ = fakeTCPListener.Close()
|
||||
}(fakeTCPListener)
|
||||
|
||||
clientAlice := NewClient(ctx, "127.0.0.1:1234", hmacTokenStore, "alice")
|
||||
err = clientAlice.Connect()
|
||||
if err == nil {
|
||||
t.Errorf("failed to connect to server: %s", err)
|
||||
}
|
||||
log.Debugf("%s", err)
|
||||
err = clientAlice.Close()
|
||||
if err != nil {
|
||||
t.Errorf("failed to close conn: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEcho(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
idAlice := "alice"
|
||||
idBob := "bob"
|
||||
srvCfg := server.ListenerConfig{Address: serverListenAddr}
|
||||
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create server: %s", err)
|
||||
}
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
err := srv.Listen(srvCfg)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
err := srv.Shutdown(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("failed to close server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// wait for servers to start
|
||||
if err := waitForServerToStart(errChan); err != nil {
|
||||
t.Fatalf("failed to start server: %s", err)
|
||||
}
|
||||
|
||||
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, idAlice)
|
||||
err = clientAlice.Connect()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect to server: %s", err)
|
||||
}
|
||||
defer func() {
|
||||
err := clientAlice.Close()
|
||||
if err != nil {
|
||||
t.Errorf("failed to close Alice client: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
clientBob := NewClient(ctx, serverURL, hmacTokenStore, idBob)
|
||||
err = clientBob.Connect()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect to server: %s", err)
|
||||
}
|
||||
defer func() {
|
||||
err := clientBob.Close()
|
||||
if err != nil {
|
||||
t.Errorf("failed to close Bob client: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
connAliceToBob, err := clientAlice.OpenConn(idBob)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind channel: %s", err)
|
||||
}
|
||||
|
||||
connBobToAlice, err := clientBob.OpenConn(idAlice)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind channel: %s", err)
|
||||
}
|
||||
|
||||
payload := "hello bob, I am alice"
|
||||
_, err = connAliceToBob.Write([]byte(payload))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to write to channel: %s", err)
|
||||
}
|
||||
|
||||
buf := make([]byte, 65535)
|
||||
n, err := connBobToAlice.Read(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read from channel: %s", err)
|
||||
}
|
||||
|
||||
_, err = connBobToAlice.Write(buf[:n])
|
||||
if err != nil {
|
||||
t.Fatalf("failed to write to channel: %s", err)
|
||||
}
|
||||
|
||||
n, err = connAliceToBob.Read(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read from channel: %s", err)
|
||||
}
|
||||
|
||||
if payload != string(buf[:n]) {
|
||||
t.Fatalf("expected %s, got %s", payload, string(buf[:n]))
|
||||
}
|
||||
}
|
||||
|
||||
func TestBindToUnavailabePeer(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
srvCfg := server.ListenerConfig{Address: serverListenAddr}
|
||||
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create server: %s", err)
|
||||
}
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
err := srv.Listen(srvCfg)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
log.Infof("closing server")
|
||||
err := srv.Shutdown(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("failed to close server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// wait for servers to start
|
||||
if err := waitForServerToStart(errChan); err != nil {
|
||||
t.Fatalf("failed to start server: %s", err)
|
||||
}
|
||||
|
||||
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice")
|
||||
err = clientAlice.Connect()
|
||||
if err != nil {
|
||||
t.Errorf("failed to connect to server: %s", err)
|
||||
}
|
||||
_, err = clientAlice.OpenConn("bob")
|
||||
if err != nil {
|
||||
t.Errorf("failed to bind channel: %s", err)
|
||||
}
|
||||
|
||||
log.Infof("closing client")
|
||||
err = clientAlice.Close()
|
||||
if err != nil {
|
||||
t.Errorf("failed to close client: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBindReconnect(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
srvCfg := server.ListenerConfig{Address: serverListenAddr}
|
||||
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create server: %s", err)
|
||||
}
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
err := srv.Listen(srvCfg)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
log.Infof("closing server")
|
||||
err := srv.Shutdown(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("failed to close server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// wait for servers to start
|
||||
if err := waitForServerToStart(errChan); err != nil {
|
||||
t.Fatalf("failed to start server: %s", err)
|
||||
}
|
||||
|
||||
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice")
|
||||
err = clientAlice.Connect()
|
||||
if err != nil {
|
||||
t.Errorf("failed to connect to server: %s", err)
|
||||
}
|
||||
|
||||
_, err = clientAlice.OpenConn("bob")
|
||||
if err != nil {
|
||||
t.Errorf("failed to bind channel: %s", err)
|
||||
}
|
||||
|
||||
clientBob := NewClient(ctx, serverURL, hmacTokenStore, "bob")
|
||||
err = clientBob.Connect()
|
||||
if err != nil {
|
||||
t.Errorf("failed to connect to server: %s", err)
|
||||
}
|
||||
|
||||
chBob, err := clientBob.OpenConn("alice")
|
||||
if err != nil {
|
||||
t.Errorf("failed to bind channel: %s", err)
|
||||
}
|
||||
|
||||
log.Infof("closing client Alice")
|
||||
err = clientAlice.Close()
|
||||
if err != nil {
|
||||
t.Errorf("failed to close client: %s", err)
|
||||
}
|
||||
|
||||
clientAlice = NewClient(ctx, serverURL, hmacTokenStore, "alice")
|
||||
err = clientAlice.Connect()
|
||||
if err != nil {
|
||||
t.Errorf("failed to connect to server: %s", err)
|
||||
}
|
||||
|
||||
chAlice, err := clientAlice.OpenConn("bob")
|
||||
if err != nil {
|
||||
t.Errorf("failed to bind channel: %s", err)
|
||||
}
|
||||
|
||||
testString := "hello alice, I am bob"
|
||||
_, err = chBob.Write([]byte(testString))
|
||||
if err != nil {
|
||||
t.Errorf("failed to write to channel: %s", err)
|
||||
}
|
||||
|
||||
buf := make([]byte, 65535)
|
||||
n, err := chAlice.Read(buf)
|
||||
if err != nil {
|
||||
t.Errorf("failed to read from channel: %s", err)
|
||||
}
|
||||
|
||||
if testString != string(buf[:n]) {
|
||||
t.Errorf("expected %s, got %s", testString, string(buf[:n]))
|
||||
}
|
||||
|
||||
log.Infof("closing client")
|
||||
err = clientAlice.Close()
|
||||
if err != nil {
|
||||
t.Errorf("failed to close client: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloseConn(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
srvCfg := server.ListenerConfig{Address: serverListenAddr}
|
||||
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create server: %s", err)
|
||||
}
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
err := srv.Listen(srvCfg)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
log.Infof("closing server")
|
||||
err := srv.Shutdown(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("failed to close server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// wait for servers to start
|
||||
if err := waitForServerToStart(errChan); err != nil {
|
||||
t.Fatalf("failed to start server: %s", err)
|
||||
}
|
||||
|
||||
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice")
|
||||
err = clientAlice.Connect()
|
||||
if err != nil {
|
||||
t.Errorf("failed to connect to server: %s", err)
|
||||
}
|
||||
|
||||
conn, err := clientAlice.OpenConn("bob")
|
||||
if err != nil {
|
||||
t.Errorf("failed to bind channel: %s", err)
|
||||
}
|
||||
|
||||
log.Infof("closing connection")
|
||||
err = conn.Close()
|
||||
if err != nil {
|
||||
t.Errorf("failed to close connection: %s", err)
|
||||
}
|
||||
|
||||
_, err = conn.Read(make([]byte, 1))
|
||||
if err == nil {
|
||||
t.Errorf("unexpected reading from closed connection")
|
||||
}
|
||||
|
||||
_, err = conn.Write([]byte("hello"))
|
||||
if err == nil {
|
||||
t.Errorf("unexpected writing from closed connection")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloseRelayConn(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
srvCfg := server.ListenerConfig{Address: serverListenAddr}
|
||||
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create server: %s", err)
|
||||
}
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
err := srv.Listen(srvCfg)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
err := srv.Shutdown(ctx)
|
||||
if err != nil {
|
||||
log.Errorf("failed to close server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// wait for servers to start
|
||||
if err := waitForServerToStart(errChan); err != nil {
|
||||
t.Fatalf("failed to start server: %s", err)
|
||||
}
|
||||
|
||||
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice")
|
||||
err = clientAlice.Connect()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect to server: %s", err)
|
||||
}
|
||||
|
||||
conn, err := clientAlice.OpenConn("bob")
|
||||
if err != nil {
|
||||
t.Errorf("failed to bind channel: %s", err)
|
||||
}
|
||||
|
||||
_ = clientAlice.relayConn.Close()
|
||||
|
||||
_, err = conn.Read(make([]byte, 1))
|
||||
if err == nil {
|
||||
t.Errorf("unexpected reading from closed connection")
|
||||
}
|
||||
|
||||
_, err = clientAlice.OpenConn("bob")
|
||||
if err == nil {
|
||||
t.Errorf("unexpected opening connection to closed server")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloseByServer(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
srvCfg := server.ListenerConfig{Address: serverListenAddr}
|
||||
srv1, err := server.NewServer(otel.Meter(""), serverURL, false, av)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create server: %s", err)
|
||||
}
|
||||
errChan := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
err := srv1.Listen(srvCfg)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
}()
|
||||
|
||||
// wait for servers to start
|
||||
if err := waitForServerToStart(errChan); err != nil {
|
||||
t.Fatalf("failed to start server: %s", err)
|
||||
}
|
||||
|
||||
idAlice := "alice"
|
||||
log.Debugf("connect by alice")
|
||||
relayClient := NewClient(ctx, serverURL, hmacTokenStore, idAlice)
|
||||
err = relayClient.Connect()
|
||||
if err != nil {
|
||||
log.Fatalf("failed to connect to server: %s", err)
|
||||
}
|
||||
|
||||
disconnected := make(chan struct{})
|
||||
relayClient.SetOnDisconnectListener(func() {
|
||||
log.Infof("client disconnected")
|
||||
close(disconnected)
|
||||
})
|
||||
|
||||
err = srv1.Shutdown(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to close server: %s", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-disconnected:
|
||||
case <-time.After(3 * time.Second):
|
||||
log.Fatalf("timeout waiting for client to disconnect")
|
||||
}
|
||||
|
||||
_, err = relayClient.OpenConn("bob")
|
||||
if err == nil {
|
||||
t.Errorf("unexpected opening connection to closed server")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloseByClient(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
srvCfg := server.ListenerConfig{Address: serverListenAddr}
|
||||
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create server: %s", err)
|
||||
}
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
err := srv.Listen(srvCfg)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
}()
|
||||
|
||||
// wait for servers to start
|
||||
if err := waitForServerToStart(errChan); err != nil {
|
||||
t.Fatalf("failed to start server: %s", err)
|
||||
}
|
||||
|
||||
idAlice := "alice"
|
||||
log.Debugf("connect by alice")
|
||||
relayClient := NewClient(ctx, serverURL, hmacTokenStore, idAlice)
|
||||
err = relayClient.Connect()
|
||||
if err != nil {
|
||||
log.Fatalf("failed to connect to server: %s", err)
|
||||
}
|
||||
|
||||
err = relayClient.Close()
|
||||
if err != nil {
|
||||
t.Errorf("failed to close client: %s", err)
|
||||
}
|
||||
|
||||
_, err = relayClient.OpenConn("bob")
|
||||
if err == nil {
|
||||
t.Errorf("unexpected opening connection to closed server")
|
||||
}
|
||||
|
||||
err = srv.Shutdown(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to close server: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func waitForServerToStart(errChan chan error) error {
|
||||
select {
|
||||
case err := <-errChan:
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
case <-time.After(300 * time.Millisecond):
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
76
relay/client/conn.go
Normal file
76
relay/client/conn.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Conn represent a connection to a relayed remote peer.
|
||||
type Conn struct {
|
||||
client *Client
|
||||
dstID []byte
|
||||
dstStringID string
|
||||
messageChan chan Msg
|
||||
instanceURL *RelayAddr
|
||||
}
|
||||
|
||||
// NewConn creates a new connection to a relayed remote peer.
|
||||
// client: the client instance, it used to send messages to the destination peer
|
||||
// dstID: the destination peer ID
|
||||
// dstStringID: the destination peer ID in string format
|
||||
// messageChan: the channel where the messages will be received
|
||||
// instanceURL: the relay instance URL, it used to get the proper server instance address for the remote peer
|
||||
func NewConn(client *Client, dstID []byte, dstStringID string, messageChan chan Msg, instanceURL *RelayAddr) *Conn {
|
||||
c := &Conn{
|
||||
client: client,
|
||||
dstID: dstID,
|
||||
dstStringID: dstStringID,
|
||||
messageChan: messageChan,
|
||||
instanceURL: instanceURL,
|
||||
}
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *Conn) Write(p []byte) (n int, err error) {
|
||||
return c.client.writeTo(c, c.dstStringID, c.dstID, p)
|
||||
}
|
||||
|
||||
func (c *Conn) Read(b []byte) (n int, err error) {
|
||||
msg, ok := <-c.messageChan
|
||||
if !ok {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
n = copy(b, msg.Payload)
|
||||
msg.Free()
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (c *Conn) Close() error {
|
||||
return c.client.closeConn(c, c.dstStringID)
|
||||
}
|
||||
|
||||
func (c *Conn) LocalAddr() net.Addr {
|
||||
return c.client.relayConn.LocalAddr()
|
||||
}
|
||||
|
||||
func (c *Conn) RemoteAddr() net.Addr {
|
||||
return c.instanceURL
|
||||
}
|
||||
|
||||
func (c *Conn) SetDeadline(t time.Time) error {
|
||||
//TODO implement me
|
||||
panic("SetDeadline is not implemented")
|
||||
}
|
||||
|
||||
func (c *Conn) SetReadDeadline(t time.Time) error {
|
||||
//TODO implement me
|
||||
panic("SetReadDeadline is not implemented")
|
||||
}
|
||||
|
||||
func (c *Conn) SetWriteDeadline(t time.Time) error {
|
||||
//TODO implement me
|
||||
panic("SetReadDeadline is not implemented")
|
||||
}
|
||||
13
relay/client/dialer/ws/addr.go
Normal file
13
relay/client/dialer/ws/addr.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package ws
|
||||
|
||||
type WebsocketAddr struct {
|
||||
addr string
|
||||
}
|
||||
|
||||
func (a WebsocketAddr) Network() string {
|
||||
return "websocket"
|
||||
}
|
||||
|
||||
func (a WebsocketAddr) String() string {
|
||||
return a.addr
|
||||
}
|
||||
66
relay/client/dialer/ws/conn.go
Normal file
66
relay/client/dialer/ws/conn.go
Normal file
@@ -0,0 +1,66 @@
|
||||
package ws
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"nhooyr.io/websocket"
|
||||
)
|
||||
|
||||
type Conn struct {
|
||||
ctx context.Context
|
||||
*websocket.Conn
|
||||
remoteAddr WebsocketAddr
|
||||
}
|
||||
|
||||
func NewConn(wsConn *websocket.Conn, serverAddress string) net.Conn {
|
||||
return &Conn{
|
||||
ctx: context.Background(),
|
||||
Conn: wsConn,
|
||||
remoteAddr: WebsocketAddr{serverAddress},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) Read(b []byte) (n int, err error) {
|
||||
t, ioReader, err := c.Conn.Reader(c.ctx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if t != websocket.MessageBinary {
|
||||
return 0, fmt.Errorf("unexpected message type")
|
||||
}
|
||||
|
||||
return ioReader.Read(b)
|
||||
}
|
||||
|
||||
func (c *Conn) Write(b []byte) (n int, err error) {
|
||||
err = c.Conn.Write(c.ctx, websocket.MessageBinary, b)
|
||||
return 0, err
|
||||
}
|
||||
|
||||
func (c *Conn) RemoteAddr() net.Addr {
|
||||
return c.remoteAddr
|
||||
}
|
||||
|
||||
func (c *Conn) LocalAddr() net.Addr {
|
||||
return WebsocketAddr{addr: "unknown"}
|
||||
}
|
||||
|
||||
func (c *Conn) SetReadDeadline(t time.Time) error {
|
||||
return fmt.Errorf("SetReadDeadline is not implemented")
|
||||
}
|
||||
|
||||
func (c *Conn) SetWriteDeadline(t time.Time) error {
|
||||
return fmt.Errorf("SetWriteDeadline is not implemented")
|
||||
}
|
||||
|
||||
func (c *Conn) SetDeadline(t time.Time) error {
|
||||
return fmt.Errorf("SetDeadline is not implemented")
|
||||
}
|
||||
|
||||
func (c *Conn) Close() error {
|
||||
return c.Conn.CloseNow()
|
||||
}
|
||||
67
relay/client/dialer/ws/ws.go
Normal file
67
relay/client/dialer/ws/ws.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package ws
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"nhooyr.io/websocket"
|
||||
|
||||
"github.com/netbirdio/netbird/relay/server/listener/ws"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
func Dial(address string) (net.Conn, error) {
|
||||
wsURL, err := prepareURL(address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
opts := &websocket.DialOptions{
|
||||
HTTPClient: httpClientNbDialer(),
|
||||
}
|
||||
|
||||
parsedURL, err := url.Parse(wsURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
parsedURL.Path = ws.URLPath
|
||||
|
||||
wsConn, resp, err := websocket.Dial(context.Background(), parsedURL.String(), opts)
|
||||
if err != nil {
|
||||
log.Errorf("failed to dial to Relay server '%s': %s", wsURL, err)
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
_ = resp.Body.Close()
|
||||
}
|
||||
|
||||
conn := NewConn(wsConn, address)
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func prepareURL(address string) (string, error) {
|
||||
if !strings.HasPrefix(address, "rel:") && !strings.HasPrefix(address, "rels:") {
|
||||
return "", fmt.Errorf("unsupported scheme: %s", address)
|
||||
}
|
||||
|
||||
return strings.Replace(address, "rel", "ws", 1), nil
|
||||
}
|
||||
|
||||
func httpClientNbDialer() *http.Client {
|
||||
customDialer := nbnet.NewDialer()
|
||||
|
||||
customTransport := &http.Transport{
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return customDialer.DialContext(ctx, network, addr)
|
||||
},
|
||||
}
|
||||
|
||||
return &http.Client{
|
||||
Transport: customTransport,
|
||||
}
|
||||
}
|
||||
12
relay/client/doc.go
Normal file
12
relay/client/doc.go
Normal file
@@ -0,0 +1,12 @@
|
||||
/*
|
||||
Package client contains the implementation of the Relay client.
|
||||
|
||||
The Relay client is responsible for establishing a connection with the Relay server and sending and receiving messages,
|
||||
Keep persistent connection with the Relay server and handle the connection issues.
|
||||
It uses the WebSocket protocol for communication and optionally supports TLS (Transport Layer Security).
|
||||
|
||||
If a peer wants to communicate with a peer on a different relay server, the manager will establish a new connection to
|
||||
the relay server. The connection with these relay servers will be closed if there is no active connection. The peers
|
||||
negotiate the common relay instance via signaling service.
|
||||
*/
|
||||
package client
|
||||
48
relay/client/guard.go
Normal file
48
relay/client/guard.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var (
|
||||
reconnectingTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
// Guard manage the reconnection tries to the Relay server in case of disconnection event.
|
||||
type Guard struct {
|
||||
ctx context.Context
|
||||
relayClient *Client
|
||||
}
|
||||
|
||||
// NewGuard creates a new guard for the relay client.
|
||||
func NewGuard(context context.Context, relayClient *Client) *Guard {
|
||||
g := &Guard{
|
||||
ctx: context,
|
||||
relayClient: relayClient,
|
||||
}
|
||||
return g
|
||||
}
|
||||
|
||||
// OnDisconnected is called when the relay client is disconnected from the relay server. It will trigger the reconnection
|
||||
// todo prevent multiple reconnection instances. In the current usage it should not happen, but it is better to prevent
|
||||
func (g *Guard) OnDisconnected() {
|
||||
ticker := time.NewTicker(reconnectingTimeout)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
err := g.relayClient.Connect()
|
||||
if err != nil {
|
||||
log.Errorf("failed to reconnect to relay server: %s", err)
|
||||
continue
|
||||
}
|
||||
return
|
||||
case <-g.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
365
relay/client/manager.go
Normal file
365
relay/client/manager.go
Normal file
@@ -0,0 +1,365 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"reflect"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
relayAuth "github.com/netbirdio/netbird/relay/auth/hmac"
|
||||
)
|
||||
|
||||
var (
|
||||
relayCleanupInterval = 60 * time.Second
|
||||
connectionTimeout = 30 * time.Second
|
||||
maxConcurrentServers = 7
|
||||
|
||||
ErrRelayClientNotConnected = fmt.Errorf("relay client not connected")
|
||||
)
|
||||
|
||||
// RelayTrack hold the relay clients for the foreign relay servers.
|
||||
// With the mutex can ensure we can open new connection in case the relay connection has been established with
|
||||
// the relay server.
|
||||
type RelayTrack struct {
|
||||
sync.RWMutex
|
||||
relayClient *Client
|
||||
}
|
||||
|
||||
func NewRelayTrack() *RelayTrack {
|
||||
return &RelayTrack{}
|
||||
}
|
||||
|
||||
type OnServerCloseListener func()
|
||||
|
||||
// ManagerService is the interface for the relay manager.
|
||||
type ManagerService interface {
|
||||
Serve() error
|
||||
OpenConn(serverAddress, peerKey string) (net.Conn, error)
|
||||
AddCloseListener(serverAddress string, onClosedListener OnServerCloseListener) error
|
||||
RelayInstanceAddress() (string, error)
|
||||
ServerURLs() []string
|
||||
HasRelayAddress() bool
|
||||
UpdateToken(token *relayAuth.Token) error
|
||||
}
|
||||
|
||||
// Manager is a manager for the relay client instances. It establishes one persistent connection to the given relay URL
|
||||
// and automatically reconnect to them in case disconnection.
|
||||
// The manager also manage temporary relay connection. If a client wants to communicate with a client on a
|
||||
// different relay server, the manager will establish a new connection to the relay server. The connection with these
|
||||
// relay servers will be closed if there is no active connection. Periodically the manager will check if there is any
|
||||
// unused relay connection and close it.
|
||||
type Manager struct {
|
||||
ctx context.Context
|
||||
serverURLs []string
|
||||
peerID string
|
||||
tokenStore *relayAuth.TokenStore
|
||||
|
||||
relayClient *Client
|
||||
reconnectGuard *Guard
|
||||
|
||||
relayClients map[string]*RelayTrack
|
||||
relayClientsMutex sync.RWMutex
|
||||
|
||||
onDisconnectedListeners map[string]*list.List
|
||||
listenerLock sync.Mutex
|
||||
}
|
||||
|
||||
// NewManager creates a new manager instance.
|
||||
// The serverURL address can be empty. In this case, the manager will not serve.
|
||||
func NewManager(ctx context.Context, serverURLs []string, peerID string) *Manager {
|
||||
return &Manager{
|
||||
ctx: ctx,
|
||||
serverURLs: serverURLs,
|
||||
peerID: peerID,
|
||||
tokenStore: &relayAuth.TokenStore{},
|
||||
relayClients: make(map[string]*RelayTrack),
|
||||
onDisconnectedListeners: make(map[string]*list.List),
|
||||
}
|
||||
}
|
||||
|
||||
// Serve starts the manager. It will establish a connection to the relay server and start the relay cleanup loop for
|
||||
// the unused relay connections. The manager will automatically reconnect to the relay server in case of disconnection.
|
||||
func (m *Manager) Serve() error {
|
||||
if m.relayClient != nil {
|
||||
return fmt.Errorf("manager already serving")
|
||||
}
|
||||
log.Debugf("starting relay client manager with %v relay servers", m.serverURLs)
|
||||
|
||||
totalServers := len(m.serverURLs)
|
||||
|
||||
successChan := make(chan *Client, 1)
|
||||
errChan := make(chan error, len(m.serverURLs))
|
||||
|
||||
ctx, cancel := context.WithTimeout(m.ctx, connectionTimeout)
|
||||
defer cancel()
|
||||
|
||||
sem := make(chan struct{}, maxConcurrentServers)
|
||||
|
||||
for _, url := range m.serverURLs {
|
||||
sem <- struct{}{}
|
||||
go func(url string) {
|
||||
defer func() { <-sem }()
|
||||
m.connect(m.ctx, url, successChan, errChan)
|
||||
}(url)
|
||||
}
|
||||
|
||||
var errCount int
|
||||
|
||||
for {
|
||||
select {
|
||||
case client := <-successChan:
|
||||
log.Infof("Successfully connected to relay server: %s", client.connectionURL)
|
||||
|
||||
m.relayClient = client
|
||||
|
||||
m.reconnectGuard = NewGuard(m.ctx, m.relayClient)
|
||||
m.relayClient.SetOnDisconnectListener(func() {
|
||||
m.onServerDisconnected(client.connectionURL)
|
||||
})
|
||||
m.startCleanupLoop()
|
||||
return nil
|
||||
case err := <-errChan:
|
||||
errCount++
|
||||
log.Warnf("Connection attempt failed: %v", err)
|
||||
if errCount == totalServers {
|
||||
return errors.New("failed to connect to any relay server: all attempts failed")
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return fmt.Errorf("failed to connect to any relay server: %w", ctx.Err())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) connect(ctx context.Context, serverURL string, successChan chan<- *Client, errChan chan<- error) {
|
||||
// TODO: abort the connection if another connection was successful
|
||||
relayClient := NewClient(ctx, serverURL, m.tokenStore, m.peerID)
|
||||
if err := relayClient.Connect(); err != nil {
|
||||
errChan <- fmt.Errorf("failed to connect to %s: %w", serverURL, err)
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case successChan <- relayClient:
|
||||
// This client was the first to connect successfully
|
||||
default:
|
||||
if err := relayClient.Close(); err != nil {
|
||||
log.Debugf("failed to close relay client: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OpenConn opens a connection to the given peer key. If the peer is on the same relay server, the connection will be
|
||||
// established via the relay server. If the peer is on a different relay server, the manager will establish a new
|
||||
// connection to the relay server. It returns back with a net.Conn what represent the remote peer connection.
|
||||
func (m *Manager) OpenConn(serverAddress, peerKey string) (net.Conn, error) {
|
||||
if m.relayClient == nil {
|
||||
return nil, ErrRelayClientNotConnected
|
||||
}
|
||||
|
||||
foreign, err := m.isForeignServer(serverAddress)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var (
|
||||
netConn net.Conn
|
||||
)
|
||||
if !foreign {
|
||||
log.Debugf("open peer connection via permanent server: %s", peerKey)
|
||||
netConn, err = m.relayClient.OpenConn(peerKey)
|
||||
} else {
|
||||
log.Debugf("open peer connection via foreign server: %s", serverAddress)
|
||||
netConn, err = m.openConnVia(serverAddress, peerKey)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return netConn, err
|
||||
}
|
||||
|
||||
// AddCloseListener adds a listener to the given server instance address. The listener will be called if the connection
|
||||
// closed.
|
||||
func (m *Manager) AddCloseListener(serverAddress string, onClosedListener OnServerCloseListener) error {
|
||||
foreign, err := m.isForeignServer(serverAddress)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var listenerAddr string
|
||||
if foreign {
|
||||
listenerAddr = serverAddress
|
||||
} else {
|
||||
listenerAddr = m.relayClient.connectionURL
|
||||
}
|
||||
m.addListener(listenerAddr, onClosedListener)
|
||||
return nil
|
||||
}
|
||||
|
||||
// RelayInstanceAddress returns the address of the permanent relay server. It could change if the network connection is
|
||||
// lost. This address will be sent to the target peer to choose the common relay server for the communication.
|
||||
func (m *Manager) RelayInstanceAddress() (string, error) {
|
||||
if m.relayClient == nil {
|
||||
return "", ErrRelayClientNotConnected
|
||||
}
|
||||
return m.relayClient.ServerInstanceURL()
|
||||
}
|
||||
|
||||
// ServerURLs returns the addresses of the relay servers.
|
||||
func (m *Manager) ServerURLs() []string {
|
||||
return m.serverURLs
|
||||
}
|
||||
|
||||
// HasRelayAddress returns true if the manager is serving. With this method can check if the peer can communicate with
|
||||
// Relay service.
|
||||
func (m *Manager) HasRelayAddress() bool {
|
||||
return len(m.serverURLs) > 0
|
||||
}
|
||||
|
||||
// UpdateToken updates the token in the token store.
|
||||
func (m *Manager) UpdateToken(token *relayAuth.Token) error {
|
||||
return m.tokenStore.UpdateToken(token)
|
||||
}
|
||||
|
||||
func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) {
|
||||
// check if already has a connection to the desired relay server
|
||||
m.relayClientsMutex.RLock()
|
||||
rt, ok := m.relayClients[serverAddress]
|
||||
if ok {
|
||||
rt.RLock()
|
||||
m.relayClientsMutex.RUnlock()
|
||||
defer rt.RUnlock()
|
||||
return rt.relayClient.OpenConn(peerKey)
|
||||
}
|
||||
m.relayClientsMutex.RUnlock()
|
||||
|
||||
// if not, establish a new connection but check it again (because changed the lock type) before starting the
|
||||
// connection
|
||||
m.relayClientsMutex.Lock()
|
||||
rt, ok = m.relayClients[serverAddress]
|
||||
if ok {
|
||||
rt.RLock()
|
||||
m.relayClientsMutex.Unlock()
|
||||
defer rt.RUnlock()
|
||||
return rt.relayClient.OpenConn(peerKey)
|
||||
}
|
||||
|
||||
// create a new relay client and store it in the relayClients map
|
||||
rt = NewRelayTrack()
|
||||
rt.Lock()
|
||||
m.relayClients[serverAddress] = rt
|
||||
m.relayClientsMutex.Unlock()
|
||||
|
||||
relayClient := NewClient(m.ctx, serverAddress, m.tokenStore, m.peerID)
|
||||
err := relayClient.Connect()
|
||||
if err != nil {
|
||||
rt.Unlock()
|
||||
m.relayClientsMutex.Lock()
|
||||
delete(m.relayClients, serverAddress)
|
||||
m.relayClientsMutex.Unlock()
|
||||
return nil, err
|
||||
}
|
||||
// if connection closed then delete the relay client from the list
|
||||
relayClient.SetOnDisconnectListener(func() {
|
||||
m.onServerDisconnected(serverAddress)
|
||||
})
|
||||
rt.relayClient = relayClient
|
||||
rt.Unlock()
|
||||
|
||||
conn, err := relayClient.OpenConn(peerKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (m *Manager) onServerDisconnected(serverAddress string) {
|
||||
if serverAddress == m.relayClient.connectionURL {
|
||||
go m.reconnectGuard.OnDisconnected()
|
||||
}
|
||||
|
||||
m.notifyOnDisconnectListeners(serverAddress)
|
||||
}
|
||||
|
||||
func (m *Manager) isForeignServer(address string) (bool, error) {
|
||||
rAddr, err := m.relayClient.ServerInstanceURL()
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("relay client not connected")
|
||||
}
|
||||
return rAddr != address, nil
|
||||
}
|
||||
|
||||
func (m *Manager) startCleanupLoop() {
|
||||
if m.ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(relayCleanupInterval)
|
||||
go func() {
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-m.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
m.cleanUpUnusedRelays()
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (m *Manager) cleanUpUnusedRelays() {
|
||||
m.relayClientsMutex.Lock()
|
||||
defer m.relayClientsMutex.Unlock()
|
||||
|
||||
for addr, rt := range m.relayClients {
|
||||
rt.Lock()
|
||||
if rt.relayClient.HasConns() {
|
||||
rt.Unlock()
|
||||
continue
|
||||
}
|
||||
rt.relayClient.SetOnDisconnectListener(nil)
|
||||
go func() {
|
||||
_ = rt.relayClient.Close()
|
||||
}()
|
||||
log.Debugf("clean up unused relay server connection: %s", addr)
|
||||
delete(m.relayClients, addr)
|
||||
rt.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) addListener(serverAddress string, onClosedListener OnServerCloseListener) {
|
||||
m.listenerLock.Lock()
|
||||
defer m.listenerLock.Unlock()
|
||||
l, ok := m.onDisconnectedListeners[serverAddress]
|
||||
if !ok {
|
||||
l = list.New()
|
||||
}
|
||||
for e := l.Front(); e != nil; e = e.Next() {
|
||||
if reflect.ValueOf(e.Value).Pointer() == reflect.ValueOf(onClosedListener).Pointer() {
|
||||
return
|
||||
}
|
||||
}
|
||||
l.PushBack(onClosedListener)
|
||||
m.onDisconnectedListeners[serverAddress] = l
|
||||
}
|
||||
|
||||
func (m *Manager) notifyOnDisconnectListeners(serverAddress string) {
|
||||
m.listenerLock.Lock()
|
||||
defer m.listenerLock.Unlock()
|
||||
|
||||
l, ok := m.onDisconnectedListeners[serverAddress]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
for e := l.Front(); e != nil; e = e.Next() {
|
||||
go e.Value.(OnServerCloseListener)()
|
||||
}
|
||||
delete(m.onDisconnectedListeners, serverAddress)
|
||||
}
|
||||
432
relay/client/manager_test.go
Normal file
432
relay/client/manager_test.go
Normal file
@@ -0,0 +1,432 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"go.opentelemetry.io/otel"
|
||||
|
||||
"github.com/netbirdio/netbird/relay/server"
|
||||
)
|
||||
|
||||
func TestEmptyURL(t *testing.T) {
|
||||
mgr := NewManager(context.Background(), nil, "alice")
|
||||
err := mgr.Serve()
|
||||
if err == nil {
|
||||
t.Errorf("expected error, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestForeignConn(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
srvCfg1 := server.ListenerConfig{
|
||||
Address: "localhost:1234",
|
||||
}
|
||||
srv1, err := server.NewServer(otel.Meter(""), srvCfg1.Address, false, av)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create server: %s", err)
|
||||
}
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
err := srv1.Listen(srvCfg1)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
err := srv1.Shutdown(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("failed to close server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := waitForServerToStart(errChan); err != nil {
|
||||
t.Fatalf("failed to start server: %s", err)
|
||||
}
|
||||
|
||||
srvCfg2 := server.ListenerConfig{
|
||||
Address: "localhost:2234",
|
||||
}
|
||||
srv2, err := server.NewServer(otel.Meter(""), srvCfg2.Address, false, av)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create server: %s", err)
|
||||
}
|
||||
errChan2 := make(chan error, 1)
|
||||
go func() {
|
||||
err := srv2.Listen(srvCfg2)
|
||||
if err != nil {
|
||||
errChan2 <- err
|
||||
}
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
err := srv2.Shutdown(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("failed to close server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := waitForServerToStart(errChan2); err != nil {
|
||||
t.Fatalf("failed to start server: %s", err)
|
||||
}
|
||||
|
||||
idAlice := "alice"
|
||||
log.Debugf("connect by alice")
|
||||
mCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
clientAlice := NewManager(mCtx, toURL(srvCfg1), idAlice)
|
||||
err = clientAlice.Serve()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to serve manager: %s", err)
|
||||
}
|
||||
|
||||
idBob := "bob"
|
||||
log.Debugf("connect by bob")
|
||||
clientBob := NewManager(mCtx, toURL(srvCfg2), idBob)
|
||||
err = clientBob.Serve()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to serve manager: %s", err)
|
||||
}
|
||||
bobsSrvAddr, err := clientBob.RelayInstanceAddress()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get relay address: %s", err)
|
||||
}
|
||||
connAliceToBob, err := clientAlice.OpenConn(bobsSrvAddr, idBob)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind channel: %s", err)
|
||||
}
|
||||
connBobToAlice, err := clientBob.OpenConn(bobsSrvAddr, idAlice)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind channel: %s", err)
|
||||
}
|
||||
|
||||
payload := "hello bob, I am alice"
|
||||
_, err = connAliceToBob.Write([]byte(payload))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to write to channel: %s", err)
|
||||
}
|
||||
|
||||
buf := make([]byte, 65535)
|
||||
n, err := connBobToAlice.Read(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read from channel: %s", err)
|
||||
}
|
||||
|
||||
_, err = connBobToAlice.Write(buf[:n])
|
||||
if err != nil {
|
||||
t.Fatalf("failed to write to channel: %s", err)
|
||||
}
|
||||
|
||||
n, err = connAliceToBob.Read(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read from channel: %s", err)
|
||||
}
|
||||
|
||||
if payload != string(buf[:n]) {
|
||||
t.Fatalf("expected %s, got %s", payload, string(buf[:n]))
|
||||
}
|
||||
}
|
||||
|
||||
func TestForeginConnClose(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
srvCfg1 := server.ListenerConfig{
|
||||
Address: "localhost:1234",
|
||||
}
|
||||
srv1, err := server.NewServer(otel.Meter(""), srvCfg1.Address, false, av)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create server: %s", err)
|
||||
}
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
err := srv1.Listen(srvCfg1)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
err := srv1.Shutdown(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("failed to close server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := waitForServerToStart(errChan); err != nil {
|
||||
t.Fatalf("failed to start server: %s", err)
|
||||
}
|
||||
|
||||
srvCfg2 := server.ListenerConfig{
|
||||
Address: "localhost:2234",
|
||||
}
|
||||
srv2, err := server.NewServer(otel.Meter(""), srvCfg2.Address, false, av)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create server: %s", err)
|
||||
}
|
||||
errChan2 := make(chan error, 1)
|
||||
go func() {
|
||||
err := srv2.Listen(srvCfg2)
|
||||
if err != nil {
|
||||
errChan2 <- err
|
||||
}
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
err := srv2.Shutdown(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("failed to close server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := waitForServerToStart(errChan2); err != nil {
|
||||
t.Fatalf("failed to start server: %s", err)
|
||||
}
|
||||
|
||||
idAlice := "alice"
|
||||
log.Debugf("connect by alice")
|
||||
mCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
mgr := NewManager(mCtx, toURL(srvCfg1), idAlice)
|
||||
err = mgr.Serve()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to serve manager: %s", err)
|
||||
}
|
||||
conn, err := mgr.OpenConn(toURL(srvCfg2)[0], "anotherpeer")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind channel: %s", err)
|
||||
}
|
||||
|
||||
err = conn.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to close connection: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestForeginAutoClose(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
relayCleanupInterval = 1 * time.Second
|
||||
srvCfg1 := server.ListenerConfig{
|
||||
Address: "localhost:1234",
|
||||
}
|
||||
srv1, err := server.NewServer(otel.Meter(""), srvCfg1.Address, false, av)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create server: %s", err)
|
||||
}
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
t.Log("binding server 1.")
|
||||
err := srv1.Listen(srvCfg1)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
t.Logf("closing server 1.")
|
||||
err := srv1.Shutdown(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("failed to close server: %s", err)
|
||||
}
|
||||
t.Logf("server 1. closed")
|
||||
}()
|
||||
|
||||
if err := waitForServerToStart(errChan); err != nil {
|
||||
t.Fatalf("failed to start server: %s", err)
|
||||
}
|
||||
|
||||
srvCfg2 := server.ListenerConfig{
|
||||
Address: "localhost:2234",
|
||||
}
|
||||
srv2, err := server.NewServer(otel.Meter(""), srvCfg2.Address, false, av)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create server: %s", err)
|
||||
}
|
||||
errChan2 := make(chan error, 1)
|
||||
go func() {
|
||||
t.Log("binding server 2.")
|
||||
err := srv2.Listen(srvCfg2)
|
||||
if err != nil {
|
||||
errChan2 <- err
|
||||
}
|
||||
}()
|
||||
defer func() {
|
||||
t.Logf("closing server 2.")
|
||||
err := srv2.Shutdown(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("failed to close server: %s", err)
|
||||
}
|
||||
t.Logf("server 2 closed.")
|
||||
}()
|
||||
|
||||
if err := waitForServerToStart(errChan2); err != nil {
|
||||
t.Fatalf("failed to start server: %s", err)
|
||||
}
|
||||
|
||||
idAlice := "alice"
|
||||
t.Log("connect to server 1.")
|
||||
mCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
mgr := NewManager(mCtx, toURL(srvCfg1), idAlice)
|
||||
err = mgr.Serve()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to serve manager: %s", err)
|
||||
}
|
||||
|
||||
t.Log("open connection to another peer")
|
||||
conn, err := mgr.OpenConn(toURL(srvCfg2)[0], "anotherpeer")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind channel: %s", err)
|
||||
}
|
||||
|
||||
t.Log("close conn")
|
||||
err = conn.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to close connection: %s", err)
|
||||
}
|
||||
|
||||
t.Logf("waiting for relay cleanup: %s", relayCleanupInterval+1*time.Second)
|
||||
time.Sleep(relayCleanupInterval + 1*time.Second)
|
||||
if len(mgr.relayClients) != 0 {
|
||||
t.Errorf("expected 0, got %d", len(mgr.relayClients))
|
||||
}
|
||||
|
||||
t.Logf("closing manager")
|
||||
}
|
||||
|
||||
func TestAutoReconnect(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
reconnectingTimeout = 2 * time.Second
|
||||
|
||||
srvCfg := server.ListenerConfig{
|
||||
Address: "localhost:1234",
|
||||
}
|
||||
srv, err := server.NewServer(otel.Meter(""), srvCfg.Address, false, av)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create server: %s", err)
|
||||
}
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
err := srv.Listen(srvCfg)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
err := srv.Shutdown(ctx)
|
||||
if err != nil {
|
||||
log.Errorf("failed to close server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := waitForServerToStart(errChan); err != nil {
|
||||
t.Fatalf("failed to start server: %s", err)
|
||||
}
|
||||
|
||||
mCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
clientAlice := NewManager(mCtx, toURL(srvCfg), "alice")
|
||||
err = clientAlice.Serve()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to serve manager: %s", err)
|
||||
}
|
||||
ra, err := clientAlice.RelayInstanceAddress()
|
||||
if err != nil {
|
||||
t.Errorf("failed to get relay address: %s", err)
|
||||
}
|
||||
conn, err := clientAlice.OpenConn(ra, "bob")
|
||||
if err != nil {
|
||||
t.Errorf("failed to bind channel: %s", err)
|
||||
}
|
||||
|
||||
t.Log("closing client relay connection")
|
||||
// todo figure out moc server
|
||||
_ = clientAlice.relayClient.relayConn.Close()
|
||||
t.Log("start test reading")
|
||||
_, err = conn.Read(make([]byte, 1))
|
||||
if err == nil {
|
||||
t.Errorf("unexpected reading from closed connection")
|
||||
}
|
||||
|
||||
log.Infof("waiting for reconnection")
|
||||
time.Sleep(reconnectingTimeout + 1*time.Second)
|
||||
|
||||
log.Infof("reopent the connection")
|
||||
_, err = clientAlice.OpenConn(ra, "bob")
|
||||
if err != nil {
|
||||
t.Errorf("failed to open channel: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNotifierDoubleAdd(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
srvCfg1 := server.ListenerConfig{
|
||||
Address: "localhost:1234",
|
||||
}
|
||||
srv1, err := server.NewServer(otel.Meter(""), srvCfg1.Address, false, av)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create server: %s", err)
|
||||
}
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
err := srv1.Listen(srvCfg1)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
err := srv1.Shutdown(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("failed to close server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := waitForServerToStart(errChan); err != nil {
|
||||
t.Fatalf("failed to start server: %s", err)
|
||||
}
|
||||
|
||||
idAlice := "alice"
|
||||
log.Debugf("connect by alice")
|
||||
mCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
clientAlice := NewManager(mCtx, toURL(srvCfg1), idAlice)
|
||||
err = clientAlice.Serve()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to serve manager: %s", err)
|
||||
}
|
||||
|
||||
conn1, err := clientAlice.OpenConn(clientAlice.ServerURLs()[0], "idBob")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind channel: %s", err)
|
||||
}
|
||||
|
||||
fnCloseListener := OnServerCloseListener(func() {
|
||||
log.Infof("close listener")
|
||||
})
|
||||
|
||||
err = clientAlice.AddCloseListener(clientAlice.ServerURLs()[0], fnCloseListener)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to add close listener: %s", err)
|
||||
}
|
||||
|
||||
err = clientAlice.AddCloseListener(clientAlice.ServerURLs()[0], fnCloseListener)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to add close listener: %s", err)
|
||||
}
|
||||
|
||||
err = conn1.Close()
|
||||
if err != nil {
|
||||
t.Errorf("failed to close connection: %s", err)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func toURL(address server.ListenerConfig) []string {
|
||||
return []string{"rel://" + address.Address}
|
||||
}
|
||||
35
relay/cmd/env.go
Normal file
35
relay/cmd/env.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/spf13/pflag"
|
||||
)
|
||||
|
||||
// setFlagsFromEnvVars reads and updates flag values from environment variables with prefix NB_
|
||||
func setFlagsFromEnvVars(cmd *cobra.Command) {
|
||||
flags := cmd.PersistentFlags()
|
||||
flags.VisitAll(func(f *pflag.Flag) {
|
||||
newEnvVar := flagNameToEnvVar(f.Name, "NB_")
|
||||
value, present := os.LookupEnv(newEnvVar)
|
||||
if !present {
|
||||
return
|
||||
}
|
||||
|
||||
err := flags.Set(f.Name, value)
|
||||
if err != nil {
|
||||
log.Infof("unable to configure flag %s using variable %s, err: %v", f.Name, newEnvVar, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// flagNameToEnvVar converts flag name to environment var name adding a prefix,
|
||||
// replacing dashes and making all uppercase (e.g. setup-keys is converted to NB_SETUP_KEYS according to the input prefix)
|
||||
func flagNameToEnvVar(cmdFlag string, prefix string) string {
|
||||
parsed := strings.ReplaceAll(cmdFlag, "-", "_")
|
||||
upper := strings.ToUpper(parsed)
|
||||
return prefix + upper
|
||||
}
|
||||
214
relay/cmd/root.go
Normal file
214
relay/cmd/root.go
Normal file
@@ -0,0 +1,214 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/encryption"
|
||||
auth "github.com/netbirdio/netbird/relay/auth/hmac"
|
||||
"github.com/netbirdio/netbird/relay/server"
|
||||
"github.com/netbirdio/netbird/signal/metrics"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
const (
|
||||
metricsPort = 9090
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
ListenAddress string
|
||||
// in HA every peer connect to a common domain, the instance domain has been distributed during the p2p connection
|
||||
// it is a domain:port or ip:port
|
||||
ExposedAddress string
|
||||
LetsencryptEmail string
|
||||
LetsencryptDataDir string
|
||||
LetsencryptDomains []string
|
||||
// in case of using Route 53 for DNS challenge the credentials should be provided in the environment variables or
|
||||
// in the AWS credentials file
|
||||
LetsencryptAWSRoute53 bool
|
||||
TlsCertFile string
|
||||
TlsKeyFile string
|
||||
AuthSecret string
|
||||
LogLevel string
|
||||
LogFile string
|
||||
}
|
||||
|
||||
func (c Config) Validate() error {
|
||||
if c.ExposedAddress == "" {
|
||||
return fmt.Errorf("exposed address is required")
|
||||
}
|
||||
if c.AuthSecret == "" {
|
||||
return fmt.Errorf("auth secret is required")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c Config) HasCertConfig() bool {
|
||||
return c.TlsCertFile != "" && c.TlsKeyFile != ""
|
||||
}
|
||||
|
||||
func (c Config) HasLetsEncrypt() bool {
|
||||
return c.LetsencryptDataDir != "" && c.LetsencryptDomains != nil && len(c.LetsencryptDomains) > 0
|
||||
}
|
||||
|
||||
var (
|
||||
cobraConfig *Config
|
||||
rootCmd = &cobra.Command{
|
||||
Use: "relay",
|
||||
Short: "Relay service",
|
||||
Long: "Relay service for Netbird agents",
|
||||
SilenceUsage: true,
|
||||
SilenceErrors: true,
|
||||
RunE: execute,
|
||||
}
|
||||
)
|
||||
|
||||
func init() {
|
||||
_ = util.InitLog("trace", "console")
|
||||
cobraConfig = &Config{}
|
||||
rootCmd.PersistentFlags().StringVarP(&cobraConfig.ListenAddress, "listen-address", "l", ":443", "listen address")
|
||||
rootCmd.PersistentFlags().StringVarP(&cobraConfig.ExposedAddress, "exposed-address", "e", "", "instance domain address (or ip) and port, it will be distributes between peers")
|
||||
rootCmd.PersistentFlags().StringVarP(&cobraConfig.LetsencryptDataDir, "letsencrypt-data-dir", "d", "", "a directory to store Let's Encrypt data. Required if Let's Encrypt is enabled.")
|
||||
rootCmd.PersistentFlags().StringSliceVarP(&cobraConfig.LetsencryptDomains, "letsencrypt-domains", "a", nil, "list of domains to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS")
|
||||
rootCmd.PersistentFlags().StringVar(&cobraConfig.LetsencryptEmail, "letsencrypt-email", "", "email address to use for Let's Encrypt certificate registration")
|
||||
rootCmd.PersistentFlags().BoolVar(&cobraConfig.LetsencryptAWSRoute53, "letsencrypt-aws-route53", false, "use AWS Route 53 for Let's Encrypt DNS challenge")
|
||||
rootCmd.PersistentFlags().StringVarP(&cobraConfig.TlsCertFile, "tls-cert-file", "c", "", "")
|
||||
rootCmd.PersistentFlags().StringVarP(&cobraConfig.TlsKeyFile, "tls-key-file", "k", "", "")
|
||||
rootCmd.PersistentFlags().StringVarP(&cobraConfig.AuthSecret, "auth-secret", "s", "", "auth secret")
|
||||
rootCmd.PersistentFlags().StringVar(&cobraConfig.LogLevel, "log-level", "info", "log level")
|
||||
rootCmd.PersistentFlags().StringVar(&cobraConfig.LogFile, "log-file", "console", "log file")
|
||||
|
||||
setFlagsFromEnvVars(rootCmd)
|
||||
}
|
||||
|
||||
func Execute() error {
|
||||
return rootCmd.Execute()
|
||||
}
|
||||
|
||||
func waitForExitSignal() {
|
||||
osSigs := make(chan os.Signal, 1)
|
||||
signal.Notify(osSigs, syscall.SIGINT, syscall.SIGTERM)
|
||||
<-osSigs
|
||||
}
|
||||
|
||||
func execute(cmd *cobra.Command, args []string) error {
|
||||
err := cobraConfig.Validate()
|
||||
if err != nil {
|
||||
log.Debugf("invalid config: %s", err)
|
||||
return fmt.Errorf("invalid config: %s", err)
|
||||
}
|
||||
|
||||
err = util.InitLog(cobraConfig.LogLevel, cobraConfig.LogFile)
|
||||
if err != nil {
|
||||
log.Debugf("failed to initialize log: %s", err)
|
||||
return fmt.Errorf("failed to initialize log: %s", err)
|
||||
}
|
||||
|
||||
metricsServer, err := metrics.NewServer(metricsPort, "")
|
||||
if err != nil {
|
||||
log.Debugf("setup metrics: %v", err)
|
||||
return fmt.Errorf("setup metrics: %v", err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
log.Infof("running metrics server: %s%s", metricsServer.Addr, metricsServer.Endpoint)
|
||||
if err := metricsServer.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
|
||||
log.Fatalf("Failed to start metrics server: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
srvListenerCfg := server.ListenerConfig{
|
||||
Address: cobraConfig.ListenAddress,
|
||||
}
|
||||
|
||||
tlsConfig, tlsSupport, err := handleTLSConfig(cobraConfig)
|
||||
if err != nil {
|
||||
log.Debugf("failed to setup TLS config: %s", err)
|
||||
return fmt.Errorf("failed to setup TLS config: %s", err)
|
||||
}
|
||||
srvListenerCfg.TLSConfig = tlsConfig
|
||||
|
||||
authenticator := auth.NewTimedHMACValidator(cobraConfig.AuthSecret, 24*time.Hour)
|
||||
srv, err := server.NewServer(metricsServer.Meter, cobraConfig.ExposedAddress, tlsSupport, authenticator)
|
||||
if err != nil {
|
||||
log.Debugf("failed to create relay server: %v", err)
|
||||
return fmt.Errorf("failed to create relay server: %v", err)
|
||||
}
|
||||
log.Infof("server will be available on: %s", srv.InstanceURL())
|
||||
go func() {
|
||||
if err := srv.Listen(srvListenerCfg); err != nil {
|
||||
log.Fatalf("failed to bind server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// it will block until exit signal
|
||||
waitForExitSignal()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
var shutDownErrors error
|
||||
if err := srv.Shutdown(ctx); err != nil {
|
||||
shutDownErrors = multierror.Append(shutDownErrors, fmt.Errorf("failed to close server: %s", err))
|
||||
}
|
||||
|
||||
log.Infof("shutting down metrics server")
|
||||
if err := metricsServer.Shutdown(ctx); err != nil {
|
||||
shutDownErrors = multierror.Append(shutDownErrors, fmt.Errorf("failed to close metrics server: %v", err))
|
||||
}
|
||||
return shutDownErrors
|
||||
}
|
||||
|
||||
func handleTLSConfig(cfg *Config) (*tls.Config, bool, error) {
|
||||
if cfg.LetsencryptAWSRoute53 {
|
||||
log.Debugf("using Let's Encrypt DNS resolver with Route 53 support")
|
||||
r53 := encryption.Route53TLS{
|
||||
DataDir: cfg.LetsencryptDataDir,
|
||||
Email: cfg.LetsencryptEmail,
|
||||
Domains: cfg.LetsencryptDomains,
|
||||
}
|
||||
tlsCfg, err := r53.GetCertificate()
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("%s", err)
|
||||
}
|
||||
return tlsCfg, true, nil
|
||||
}
|
||||
|
||||
if cfg.HasLetsEncrypt() {
|
||||
log.Infof("setting up TLS with Let's Encrypt.")
|
||||
tlsCfg, err := setupTLSCertManager(cfg.LetsencryptDataDir, cfg.LetsencryptDomains...)
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("%s", err)
|
||||
}
|
||||
return tlsCfg, true, nil
|
||||
}
|
||||
|
||||
if cfg.HasCertConfig() {
|
||||
log.Debugf("using file based TLS config")
|
||||
tlsCfg, err := encryption.LoadTLSConfig(cfg.TlsCertFile, cfg.TlsKeyFile)
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("%s", err)
|
||||
}
|
||||
return tlsCfg, true, nil
|
||||
}
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
func setupTLSCertManager(letsencryptDataDir string, letsencryptDomains ...string) (*tls.Config, error) {
|
||||
certManager, err := encryption.CreateCertManager(letsencryptDataDir, letsencryptDomains...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed creating LetsEncrypt cert manager: %v", err)
|
||||
}
|
||||
return certManager.TLSConfig(), nil
|
||||
}
|
||||
14
relay/doc.go
Normal file
14
relay/doc.go
Normal file
@@ -0,0 +1,14 @@
|
||||
//Package main
|
||||
/*
|
||||
The `relay` package contains the implementation of the Relay server and client. The Relay server can be used to relay
|
||||
messages between peers on a single network channel. In this implementation the transport layer is the WebSocket
|
||||
protocol.
|
||||
|
||||
Between the server and client communication has been design a custom protocol and message format. These messages are
|
||||
transported over the WebSocket connection. Optionally the server can use TLS to secure the communication.
|
||||
|
||||
The service can support multiple Relay server instances. For this purpose the peers must know the server instance URL.
|
||||
This URL will be sent to the target peer to choose the common Relay server for the communication via Signal service.
|
||||
|
||||
*/
|
||||
package main
|
||||
17
relay/healthcheck/doc.go
Normal file
17
relay/healthcheck/doc.go
Normal file
@@ -0,0 +1,17 @@
|
||||
/*
|
||||
The `healthcheck` package is responsible for managing the health checks between the client and the relay server. It
|
||||
ensures that the connection between the client and the server are alive and functioning properly.
|
||||
|
||||
The `Sender` struct is responsible for sending health check signals to the receiver. The receiver listens for these
|
||||
signals and sends a new signal back to the sender to acknowledge that the signal has been received. If the sender does
|
||||
not receive an acknowledgment signal within a certain time frame, it will send a timeout signal via timeout channel
|
||||
and stop working.
|
||||
|
||||
The `Receiver` struct is responsible for receiving the health check signals from the sender. If the receiver does not
|
||||
receive a signal within a certain time frame, it will send a timeout signal via the OnTimeout channel and stop working.
|
||||
|
||||
In the Relay usage the signal is sent to the peer in message type Healthcheck. In case of timeout the connection is
|
||||
closed and the peer is removed from the relay.
|
||||
*/
|
||||
|
||||
package healthcheck
|
||||
82
relay/healthcheck/receiver.go
Normal file
82
relay/healthcheck/receiver.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package healthcheck
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
heartbeatTimeout = healthCheckInterval + 3*time.Second
|
||||
)
|
||||
|
||||
// Receiver is a healthcheck receiver
|
||||
// It will listen for heartbeat and check if the heartbeat is not received in a certain time
|
||||
// If the heartbeat is not received in a certain time, it will send a timeout signal and stop to work
|
||||
// The heartbeat timeout is a bit longer than the sender's healthcheck interval
|
||||
type Receiver struct {
|
||||
OnTimeout chan struct{}
|
||||
|
||||
ctx context.Context
|
||||
ctxCancel context.CancelFunc
|
||||
heartbeat chan struct{}
|
||||
alive bool
|
||||
}
|
||||
|
||||
// NewReceiver creates a new healthcheck receiver and start the timer in the background
|
||||
func NewReceiver() *Receiver {
|
||||
ctx, ctxCancel := context.WithCancel(context.Background())
|
||||
|
||||
r := &Receiver{
|
||||
OnTimeout: make(chan struct{}, 1),
|
||||
ctx: ctx,
|
||||
ctxCancel: ctxCancel,
|
||||
heartbeat: make(chan struct{}, 1),
|
||||
}
|
||||
|
||||
go r.waitForHealthcheck()
|
||||
return r
|
||||
}
|
||||
|
||||
// Heartbeat acknowledge the heartbeat has been received
|
||||
func (r *Receiver) Heartbeat() {
|
||||
select {
|
||||
case r.heartbeat <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// Stop check the timeout and do not send new notifications
|
||||
func (r *Receiver) Stop() {
|
||||
r.ctxCancel()
|
||||
}
|
||||
|
||||
func (r *Receiver) waitForHealthcheck() {
|
||||
ticker := time.NewTicker(heartbeatTimeout)
|
||||
defer ticker.Stop()
|
||||
defer r.ctxCancel()
|
||||
defer close(r.OnTimeout)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-r.heartbeat:
|
||||
r.alive = true
|
||||
case <-ticker.C:
|
||||
if r.alive {
|
||||
r.alive = false
|
||||
continue
|
||||
}
|
||||
|
||||
r.notifyTimeout()
|
||||
return
|
||||
case <-r.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Receiver) notifyTimeout() {
|
||||
select {
|
||||
case r.OnTimeout <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
42
relay/healthcheck/receiver_test.go
Normal file
42
relay/healthcheck/receiver_test.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package healthcheck
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNewReceiver(t *testing.T) {
|
||||
heartbeatTimeout = 5 * time.Second
|
||||
r := NewReceiver()
|
||||
|
||||
select {
|
||||
case <-r.OnTimeout:
|
||||
t.Error("unexpected timeout")
|
||||
case <-time.After(1 * time.Second):
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewReceiverNotReceive(t *testing.T) {
|
||||
heartbeatTimeout = 1 * time.Second
|
||||
r := NewReceiver()
|
||||
|
||||
select {
|
||||
case <-r.OnTimeout:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Error("timeout not received")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewReceiverAck(t *testing.T) {
|
||||
heartbeatTimeout = 2 * time.Second
|
||||
r := NewReceiver()
|
||||
|
||||
r.Heartbeat()
|
||||
|
||||
select {
|
||||
case <-r.OnTimeout:
|
||||
t.Error("unexpected timeout")
|
||||
case <-time.After(3 * time.Second):
|
||||
}
|
||||
}
|
||||
68
relay/healthcheck/sender.go
Normal file
68
relay/healthcheck/sender.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package healthcheck
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
healthCheckInterval = 25 * time.Second
|
||||
healthCheckTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
// Sender is a healthcheck sender
|
||||
// It will send healthcheck signal to the receiver
|
||||
// If the receiver does not receive the signal in a certain time, it will send a timeout signal and stop to work
|
||||
// It will also stop if the context is canceled
|
||||
type Sender struct {
|
||||
// HealthCheck is a channel to send health check signal to the peer
|
||||
HealthCheck chan struct{}
|
||||
// Timeout is a channel to the health check signal is not received in a certain time
|
||||
Timeout chan struct{}
|
||||
|
||||
ack chan struct{}
|
||||
}
|
||||
|
||||
// NewSender creates a new healthcheck sender
|
||||
func NewSender() *Sender {
|
||||
hc := &Sender{
|
||||
HealthCheck: make(chan struct{}, 1),
|
||||
Timeout: make(chan struct{}, 1),
|
||||
ack: make(chan struct{}, 1),
|
||||
}
|
||||
|
||||
return hc
|
||||
}
|
||||
|
||||
// OnHCResponse sends an acknowledgment signal to the sender
|
||||
func (hc *Sender) OnHCResponse() {
|
||||
select {
|
||||
case hc.ack <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (hc *Sender) StartHealthCheck(ctx context.Context) {
|
||||
ticker := time.NewTicker(healthCheckInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
timeoutTimer := time.NewTimer(healthCheckInterval + healthCheckTimeout)
|
||||
defer timeoutTimer.Stop()
|
||||
|
||||
defer close(hc.HealthCheck)
|
||||
defer close(hc.Timeout)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
hc.HealthCheck <- struct{}{}
|
||||
case <-timeoutTimer.C:
|
||||
hc.Timeout <- struct{}{}
|
||||
return
|
||||
case <-hc.ack:
|
||||
timeoutTimer.Reset(healthCheckInterval + healthCheckTimeout)
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
103
relay/healthcheck/sender_test.go
Normal file
103
relay/healthcheck/sender_test.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package healthcheck
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
// override the health check interval to speed up the test
|
||||
healthCheckInterval = 2 * time.Second
|
||||
healthCheckTimeout = 100 * time.Millisecond
|
||||
code := m.Run()
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
func TestNewHealthPeriod(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
hc := NewSender()
|
||||
go hc.StartHealthCheck(ctx)
|
||||
|
||||
iterations := 0
|
||||
for i := 0; i < 3; i++ {
|
||||
select {
|
||||
case <-hc.HealthCheck:
|
||||
iterations++
|
||||
hc.OnHCResponse()
|
||||
case <-hc.Timeout:
|
||||
t.Fatalf("health check is timed out")
|
||||
case <-time.After(healthCheckInterval + 100*time.Millisecond):
|
||||
t.Fatalf("health check not received")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewHealthFailed(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
hc := NewSender()
|
||||
go hc.StartHealthCheck(ctx)
|
||||
|
||||
select {
|
||||
case <-hc.Timeout:
|
||||
case <-time.After(healthCheckInterval + healthCheckTimeout + 100*time.Millisecond):
|
||||
t.Fatalf("health check is not timed out")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewHealthcheckStop(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
hc := NewSender()
|
||||
go hc.StartHealthCheck(ctx)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
cancel()
|
||||
|
||||
select {
|
||||
case _, ok := <-hc.HealthCheck:
|
||||
if ok {
|
||||
t.Fatalf("health check on received")
|
||||
}
|
||||
case _, ok := <-hc.Timeout:
|
||||
if ok {
|
||||
t.Fatalf("health check on received")
|
||||
}
|
||||
case <-ctx.Done():
|
||||
// expected
|
||||
case <-time.After(10 * time.Second):
|
||||
t.Fatalf("is not exited")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTimeoutReset(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
hc := NewSender()
|
||||
go hc.StartHealthCheck(ctx)
|
||||
|
||||
iterations := 0
|
||||
for i := 0; i < 3; i++ {
|
||||
select {
|
||||
case <-hc.HealthCheck:
|
||||
iterations++
|
||||
hc.OnHCResponse()
|
||||
case <-hc.Timeout:
|
||||
t.Fatalf("health check is timed out")
|
||||
case <-time.After(healthCheckInterval + 100*time.Millisecond):
|
||||
t.Fatalf("health check not received")
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case <-hc.HealthCheck:
|
||||
case <-hc.Timeout:
|
||||
// expected
|
||||
case <-ctx.Done():
|
||||
t.Fatalf("context is done")
|
||||
case <-time.After(10 * time.Second):
|
||||
t.Fatalf("is not exited")
|
||||
}
|
||||
}
|
||||
13
relay/main.go
Normal file
13
relay/main.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/relay/cmd"
|
||||
)
|
||||
|
||||
func main() {
|
||||
if err := cmd.Execute(); err != nil {
|
||||
log.Fatalf("failed to execute command: %v", err)
|
||||
}
|
||||
}
|
||||
30
relay/messages/address/address.go
Normal file
30
relay/messages/address/address.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package address
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/gob"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type Address struct {
|
||||
URL string
|
||||
}
|
||||
|
||||
func (addr *Address) Marshal() ([]byte, error) {
|
||||
var buf bytes.Buffer
|
||||
enc := gob.NewEncoder(&buf)
|
||||
if err := enc.Encode(addr); err != nil {
|
||||
return nil, fmt.Errorf("encode Address: %w", err)
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
func Unmarshal(data []byte) (*Address, error) {
|
||||
var addr Address
|
||||
buf := bytes.NewBuffer(data)
|
||||
dec := gob.NewDecoder(buf)
|
||||
if err := dec.Decode(&addr); err != nil {
|
||||
return nil, fmt.Errorf("decode Address: %w", err)
|
||||
}
|
||||
return &addr, nil
|
||||
}
|
||||
51
relay/messages/auth/auth.go
Normal file
51
relay/messages/auth/auth.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/gob"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type Algorithm int
|
||||
|
||||
const (
|
||||
AlgoUnknown Algorithm = iota
|
||||
AlgoHMACSHA256
|
||||
AlgoHMACSHA512
|
||||
)
|
||||
|
||||
func (a Algorithm) String() string {
|
||||
switch a {
|
||||
case AlgoHMACSHA256:
|
||||
return "HMAC-SHA256"
|
||||
case AlgoHMACSHA512:
|
||||
return "HMAC-SHA512"
|
||||
default:
|
||||
return "Unknown"
|
||||
}
|
||||
}
|
||||
|
||||
type Msg struct {
|
||||
AuthAlgorithm Algorithm
|
||||
AdditionalData []byte
|
||||
}
|
||||
|
||||
func (msg *Msg) Marshal() ([]byte, error) {
|
||||
var buf bytes.Buffer
|
||||
enc := gob.NewEncoder(&buf)
|
||||
if err := enc.Encode(msg); err != nil {
|
||||
return nil, fmt.Errorf("encode Msg: %w", err)
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
func UnmarshalMsg(data []byte) (*Msg, error) {
|
||||
var msg *Msg
|
||||
|
||||
buf := bytes.NewBuffer(data)
|
||||
dec := gob.NewDecoder(buf)
|
||||
if err := dec.Decode(&msg); err != nil {
|
||||
return nil, fmt.Errorf("decode Msg: %w", err)
|
||||
}
|
||||
return msg, nil
|
||||
}
|
||||
5
relay/messages/doc.go
Normal file
5
relay/messages/doc.go
Normal file
@@ -0,0 +1,5 @@
|
||||
/*
|
||||
Package messages provides the message types that are used to communicate between the relay and the client.
|
||||
This package is used to determine the type of message that is being sent and received between the relay and the client.
|
||||
*/
|
||||
package messages
|
||||
31
relay/messages/id.go
Normal file
31
relay/messages/id.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package messages
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
const (
|
||||
prefixLength = 4
|
||||
IDSize = prefixLength + sha256.Size
|
||||
)
|
||||
|
||||
var (
|
||||
prefix = []byte("sha-") // 4 bytes
|
||||
)
|
||||
|
||||
// HashID generates a sha256 hash from the peerID and returns the hash and the human-readable string
|
||||
func HashID(peerID string) ([]byte, string) {
|
||||
idHash := sha256.Sum256([]byte(peerID))
|
||||
idHashString := string(prefix) + base64.StdEncoding.EncodeToString(idHash[:])
|
||||
var prefixedHash []byte
|
||||
prefixedHash = append(prefixedHash, prefix...)
|
||||
prefixedHash = append(prefixedHash, idHash[:]...)
|
||||
return prefixedHash, idHashString
|
||||
}
|
||||
|
||||
// HashIDToString converts a hash to a human-readable string
|
||||
func HashIDToString(idHash []byte) string {
|
||||
return fmt.Sprintf("%s%s", idHash[:prefixLength], base64.StdEncoding.EncodeToString(idHash[prefixLength:]))
|
||||
}
|
||||
13
relay/messages/id_test.go
Normal file
13
relay/messages/id_test.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package messages
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestHashID(t *testing.T) {
|
||||
hashedID, hashedStringId := HashID("alice")
|
||||
enc := HashIDToString(hashedID)
|
||||
if enc != hashedStringId {
|
||||
t.Errorf("expected %s, got %s", hashedStringId, enc)
|
||||
}
|
||||
}
|
||||
239
relay/messages/message.go
Normal file
239
relay/messages/message.go
Normal file
@@ -0,0 +1,239 @@
|
||||
package messages
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
const (
|
||||
MsgTypeUnknown MsgType = 0
|
||||
MsgTypeHello MsgType = 1
|
||||
MsgTypeHelloResponse MsgType = 2
|
||||
MsgTypeTransport MsgType = 3
|
||||
MsgTypeClose MsgType = 4
|
||||
MsgTypeHealthCheck MsgType = 5
|
||||
|
||||
SizeOfVersionByte = 1
|
||||
SizeOfMsgType = 1
|
||||
|
||||
SizeOfProtoHeader = SizeOfVersionByte + SizeOfMsgType
|
||||
|
||||
sizeOfMagicByte = 4
|
||||
|
||||
headerSizeTransport = IDSize
|
||||
headerSizeHello = sizeOfMagicByte + IDSize
|
||||
headerSizeHelloResp = 0
|
||||
|
||||
MaxHandshakeSize = 8192
|
||||
|
||||
CurrentProtocolVersion = 1
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidMessageLength = errors.New("invalid message length")
|
||||
ErrUnsupportedVersion = errors.New("unsupported version")
|
||||
|
||||
magicHeader = []byte{0x21, 0x12, 0xA4, 0x42}
|
||||
|
||||
healthCheckMsg = []byte{byte(CurrentProtocolVersion), byte(MsgTypeHealthCheck)}
|
||||
)
|
||||
|
||||
type MsgType byte
|
||||
|
||||
func (m MsgType) String() string {
|
||||
switch m {
|
||||
case MsgTypeHello:
|
||||
return "hello"
|
||||
case MsgTypeHelloResponse:
|
||||
return "hello response"
|
||||
case MsgTypeTransport:
|
||||
return "transport"
|
||||
case MsgTypeClose:
|
||||
return "close"
|
||||
case MsgTypeHealthCheck:
|
||||
return "health check"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
type HelloResponse struct {
|
||||
InstanceAddress string
|
||||
}
|
||||
|
||||
// ValidateVersion checks if the given version is supported by the protocol
|
||||
func ValidateVersion(msg []byte) (int, error) {
|
||||
if len(msg) < SizeOfVersionByte {
|
||||
return 0, ErrInvalidMessageLength
|
||||
}
|
||||
version := int(msg[0])
|
||||
if version != CurrentProtocolVersion {
|
||||
return 0, fmt.Errorf("%d: %w", version, ErrUnsupportedVersion)
|
||||
}
|
||||
return version, nil
|
||||
}
|
||||
|
||||
// DetermineClientMessageType determines the message type from the first the message
|
||||
func DetermineClientMessageType(msg []byte) (MsgType, error) {
|
||||
if len(msg) < SizeOfMsgType {
|
||||
return 0, ErrInvalidMessageLength
|
||||
}
|
||||
|
||||
msgType := MsgType(msg[0])
|
||||
switch msgType {
|
||||
case
|
||||
MsgTypeHello,
|
||||
MsgTypeTransport,
|
||||
MsgTypeClose,
|
||||
MsgTypeHealthCheck:
|
||||
return msgType, nil
|
||||
default:
|
||||
return MsgTypeUnknown, fmt.Errorf("invalid msg type %d", msgType)
|
||||
}
|
||||
}
|
||||
|
||||
// DetermineServerMessageType determines the message type from the first the message
|
||||
func DetermineServerMessageType(msg []byte) (MsgType, error) {
|
||||
if len(msg) < SizeOfMsgType {
|
||||
return 0, ErrInvalidMessageLength
|
||||
}
|
||||
|
||||
msgType := MsgType(msg[0])
|
||||
switch msgType {
|
||||
case
|
||||
MsgTypeHelloResponse,
|
||||
MsgTypeTransport,
|
||||
MsgTypeClose,
|
||||
MsgTypeHealthCheck:
|
||||
return msgType, nil
|
||||
default:
|
||||
return MsgTypeUnknown, fmt.Errorf("invalid msg type %d", msgType)
|
||||
}
|
||||
}
|
||||
|
||||
// MarshalHelloMsg initial hello message
|
||||
// The Hello message is the first message sent by a client after establishing a connection with the Relay server. This
|
||||
// message is used to authenticate the client with the server. The authentication is done using an HMAC method.
|
||||
// The protocol does not limit to use HMAC, it can be any other method. If the authentication failed the server will
|
||||
// close the network connection without any response.
|
||||
func MarshalHelloMsg(peerID []byte, additions []byte) ([]byte, error) {
|
||||
if len(peerID) != IDSize {
|
||||
return nil, fmt.Errorf("invalid peerID length: %d", len(peerID))
|
||||
}
|
||||
|
||||
msg := make([]byte, SizeOfProtoHeader+sizeOfMagicByte, SizeOfProtoHeader+headerSizeHello+len(additions))
|
||||
|
||||
msg[0] = byte(CurrentProtocolVersion)
|
||||
msg[1] = byte(MsgTypeHello)
|
||||
|
||||
copy(msg[SizeOfProtoHeader:SizeOfProtoHeader+sizeOfMagicByte], magicHeader)
|
||||
|
||||
msg = append(msg, peerID...)
|
||||
msg = append(msg, additions...)
|
||||
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// UnmarshalHelloMsg extracts peerID and the additional data from the hello message. The Additional data is used to
|
||||
// authenticate the client with the server.
|
||||
func UnmarshalHelloMsg(msg []byte) ([]byte, []byte, error) {
|
||||
if len(msg) < headerSizeHello {
|
||||
return nil, nil, ErrInvalidMessageLength
|
||||
}
|
||||
if !bytes.Equal(msg[:sizeOfMagicByte], magicHeader) {
|
||||
return nil, nil, errors.New("invalid magic header")
|
||||
}
|
||||
|
||||
return msg[sizeOfMagicByte:headerSizeHello], msg[headerSizeHello:], nil
|
||||
}
|
||||
|
||||
// MarshalHelloResponse creates a response message to the hello message.
|
||||
// In case of success connection the server response with a Hello Response message. This message contains the server's
|
||||
// instance URL. This URL will be used by choose the common Relay server in case if the peers are in different Relay
|
||||
// servers.
|
||||
func MarshalHelloResponse(additionalData []byte) ([]byte, error) {
|
||||
msg := make([]byte, SizeOfProtoHeader, SizeOfProtoHeader+headerSizeHelloResp+len(additionalData))
|
||||
|
||||
msg[0] = byte(CurrentProtocolVersion)
|
||||
msg[1] = byte(MsgTypeHelloResponse)
|
||||
|
||||
msg = append(msg, additionalData...)
|
||||
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// UnmarshalHelloResponse extracts the additional data from the hello response message.
|
||||
func UnmarshalHelloResponse(msg []byte) ([]byte, error) {
|
||||
if len(msg) < headerSizeHelloResp {
|
||||
return nil, ErrInvalidMessageLength
|
||||
}
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// MarshalCloseMsg creates a close message.
|
||||
// The close message is used to close the connection gracefully between the client and the server. The server and the
|
||||
// client can send this message. After receiving this message, the server or client will close the connection.
|
||||
func MarshalCloseMsg() []byte {
|
||||
msg := make([]byte, SizeOfProtoHeader)
|
||||
|
||||
msg[0] = byte(CurrentProtocolVersion)
|
||||
msg[1] = byte(MsgTypeClose)
|
||||
|
||||
return msg
|
||||
}
|
||||
|
||||
// MarshalTransportMsg creates a transport message.
|
||||
// The transport message is used to exchange data between peers. The message contains the data to be exchanged and the
|
||||
// destination peer hashed ID.
|
||||
func MarshalTransportMsg(peerID []byte, payload []byte) ([]byte, error) {
|
||||
if len(peerID) != IDSize {
|
||||
return nil, fmt.Errorf("invalid peerID length: %d", len(peerID))
|
||||
}
|
||||
|
||||
msg := make([]byte, SizeOfProtoHeader+headerSizeTransport, SizeOfProtoHeader+headerSizeTransport+len(payload))
|
||||
|
||||
msg[0] = byte(CurrentProtocolVersion)
|
||||
msg[1] = byte(MsgTypeTransport)
|
||||
|
||||
copy(msg[SizeOfProtoHeader:], peerID)
|
||||
|
||||
msg = append(msg, payload...)
|
||||
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// UnmarshalTransportMsg extracts the peerID and the payload from the transport message.
|
||||
func UnmarshalTransportMsg(buf []byte) ([]byte, []byte, error) {
|
||||
if len(buf) < headerSizeTransport {
|
||||
return nil, nil, ErrInvalidMessageLength
|
||||
}
|
||||
|
||||
return buf[:headerSizeTransport], buf[headerSizeTransport:], nil
|
||||
}
|
||||
|
||||
// UnmarshalTransportID extracts the peerID from the transport message.
|
||||
func UnmarshalTransportID(buf []byte) ([]byte, error) {
|
||||
if len(buf) < headerSizeTransport {
|
||||
return nil, ErrInvalidMessageLength
|
||||
}
|
||||
return buf[:headerSizeTransport], nil
|
||||
}
|
||||
|
||||
// UpdateTransportMsg updates the peerID in the transport message.
|
||||
// With this function the server can reuse the given byte slice to update the peerID in the transport message. So do
|
||||
// need to allocate a new byte slice.
|
||||
func UpdateTransportMsg(msg []byte, peerID []byte) error {
|
||||
if len(msg) < len(peerID) {
|
||||
return ErrInvalidMessageLength
|
||||
}
|
||||
copy(msg, peerID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarshalHealthcheck creates a health check message.
|
||||
// Health check message is sent by the server periodically. The client will respond with a health check response
|
||||
// message. If the client does not respond to the health check message, the server will close the connection.
|
||||
func MarshalHealthcheck() []byte {
|
||||
return healthCheckMsg
|
||||
}
|
||||
43
relay/messages/message_test.go
Normal file
43
relay/messages/message_test.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package messages
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMarshalHelloMsg(t *testing.T) {
|
||||
peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
|
||||
bHello, err := MarshalHelloMsg(peerID, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
|
||||
receivedPeerID, _, err := UnmarshalHelloMsg(bHello[SizeOfProtoHeader:])
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
if string(receivedPeerID) != string(peerID) {
|
||||
t.Errorf("expected %s, got %s", peerID, receivedPeerID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshalTransportMsg(t *testing.T) {
|
||||
peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
|
||||
payload := []byte("payload")
|
||||
msg, err := MarshalTransportMsg(peerID, payload)
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
|
||||
id, respPayload, err := UnmarshalTransportMsg(msg[SizeOfProtoHeader:])
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
|
||||
if string(id) != string(peerID) {
|
||||
t.Errorf("expected %s, got %s", peerID, id)
|
||||
}
|
||||
|
||||
if string(respPayload) != string(payload) {
|
||||
t.Errorf("expected %s, got %s", payload, respPayload)
|
||||
}
|
||||
}
|
||||
136
relay/metrics/realy.go
Normal file
136
relay/metrics/realy.go
Normal file
@@ -0,0 +1,136 @@
|
||||
package metrics
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
)
|
||||
|
||||
const (
|
||||
idleTimeout = 30 * time.Second
|
||||
)
|
||||
|
||||
type Metrics struct {
|
||||
metric.Meter
|
||||
|
||||
TransferBytesSent metric.Int64Counter
|
||||
TransferBytesRecv metric.Int64Counter
|
||||
|
||||
peers metric.Int64UpDownCounter
|
||||
peerActivityChan chan string
|
||||
peerLastActive map[string]time.Time
|
||||
mutexActivity sync.Mutex
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func NewMetrics(ctx context.Context, meter metric.Meter) (*Metrics, error) {
|
||||
bytesSent, err := meter.Int64Counter("relay_transfer_sent_bytes_total")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
bytesRecv, err := meter.Int64Counter("relay_transfer_received_bytes_total")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
peers, err := meter.Int64UpDownCounter("relay_peers")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
peersActive, err := meter.Int64ObservableGauge("relay_peers_active")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
peersIdle, err := meter.Int64ObservableGauge("relay_peers_idle")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m := &Metrics{
|
||||
Meter: meter,
|
||||
TransferBytesSent: bytesSent,
|
||||
TransferBytesRecv: bytesRecv,
|
||||
peers: peers,
|
||||
|
||||
ctx: ctx,
|
||||
peerActivityChan: make(chan string, 10),
|
||||
peerLastActive: make(map[string]time.Time),
|
||||
}
|
||||
|
||||
_, err = meter.RegisterCallback(
|
||||
func(ctx context.Context, o metric.Observer) error {
|
||||
active, idle := m.calculateActiveIdleConnections()
|
||||
o.ObserveInt64(peersActive, active)
|
||||
o.ObserveInt64(peersIdle, idle)
|
||||
return nil
|
||||
},
|
||||
peersActive, peersIdle,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
go m.readPeerActivity()
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// PeerConnected increments the number of connected peers and increments number of idle connections
|
||||
func (m *Metrics) PeerConnected(id string) {
|
||||
m.peers.Add(m.ctx, 1)
|
||||
m.mutexActivity.Lock()
|
||||
defer m.mutexActivity.Unlock()
|
||||
|
||||
m.peerLastActive[id] = time.Time{}
|
||||
}
|
||||
|
||||
// PeerDisconnected decrements the number of connected peers and decrements number of idle or active connections
|
||||
func (m *Metrics) PeerDisconnected(id string) {
|
||||
m.peers.Add(m.ctx, -1)
|
||||
m.mutexActivity.Lock()
|
||||
defer m.mutexActivity.Unlock()
|
||||
|
||||
delete(m.peerLastActive, id)
|
||||
}
|
||||
|
||||
// PeerActivity increases the active connections
|
||||
func (m *Metrics) PeerActivity(peerID string) {
|
||||
select {
|
||||
case m.peerActivityChan <- peerID:
|
||||
default:
|
||||
log.Errorf("peer activity channel is full, dropping activity metrics for peer %s", peerID)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Metrics) calculateActiveIdleConnections() (int64, int64) {
|
||||
active, idle := int64(0), int64(0)
|
||||
m.mutexActivity.Lock()
|
||||
defer m.mutexActivity.Unlock()
|
||||
|
||||
for _, lastActive := range m.peerLastActive {
|
||||
if time.Since(lastActive) > idleTimeout {
|
||||
idle++
|
||||
} else {
|
||||
active++
|
||||
}
|
||||
}
|
||||
return active, idle
|
||||
}
|
||||
|
||||
func (m *Metrics) readPeerActivity() {
|
||||
for {
|
||||
select {
|
||||
case peerID := <-m.peerActivityChan:
|
||||
m.mutexActivity.Lock()
|
||||
m.peerLastActive[peerID] = time.Now()
|
||||
m.mutexActivity.Unlock()
|
||||
case <-m.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
11
relay/server/listener/listener.go
Normal file
11
relay/server/listener/listener.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package listener
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
)
|
||||
|
||||
type Listener interface {
|
||||
Listen(func(conn net.Conn)) error
|
||||
Shutdown(ctx context.Context) error
|
||||
}
|
||||
114
relay/server/listener/ws/conn.go
Normal file
114
relay/server/listener/ws/conn.go
Normal file
@@ -0,0 +1,114 @@
|
||||
package ws
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"nhooyr.io/websocket"
|
||||
)
|
||||
|
||||
const (
|
||||
writeTimeout = 10 * time.Second
|
||||
)
|
||||
|
||||
type Conn struct {
|
||||
*websocket.Conn
|
||||
lAddr *net.TCPAddr
|
||||
rAddr *net.TCPAddr
|
||||
|
||||
closed bool
|
||||
closedMu sync.Mutex
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func NewConn(wsConn *websocket.Conn, lAddr, rAddr *net.TCPAddr) *Conn {
|
||||
return &Conn{
|
||||
Conn: wsConn,
|
||||
lAddr: lAddr,
|
||||
rAddr: rAddr,
|
||||
ctx: context.Background(),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) Read(b []byte) (n int, err error) {
|
||||
t, r, err := c.Reader(c.ctx)
|
||||
if err != nil {
|
||||
return 0, c.ioErrHandling(err)
|
||||
}
|
||||
|
||||
if t != websocket.MessageBinary {
|
||||
log.Errorf("unexpected message type: %d", t)
|
||||
return 0, fmt.Errorf("unexpected message type")
|
||||
}
|
||||
|
||||
n, err = r.Read(b)
|
||||
if err != nil {
|
||||
return 0, c.ioErrHandling(err)
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
// Write writes a binary message with the given payload.
|
||||
// It does not block until fill the internal buffer.
|
||||
// If the buffer filled up, wait until the buffer is drained or timeout.
|
||||
func (c *Conn) Write(b []byte) (int, error) {
|
||||
ctx, ctxCancel := context.WithTimeout(c.ctx, writeTimeout)
|
||||
defer ctxCancel()
|
||||
|
||||
err := c.Conn.Write(ctx, websocket.MessageBinary, b)
|
||||
return len(b), err
|
||||
}
|
||||
|
||||
func (c *Conn) LocalAddr() net.Addr {
|
||||
return c.lAddr
|
||||
}
|
||||
|
||||
func (c *Conn) RemoteAddr() net.Addr {
|
||||
return c.rAddr
|
||||
}
|
||||
|
||||
func (c *Conn) SetReadDeadline(t time.Time) error {
|
||||
return fmt.Errorf("SetReadDeadline is not implemented")
|
||||
}
|
||||
|
||||
func (c *Conn) SetWriteDeadline(t time.Time) error {
|
||||
return fmt.Errorf("SetWriteDeadline is not implemented")
|
||||
}
|
||||
|
||||
func (c *Conn) SetDeadline(t time.Time) error {
|
||||
return fmt.Errorf("SetDeadline is not implemented")
|
||||
}
|
||||
|
||||
func (c *Conn) Close() error {
|
||||
c.closedMu.Lock()
|
||||
c.closed = true
|
||||
c.closedMu.Unlock()
|
||||
return c.Conn.CloseNow()
|
||||
}
|
||||
|
||||
func (c *Conn) isClosed() bool {
|
||||
c.closedMu.Lock()
|
||||
defer c.closedMu.Unlock()
|
||||
return c.closed
|
||||
}
|
||||
|
||||
func (c *Conn) ioErrHandling(err error) error {
|
||||
if c.isClosed() {
|
||||
return io.EOF
|
||||
}
|
||||
|
||||
var wErr *websocket.CloseError
|
||||
if !errors.As(err, &wErr) {
|
||||
return err
|
||||
}
|
||||
if wErr.Code == websocket.StatusNormalClosure {
|
||||
return io.EOF
|
||||
}
|
||||
return err
|
||||
}
|
||||
92
relay/server/listener/ws/listener.go
Normal file
92
relay/server/listener/ws/listener.go
Normal file
@@ -0,0 +1,92 @@
|
||||
package ws
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"nhooyr.io/websocket"
|
||||
)
|
||||
|
||||
// URLPath is the path for the websocket connection.
|
||||
const URLPath = "/relay"
|
||||
|
||||
type Listener struct {
|
||||
// Address is the address to listen on.
|
||||
Address string
|
||||
// TLSConfig is the TLS configuration for the server.
|
||||
TLSConfig *tls.Config
|
||||
|
||||
server *http.Server
|
||||
acceptFn func(conn net.Conn)
|
||||
}
|
||||
|
||||
func (l *Listener) Listen(acceptFn func(conn net.Conn)) error {
|
||||
l.acceptFn = acceptFn
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc(URLPath, l.onAccept)
|
||||
|
||||
l.server = &http.Server{
|
||||
Addr: l.Address,
|
||||
Handler: mux,
|
||||
TLSConfig: l.TLSConfig,
|
||||
}
|
||||
|
||||
log.Infof("WS server listening address: %s", l.Address)
|
||||
var err error
|
||||
if l.TLSConfig != nil {
|
||||
err = l.server.ListenAndServeTLS("", "")
|
||||
} else {
|
||||
err = l.server.ListenAndServe()
|
||||
}
|
||||
if errors.Is(err, http.ErrServerClosed) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (l *Listener) Shutdown(ctx context.Context) error {
|
||||
if l.server == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Infof("stop WS listener")
|
||||
if err := l.server.Shutdown(ctx); err != nil {
|
||||
return fmt.Errorf("server shutdown failed: %v", err)
|
||||
}
|
||||
log.Infof("WS listener stopped")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *Listener) onAccept(w http.ResponseWriter, r *http.Request) {
|
||||
wsConn, err := websocket.Accept(w, r, nil)
|
||||
if err != nil {
|
||||
log.Errorf("failed to accept ws connection from %s: %s", r.RemoteAddr, err)
|
||||
return
|
||||
}
|
||||
|
||||
rAddr, err := net.ResolveTCPAddr("tcp", r.RemoteAddr)
|
||||
if err != nil {
|
||||
err = wsConn.Close(websocket.StatusInternalError, "internal error")
|
||||
if err != nil {
|
||||
log.Errorf("failed to close ws connection: %s", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
lAddr, err := net.ResolveTCPAddr("tcp", l.server.Addr)
|
||||
if err != nil {
|
||||
err = wsConn.Close(websocket.StatusInternalError, "internal error")
|
||||
if err != nil {
|
||||
log.Errorf("failed to close ws connection: %s", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
conn := NewConn(wsConn, lAddr, rAddr)
|
||||
l.acceptFn(conn)
|
||||
}
|
||||
203
relay/server/peer.go
Normal file
203
relay/server/peer.go
Normal file
@@ -0,0 +1,203 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/relay/healthcheck"
|
||||
"github.com/netbirdio/netbird/relay/messages"
|
||||
"github.com/netbirdio/netbird/relay/metrics"
|
||||
)
|
||||
|
||||
const (
|
||||
bufferSize = 8820
|
||||
)
|
||||
|
||||
// Peer represents a peer connection
|
||||
type Peer struct {
|
||||
metrics *metrics.Metrics
|
||||
log *log.Entry
|
||||
idS string
|
||||
idB []byte
|
||||
conn net.Conn
|
||||
connMu sync.RWMutex
|
||||
store *Store
|
||||
}
|
||||
|
||||
// NewPeer creates a new Peer instance and prepare custom logging
|
||||
func NewPeer(metrics *metrics.Metrics, id []byte, conn net.Conn, store *Store) *Peer {
|
||||
stringID := messages.HashIDToString(id)
|
||||
return &Peer{
|
||||
metrics: metrics,
|
||||
log: log.WithField("peer_id", stringID),
|
||||
idS: stringID,
|
||||
idB: id,
|
||||
conn: conn,
|
||||
store: store,
|
||||
}
|
||||
}
|
||||
|
||||
// Work reads data from the connection
|
||||
// It manages the protocol (healthcheck, transport, close). Read the message and determine the message type and handle
|
||||
// the message accordingly.
|
||||
func (p *Peer) Work() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
hc := healthcheck.NewSender()
|
||||
go hc.StartHealthCheck(ctx)
|
||||
go p.handleHealthcheckEvents(ctx, hc)
|
||||
|
||||
buf := make([]byte, bufferSize)
|
||||
for {
|
||||
n, err := p.conn.Read(buf)
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
p.log.Errorf("failed to read message: %s", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if n == 0 {
|
||||
p.log.Errorf("received empty message")
|
||||
return
|
||||
}
|
||||
|
||||
msg := buf[:n]
|
||||
|
||||
_, err = messages.ValidateVersion(msg)
|
||||
if err != nil {
|
||||
p.log.Warnf("failed to validate protocol version: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
msgType, err := messages.DetermineClientMessageType(msg[messages.SizeOfVersionByte:])
|
||||
if err != nil {
|
||||
p.log.Errorf("failed to determine message type: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
p.handleMsgType(ctx, msgType, hc, n, msg)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Peer) handleMsgType(ctx context.Context, msgType messages.MsgType, hc *healthcheck.Sender, n int, msg []byte) {
|
||||
switch msgType {
|
||||
case messages.MsgTypeHealthCheck:
|
||||
hc.OnHCResponse()
|
||||
case messages.MsgTypeTransport:
|
||||
p.metrics.TransferBytesRecv.Add(ctx, int64(n))
|
||||
p.metrics.PeerActivity(p.String())
|
||||
p.handleTransportMsg(msg)
|
||||
case messages.MsgTypeClose:
|
||||
p.log.Infof("peer exited gracefully")
|
||||
if err := p.conn.Close(); err != nil {
|
||||
log.Errorf("failed to close connection to peer: %s", err)
|
||||
}
|
||||
default:
|
||||
p.log.Warnf("received unexpected message type: %s", msgType)
|
||||
}
|
||||
}
|
||||
|
||||
// Write writes data to the connection
|
||||
func (p *Peer) Write(b []byte) (int, error) {
|
||||
p.connMu.RLock()
|
||||
defer p.connMu.RUnlock()
|
||||
return p.conn.Write(b)
|
||||
}
|
||||
|
||||
// CloseGracefully closes the connection with the peer gracefully. Send a close message to the client and close the
|
||||
// connection.
|
||||
func (p *Peer) CloseGracefully(ctx context.Context) {
|
||||
p.connMu.Lock()
|
||||
err := p.writeWithTimeout(ctx, messages.MarshalCloseMsg())
|
||||
if err != nil {
|
||||
p.log.Errorf("failed to send close message to peer: %s", p.String())
|
||||
}
|
||||
|
||||
err = p.conn.Close()
|
||||
if err != nil {
|
||||
p.log.Errorf("failed to close connection to peer: %s", err)
|
||||
}
|
||||
|
||||
defer p.connMu.Unlock()
|
||||
}
|
||||
|
||||
// String returns the peer ID
|
||||
func (p *Peer) String() string {
|
||||
return p.idS
|
||||
}
|
||||
|
||||
func (p *Peer) writeWithTimeout(ctx context.Context, buf []byte) error {
|
||||
ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
writeDone := make(chan struct{})
|
||||
var err error
|
||||
go func() {
|
||||
_, err = p.conn.Write(buf)
|
||||
close(writeDone)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-writeDone:
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Peer) handleHealthcheckEvents(ctx context.Context, hc *healthcheck.Sender) {
|
||||
for {
|
||||
select {
|
||||
case <-hc.HealthCheck:
|
||||
_, err := p.Write(messages.MarshalHealthcheck())
|
||||
if err != nil {
|
||||
p.log.Errorf("failed to send healthcheck message: %s", err)
|
||||
return
|
||||
}
|
||||
case <-hc.Timeout:
|
||||
p.log.Errorf("peer healthcheck timeout")
|
||||
err := p.conn.Close()
|
||||
if err != nil {
|
||||
p.log.Errorf("failed to close connection to peer: %s", err)
|
||||
}
|
||||
return
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Peer) handleTransportMsg(msg []byte) {
|
||||
peerID, err := messages.UnmarshalTransportID(msg[messages.SizeOfProtoHeader:])
|
||||
if err != nil {
|
||||
p.log.Errorf("failed to unmarshal transport message: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
stringPeerID := messages.HashIDToString(peerID)
|
||||
dp, ok := p.store.Peer(stringPeerID)
|
||||
if !ok {
|
||||
p.log.Errorf("peer not found: %s", stringPeerID)
|
||||
return
|
||||
}
|
||||
|
||||
err = messages.UpdateTransportMsg(msg[messages.SizeOfProtoHeader:], p.idB)
|
||||
if err != nil {
|
||||
p.log.Errorf("failed to update transport message: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
n, err := dp.Write(msg)
|
||||
if err != nil {
|
||||
p.log.Errorf("failed to write transport message to: %s", dp.String())
|
||||
return
|
||||
}
|
||||
p.metrics.TransferBytesSent.Add(context.Background(), int64(n))
|
||||
}
|
||||
206
relay/server/relay.go
Normal file
206
relay/server/relay.go
Normal file
@@ -0,0 +1,206 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
|
||||
"github.com/netbirdio/netbird/relay/auth"
|
||||
"github.com/netbirdio/netbird/relay/messages"
|
||||
"github.com/netbirdio/netbird/relay/messages/address"
|
||||
authmsg "github.com/netbirdio/netbird/relay/messages/auth"
|
||||
"github.com/netbirdio/netbird/relay/metrics"
|
||||
)
|
||||
|
||||
// Relay represents the relay server
|
||||
type Relay struct {
|
||||
metrics *metrics.Metrics
|
||||
metricsCancel context.CancelFunc
|
||||
validator auth.Validator
|
||||
|
||||
store *Store
|
||||
instanceURL string
|
||||
|
||||
closed bool
|
||||
closeMu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewRelay creates a new Relay instance
|
||||
//
|
||||
// Parameters:
|
||||
// meter: An instance of metric.Meter from the go.opentelemetry.io/otel/metric package. It is used to create and manage
|
||||
// metrics for the relay server.
|
||||
// exposedAddress: A string representing the address that the relay server is exposed on. The client will use this
|
||||
// address as the relay server's instance URL.
|
||||
// tlsSupport: A boolean indicating whether the relay server supports TLS (Transport Layer Security) or not. The
|
||||
// instance URL depends on this value.
|
||||
// validator: An instance of auth.Validator from the auth package. It is used to validate the authentication of the
|
||||
// peers.
|
||||
//
|
||||
// Returns:
|
||||
// A pointer to a Relay instance and an error. If the Relay instance is successfully created, the error is nil.
|
||||
// Otherwise, the error contains the details of what went wrong.
|
||||
func NewRelay(meter metric.Meter, exposedAddress string, tlsSupport bool, validator auth.Validator) (*Relay, error) {
|
||||
ctx, metricsCancel := context.WithCancel(context.Background())
|
||||
m, err := metrics.NewMetrics(ctx, meter)
|
||||
if err != nil {
|
||||
metricsCancel()
|
||||
return nil, fmt.Errorf("creating app metrics: %v", err)
|
||||
}
|
||||
|
||||
r := &Relay{
|
||||
metrics: m,
|
||||
metricsCancel: metricsCancel,
|
||||
validator: validator,
|
||||
store: NewStore(),
|
||||
}
|
||||
|
||||
r.instanceURL, err = getInstanceURL(exposedAddress, tlsSupport)
|
||||
if err != nil {
|
||||
metricsCancel()
|
||||
return nil, fmt.Errorf("get instance URL: %v", err)
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// getInstanceURL checks if user supplied a URL scheme otherwise adds to the
|
||||
// provided address according to TLS definition and parses the address before returning it
|
||||
func getInstanceURL(exposedAddress string, tlsSupported bool) (string, error) {
|
||||
addr := exposedAddress
|
||||
split := strings.Split(exposedAddress, "://")
|
||||
switch {
|
||||
case len(split) == 1 && tlsSupported:
|
||||
addr = "rels://" + exposedAddress
|
||||
case len(split) == 1 && !tlsSupported:
|
||||
addr = "rel://" + exposedAddress
|
||||
case len(split) > 2:
|
||||
return "", fmt.Errorf("invalid exposed address: %s", exposedAddress)
|
||||
}
|
||||
|
||||
parsedURL, err := url.ParseRequestURI(addr)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid exposed address: %v", err)
|
||||
}
|
||||
|
||||
if parsedURL.Scheme != "rel" && parsedURL.Scheme != "rels" {
|
||||
return "", fmt.Errorf("invalid scheme: %s", parsedURL.Scheme)
|
||||
}
|
||||
|
||||
return parsedURL.String(), nil
|
||||
}
|
||||
|
||||
// Accept start to handle a new peer connection
|
||||
func (r *Relay) Accept(conn net.Conn) {
|
||||
r.closeMu.RLock()
|
||||
defer r.closeMu.RUnlock()
|
||||
if r.closed {
|
||||
return
|
||||
}
|
||||
|
||||
peerID, err := r.handshake(conn)
|
||||
if err != nil {
|
||||
log.Errorf("failed to handshake: %s", err)
|
||||
cErr := conn.Close()
|
||||
if cErr != nil {
|
||||
log.Errorf("failed to close connection, %s: %s", conn.RemoteAddr(), cErr)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
peer := NewPeer(r.metrics, peerID, conn, r.store)
|
||||
peer.log.Infof("peer connected from: %s", conn.RemoteAddr())
|
||||
r.store.AddPeer(peer)
|
||||
r.metrics.PeerConnected(peer.String())
|
||||
go func() {
|
||||
peer.Work()
|
||||
r.store.DeletePeer(peer)
|
||||
peer.log.Debugf("relay connection closed")
|
||||
r.metrics.PeerDisconnected(peer.String())
|
||||
}()
|
||||
}
|
||||
|
||||
// Shutdown closes the relay server
|
||||
// It closes the connection with all peers in gracefully and stops accepting new connections.
|
||||
func (r *Relay) Shutdown(ctx context.Context) {
|
||||
log.Infof("close connection with all peers")
|
||||
r.closeMu.Lock()
|
||||
wg := sync.WaitGroup{}
|
||||
peers := r.store.Peers()
|
||||
for _, peer := range peers {
|
||||
wg.Add(1)
|
||||
go func(p *Peer) {
|
||||
p.CloseGracefully(ctx)
|
||||
wg.Done()
|
||||
}(peer)
|
||||
}
|
||||
wg.Wait()
|
||||
r.metricsCancel()
|
||||
r.closeMu.Unlock()
|
||||
}
|
||||
|
||||
// InstanceURL returns the instance URL of the relay server
|
||||
func (r *Relay) InstanceURL() string {
|
||||
return r.instanceURL
|
||||
}
|
||||
|
||||
func (r *Relay) handshake(conn net.Conn) ([]byte, error) {
|
||||
buf := make([]byte, messages.MaxHandshakeSize)
|
||||
n, err := conn.Read(buf)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read from %s: %w", conn.RemoteAddr(), err)
|
||||
}
|
||||
|
||||
_, err = messages.ValidateVersion(buf[:n])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("validate version from %s: %w", conn.RemoteAddr(), err)
|
||||
}
|
||||
|
||||
msgType, err := messages.DetermineClientMessageType(buf[messages.SizeOfVersionByte:n])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("determine message type from %s: %w", conn.RemoteAddr(), err)
|
||||
}
|
||||
|
||||
if msgType != messages.MsgTypeHello {
|
||||
return nil, fmt.Errorf("invalid message type from %s", conn.RemoteAddr())
|
||||
}
|
||||
|
||||
peerID, authData, err := messages.UnmarshalHelloMsg(buf[messages.SizeOfProtoHeader:n])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unmarshal hello message: %w", err)
|
||||
}
|
||||
|
||||
authMsg, err := authmsg.UnmarshalMsg(authData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unmarshal auth message: %w", err)
|
||||
}
|
||||
|
||||
if err := r.validator.Validate(sha256.New, authMsg.AdditionalData); err != nil {
|
||||
return nil, fmt.Errorf("validate %s (%s): %w", peerID, conn.RemoteAddr(), err)
|
||||
}
|
||||
|
||||
addr := &address.Address{URL: r.instanceURL}
|
||||
addrData, err := addr.Marshal()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal addressc to %s (%s): %w", peerID, conn.RemoteAddr(), err)
|
||||
}
|
||||
|
||||
msg, err := messages.MarshalHelloResponse(addrData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal hello response to %s (%s): %w", peerID, conn.RemoteAddr(), err)
|
||||
}
|
||||
|
||||
_, err = conn.Write(msg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("write to %s (%s): %w", peerID, conn.RemoteAddr(), err)
|
||||
}
|
||||
|
||||
return peerID, nil
|
||||
}
|
||||
36
relay/server/relay_test.go
Normal file
36
relay/server/relay_test.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package server
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestGetInstanceURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
exposedAddress string
|
||||
tlsSupported bool
|
||||
expectedURL string
|
||||
expectError bool
|
||||
}{
|
||||
{"Valid address with TLS", "example.com", true, "rels://example.com", false},
|
||||
{"Valid address without TLS", "example.com", false, "rel://example.com", false},
|
||||
{"Valid address with scheme", "rel://example.com", false, "rel://example.com", false},
|
||||
{"Valid address with non TLS scheme and TLS true", "rel://example.com", true, "rel://example.com", false},
|
||||
{"Valid address with TLS scheme", "rels://example.com", true, "rels://example.com", false},
|
||||
{"Valid address with TLS scheme and TLS false", "rels://example.com", false, "rels://example.com", false},
|
||||
{"Valid address with TLS scheme and custom port", "rels://example.com:9300", true, "rels://example.com:9300", false},
|
||||
{"Invalid address with multiple schemes", "rel://rels://example.com", false, "", true},
|
||||
{"Invalid address with unsupported scheme", "http://example.com", false, "", true},
|
||||
{"Invalid address format", "://example.com", false, "", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
url, err := getInstanceURL(tt.exposedAddress, tt.tlsSupported)
|
||||
if (err != nil) != tt.expectError {
|
||||
t.Errorf("expected error: %v, got: %v", tt.expectError, err)
|
||||
}
|
||||
if url != tt.expectedURL {
|
||||
t.Errorf("expected URL: %s, got: %s", tt.expectedURL, url)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
76
relay/server/server.go
Normal file
76
relay/server/server.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
|
||||
"github.com/netbirdio/netbird/relay/auth"
|
||||
"github.com/netbirdio/netbird/relay/server/listener"
|
||||
"github.com/netbirdio/netbird/relay/server/listener/ws"
|
||||
)
|
||||
|
||||
// ListenerConfig is the configuration for the listener.
|
||||
// Address: the address to bind the listener to. It could be an address behind a reverse proxy.
|
||||
// TLSConfig: the TLS configuration for the listener.
|
||||
type ListenerConfig struct {
|
||||
Address string
|
||||
TLSConfig *tls.Config
|
||||
}
|
||||
|
||||
// Server is the main entry point for the relay server.
|
||||
// It is the gate between the WebSocket listener and the Relay server logic.
|
||||
// In a new HTTP connection, the server will accept the connection and pass it to the Relay server via the Accept method.
|
||||
type Server struct {
|
||||
relay *Relay
|
||||
wSListener listener.Listener
|
||||
}
|
||||
|
||||
// NewServer creates a new relay server instance.
|
||||
// meter: the OpenTelemetry meter
|
||||
// exposedAddress: this address will be used as the instance URL. It should be a domain:port format.
|
||||
// tlsSupport: if true, the server will support TLS
|
||||
// authValidator: the auth validator to use for the server
|
||||
func NewServer(meter metric.Meter, exposedAddress string, tlsSupport bool, authValidator auth.Validator) (*Server, error) {
|
||||
relay, err := NewRelay(meter, exposedAddress, tlsSupport, authValidator)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Server{
|
||||
relay: relay,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Listen starts the relay server.
|
||||
func (r *Server) Listen(cfg ListenerConfig) error {
|
||||
r.wSListener = &ws.Listener{
|
||||
Address: cfg.Address,
|
||||
TLSConfig: cfg.TLSConfig,
|
||||
}
|
||||
|
||||
wslErr := r.wSListener.Listen(r.relay.Accept)
|
||||
if wslErr != nil {
|
||||
log.Errorf("failed to bind ws server: %s", wslErr)
|
||||
}
|
||||
|
||||
return wslErr
|
||||
}
|
||||
|
||||
// Shutdown stops the relay server. If there are active connections, they will be closed gracefully. In case of a context,
|
||||
// the connections will be forcefully closed.
|
||||
func (r *Server) Shutdown(ctx context.Context) (err error) {
|
||||
// stop service new connections
|
||||
if r.wSListener != nil {
|
||||
err = r.wSListener.Shutdown(ctx)
|
||||
}
|
||||
|
||||
r.relay.Shutdown(ctx)
|
||||
return
|
||||
}
|
||||
|
||||
// InstanceURL returns the instance URL of the relay server.
|
||||
func (r *Server) InstanceURL() string {
|
||||
return r.relay.instanceURL
|
||||
}
|
||||
64
relay/server/store.go
Normal file
64
relay/server/store.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Store is a thread-safe store of peers
|
||||
// It is used to store the peers that are connected to the relay server
|
||||
type Store struct {
|
||||
peers map[string]*Peer // consider to use [32]byte as key. The Peer(id string) would be faster
|
||||
peersLock sync.RWMutex
|
||||
}
|
||||
|
||||
// NewStore creates a new Store instance
|
||||
func NewStore() *Store {
|
||||
return &Store{
|
||||
peers: make(map[string]*Peer),
|
||||
}
|
||||
}
|
||||
|
||||
// AddPeer adds a peer to the store
|
||||
// todo: consider to close peer conn if the peer already exists
|
||||
func (s *Store) AddPeer(peer *Peer) {
|
||||
s.peersLock.Lock()
|
||||
defer s.peersLock.Unlock()
|
||||
s.peers[peer.String()] = peer
|
||||
}
|
||||
|
||||
// DeletePeer deletes a peer from the store
|
||||
func (s *Store) DeletePeer(peer *Peer) {
|
||||
s.peersLock.Lock()
|
||||
defer s.peersLock.Unlock()
|
||||
|
||||
dp, ok := s.peers[peer.String()]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if dp != peer {
|
||||
return
|
||||
}
|
||||
|
||||
delete(s.peers, peer.String())
|
||||
}
|
||||
|
||||
// Peer returns a peer by its ID
|
||||
func (s *Store) Peer(id string) (*Peer, bool) {
|
||||
s.peersLock.RLock()
|
||||
defer s.peersLock.RUnlock()
|
||||
|
||||
p, ok := s.peers[id]
|
||||
return p, ok
|
||||
}
|
||||
|
||||
// Peers returns all the peers in the store
|
||||
func (s *Store) Peers() []*Peer {
|
||||
s.peersLock.RLock()
|
||||
defer s.peersLock.RUnlock()
|
||||
|
||||
peers := make([]*Peer, 0, len(s.peers))
|
||||
for _, p := range s.peers {
|
||||
peers = append(peers, p)
|
||||
}
|
||||
return peers
|
||||
}
|
||||
40
relay/server/store_test.go
Normal file
40
relay/server/store_test.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"go.opentelemetry.io/otel"
|
||||
|
||||
"github.com/netbirdio/netbird/relay/metrics"
|
||||
)
|
||||
|
||||
func TestStore_DeletePeer(t *testing.T) {
|
||||
s := NewStore()
|
||||
|
||||
m, _ := metrics.NewMetrics(context.Background(), otel.Meter(""))
|
||||
|
||||
p := NewPeer(m, []byte("peer_one"), nil, nil)
|
||||
s.AddPeer(p)
|
||||
s.DeletePeer(p)
|
||||
if _, ok := s.Peer(p.String()); ok {
|
||||
t.Errorf("peer was not deleted")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStore_DeleteDeprecatedPeer(t *testing.T) {
|
||||
s := NewStore()
|
||||
|
||||
m, _ := metrics.NewMetrics(context.Background(), otel.Meter(""))
|
||||
|
||||
p1 := NewPeer(m, []byte("peer_id"), nil, nil)
|
||||
p2 := NewPeer(m, []byte("peer_id"), nil, nil)
|
||||
|
||||
s.AddPeer(p1)
|
||||
s.AddPeer(p2)
|
||||
s.DeletePeer(p1)
|
||||
|
||||
if _, ok := s.Peer(p2.String()); !ok {
|
||||
t.Errorf("second peer was deleted")
|
||||
}
|
||||
}
|
||||
386
relay/test/benchmark_test.go
Normal file
386
relay/test/benchmark_test.go
Normal file
@@ -0,0 +1,386 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pion/logging"
|
||||
"github.com/pion/turn/v3"
|
||||
"go.opentelemetry.io/otel"
|
||||
|
||||
"github.com/netbirdio/netbird/relay/auth/allow"
|
||||
"github.com/netbirdio/netbird/relay/auth/hmac"
|
||||
"github.com/netbirdio/netbird/relay/client"
|
||||
"github.com/netbirdio/netbird/relay/server"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
var (
|
||||
av = &allow.Auth{}
|
||||
hmacTokenStore = &hmac.TokenStore{}
|
||||
pairs = []int{1, 5, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100}
|
||||
dataSize = 1024 * 1024 * 10
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
_ = util.InitLog("error", "console")
|
||||
code := m.Run()
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
func TestRelayDataTransfer(t *testing.T) {
|
||||
t.SkipNow() // skip this test on CI because it is a benchmark test
|
||||
testData, err := seedRandomData(dataSize)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to seed random data: %s", err)
|
||||
}
|
||||
|
||||
for _, peerPairs := range pairs {
|
||||
t.Run(fmt.Sprintf("peerPairs-%d", peerPairs), func(t *testing.T) {
|
||||
transfer(t, testData, peerPairs)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestTurnDataTransfer run turn server:
|
||||
// docker run --rm --name coturn -d --network=host coturn/coturn --user test:test
|
||||
func TestTurnDataTransfer(t *testing.T) {
|
||||
t.SkipNow() // skip this test on CI because it is a benchmark test
|
||||
testData, err := seedRandomData(dataSize)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to seed random data: %s", err)
|
||||
}
|
||||
|
||||
for _, peerPairs := range pairs {
|
||||
t.Run(fmt.Sprintf("peerPairs-%d", peerPairs), func(t *testing.T) {
|
||||
runTurnTest(t, testData, peerPairs)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func transfer(t *testing.T, testData []byte, peerPairs int) {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
port := 35000 + peerPairs
|
||||
serverAddress := fmt.Sprintf("127.0.0.1:%d", port)
|
||||
serverConnURL := fmt.Sprintf("rel://%s", serverAddress)
|
||||
|
||||
srv, err := server.NewServer(otel.Meter(""), serverConnURL, false, av)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create server: %s", err)
|
||||
}
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
listenCfg := server.ListenerConfig{Address: serverAddress}
|
||||
err := srv.Listen(listenCfg)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
err := srv.Shutdown(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("failed to close server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// wait for server to start
|
||||
if err := waitForServerToStart(errChan); err != nil {
|
||||
t.Fatalf("failed to start server: %s", err)
|
||||
}
|
||||
|
||||
clientsSender := make([]*client.Client, peerPairs)
|
||||
for i := 0; i < cap(clientsSender); i++ {
|
||||
c := client.NewClient(ctx, serverConnURL, hmacTokenStore, "sender-"+fmt.Sprint(i))
|
||||
err := c.Connect()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect to server: %s", err)
|
||||
}
|
||||
clientsSender[i] = c
|
||||
}
|
||||
|
||||
clientsReceiver := make([]*client.Client, peerPairs)
|
||||
for i := 0; i < cap(clientsReceiver); i++ {
|
||||
c := client.NewClient(ctx, serverConnURL, hmacTokenStore, "receiver-"+fmt.Sprint(i))
|
||||
err := c.Connect()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect to server: %s", err)
|
||||
}
|
||||
clientsReceiver[i] = c
|
||||
}
|
||||
|
||||
connsSender := make([]net.Conn, 0, peerPairs)
|
||||
connsReceiver := make([]net.Conn, 0, peerPairs)
|
||||
for i := 0; i < len(clientsSender); i++ {
|
||||
conn, err := clientsSender[i].OpenConn("receiver-" + fmt.Sprint(i))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind channel: %s", err)
|
||||
}
|
||||
connsSender = append(connsSender, conn)
|
||||
|
||||
conn, err = clientsReceiver[i].OpenConn("sender-" + fmt.Sprint(i))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind channel: %s", err)
|
||||
}
|
||||
connsReceiver = append(connsReceiver, conn)
|
||||
}
|
||||
|
||||
var transferDuration []time.Duration
|
||||
wg := sync.WaitGroup{}
|
||||
var writeErr error
|
||||
var readErr error
|
||||
for i := 0; i < len(connsSender); i++ {
|
||||
wg.Add(2)
|
||||
start := time.Now()
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
pieceSize := 1024
|
||||
testDataLen := len(testData)
|
||||
|
||||
for j := 0; j < testDataLen; j += pieceSize {
|
||||
end := j + pieceSize
|
||||
if end > testDataLen {
|
||||
end = testDataLen
|
||||
}
|
||||
_, writeErr = connsSender[i].Write(testData[j:end])
|
||||
if writeErr != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
}(i)
|
||||
|
||||
go func(i int, start time.Time) {
|
||||
defer wg.Done()
|
||||
buf := make([]byte, 8192)
|
||||
rcv := 0
|
||||
var n int
|
||||
for receivedSize := 0; receivedSize < len(testData); {
|
||||
|
||||
n, readErr = connsReceiver[i].Read(buf)
|
||||
if readErr != nil {
|
||||
return
|
||||
}
|
||||
|
||||
receivedSize += n
|
||||
rcv += n
|
||||
}
|
||||
transferDuration = append(transferDuration, time.Since(start))
|
||||
}(i, start)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if writeErr != nil {
|
||||
t.Fatalf("failed to write to channel: %s", err)
|
||||
}
|
||||
|
||||
if readErr != nil {
|
||||
t.Fatalf("failed to read from channel: %s", err)
|
||||
}
|
||||
|
||||
// calculate the megabytes per second from the average transferDuration against the dataSize
|
||||
var totalDuration time.Duration
|
||||
for _, d := range transferDuration {
|
||||
totalDuration += d
|
||||
}
|
||||
avgDuration := totalDuration / time.Duration(len(transferDuration))
|
||||
mbps := float64(len(testData)) / avgDuration.Seconds() / 1024 / 1024
|
||||
t.Logf("average transfer duration: %s", avgDuration)
|
||||
t.Logf("average transfer speed: %.2f MB/s", mbps)
|
||||
|
||||
for i := 0; i < len(connsSender); i++ {
|
||||
err := connsSender[i].Close()
|
||||
if err != nil {
|
||||
t.Errorf("failed to close connection: %s", err)
|
||||
}
|
||||
|
||||
err = connsReceiver[i].Close()
|
||||
if err != nil {
|
||||
t.Errorf("failed to close connection: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func runTurnTest(t *testing.T, testData []byte, maxPairs int) {
|
||||
t.Helper()
|
||||
var transferDuration []time.Duration
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for i := 0; i < maxPairs; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
d := runTurnDataTransfer(t, testData)
|
||||
transferDuration = append(transferDuration, d)
|
||||
}()
|
||||
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
var totalDuration time.Duration
|
||||
for _, d := range transferDuration {
|
||||
totalDuration += d
|
||||
}
|
||||
avgDuration := totalDuration / time.Duration(len(transferDuration))
|
||||
mbps := float64(len(testData)) / avgDuration.Seconds() / 1024 / 1024
|
||||
t.Logf("average transfer duration: %s", avgDuration)
|
||||
t.Logf("average transfer speed: %.2f MB/s", mbps)
|
||||
}
|
||||
|
||||
func runTurnDataTransfer(t *testing.T, testData []byte) time.Duration {
|
||||
t.Helper()
|
||||
testDataLen := len(testData)
|
||||
relayAddress := "192.168.0.10:3478"
|
||||
conn, err := net.Dial("tcp", relayAddress)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func(conn net.Conn) {
|
||||
_ = conn.Close()
|
||||
}(conn)
|
||||
|
||||
turnClient, err := getTurnClient(t, relayAddress, conn)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer turnClient.Close()
|
||||
|
||||
relayConn, err := turnClient.Allocate()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func(relayConn net.PacketConn) {
|
||||
_ = relayConn.Close()
|
||||
}(relayConn)
|
||||
|
||||
receiverConn, err := net.Dial("udp", relayConn.LocalAddr().String())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func(receiverConn net.Conn) {
|
||||
_ = receiverConn.Close()
|
||||
}(receiverConn)
|
||||
|
||||
var (
|
||||
tb int
|
||||
start time.Time
|
||||
timerInit bool
|
||||
readDone = make(chan struct{})
|
||||
ack = make([]byte, 1)
|
||||
)
|
||||
go func() {
|
||||
defer func() {
|
||||
readDone <- struct{}{}
|
||||
}()
|
||||
buff := make([]byte, 8192)
|
||||
for {
|
||||
n, e := receiverConn.Read(buff)
|
||||
if e != nil {
|
||||
return
|
||||
}
|
||||
if !timerInit {
|
||||
start = time.Now()
|
||||
timerInit = true
|
||||
}
|
||||
tb += n
|
||||
_, _ = receiverConn.Write(ack)
|
||||
|
||||
if tb >= testDataLen {
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
pieceSize := 1024
|
||||
ackBuff := make([]byte, 1)
|
||||
pipelineSize := 10
|
||||
for j := 0; j < testDataLen; j += pieceSize {
|
||||
end := j + pieceSize
|
||||
if end > testDataLen {
|
||||
end = testDataLen
|
||||
}
|
||||
_, err := relayConn.WriteTo(testData[j:end], receiverConn.LocalAddr())
|
||||
if err != nil {
|
||||
t.Fatalf("failed to write to channel: %s", err)
|
||||
}
|
||||
if pipelineSize == 0 {
|
||||
_, _, _ = relayConn.ReadFrom(ackBuff)
|
||||
} else {
|
||||
pipelineSize--
|
||||
}
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
|
||||
defer cancel()
|
||||
select {
|
||||
case <-readDone:
|
||||
if tb != testDataLen {
|
||||
t.Fatalf("failed to read all data: %d/%d", tb, testDataLen)
|
||||
}
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timeout")
|
||||
}
|
||||
return time.Since(start)
|
||||
}
|
||||
|
||||
func getTurnClient(t *testing.T, address string, conn net.Conn) (*turn.Client, error) {
|
||||
t.Helper()
|
||||
// Dial TURN Server
|
||||
addrStr := fmt.Sprintf("%s:%d", address, 443)
|
||||
|
||||
fac := logging.NewDefaultLoggerFactory()
|
||||
//fac.DefaultLogLevel = logging.LogLevelTrace
|
||||
|
||||
// Start a new TURN Client and wrap our net.Conn in a STUNConn
|
||||
// This allows us to simulate datagram based communication over a net.Conn
|
||||
cfg := &turn.ClientConfig{
|
||||
TURNServerAddr: address,
|
||||
Conn: turn.NewSTUNConn(conn),
|
||||
Username: "test",
|
||||
Password: "test",
|
||||
LoggerFactory: fac,
|
||||
}
|
||||
|
||||
client, err := turn.NewClient(cfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create TURN client for server %s: %s", addrStr, err)
|
||||
}
|
||||
|
||||
// Start listening on the conn provided.
|
||||
err = client.Listen()
|
||||
if err != nil {
|
||||
client.Close()
|
||||
return nil, fmt.Errorf("failed to listen on TURN client for server %s: %s", addrStr, err)
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func seedRandomData(size int) ([]byte, error) {
|
||||
token := make([]byte, size)
|
||||
_, err := rand.Read(token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func waitForServerToStart(errChan chan error) error {
|
||||
select {
|
||||
case err := <-errChan:
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
case <-time.After(300 * time.Millisecond):
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
258
relay/testec2/main.go
Normal file
258
relay/testec2/main.go
Normal file
@@ -0,0 +1,258 @@
|
||||
//go:build linux || darwin
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"flag"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
const (
|
||||
errMsgFailedReadTCP = "failed to read from tcp: %s"
|
||||
)
|
||||
|
||||
var (
|
||||
dataSize = 1024 * 1024 * 50 // 50MB
|
||||
pairs = []int{1, 5, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100}
|
||||
signalListenAddress = ":8081"
|
||||
|
||||
relaySrvAddress string
|
||||
turnSrvAddress string
|
||||
signalURL string
|
||||
udpListener string // used for TURN test
|
||||
)
|
||||
|
||||
type testResult struct {
|
||||
numOfPairs int
|
||||
duration time.Duration
|
||||
speed float64
|
||||
}
|
||||
|
||||
func (tr testResult) Speed() string {
|
||||
speed := tr.speed
|
||||
var unit string
|
||||
|
||||
switch {
|
||||
case speed < 1024:
|
||||
unit = "B/s"
|
||||
case speed < 1048576:
|
||||
speed /= 1024
|
||||
unit = "KB/s"
|
||||
case speed < 1073741824:
|
||||
speed /= 1048576
|
||||
unit = "MB/s"
|
||||
default:
|
||||
speed /= 1073741824
|
||||
unit = "GB/s"
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%.2f %s", speed, unit)
|
||||
}
|
||||
|
||||
func seedRandomData(size int) ([]byte, error) {
|
||||
token := make([]byte, size)
|
||||
_, err := rand.Read(token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func avg(transferDuration []time.Duration) (time.Duration, float64) {
|
||||
var totalDuration time.Duration
|
||||
for _, d := range transferDuration {
|
||||
totalDuration += d
|
||||
}
|
||||
avgDuration := totalDuration / time.Duration(len(transferDuration))
|
||||
bps := float64(dataSize) / avgDuration.Seconds()
|
||||
return avgDuration, bps
|
||||
}
|
||||
|
||||
func RelayReceiverMain() []testResult {
|
||||
testResults := make([]testResult, 0, len(pairs))
|
||||
for _, p := range pairs {
|
||||
tr := testResult{numOfPairs: p}
|
||||
td := relayReceive(relaySrvAddress, p)
|
||||
tr.duration, tr.speed = avg(td)
|
||||
|
||||
testResults = append(testResults, tr)
|
||||
}
|
||||
|
||||
return testResults
|
||||
}
|
||||
|
||||
func RelaySenderMain() {
|
||||
log.Infof("starting sender")
|
||||
log.Infof("starting seed phase")
|
||||
|
||||
testData, err := seedRandomData(dataSize)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to seed random data: %s", err)
|
||||
}
|
||||
|
||||
log.Infof("data size: %d", len(testData))
|
||||
|
||||
for n, p := range pairs {
|
||||
log.Infof("running test with %d pairs", p)
|
||||
relayTransfer(relaySrvAddress, testData, p)
|
||||
|
||||
// grant time to prepare new receivers
|
||||
if n < len(pairs)-1 {
|
||||
time.Sleep(3 * time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// TRUNSenderMain is the sender
|
||||
// - allocate turn clients
|
||||
// - send relayed addresses to signal server in batch
|
||||
// - wait for signal server to send back addresses in a map
|
||||
// - send test data to each address in parallel
|
||||
func TRUNSenderMain() {
|
||||
log.Infof("starting TURN sender test")
|
||||
|
||||
log.Infof("starting seed random data: %d", dataSize)
|
||||
testData, err := seedRandomData(dataSize)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to seed random data: %s", err)
|
||||
}
|
||||
|
||||
ss := SignalClient{signalURL}
|
||||
|
||||
for _, p := range pairs {
|
||||
log.Infof("running test with %d pairs", p)
|
||||
turnSender := &TurnSender{}
|
||||
|
||||
createTurnConns(p, turnSender)
|
||||
|
||||
log.Infof("send addresses via signal server: %d", len(turnSender.addresses))
|
||||
clientAddresses, err := ss.SendAddress(turnSender.addresses)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to send address: %s", err)
|
||||
}
|
||||
log.Infof("received addresses: %v", clientAddresses.Address)
|
||||
|
||||
createSenderDevices(turnSender, clientAddresses)
|
||||
|
||||
log.Infof("waiting for tcpListeners to be ready")
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
tcpConns := make([]net.Conn, 0, len(turnSender.devices))
|
||||
for i := range turnSender.devices {
|
||||
addr := fmt.Sprintf("10.0.%d.2:9999", i)
|
||||
log.Infof("dialing: %s", addr)
|
||||
tcpConn, err := net.Dial("tcp", addr)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to dial tcp: %s", err)
|
||||
}
|
||||
tcpConns = append(tcpConns, tcpConn)
|
||||
}
|
||||
|
||||
log.Infof("start test data transfer for %d pairs", p)
|
||||
testDataLen := len(testData)
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(len(tcpConns))
|
||||
for i, tcpConn := range tcpConns {
|
||||
log.Infof("sending test data to device: %d", i)
|
||||
go runTurnWriting(tcpConn, testData, testDataLen, &wg)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
for _, d := range turnSender.devices {
|
||||
_ = d.Close()
|
||||
}
|
||||
|
||||
log.Infof("test finished with %d pairs", p)
|
||||
}
|
||||
}
|
||||
|
||||
func TURNReaderMain() []testResult {
|
||||
log.Infof("starting TURN receiver test")
|
||||
si := NewSignalService()
|
||||
go func() {
|
||||
log.Infof("starting signal server")
|
||||
err := si.Listen(signalListenAddress)
|
||||
if err != nil {
|
||||
log.Errorf("failed to listen: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
testResults := make([]testResult, 0, len(pairs))
|
||||
for range pairs {
|
||||
addresses := <-si.AddressesChan
|
||||
instanceNumber := len(addresses)
|
||||
log.Infof("received addresses: %d", instanceNumber)
|
||||
|
||||
turnReceiver := &TurnReceiver{}
|
||||
err := createDevices(addresses, turnReceiver)
|
||||
if err != nil {
|
||||
log.Fatalf("%s", err)
|
||||
}
|
||||
|
||||
// send client addresses back via signal server
|
||||
si.ClientAddressChan <- turnReceiver.clientAddresses
|
||||
|
||||
durations := make(chan time.Duration, instanceNumber)
|
||||
for _, device := range turnReceiver.devices {
|
||||
go runTurnReading(device, durations)
|
||||
}
|
||||
|
||||
durationsList := make([]time.Duration, 0, instanceNumber)
|
||||
for d := range durations {
|
||||
durationsList = append(durationsList, d)
|
||||
if len(durationsList) == instanceNumber {
|
||||
close(durations)
|
||||
}
|
||||
}
|
||||
|
||||
avgDuration, avgSpeed := avg(durationsList)
|
||||
ts := testResult{
|
||||
numOfPairs: len(durationsList),
|
||||
duration: avgDuration,
|
||||
speed: avgSpeed,
|
||||
}
|
||||
testResults = append(testResults, ts)
|
||||
|
||||
for _, d := range turnReceiver.devices {
|
||||
_ = d.Close()
|
||||
}
|
||||
}
|
||||
return testResults
|
||||
}
|
||||
|
||||
func main() {
|
||||
var mode string
|
||||
|
||||
_ = util.InitLog("debug", "console")
|
||||
flag.StringVar(&mode, "mode", "sender", "sender or receiver mode")
|
||||
flag.Parse()
|
||||
|
||||
relaySrvAddress = os.Getenv("TEST_RELAY_SERVER") // rel://ip:port
|
||||
turnSrvAddress = os.Getenv("TEST_TURN_SERVER") // ip:3478
|
||||
signalURL = os.Getenv("TEST_SIGNAL_URL") // http://receiver_ip:8081
|
||||
udpListener = os.Getenv("TEST_UDP_LISTENER") // IP:0
|
||||
|
||||
if mode == "receiver" {
|
||||
relayResult := RelayReceiverMain()
|
||||
turnResults := TURNReaderMain()
|
||||
for i := 0; i < len(turnResults); i++ {
|
||||
log.Infof("pairs: %d,\tRelay speed:\t%s,\trelay duration:\t%s", relayResult[i].numOfPairs, relayResult[i].Speed(), relayResult[i].duration)
|
||||
log.Infof("pairs: %d,\tTURN speed:\t%s,\tturn duration:\t%s", turnResults[i].numOfPairs, turnResults[i].Speed(), turnResults[i].duration)
|
||||
}
|
||||
} else {
|
||||
RelaySenderMain()
|
||||
// grant time for receiver to start
|
||||
time.Sleep(3 * time.Second)
|
||||
TRUNSenderMain()
|
||||
}
|
||||
}
|
||||
176
relay/testec2/relay.go
Normal file
176
relay/testec2/relay.go
Normal file
@@ -0,0 +1,176 @@
|
||||
//go:build linux || darwin
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/relay/auth/hmac"
|
||||
"github.com/netbirdio/netbird/relay/client"
|
||||
)
|
||||
|
||||
var (
|
||||
hmacTokenStore = &hmac.TokenStore{}
|
||||
)
|
||||
|
||||
func relayTransfer(serverConnURL string, testData []byte, peerPairs int) {
|
||||
connsSender := prepareConnsSender(serverConnURL, peerPairs)
|
||||
defer func() {
|
||||
for i := 0; i < len(connsSender); i++ {
|
||||
err := connsSender[i].Close()
|
||||
if err != nil {
|
||||
log.Errorf("failed to close connection: %s", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(len(connsSender))
|
||||
for _, conn := range connsSender {
|
||||
go func(conn net.Conn) {
|
||||
defer wg.Done()
|
||||
runWriter(conn, testData)
|
||||
}(conn)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func runWriter(conn net.Conn, testData []byte) {
|
||||
si := NewStartInidication(time.Now(), len(testData))
|
||||
_, err := conn.Write(si)
|
||||
if err != nil {
|
||||
log.Errorf("failed to write to channel: %s", err)
|
||||
return
|
||||
}
|
||||
log.Infof("sent start indication")
|
||||
|
||||
pieceSize := 1024
|
||||
testDataLen := len(testData)
|
||||
|
||||
for j := 0; j < testDataLen; j += pieceSize {
|
||||
end := j + pieceSize
|
||||
if end > testDataLen {
|
||||
end = testDataLen
|
||||
}
|
||||
_, writeErr := conn.Write(testData[j:end])
|
||||
if writeErr != nil {
|
||||
log.Errorf("failed to write to channel: %s", writeErr)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func prepareConnsSender(serverConnURL string, peerPairs int) []net.Conn {
|
||||
ctx := context.Background()
|
||||
clientsSender := make([]*client.Client, peerPairs)
|
||||
for i := 0; i < cap(clientsSender); i++ {
|
||||
c := client.NewClient(ctx, serverConnURL, hmacTokenStore, "sender-"+fmt.Sprint(i))
|
||||
if err := c.Connect(); err != nil {
|
||||
log.Fatalf("failed to connect to server: %s", err)
|
||||
}
|
||||
clientsSender[i] = c
|
||||
}
|
||||
|
||||
connsSender := make([]net.Conn, 0, peerPairs)
|
||||
for i := 0; i < len(clientsSender); i++ {
|
||||
conn, err := clientsSender[i].OpenConn("receiver-" + fmt.Sprint(i))
|
||||
if err != nil {
|
||||
log.Fatalf("failed to bind channel: %s", err)
|
||||
}
|
||||
connsSender = append(connsSender, conn)
|
||||
}
|
||||
return connsSender
|
||||
}
|
||||
|
||||
func relayReceive(serverConnURL string, peerPairs int) []time.Duration {
|
||||
connsReceiver := prepareConnsReceiver(serverConnURL, peerPairs)
|
||||
defer func() {
|
||||
for i := 0; i < len(connsReceiver); i++ {
|
||||
if err := connsReceiver[i].Close(); err != nil {
|
||||
log.Errorf("failed to close connection: %s", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
durations := make(chan time.Duration, len(connsReceiver))
|
||||
wg := sync.WaitGroup{}
|
||||
for _, conn := range connsReceiver {
|
||||
wg.Add(1)
|
||||
go func(conn net.Conn) {
|
||||
defer wg.Done()
|
||||
duration := runReader(conn)
|
||||
durations <- duration
|
||||
}(conn)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
durationsList := make([]time.Duration, 0, len(connsReceiver))
|
||||
for d := range durations {
|
||||
durationsList = append(durationsList, d)
|
||||
if len(durationsList) == len(connsReceiver) {
|
||||
close(durations)
|
||||
}
|
||||
}
|
||||
|
||||
return durationsList
|
||||
}
|
||||
|
||||
func runReader(conn net.Conn) time.Duration {
|
||||
buf := make([]byte, 8192)
|
||||
|
||||
n, readErr := conn.Read(buf)
|
||||
if readErr != nil {
|
||||
log.Errorf("failed to read from channel: %s", readErr)
|
||||
return 0
|
||||
}
|
||||
|
||||
si := DecodeStartIndication(buf[:n])
|
||||
log.Infof("received start indication: %v", si)
|
||||
|
||||
receivedSize, err := conn.Read(buf)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to read from relay: %s", err)
|
||||
}
|
||||
now := time.Now()
|
||||
|
||||
rcv := 0
|
||||
for receivedSize < si.TransferSize {
|
||||
n, readErr = conn.Read(buf)
|
||||
if readErr != nil {
|
||||
log.Errorf("failed to read from channel: %s", readErr)
|
||||
return 0
|
||||
}
|
||||
|
||||
receivedSize += n
|
||||
rcv += n
|
||||
}
|
||||
return time.Since(now)
|
||||
}
|
||||
|
||||
func prepareConnsReceiver(serverConnURL string, peerPairs int) []net.Conn {
|
||||
clientsReceiver := make([]*client.Client, peerPairs)
|
||||
for i := 0; i < cap(clientsReceiver); i++ {
|
||||
c := client.NewClient(context.Background(), serverConnURL, hmacTokenStore, "receiver-"+fmt.Sprint(i))
|
||||
err := c.Connect()
|
||||
if err != nil {
|
||||
log.Fatalf("failed to connect to server: %s", err)
|
||||
}
|
||||
clientsReceiver[i] = c
|
||||
}
|
||||
|
||||
connsReceiver := make([]net.Conn, 0, peerPairs)
|
||||
for i := 0; i < len(clientsReceiver); i++ {
|
||||
conn, err := clientsReceiver[i].OpenConn("sender-" + fmt.Sprint(i))
|
||||
if err != nil {
|
||||
log.Fatalf("failed to bind channel: %s", err)
|
||||
}
|
||||
connsReceiver = append(connsReceiver, conn)
|
||||
}
|
||||
return connsReceiver
|
||||
}
|
||||
91
relay/testec2/signal.go
Normal file
91
relay/testec2/signal.go
Normal file
@@ -0,0 +1,91 @@
|
||||
//go:build linux || darwin
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type PeerAddr struct {
|
||||
Address []string
|
||||
}
|
||||
|
||||
type ClientPeerAddr struct {
|
||||
Address map[string]string
|
||||
}
|
||||
|
||||
type Signal struct {
|
||||
AddressesChan chan []string
|
||||
ClientAddressChan chan map[string]string
|
||||
}
|
||||
|
||||
func NewSignalService() *Signal {
|
||||
return &Signal{
|
||||
AddressesChan: make(chan []string),
|
||||
ClientAddressChan: make(chan map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
func (rs *Signal) Listen(listenAddr string) error {
|
||||
http.HandleFunc("/", rs.onNewAddresses)
|
||||
return http.ListenAndServe(listenAddr, nil)
|
||||
}
|
||||
|
||||
func (rs *Signal) onNewAddresses(w http.ResponseWriter, r *http.Request) {
|
||||
var msg PeerAddr
|
||||
err := json.NewDecoder(r.Body).Decode(&msg)
|
||||
if err != nil {
|
||||
log.Errorf("Error decoding message: %v", err)
|
||||
}
|
||||
|
||||
log.Infof("received addresses: %d", len(msg.Address))
|
||||
rs.AddressesChan <- msg.Address
|
||||
clientAddresses := <-rs.ClientAddressChan
|
||||
|
||||
respMsg := ClientPeerAddr{
|
||||
Address: clientAddresses,
|
||||
}
|
||||
data, err := json.Marshal(respMsg)
|
||||
if err != nil {
|
||||
log.Errorf("Error marshalling message: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
_, err = w.Write(data)
|
||||
if err != nil {
|
||||
log.Errorf("Error writing response: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
type SignalClient struct {
|
||||
SignalURL string
|
||||
}
|
||||
|
||||
func (ss SignalClient) SendAddress(addresses []string) (*ClientPeerAddr, error) {
|
||||
msg := PeerAddr{
|
||||
Address: addresses,
|
||||
}
|
||||
data, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
response, err := http.Post(ss.SignalURL, "application/json", bytes.NewBuffer(data))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer response.Body.Close()
|
||||
|
||||
log.Debugf("wait for signal response")
|
||||
var respPeerAddress ClientPeerAddr
|
||||
err = json.NewDecoder(response.Body).Decode(&respPeerAddress)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &respPeerAddress, nil
|
||||
}
|
||||
39
relay/testec2/start_msg.go
Normal file
39
relay/testec2/start_msg.go
Normal file
@@ -0,0 +1,39 @@
|
||||
//go:build linux || darwin
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/gob"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type StartIndication struct {
|
||||
Started time.Time
|
||||
TransferSize int
|
||||
}
|
||||
|
||||
func NewStartInidication(started time.Time, transferSize int) []byte {
|
||||
si := StartIndication{
|
||||
Started: started,
|
||||
TransferSize: transferSize,
|
||||
}
|
||||
|
||||
var data bytes.Buffer
|
||||
err := gob.NewEncoder(&data).Encode(si)
|
||||
if err != nil {
|
||||
log.Fatal("encode error:", err)
|
||||
}
|
||||
return data.Bytes()
|
||||
}
|
||||
|
||||
func DecodeStartIndication(data []byte) StartIndication {
|
||||
var si StartIndication
|
||||
err := gob.NewDecoder(bytes.NewReader(data)).Decode(&si)
|
||||
if err != nil {
|
||||
log.Fatal("decode error:", err)
|
||||
}
|
||||
return si
|
||||
}
|
||||
72
relay/testec2/tun/proxy.go
Normal file
72
relay/testec2/tun/proxy.go
Normal file
@@ -0,0 +1,72 @@
|
||||
//go:build linux || darwin
|
||||
|
||||
package tun
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync/atomic"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type Proxy struct {
|
||||
Device *Device
|
||||
PConn net.PacketConn
|
||||
DstAddr net.Addr
|
||||
shutdownFlag atomic.Bool
|
||||
}
|
||||
|
||||
func (p *Proxy) Start() {
|
||||
go p.readFromDevice()
|
||||
go p.readFromConn()
|
||||
}
|
||||
|
||||
func (p *Proxy) Close() {
|
||||
p.shutdownFlag.Store(true)
|
||||
}
|
||||
|
||||
func (p *Proxy) readFromDevice() {
|
||||
buf := make([]byte, 1500)
|
||||
for {
|
||||
n, err := p.Device.Read(buf)
|
||||
if err != nil {
|
||||
if p.shutdownFlag.Load() {
|
||||
return
|
||||
}
|
||||
log.Errorf("failed to read from device: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
_, err = p.PConn.WriteTo(buf[:n], p.DstAddr)
|
||||
if err != nil {
|
||||
if p.shutdownFlag.Load() {
|
||||
return
|
||||
}
|
||||
log.Errorf("failed to write to conn: %s", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Proxy) readFromConn() {
|
||||
buf := make([]byte, 1500)
|
||||
for {
|
||||
n, _, err := p.PConn.ReadFrom(buf)
|
||||
if err != nil {
|
||||
if p.shutdownFlag.Load() {
|
||||
return
|
||||
}
|
||||
log.Errorf("failed to read from conn: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
_, err = p.Device.Write(buf[:n])
|
||||
if err != nil {
|
||||
if p.shutdownFlag.Load() {
|
||||
return
|
||||
}
|
||||
log.Errorf("failed to write to device: %s", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
110
relay/testec2/tun/tun.go
Normal file
110
relay/testec2/tun/tun.go
Normal file
@@ -0,0 +1,110 @@
|
||||
//go:build linux || darwin
|
||||
|
||||
package tun
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/songgao/water"
|
||||
"github.com/vishvananda/netlink"
|
||||
)
|
||||
|
||||
type Device struct {
|
||||
Name string
|
||||
IP string
|
||||
PConn net.PacketConn
|
||||
DstAddr net.Addr
|
||||
|
||||
iFace *water.Interface
|
||||
proxy *Proxy
|
||||
}
|
||||
|
||||
func (d *Device) Up() error {
|
||||
cfg := water.Config{
|
||||
DeviceType: water.TUN,
|
||||
PlatformSpecificParams: water.PlatformSpecificParams{
|
||||
Name: d.Name,
|
||||
},
|
||||
}
|
||||
iFace, err := water.New(cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
d.iFace = iFace
|
||||
|
||||
err = d.assignIP()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = d.bringUp()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
d.proxy = &Proxy{
|
||||
Device: d,
|
||||
PConn: d.PConn,
|
||||
DstAddr: d.DstAddr,
|
||||
}
|
||||
d.proxy.Start()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Device) Close() error {
|
||||
if d.proxy != nil {
|
||||
d.proxy.Close()
|
||||
}
|
||||
if d.iFace != nil {
|
||||
return d.iFace.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Device) Read(b []byte) (int, error) {
|
||||
return d.iFace.Read(b)
|
||||
}
|
||||
|
||||
func (d *Device) Write(b []byte) (int, error) {
|
||||
return d.iFace.Write(b)
|
||||
}
|
||||
|
||||
func (d *Device) assignIP() error {
|
||||
iface, err := netlink.LinkByName(d.Name)
|
||||
if err != nil {
|
||||
log.Errorf("failed to get TUN device: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
ip := net.IPNet{
|
||||
IP: net.ParseIP(d.IP),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
}
|
||||
|
||||
addr := &netlink.Addr{
|
||||
IPNet: &ip,
|
||||
}
|
||||
err = netlink.AddrAdd(iface, addr)
|
||||
if err != nil {
|
||||
log.Errorf("failed to add IP address: %v", err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Device) bringUp() error {
|
||||
iface, err := netlink.LinkByName(d.Name)
|
||||
if err != nil {
|
||||
log.Errorf("failed to get device: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// Bring the interface up
|
||||
err = netlink.LinkSetUp(iface)
|
||||
if err != nil {
|
||||
log.Errorf("failed to set device up: %v", err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
181
relay/testec2/turn.go
Normal file
181
relay/testec2/turn.go
Normal file
@@ -0,0 +1,181 @@
|
||||
//go:build linux || darwin
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/relay/testec2/tun"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type TurnReceiver struct {
|
||||
conns []*net.UDPConn
|
||||
clientAddresses map[string]string
|
||||
devices []*tun.Device
|
||||
}
|
||||
|
||||
type TurnSender struct {
|
||||
turnConns map[string]*TurnConn
|
||||
addresses []string
|
||||
devices []*tun.Device
|
||||
}
|
||||
|
||||
func runTurnWriting(tcpConn net.Conn, testData []byte, testDataLen int, wg *sync.WaitGroup) {
|
||||
defer wg.Done()
|
||||
defer tcpConn.Close()
|
||||
|
||||
log.Infof("start to sending test data: %s", tcpConn.RemoteAddr())
|
||||
|
||||
si := NewStartInidication(time.Now(), testDataLen)
|
||||
_, err := tcpConn.Write(si)
|
||||
if err != nil {
|
||||
log.Errorf("failed to write to tcp: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
pieceSize := 1024
|
||||
for j := 0; j < testDataLen; j += pieceSize {
|
||||
end := j + pieceSize
|
||||
if end > testDataLen {
|
||||
end = testDataLen
|
||||
}
|
||||
_, writeErr := tcpConn.Write(testData[j:end])
|
||||
if writeErr != nil {
|
||||
log.Errorf("failed to write to tcp conn: %s", writeErr)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// grant time to flush out packages
|
||||
time.Sleep(3 * time.Second)
|
||||
}
|
||||
|
||||
func createSenderDevices(sender *TurnSender, clientAddresses *ClientPeerAddr) {
|
||||
var i int
|
||||
devices := make([]*tun.Device, 0, len(clientAddresses.Address))
|
||||
for k, v := range clientAddresses.Address {
|
||||
tc, ok := sender.turnConns[k]
|
||||
if !ok {
|
||||
log.Fatalf("failed to find turn conn: %s", k)
|
||||
}
|
||||
|
||||
addr, err := net.ResolveUDPAddr("udp", v)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to resolve udp address: %s", err)
|
||||
}
|
||||
device := &tun.Device{
|
||||
Name: fmt.Sprintf("mtun-sender-%d", i),
|
||||
IP: fmt.Sprintf("10.0.%d.1", i),
|
||||
PConn: tc.relayConn,
|
||||
DstAddr: addr,
|
||||
}
|
||||
|
||||
err = device.Up()
|
||||
if err != nil {
|
||||
log.Fatalf("failed to bring up device: %s", err)
|
||||
}
|
||||
|
||||
devices = append(devices, device)
|
||||
i++
|
||||
}
|
||||
sender.devices = devices
|
||||
}
|
||||
|
||||
func createTurnConns(p int, sender *TurnSender) {
|
||||
turnConns := make(map[string]*TurnConn)
|
||||
addresses := make([]string, 0, len(pairs))
|
||||
for i := 0; i < p; i++ {
|
||||
tc := AllocateTurnClient(turnSrvAddress)
|
||||
log.Infof("allocated turn client: %s", tc.Address().String())
|
||||
turnConns[tc.Address().String()] = tc
|
||||
addresses = append(addresses, tc.Address().String())
|
||||
}
|
||||
|
||||
sender.turnConns = turnConns
|
||||
sender.addresses = addresses
|
||||
}
|
||||
|
||||
func runTurnReading(d *tun.Device, durations chan time.Duration) {
|
||||
tcpListener, err := net.Listen("tcp", d.IP+":9999")
|
||||
if err != nil {
|
||||
log.Fatalf("failed to listen on tcp: %s", err)
|
||||
}
|
||||
log := log.WithField("device", tcpListener.Addr())
|
||||
|
||||
tcpConn, err := tcpListener.Accept()
|
||||
if err != nil {
|
||||
_ = tcpListener.Close()
|
||||
log.Fatalf("failed to accept connection: %s", err)
|
||||
}
|
||||
log.Infof("remote peer connected")
|
||||
|
||||
buf := make([]byte, 103)
|
||||
n, err := tcpConn.Read(buf)
|
||||
if err != nil {
|
||||
_ = tcpListener.Close()
|
||||
log.Fatalf(errMsgFailedReadTCP, err)
|
||||
}
|
||||
|
||||
si := DecodeStartIndication(buf[:n])
|
||||
log.Infof("received start indication: %v, %d", si, n)
|
||||
|
||||
buf = make([]byte, 8192)
|
||||
i, err := tcpConn.Read(buf)
|
||||
if err != nil {
|
||||
_ = tcpListener.Close()
|
||||
log.Fatalf(errMsgFailedReadTCP, err)
|
||||
}
|
||||
now := time.Now()
|
||||
for i < si.TransferSize {
|
||||
n, err := tcpConn.Read(buf)
|
||||
if err != nil {
|
||||
_ = tcpListener.Close()
|
||||
log.Fatalf(errMsgFailedReadTCP, err)
|
||||
}
|
||||
i += n
|
||||
}
|
||||
durations <- time.Since(now)
|
||||
}
|
||||
|
||||
func createDevices(addresses []string, receiver *TurnReceiver) error {
|
||||
receiver.conns = make([]*net.UDPConn, 0, len(addresses))
|
||||
receiver.clientAddresses = make(map[string]string, len(addresses))
|
||||
receiver.devices = make([]*tun.Device, 0, len(addresses))
|
||||
for i, addr := range addresses {
|
||||
localAddr, err := net.ResolveUDPAddr("udp", udpListener)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to resolve UDP address: %s", err)
|
||||
}
|
||||
|
||||
conn, err := net.ListenUDP("udp", localAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create UDP connection: %s", err)
|
||||
}
|
||||
|
||||
receiver.conns = append(receiver.conns, conn)
|
||||
receiver.clientAddresses[addr] = conn.LocalAddr().String()
|
||||
|
||||
dstAddr, err := net.ResolveUDPAddr("udp", addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to resolve address: %s", err)
|
||||
}
|
||||
|
||||
device := &tun.Device{
|
||||
Name: fmt.Sprintf("mtun-%d", i),
|
||||
IP: fmt.Sprintf("10.0.%d.2", i),
|
||||
PConn: conn,
|
||||
DstAddr: dstAddr,
|
||||
}
|
||||
|
||||
if err = device.Up(); err != nil {
|
||||
return fmt.Errorf("failed to bring up device: %s, %s", device.Name, err)
|
||||
}
|
||||
receiver.devices = append(receiver.devices, device)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
83
relay/testec2/turn_allocator.go
Normal file
83
relay/testec2/turn_allocator.go
Normal file
@@ -0,0 +1,83 @@
|
||||
//go:build linux || darwin
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/pion/logging"
|
||||
"github.com/pion/turn/v3"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type TurnConn struct {
|
||||
conn net.Conn
|
||||
turnClient *turn.Client
|
||||
relayConn net.PacketConn
|
||||
}
|
||||
|
||||
func (tc *TurnConn) Address() net.Addr {
|
||||
return tc.relayConn.LocalAddr()
|
||||
}
|
||||
|
||||
func (tc *TurnConn) Close() {
|
||||
_ = tc.relayConn.Close()
|
||||
tc.turnClient.Close()
|
||||
_ = tc.conn.Close()
|
||||
}
|
||||
|
||||
func AllocateTurnClient(serverAddr string) *TurnConn {
|
||||
conn, err := net.Dial("tcp", serverAddr)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
turnClient, err := getTurnClient(serverAddr, conn)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
relayConn, err := turnClient.Allocate()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return &TurnConn{
|
||||
conn: conn,
|
||||
turnClient: turnClient,
|
||||
relayConn: relayConn,
|
||||
}
|
||||
}
|
||||
|
||||
func getTurnClient(address string, conn net.Conn) (*turn.Client, error) {
|
||||
// Dial TURN Server
|
||||
addrStr := fmt.Sprintf("%s:%d", address, 443)
|
||||
|
||||
fac := logging.NewDefaultLoggerFactory()
|
||||
//fac.DefaultLogLevel = logging.LogLevelTrace
|
||||
|
||||
// Start a new TURN Client and wrap our net.Conn in a STUNConn
|
||||
// This allows us to simulate datagram based communication over a net.Conn
|
||||
cfg := &turn.ClientConfig{
|
||||
TURNServerAddr: address,
|
||||
Conn: turn.NewSTUNConn(conn),
|
||||
Username: "test",
|
||||
Password: "test",
|
||||
LoggerFactory: fac,
|
||||
}
|
||||
|
||||
client, err := turn.NewClient(cfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create TURN client for server %s: %s", addrStr, err)
|
||||
}
|
||||
|
||||
// Start listening on the conn provided.
|
||||
err = client.Listen()
|
||||
if err != nil {
|
||||
client.Close()
|
||||
return nil, fmt.Errorf("failed to listen on TURN client for server %s: %s", addrStr, err)
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
||||
Reference in New Issue
Block a user