[relay] Feature/relay integration (#2244)

This update adds new relay integration for NetBird clients. The new relay is based on web sockets and listens on a single port.

- Adds new relay implementation with websocket with single port relaying mechanism
- refactor peer connection logic, allowing upgrade and downgrade from/to P2P connection
- peer connections are faster since it connects first to relay and then upgrades to P2P
- maintains compatibility with old clients by not using the new relay
- updates infrastructure scripts with new relay service
This commit is contained in:
Zoltan Papp
2024-09-08 12:06:14 +02:00
committed by GitHub
parent fcac02a92f
commit 0c039274a4
120 changed files with 9879 additions and 1940 deletions

13
relay/client/addr.go Normal file
View File

@@ -0,0 +1,13 @@
package client
type RelayAddr struct {
addr string
}
func (a RelayAddr) Network() string {
return "relay"
}
func (a RelayAddr) String() string {
return a.addr
}

553
relay/client/client.go Normal file
View File

@@ -0,0 +1,553 @@
package client
import (
"context"
"fmt"
"io"
"net"
"sync"
"time"
log "github.com/sirupsen/logrus"
auth "github.com/netbirdio/netbird/relay/auth/hmac"
"github.com/netbirdio/netbird/relay/client/dialer/ws"
"github.com/netbirdio/netbird/relay/healthcheck"
"github.com/netbirdio/netbird/relay/messages"
"github.com/netbirdio/netbird/relay/messages/address"
auth2 "github.com/netbirdio/netbird/relay/messages/auth"
)
const (
bufferSize = 8820
serverResponseTimeout = 8 * time.Second
)
var (
ErrConnAlreadyExists = fmt.Errorf("connection already exists")
)
type internalStopFlag struct {
sync.Mutex
stop bool
}
func newInternalStopFlag() *internalStopFlag {
return &internalStopFlag{}
}
func (isf *internalStopFlag) set() {
isf.Lock()
defer isf.Unlock()
isf.stop = true
}
func (isf *internalStopFlag) isSet() bool {
isf.Lock()
defer isf.Unlock()
return isf.stop
}
// Msg carry the payload from the server to the client. With this struct, the net.Conn can free the buffer.
type Msg struct {
Payload []byte
bufPool *sync.Pool
bufPtr *[]byte
}
func (m *Msg) Free() {
m.bufPool.Put(m.bufPtr)
}
type connContainer struct {
conn *Conn
messages chan Msg
msgChanLock sync.Mutex
closed bool // flag to check if channel is closed
}
func newConnContainer(conn *Conn, messages chan Msg) *connContainer {
return &connContainer{
conn: conn,
messages: messages,
}
}
func (cc *connContainer) writeMsg(msg Msg) {
cc.msgChanLock.Lock()
defer cc.msgChanLock.Unlock()
if cc.closed {
return
}
cc.messages <- msg
}
func (cc *connContainer) close() {
cc.msgChanLock.Lock()
defer cc.msgChanLock.Unlock()
if cc.closed {
return
}
close(cc.messages)
cc.closed = true
}
// Client is a client for the relay server. It is responsible for establishing a connection to the relay server and
// managing connections to other peers. All exported functions are safe to call concurrently. After close the connection,
// the client can be reused by calling Connect again. When the client is closed, all connections are closed too.
// While the Connect is in progress, the OpenConn function will block until the connection is established with relay server.
type Client struct {
log *log.Entry
parentCtx context.Context
connectionURL string
authTokenStore *auth.TokenStore
hashedID []byte
bufPool *sync.Pool
relayConn net.Conn
conns map[string]*connContainer
serviceIsRunning bool
mu sync.Mutex // protect serviceIsRunning and conns
readLoopMutex sync.Mutex
wgReadLoop sync.WaitGroup
instanceURL *RelayAddr
muInstanceURL sync.Mutex
onDisconnectListener func()
listenerMutex sync.Mutex
}
// NewClient creates a new client for the relay server. The client is not connected to the server until the Connect
func NewClient(ctx context.Context, serverURL string, authTokenStore *auth.TokenStore, peerID string) *Client {
hashedID, hashedStringId := messages.HashID(peerID)
return &Client{
log: log.WithField("client_id", hashedStringId),
parentCtx: ctx,
connectionURL: serverURL,
authTokenStore: authTokenStore,
hashedID: hashedID,
bufPool: &sync.Pool{
New: func() any {
buf := make([]byte, bufferSize)
return &buf
},
},
conns: make(map[string]*connContainer),
}
}
// Connect establishes a connection to the relay server. It blocks until the connection is established or an error occurs.
func (c *Client) Connect() error {
c.log.Infof("connecting to relay server: %s", c.connectionURL)
c.readLoopMutex.Lock()
defer c.readLoopMutex.Unlock()
c.mu.Lock()
defer c.mu.Unlock()
if c.serviceIsRunning {
return nil
}
err := c.connect()
if err != nil {
return err
}
c.serviceIsRunning = true
c.wgReadLoop.Add(1)
go c.readLoop(c.relayConn)
c.log.Infof("relay connection established with: %s", c.connectionURL)
return nil
}
// OpenConn create a new net.Conn for the destination peer ID. In case if the connection is in progress
// to the relay server, the function will block until the connection is established or timed out. Otherwise,
// it will return immediately.
// todo: what should happen if call with the same peerID with multiple times?
func (c *Client) OpenConn(dstPeerID string) (net.Conn, error) {
c.mu.Lock()
defer c.mu.Unlock()
if !c.serviceIsRunning {
return nil, fmt.Errorf("relay connection is not established")
}
hashedID, hashedStringID := messages.HashID(dstPeerID)
_, ok := c.conns[hashedStringID]
if ok {
return nil, ErrConnAlreadyExists
}
log.Infof("open connection to peer: %s", hashedStringID)
msgChannel := make(chan Msg, 2)
conn := NewConn(c, hashedID, hashedStringID, msgChannel, c.instanceURL)
c.conns[hashedStringID] = newConnContainer(conn, msgChannel)
return conn, nil
}
// ServerInstanceURL returns the address of the relay server. It could change after the close and reopen the connection.
func (c *Client) ServerInstanceURL() (string, error) {
c.muInstanceURL.Lock()
defer c.muInstanceURL.Unlock()
if c.instanceURL == nil {
return "", fmt.Errorf("relay connection is not established")
}
return c.instanceURL.String(), nil
}
// SetOnDisconnectListener sets a function that will be called when the connection to the relay server is closed.
func (c *Client) SetOnDisconnectListener(fn func()) {
c.listenerMutex.Lock()
defer c.listenerMutex.Unlock()
c.onDisconnectListener = fn
}
// HasConns returns true if there are connections.
func (c *Client) HasConns() bool {
c.mu.Lock()
defer c.mu.Unlock()
return len(c.conns) > 0
}
// Close closes the connection to the relay server and all connections to other peers.
func (c *Client) Close() error {
return c.close(true)
}
func (c *Client) connect() error {
conn, err := ws.Dial(c.connectionURL)
if err != nil {
return err
}
c.relayConn = conn
err = c.handShake()
if err != nil {
cErr := conn.Close()
if cErr != nil {
log.Errorf("failed to close connection: %s", cErr)
}
return err
}
return nil
}
func (c *Client) handShake() error {
authMsg := &auth2.Msg{
AuthAlgorithm: auth2.AlgoHMACSHA256,
AdditionalData: c.authTokenStore.TokenBinary(),
}
authData, err := authMsg.Marshal()
if err != nil {
return fmt.Errorf("marshal auth message: %w", err)
}
msg, err := messages.MarshalHelloMsg(c.hashedID, authData)
if err != nil {
log.Errorf("failed to marshal hello message: %s", err)
return err
}
_, err = c.relayConn.Write(msg)
if err != nil {
log.Errorf("failed to send hello message: %s", err)
return err
}
buf := make([]byte, messages.MaxHandshakeSize)
n, err := c.readWithTimeout(buf)
if err != nil {
log.Errorf("failed to read hello response: %s", err)
return err
}
_, err = messages.ValidateVersion(buf[:n])
if err != nil {
return fmt.Errorf("validate version: %w", err)
}
msgType, err := messages.DetermineServerMessageType(buf[messages.SizeOfVersionByte:n])
if err != nil {
log.Errorf("failed to determine message type: %s", err)
return err
}
if msgType != messages.MsgTypeHelloResponse {
log.Errorf("unexpected message type: %s", msgType)
return fmt.Errorf("unexpected message type")
}
additionalData, err := messages.UnmarshalHelloResponse(buf[messages.SizeOfProtoHeader:n])
if err != nil {
return err
}
addr, err := address.Unmarshal(additionalData)
if err != nil {
return fmt.Errorf("unmarshal address: %w", err)
}
c.muInstanceURL.Lock()
c.instanceURL = &RelayAddr{addr: addr.URL}
c.muInstanceURL.Unlock()
return nil
}
func (c *Client) readLoop(relayConn net.Conn) {
internallyStoppedFlag := newInternalStopFlag()
hc := healthcheck.NewReceiver()
go c.listenForStopEvents(hc, relayConn, internallyStoppedFlag)
var (
errExit error
n int
)
for {
bufPtr := c.bufPool.Get().(*[]byte)
buf := *bufPtr
n, errExit = relayConn.Read(buf)
if errExit != nil {
c.mu.Lock()
if c.serviceIsRunning && !internallyStoppedFlag.isSet() {
c.log.Debugf("failed to read message from relay server: %s", errExit)
}
c.mu.Unlock()
break
}
_, err := messages.ValidateVersion(buf[:n])
if err != nil {
c.log.Errorf("failed to validate protocol version: %s", err)
c.bufPool.Put(bufPtr)
continue
}
msgType, err := messages.DetermineServerMessageType(buf[messages.SizeOfVersionByte:n])
if err != nil {
c.log.Errorf("failed to determine message type: %s", err)
c.bufPool.Put(bufPtr)
continue
}
if !c.handleMsg(msgType, buf[messages.SizeOfProtoHeader:n], bufPtr, hc, internallyStoppedFlag) {
break
}
}
hc.Stop()
c.muInstanceURL.Lock()
c.instanceURL = nil
c.muInstanceURL.Unlock()
c.notifyDisconnected()
c.wgReadLoop.Done()
_ = c.close(false)
}
func (c *Client) handleMsg(msgType messages.MsgType, buf []byte, bufPtr *[]byte, hc *healthcheck.Receiver, internallyStoppedFlag *internalStopFlag) (continueLoop bool) {
switch msgType {
case messages.MsgTypeHealthCheck:
c.handleHealthCheck(hc, internallyStoppedFlag)
c.bufPool.Put(bufPtr)
case messages.MsgTypeTransport:
return c.handleTransportMsg(buf, bufPtr, internallyStoppedFlag)
case messages.MsgTypeClose:
log.Debugf("relay connection close by server")
c.bufPool.Put(bufPtr)
return false
}
return true
}
func (c *Client) handleHealthCheck(hc *healthcheck.Receiver, internallyStoppedFlag *internalStopFlag) {
msg := messages.MarshalHealthcheck()
_, wErr := c.relayConn.Write(msg)
if wErr != nil {
if c.serviceIsRunning && !internallyStoppedFlag.isSet() {
c.log.Errorf("failed to send heartbeat: %s", wErr)
}
}
hc.Heartbeat()
}
func (c *Client) handleTransportMsg(buf []byte, bufPtr *[]byte, internallyStoppedFlag *internalStopFlag) bool {
peerID, payload, err := messages.UnmarshalTransportMsg(buf)
if err != nil {
if c.serviceIsRunning && !internallyStoppedFlag.isSet() {
c.log.Errorf("failed to parse transport message: %v", err)
}
c.bufPool.Put(bufPtr)
return true
}
stringID := messages.HashIDToString(peerID)
c.mu.Lock()
if !c.serviceIsRunning {
c.mu.Unlock()
c.bufPool.Put(bufPtr)
return false
}
container, ok := c.conns[stringID]
c.mu.Unlock()
if !ok {
c.log.Errorf("peer not found: %s", stringID)
c.bufPool.Put(bufPtr)
return true
}
msg := Msg{
bufPool: c.bufPool,
bufPtr: bufPtr,
Payload: payload,
}
container.writeMsg(msg)
return true
}
func (c *Client) writeTo(connReference *Conn, id string, dstID []byte, payload []byte) (int, error) {
c.mu.Lock()
conn, ok := c.conns[id]
c.mu.Unlock()
if !ok {
return 0, io.EOF
}
if conn.conn != connReference {
return 0, io.EOF
}
// todo: use buffer pool instead of create new transport msg.
msg, err := messages.MarshalTransportMsg(dstID, payload)
if err != nil {
log.Errorf("failed to marshal transport message: %s", err)
return 0, err
}
// the write always return with 0 length because the underling does not support the size feedback.
_, err = c.relayConn.Write(msg)
if err != nil {
log.Errorf("failed to write transport message: %s", err)
}
return len(payload), err
}
func (c *Client) listenForStopEvents(hc *healthcheck.Receiver, conn net.Conn, internalStopFlag *internalStopFlag) {
for {
select {
case _, ok := <-hc.OnTimeout:
if !ok {
return
}
c.log.Errorf("health check timeout")
internalStopFlag.set()
_ = conn.Close() // ignore the err because the readLoop will handle it
return
case <-c.parentCtx.Done():
err := c.close(true)
if err != nil {
log.Errorf("failed to teardown connection: %s", err)
}
return
}
}
}
func (c *Client) closeAllConns() {
for _, container := range c.conns {
container.close()
}
c.conns = make(map[string]*connContainer)
}
func (c *Client) closeConn(connReference *Conn, id string) error {
c.mu.Lock()
defer c.mu.Unlock()
container, ok := c.conns[id]
if !ok {
return fmt.Errorf("connection already closed")
}
if container.conn != connReference {
return fmt.Errorf("conn reference mismatch")
}
container.close()
delete(c.conns, id)
return nil
}
func (c *Client) close(gracefullyExit bool) error {
c.readLoopMutex.Lock()
defer c.readLoopMutex.Unlock()
c.mu.Lock()
var err error
if !c.serviceIsRunning {
c.mu.Unlock()
return nil
}
c.serviceIsRunning = false
c.closeAllConns()
if gracefullyExit {
c.writeCloseMsg()
}
err = c.relayConn.Close()
c.mu.Unlock()
c.wgReadLoop.Wait()
c.log.Infof("relay connection closed with: %s", c.connectionURL)
return err
}
func (c *Client) notifyDisconnected() {
c.listenerMutex.Lock()
defer c.listenerMutex.Unlock()
if c.onDisconnectListener == nil {
return
}
go c.onDisconnectListener()
}
func (c *Client) writeCloseMsg() {
msg := messages.MarshalCloseMsg()
_, err := c.relayConn.Write(msg)
if err != nil {
c.log.Errorf("failed to send close message: %s", err)
}
}
func (c *Client) readWithTimeout(buf []byte) (int, error) {
ctx, cancel := context.WithTimeout(c.parentCtx, serverResponseTimeout)
defer cancel()
readDone := make(chan struct{})
var (
n int
err error
)
go func() {
n, err = c.relayConn.Read(buf)
close(readDone)
}()
select {
case <-ctx.Done():
return 0, fmt.Errorf("read operation timed out")
case <-readDone:
return n, err
}
}

631
relay/client/client_test.go Normal file
View File

@@ -0,0 +1,631 @@
package client
import (
"context"
"net"
"os"
"testing"
"time"
log "github.com/sirupsen/logrus"
"go.opentelemetry.io/otel"
"github.com/netbirdio/netbird/relay/auth/allow"
"github.com/netbirdio/netbird/relay/auth/hmac"
"github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/relay/server"
)
var (
av = &allow.Auth{}
hmacTokenStore = &hmac.TokenStore{}
serverListenAddr = "127.0.0.1:1234"
serverURL = "rel://127.0.0.1:1234"
)
func TestMain(m *testing.M) {
_ = util.InitLog("error", "console")
code := m.Run()
os.Exit(code)
}
func TestClient(t *testing.T) {
ctx := context.Background()
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av)
if err != nil {
t.Fatalf("failed to create server: %s", err)
}
errChan := make(chan error, 1)
go func() {
listenCfg := server.ListenerConfig{Address: serverListenAddr}
err := srv.Listen(listenCfg)
if err != nil {
errChan <- err
}
}()
defer func() {
err := srv.Shutdown(ctx)
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}()
// wait for server to start
if err := waitForServerToStart(errChan); err != nil {
t.Fatalf("failed to start server: %s", err)
}
t.Log("alice connecting to server")
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice")
err = clientAlice.Connect()
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
}
defer clientAlice.Close()
t.Log("placeholder connecting to server")
clientPlaceHolder := NewClient(ctx, serverURL, hmacTokenStore, "clientPlaceHolder")
err = clientPlaceHolder.Connect()
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
}
defer clientPlaceHolder.Close()
t.Log("Bob connecting to server")
clientBob := NewClient(ctx, serverURL, hmacTokenStore, "bob")
err = clientBob.Connect()
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
}
defer clientBob.Close()
t.Log("Alice open connection to Bob")
connAliceToBob, err := clientAlice.OpenConn("bob")
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}
t.Log("Bob open connection to Alice")
connBobToAlice, err := clientBob.OpenConn("alice")
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}
payload := "hello bob, I am alice"
_, err = connAliceToBob.Write([]byte(payload))
if err != nil {
t.Fatalf("failed to write to channel: %s", err)
}
log.Debugf("alice sent message to bob")
buf := make([]byte, 65535)
n, err := connBobToAlice.Read(buf)
if err != nil {
t.Fatalf("failed to read from channel: %s", err)
}
log.Debugf("on new message from alice to bob")
if payload != string(buf[:n]) {
t.Fatalf("expected %s, got %s", payload, string(buf[:n]))
}
}
func TestRegistration(t *testing.T) {
ctx := context.Background()
srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av)
if err != nil {
t.Fatalf("failed to create server: %s", err)
}
errChan := make(chan error, 1)
go func() {
err := srv.Listen(srvCfg)
if err != nil {
errChan <- err
}
}()
// wait for server to start
if err := waitForServerToStart(errChan); err != nil {
t.Fatalf("failed to start server: %s", err)
}
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice")
err = clientAlice.Connect()
if err != nil {
_ = srv.Shutdown(ctx)
t.Fatalf("failed to connect to server: %s", err)
}
err = clientAlice.Close()
if err != nil {
t.Errorf("failed to close conn: %s", err)
}
err = srv.Shutdown(ctx)
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}
func TestRegistrationTimeout(t *testing.T) {
ctx := context.Background()
fakeUDPListener, err := net.ListenUDP("udp", &net.UDPAddr{
Port: 1234,
IP: net.ParseIP("0.0.0.0"),
})
if err != nil {
t.Fatalf("failed to bind UDP server: %s", err)
}
defer func(fakeUDPListener *net.UDPConn) {
_ = fakeUDPListener.Close()
}(fakeUDPListener)
fakeTCPListener, err := net.ListenTCP("tcp", &net.TCPAddr{
Port: 1234,
IP: net.ParseIP("0.0.0.0"),
})
if err != nil {
t.Fatalf("failed to bind TCP server: %s", err)
}
defer func(fakeTCPListener *net.TCPListener) {
_ = fakeTCPListener.Close()
}(fakeTCPListener)
clientAlice := NewClient(ctx, "127.0.0.1:1234", hmacTokenStore, "alice")
err = clientAlice.Connect()
if err == nil {
t.Errorf("failed to connect to server: %s", err)
}
log.Debugf("%s", err)
err = clientAlice.Close()
if err != nil {
t.Errorf("failed to close conn: %s", err)
}
}
func TestEcho(t *testing.T) {
ctx := context.Background()
idAlice := "alice"
idBob := "bob"
srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av)
if err != nil {
t.Fatalf("failed to create server: %s", err)
}
errChan := make(chan error, 1)
go func() {
err := srv.Listen(srvCfg)
if err != nil {
errChan <- err
}
}()
defer func() {
err := srv.Shutdown(ctx)
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}()
// wait for servers to start
if err := waitForServerToStart(errChan); err != nil {
t.Fatalf("failed to start server: %s", err)
}
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, idAlice)
err = clientAlice.Connect()
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
}
defer func() {
err := clientAlice.Close()
if err != nil {
t.Errorf("failed to close Alice client: %s", err)
}
}()
clientBob := NewClient(ctx, serverURL, hmacTokenStore, idBob)
err = clientBob.Connect()
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
}
defer func() {
err := clientBob.Close()
if err != nil {
t.Errorf("failed to close Bob client: %s", err)
}
}()
connAliceToBob, err := clientAlice.OpenConn(idBob)
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}
connBobToAlice, err := clientBob.OpenConn(idAlice)
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}
payload := "hello bob, I am alice"
_, err = connAliceToBob.Write([]byte(payload))
if err != nil {
t.Fatalf("failed to write to channel: %s", err)
}
buf := make([]byte, 65535)
n, err := connBobToAlice.Read(buf)
if err != nil {
t.Fatalf("failed to read from channel: %s", err)
}
_, err = connBobToAlice.Write(buf[:n])
if err != nil {
t.Fatalf("failed to write to channel: %s", err)
}
n, err = connAliceToBob.Read(buf)
if err != nil {
t.Fatalf("failed to read from channel: %s", err)
}
if payload != string(buf[:n]) {
t.Fatalf("expected %s, got %s", payload, string(buf[:n]))
}
}
func TestBindToUnavailabePeer(t *testing.T) {
ctx := context.Background()
srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av)
if err != nil {
t.Fatalf("failed to create server: %s", err)
}
errChan := make(chan error, 1)
go func() {
err := srv.Listen(srvCfg)
if err != nil {
errChan <- err
}
}()
defer func() {
log.Infof("closing server")
err := srv.Shutdown(ctx)
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}()
// wait for servers to start
if err := waitForServerToStart(errChan); err != nil {
t.Fatalf("failed to start server: %s", err)
}
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice")
err = clientAlice.Connect()
if err != nil {
t.Errorf("failed to connect to server: %s", err)
}
_, err = clientAlice.OpenConn("bob")
if err != nil {
t.Errorf("failed to bind channel: %s", err)
}
log.Infof("closing client")
err = clientAlice.Close()
if err != nil {
t.Errorf("failed to close client: %s", err)
}
}
func TestBindReconnect(t *testing.T) {
ctx := context.Background()
srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av)
if err != nil {
t.Fatalf("failed to create server: %s", err)
}
errChan := make(chan error, 1)
go func() {
err := srv.Listen(srvCfg)
if err != nil {
errChan <- err
}
}()
defer func() {
log.Infof("closing server")
err := srv.Shutdown(ctx)
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}()
// wait for servers to start
if err := waitForServerToStart(errChan); err != nil {
t.Fatalf("failed to start server: %s", err)
}
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice")
err = clientAlice.Connect()
if err != nil {
t.Errorf("failed to connect to server: %s", err)
}
_, err = clientAlice.OpenConn("bob")
if err != nil {
t.Errorf("failed to bind channel: %s", err)
}
clientBob := NewClient(ctx, serverURL, hmacTokenStore, "bob")
err = clientBob.Connect()
if err != nil {
t.Errorf("failed to connect to server: %s", err)
}
chBob, err := clientBob.OpenConn("alice")
if err != nil {
t.Errorf("failed to bind channel: %s", err)
}
log.Infof("closing client Alice")
err = clientAlice.Close()
if err != nil {
t.Errorf("failed to close client: %s", err)
}
clientAlice = NewClient(ctx, serverURL, hmacTokenStore, "alice")
err = clientAlice.Connect()
if err != nil {
t.Errorf("failed to connect to server: %s", err)
}
chAlice, err := clientAlice.OpenConn("bob")
if err != nil {
t.Errorf("failed to bind channel: %s", err)
}
testString := "hello alice, I am bob"
_, err = chBob.Write([]byte(testString))
if err != nil {
t.Errorf("failed to write to channel: %s", err)
}
buf := make([]byte, 65535)
n, err := chAlice.Read(buf)
if err != nil {
t.Errorf("failed to read from channel: %s", err)
}
if testString != string(buf[:n]) {
t.Errorf("expected %s, got %s", testString, string(buf[:n]))
}
log.Infof("closing client")
err = clientAlice.Close()
if err != nil {
t.Errorf("failed to close client: %s", err)
}
}
func TestCloseConn(t *testing.T) {
ctx := context.Background()
srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av)
if err != nil {
t.Fatalf("failed to create server: %s", err)
}
errChan := make(chan error, 1)
go func() {
err := srv.Listen(srvCfg)
if err != nil {
errChan <- err
}
}()
defer func() {
log.Infof("closing server")
err := srv.Shutdown(ctx)
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}()
// wait for servers to start
if err := waitForServerToStart(errChan); err != nil {
t.Fatalf("failed to start server: %s", err)
}
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice")
err = clientAlice.Connect()
if err != nil {
t.Errorf("failed to connect to server: %s", err)
}
conn, err := clientAlice.OpenConn("bob")
if err != nil {
t.Errorf("failed to bind channel: %s", err)
}
log.Infof("closing connection")
err = conn.Close()
if err != nil {
t.Errorf("failed to close connection: %s", err)
}
_, err = conn.Read(make([]byte, 1))
if err == nil {
t.Errorf("unexpected reading from closed connection")
}
_, err = conn.Write([]byte("hello"))
if err == nil {
t.Errorf("unexpected writing from closed connection")
}
}
func TestCloseRelayConn(t *testing.T) {
ctx := context.Background()
srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av)
if err != nil {
t.Fatalf("failed to create server: %s", err)
}
errChan := make(chan error, 1)
go func() {
err := srv.Listen(srvCfg)
if err != nil {
errChan <- err
}
}()
defer func() {
err := srv.Shutdown(ctx)
if err != nil {
log.Errorf("failed to close server: %s", err)
}
}()
// wait for servers to start
if err := waitForServerToStart(errChan); err != nil {
t.Fatalf("failed to start server: %s", err)
}
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice")
err = clientAlice.Connect()
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
}
conn, err := clientAlice.OpenConn("bob")
if err != nil {
t.Errorf("failed to bind channel: %s", err)
}
_ = clientAlice.relayConn.Close()
_, err = conn.Read(make([]byte, 1))
if err == nil {
t.Errorf("unexpected reading from closed connection")
}
_, err = clientAlice.OpenConn("bob")
if err == nil {
t.Errorf("unexpected opening connection to closed server")
}
}
func TestCloseByServer(t *testing.T) {
ctx := context.Background()
srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv1, err := server.NewServer(otel.Meter(""), serverURL, false, av)
if err != nil {
t.Fatalf("failed to create server: %s", err)
}
errChan := make(chan error, 1)
go func() {
err := srv1.Listen(srvCfg)
if err != nil {
errChan <- err
}
}()
// wait for servers to start
if err := waitForServerToStart(errChan); err != nil {
t.Fatalf("failed to start server: %s", err)
}
idAlice := "alice"
log.Debugf("connect by alice")
relayClient := NewClient(ctx, serverURL, hmacTokenStore, idAlice)
err = relayClient.Connect()
if err != nil {
log.Fatalf("failed to connect to server: %s", err)
}
disconnected := make(chan struct{})
relayClient.SetOnDisconnectListener(func() {
log.Infof("client disconnected")
close(disconnected)
})
err = srv1.Shutdown(ctx)
if err != nil {
t.Fatalf("failed to close server: %s", err)
}
select {
case <-disconnected:
case <-time.After(3 * time.Second):
log.Fatalf("timeout waiting for client to disconnect")
}
_, err = relayClient.OpenConn("bob")
if err == nil {
t.Errorf("unexpected opening connection to closed server")
}
}
func TestCloseByClient(t *testing.T) {
ctx := context.Background()
srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av)
if err != nil {
t.Fatalf("failed to create server: %s", err)
}
errChan := make(chan error, 1)
go func() {
err := srv.Listen(srvCfg)
if err != nil {
errChan <- err
}
}()
// wait for servers to start
if err := waitForServerToStart(errChan); err != nil {
t.Fatalf("failed to start server: %s", err)
}
idAlice := "alice"
log.Debugf("connect by alice")
relayClient := NewClient(ctx, serverURL, hmacTokenStore, idAlice)
err = relayClient.Connect()
if err != nil {
log.Fatalf("failed to connect to server: %s", err)
}
err = relayClient.Close()
if err != nil {
t.Errorf("failed to close client: %s", err)
}
_, err = relayClient.OpenConn("bob")
if err == nil {
t.Errorf("unexpected opening connection to closed server")
}
err = srv.Shutdown(ctx)
if err != nil {
t.Fatalf("failed to close server: %s", err)
}
}
func waitForServerToStart(errChan chan error) error {
select {
case err := <-errChan:
if err != nil {
return err
}
case <-time.After(300 * time.Millisecond):
return nil
}
return nil
}

76
relay/client/conn.go Normal file
View File

@@ -0,0 +1,76 @@
package client
import (
"io"
"net"
"time"
)
// Conn represent a connection to a relayed remote peer.
type Conn struct {
client *Client
dstID []byte
dstStringID string
messageChan chan Msg
instanceURL *RelayAddr
}
// NewConn creates a new connection to a relayed remote peer.
// client: the client instance, it used to send messages to the destination peer
// dstID: the destination peer ID
// dstStringID: the destination peer ID in string format
// messageChan: the channel where the messages will be received
// instanceURL: the relay instance URL, it used to get the proper server instance address for the remote peer
func NewConn(client *Client, dstID []byte, dstStringID string, messageChan chan Msg, instanceURL *RelayAddr) *Conn {
c := &Conn{
client: client,
dstID: dstID,
dstStringID: dstStringID,
messageChan: messageChan,
instanceURL: instanceURL,
}
return c
}
func (c *Conn) Write(p []byte) (n int, err error) {
return c.client.writeTo(c, c.dstStringID, c.dstID, p)
}
func (c *Conn) Read(b []byte) (n int, err error) {
msg, ok := <-c.messageChan
if !ok {
return 0, io.EOF
}
n = copy(b, msg.Payload)
msg.Free()
return n, nil
}
func (c *Conn) Close() error {
return c.client.closeConn(c, c.dstStringID)
}
func (c *Conn) LocalAddr() net.Addr {
return c.client.relayConn.LocalAddr()
}
func (c *Conn) RemoteAddr() net.Addr {
return c.instanceURL
}
func (c *Conn) SetDeadline(t time.Time) error {
//TODO implement me
panic("SetDeadline is not implemented")
}
func (c *Conn) SetReadDeadline(t time.Time) error {
//TODO implement me
panic("SetReadDeadline is not implemented")
}
func (c *Conn) SetWriteDeadline(t time.Time) error {
//TODO implement me
panic("SetReadDeadline is not implemented")
}

View File

@@ -0,0 +1,13 @@
package ws
type WebsocketAddr struct {
addr string
}
func (a WebsocketAddr) Network() string {
return "websocket"
}
func (a WebsocketAddr) String() string {
return a.addr
}

View File

@@ -0,0 +1,66 @@
package ws
import (
"context"
"fmt"
"net"
"time"
"nhooyr.io/websocket"
)
type Conn struct {
ctx context.Context
*websocket.Conn
remoteAddr WebsocketAddr
}
func NewConn(wsConn *websocket.Conn, serverAddress string) net.Conn {
return &Conn{
ctx: context.Background(),
Conn: wsConn,
remoteAddr: WebsocketAddr{serverAddress},
}
}
func (c *Conn) Read(b []byte) (n int, err error) {
t, ioReader, err := c.Conn.Reader(c.ctx)
if err != nil {
return 0, err
}
if t != websocket.MessageBinary {
return 0, fmt.Errorf("unexpected message type")
}
return ioReader.Read(b)
}
func (c *Conn) Write(b []byte) (n int, err error) {
err = c.Conn.Write(c.ctx, websocket.MessageBinary, b)
return 0, err
}
func (c *Conn) RemoteAddr() net.Addr {
return c.remoteAddr
}
func (c *Conn) LocalAddr() net.Addr {
return WebsocketAddr{addr: "unknown"}
}
func (c *Conn) SetReadDeadline(t time.Time) error {
return fmt.Errorf("SetReadDeadline is not implemented")
}
func (c *Conn) SetWriteDeadline(t time.Time) error {
return fmt.Errorf("SetWriteDeadline is not implemented")
}
func (c *Conn) SetDeadline(t time.Time) error {
return fmt.Errorf("SetDeadline is not implemented")
}
func (c *Conn) Close() error {
return c.Conn.CloseNow()
}

View File

@@ -0,0 +1,67 @@
package ws
import (
"context"
"fmt"
"net"
"net/http"
"net/url"
"strings"
log "github.com/sirupsen/logrus"
"nhooyr.io/websocket"
"github.com/netbirdio/netbird/relay/server/listener/ws"
nbnet "github.com/netbirdio/netbird/util/net"
)
func Dial(address string) (net.Conn, error) {
wsURL, err := prepareURL(address)
if err != nil {
return nil, err
}
opts := &websocket.DialOptions{
HTTPClient: httpClientNbDialer(),
}
parsedURL, err := url.Parse(wsURL)
if err != nil {
return nil, err
}
parsedURL.Path = ws.URLPath
wsConn, resp, err := websocket.Dial(context.Background(), parsedURL.String(), opts)
if err != nil {
log.Errorf("failed to dial to Relay server '%s': %s", wsURL, err)
return nil, err
}
if resp.Body != nil {
_ = resp.Body.Close()
}
conn := NewConn(wsConn, address)
return conn, nil
}
func prepareURL(address string) (string, error) {
if !strings.HasPrefix(address, "rel:") && !strings.HasPrefix(address, "rels:") {
return "", fmt.Errorf("unsupported scheme: %s", address)
}
return strings.Replace(address, "rel", "ws", 1), nil
}
func httpClientNbDialer() *http.Client {
customDialer := nbnet.NewDialer()
customTransport := &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return customDialer.DialContext(ctx, network, addr)
},
}
return &http.Client{
Transport: customTransport,
}
}

12
relay/client/doc.go Normal file
View File

@@ -0,0 +1,12 @@
/*
Package client contains the implementation of the Relay client.
The Relay client is responsible for establishing a connection with the Relay server and sending and receiving messages,
Keep persistent connection with the Relay server and handle the connection issues.
It uses the WebSocket protocol for communication and optionally supports TLS (Transport Layer Security).
If a peer wants to communicate with a peer on a different relay server, the manager will establish a new connection to
the relay server. The connection with these relay servers will be closed if there is no active connection. The peers
negotiate the common relay instance via signaling service.
*/
package client

48
relay/client/guard.go Normal file
View File

@@ -0,0 +1,48 @@
package client
import (
"context"
"time"
log "github.com/sirupsen/logrus"
)
var (
reconnectingTimeout = 5 * time.Second
)
// Guard manage the reconnection tries to the Relay server in case of disconnection event.
type Guard struct {
ctx context.Context
relayClient *Client
}
// NewGuard creates a new guard for the relay client.
func NewGuard(context context.Context, relayClient *Client) *Guard {
g := &Guard{
ctx: context,
relayClient: relayClient,
}
return g
}
// OnDisconnected is called when the relay client is disconnected from the relay server. It will trigger the reconnection
// todo prevent multiple reconnection instances. In the current usage it should not happen, but it is better to prevent
func (g *Guard) OnDisconnected() {
ticker := time.NewTicker(reconnectingTimeout)
defer ticker.Stop()
for {
select {
case <-ticker.C:
err := g.relayClient.Connect()
if err != nil {
log.Errorf("failed to reconnect to relay server: %s", err)
continue
}
return
case <-g.ctx.Done():
return
}
}
}

365
relay/client/manager.go Normal file
View File

@@ -0,0 +1,365 @@
package client
import (
"container/list"
"context"
"errors"
"fmt"
"net"
"reflect"
"sync"
"time"
log "github.com/sirupsen/logrus"
relayAuth "github.com/netbirdio/netbird/relay/auth/hmac"
)
var (
relayCleanupInterval = 60 * time.Second
connectionTimeout = 30 * time.Second
maxConcurrentServers = 7
ErrRelayClientNotConnected = fmt.Errorf("relay client not connected")
)
// RelayTrack hold the relay clients for the foreign relay servers.
// With the mutex can ensure we can open new connection in case the relay connection has been established with
// the relay server.
type RelayTrack struct {
sync.RWMutex
relayClient *Client
}
func NewRelayTrack() *RelayTrack {
return &RelayTrack{}
}
type OnServerCloseListener func()
// ManagerService is the interface for the relay manager.
type ManagerService interface {
Serve() error
OpenConn(serverAddress, peerKey string) (net.Conn, error)
AddCloseListener(serverAddress string, onClosedListener OnServerCloseListener) error
RelayInstanceAddress() (string, error)
ServerURLs() []string
HasRelayAddress() bool
UpdateToken(token *relayAuth.Token) error
}
// Manager is a manager for the relay client instances. It establishes one persistent connection to the given relay URL
// and automatically reconnect to them in case disconnection.
// The manager also manage temporary relay connection. If a client wants to communicate with a client on a
// different relay server, the manager will establish a new connection to the relay server. The connection with these
// relay servers will be closed if there is no active connection. Periodically the manager will check if there is any
// unused relay connection and close it.
type Manager struct {
ctx context.Context
serverURLs []string
peerID string
tokenStore *relayAuth.TokenStore
relayClient *Client
reconnectGuard *Guard
relayClients map[string]*RelayTrack
relayClientsMutex sync.RWMutex
onDisconnectedListeners map[string]*list.List
listenerLock sync.Mutex
}
// NewManager creates a new manager instance.
// The serverURL address can be empty. In this case, the manager will not serve.
func NewManager(ctx context.Context, serverURLs []string, peerID string) *Manager {
return &Manager{
ctx: ctx,
serverURLs: serverURLs,
peerID: peerID,
tokenStore: &relayAuth.TokenStore{},
relayClients: make(map[string]*RelayTrack),
onDisconnectedListeners: make(map[string]*list.List),
}
}
// Serve starts the manager. It will establish a connection to the relay server and start the relay cleanup loop for
// the unused relay connections. The manager will automatically reconnect to the relay server in case of disconnection.
func (m *Manager) Serve() error {
if m.relayClient != nil {
return fmt.Errorf("manager already serving")
}
log.Debugf("starting relay client manager with %v relay servers", m.serverURLs)
totalServers := len(m.serverURLs)
successChan := make(chan *Client, 1)
errChan := make(chan error, len(m.serverURLs))
ctx, cancel := context.WithTimeout(m.ctx, connectionTimeout)
defer cancel()
sem := make(chan struct{}, maxConcurrentServers)
for _, url := range m.serverURLs {
sem <- struct{}{}
go func(url string) {
defer func() { <-sem }()
m.connect(m.ctx, url, successChan, errChan)
}(url)
}
var errCount int
for {
select {
case client := <-successChan:
log.Infof("Successfully connected to relay server: %s", client.connectionURL)
m.relayClient = client
m.reconnectGuard = NewGuard(m.ctx, m.relayClient)
m.relayClient.SetOnDisconnectListener(func() {
m.onServerDisconnected(client.connectionURL)
})
m.startCleanupLoop()
return nil
case err := <-errChan:
errCount++
log.Warnf("Connection attempt failed: %v", err)
if errCount == totalServers {
return errors.New("failed to connect to any relay server: all attempts failed")
}
case <-ctx.Done():
return fmt.Errorf("failed to connect to any relay server: %w", ctx.Err())
}
}
}
func (m *Manager) connect(ctx context.Context, serverURL string, successChan chan<- *Client, errChan chan<- error) {
// TODO: abort the connection if another connection was successful
relayClient := NewClient(ctx, serverURL, m.tokenStore, m.peerID)
if err := relayClient.Connect(); err != nil {
errChan <- fmt.Errorf("failed to connect to %s: %w", serverURL, err)
return
}
select {
case successChan <- relayClient:
// This client was the first to connect successfully
default:
if err := relayClient.Close(); err != nil {
log.Debugf("failed to close relay client: %s", err)
}
}
}
// OpenConn opens a connection to the given peer key. If the peer is on the same relay server, the connection will be
// established via the relay server. If the peer is on a different relay server, the manager will establish a new
// connection to the relay server. It returns back with a net.Conn what represent the remote peer connection.
func (m *Manager) OpenConn(serverAddress, peerKey string) (net.Conn, error) {
if m.relayClient == nil {
return nil, ErrRelayClientNotConnected
}
foreign, err := m.isForeignServer(serverAddress)
if err != nil {
return nil, err
}
var (
netConn net.Conn
)
if !foreign {
log.Debugf("open peer connection via permanent server: %s", peerKey)
netConn, err = m.relayClient.OpenConn(peerKey)
} else {
log.Debugf("open peer connection via foreign server: %s", serverAddress)
netConn, err = m.openConnVia(serverAddress, peerKey)
}
if err != nil {
return nil, err
}
return netConn, err
}
// AddCloseListener adds a listener to the given server instance address. The listener will be called if the connection
// closed.
func (m *Manager) AddCloseListener(serverAddress string, onClosedListener OnServerCloseListener) error {
foreign, err := m.isForeignServer(serverAddress)
if err != nil {
return err
}
var listenerAddr string
if foreign {
listenerAddr = serverAddress
} else {
listenerAddr = m.relayClient.connectionURL
}
m.addListener(listenerAddr, onClosedListener)
return nil
}
// RelayInstanceAddress returns the address of the permanent relay server. It could change if the network connection is
// lost. This address will be sent to the target peer to choose the common relay server for the communication.
func (m *Manager) RelayInstanceAddress() (string, error) {
if m.relayClient == nil {
return "", ErrRelayClientNotConnected
}
return m.relayClient.ServerInstanceURL()
}
// ServerURLs returns the addresses of the relay servers.
func (m *Manager) ServerURLs() []string {
return m.serverURLs
}
// HasRelayAddress returns true if the manager is serving. With this method can check if the peer can communicate with
// Relay service.
func (m *Manager) HasRelayAddress() bool {
return len(m.serverURLs) > 0
}
// UpdateToken updates the token in the token store.
func (m *Manager) UpdateToken(token *relayAuth.Token) error {
return m.tokenStore.UpdateToken(token)
}
func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) {
// check if already has a connection to the desired relay server
m.relayClientsMutex.RLock()
rt, ok := m.relayClients[serverAddress]
if ok {
rt.RLock()
m.relayClientsMutex.RUnlock()
defer rt.RUnlock()
return rt.relayClient.OpenConn(peerKey)
}
m.relayClientsMutex.RUnlock()
// if not, establish a new connection but check it again (because changed the lock type) before starting the
// connection
m.relayClientsMutex.Lock()
rt, ok = m.relayClients[serverAddress]
if ok {
rt.RLock()
m.relayClientsMutex.Unlock()
defer rt.RUnlock()
return rt.relayClient.OpenConn(peerKey)
}
// create a new relay client and store it in the relayClients map
rt = NewRelayTrack()
rt.Lock()
m.relayClients[serverAddress] = rt
m.relayClientsMutex.Unlock()
relayClient := NewClient(m.ctx, serverAddress, m.tokenStore, m.peerID)
err := relayClient.Connect()
if err != nil {
rt.Unlock()
m.relayClientsMutex.Lock()
delete(m.relayClients, serverAddress)
m.relayClientsMutex.Unlock()
return nil, err
}
// if connection closed then delete the relay client from the list
relayClient.SetOnDisconnectListener(func() {
m.onServerDisconnected(serverAddress)
})
rt.relayClient = relayClient
rt.Unlock()
conn, err := relayClient.OpenConn(peerKey)
if err != nil {
return nil, err
}
return conn, nil
}
func (m *Manager) onServerDisconnected(serverAddress string) {
if serverAddress == m.relayClient.connectionURL {
go m.reconnectGuard.OnDisconnected()
}
m.notifyOnDisconnectListeners(serverAddress)
}
func (m *Manager) isForeignServer(address string) (bool, error) {
rAddr, err := m.relayClient.ServerInstanceURL()
if err != nil {
return false, fmt.Errorf("relay client not connected")
}
return rAddr != address, nil
}
func (m *Manager) startCleanupLoop() {
if m.ctx.Err() != nil {
return
}
ticker := time.NewTicker(relayCleanupInterval)
go func() {
defer ticker.Stop()
for {
select {
case <-m.ctx.Done():
return
case <-ticker.C:
m.cleanUpUnusedRelays()
}
}
}()
}
func (m *Manager) cleanUpUnusedRelays() {
m.relayClientsMutex.Lock()
defer m.relayClientsMutex.Unlock()
for addr, rt := range m.relayClients {
rt.Lock()
if rt.relayClient.HasConns() {
rt.Unlock()
continue
}
rt.relayClient.SetOnDisconnectListener(nil)
go func() {
_ = rt.relayClient.Close()
}()
log.Debugf("clean up unused relay server connection: %s", addr)
delete(m.relayClients, addr)
rt.Unlock()
}
}
func (m *Manager) addListener(serverAddress string, onClosedListener OnServerCloseListener) {
m.listenerLock.Lock()
defer m.listenerLock.Unlock()
l, ok := m.onDisconnectedListeners[serverAddress]
if !ok {
l = list.New()
}
for e := l.Front(); e != nil; e = e.Next() {
if reflect.ValueOf(e.Value).Pointer() == reflect.ValueOf(onClosedListener).Pointer() {
return
}
}
l.PushBack(onClosedListener)
m.onDisconnectedListeners[serverAddress] = l
}
func (m *Manager) notifyOnDisconnectListeners(serverAddress string) {
m.listenerLock.Lock()
defer m.listenerLock.Unlock()
l, ok := m.onDisconnectedListeners[serverAddress]
if !ok {
return
}
for e := l.Front(); e != nil; e = e.Next() {
go e.Value.(OnServerCloseListener)()
}
delete(m.onDisconnectedListeners, serverAddress)
}

View File

@@ -0,0 +1,432 @@
package client
import (
"context"
"testing"
"time"
log "github.com/sirupsen/logrus"
"go.opentelemetry.io/otel"
"github.com/netbirdio/netbird/relay/server"
)
func TestEmptyURL(t *testing.T) {
mgr := NewManager(context.Background(), nil, "alice")
err := mgr.Serve()
if err == nil {
t.Errorf("expected error, got nil")
}
}
func TestForeignConn(t *testing.T) {
ctx := context.Background()
srvCfg1 := server.ListenerConfig{
Address: "localhost:1234",
}
srv1, err := server.NewServer(otel.Meter(""), srvCfg1.Address, false, av)
if err != nil {
t.Fatalf("failed to create server: %s", err)
}
errChan := make(chan error, 1)
go func() {
err := srv1.Listen(srvCfg1)
if err != nil {
errChan <- err
}
}()
defer func() {
err := srv1.Shutdown(ctx)
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}()
if err := waitForServerToStart(errChan); err != nil {
t.Fatalf("failed to start server: %s", err)
}
srvCfg2 := server.ListenerConfig{
Address: "localhost:2234",
}
srv2, err := server.NewServer(otel.Meter(""), srvCfg2.Address, false, av)
if err != nil {
t.Fatalf("failed to create server: %s", err)
}
errChan2 := make(chan error, 1)
go func() {
err := srv2.Listen(srvCfg2)
if err != nil {
errChan2 <- err
}
}()
defer func() {
err := srv2.Shutdown(ctx)
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}()
if err := waitForServerToStart(errChan2); err != nil {
t.Fatalf("failed to start server: %s", err)
}
idAlice := "alice"
log.Debugf("connect by alice")
mCtx, cancel := context.WithCancel(ctx)
defer cancel()
clientAlice := NewManager(mCtx, toURL(srvCfg1), idAlice)
err = clientAlice.Serve()
if err != nil {
t.Fatalf("failed to serve manager: %s", err)
}
idBob := "bob"
log.Debugf("connect by bob")
clientBob := NewManager(mCtx, toURL(srvCfg2), idBob)
err = clientBob.Serve()
if err != nil {
t.Fatalf("failed to serve manager: %s", err)
}
bobsSrvAddr, err := clientBob.RelayInstanceAddress()
if err != nil {
t.Fatalf("failed to get relay address: %s", err)
}
connAliceToBob, err := clientAlice.OpenConn(bobsSrvAddr, idBob)
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}
connBobToAlice, err := clientBob.OpenConn(bobsSrvAddr, idAlice)
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}
payload := "hello bob, I am alice"
_, err = connAliceToBob.Write([]byte(payload))
if err != nil {
t.Fatalf("failed to write to channel: %s", err)
}
buf := make([]byte, 65535)
n, err := connBobToAlice.Read(buf)
if err != nil {
t.Fatalf("failed to read from channel: %s", err)
}
_, err = connBobToAlice.Write(buf[:n])
if err != nil {
t.Fatalf("failed to write to channel: %s", err)
}
n, err = connAliceToBob.Read(buf)
if err != nil {
t.Fatalf("failed to read from channel: %s", err)
}
if payload != string(buf[:n]) {
t.Fatalf("expected %s, got %s", payload, string(buf[:n]))
}
}
func TestForeginConnClose(t *testing.T) {
ctx := context.Background()
srvCfg1 := server.ListenerConfig{
Address: "localhost:1234",
}
srv1, err := server.NewServer(otel.Meter(""), srvCfg1.Address, false, av)
if err != nil {
t.Fatalf("failed to create server: %s", err)
}
errChan := make(chan error, 1)
go func() {
err := srv1.Listen(srvCfg1)
if err != nil {
errChan <- err
}
}()
defer func() {
err := srv1.Shutdown(ctx)
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}()
if err := waitForServerToStart(errChan); err != nil {
t.Fatalf("failed to start server: %s", err)
}
srvCfg2 := server.ListenerConfig{
Address: "localhost:2234",
}
srv2, err := server.NewServer(otel.Meter(""), srvCfg2.Address, false, av)
if err != nil {
t.Fatalf("failed to create server: %s", err)
}
errChan2 := make(chan error, 1)
go func() {
err := srv2.Listen(srvCfg2)
if err != nil {
errChan2 <- err
}
}()
defer func() {
err := srv2.Shutdown(ctx)
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}()
if err := waitForServerToStart(errChan2); err != nil {
t.Fatalf("failed to start server: %s", err)
}
idAlice := "alice"
log.Debugf("connect by alice")
mCtx, cancel := context.WithCancel(ctx)
defer cancel()
mgr := NewManager(mCtx, toURL(srvCfg1), idAlice)
err = mgr.Serve()
if err != nil {
t.Fatalf("failed to serve manager: %s", err)
}
conn, err := mgr.OpenConn(toURL(srvCfg2)[0], "anotherpeer")
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}
err = conn.Close()
if err != nil {
t.Fatalf("failed to close connection: %s", err)
}
}
func TestForeginAutoClose(t *testing.T) {
ctx := context.Background()
relayCleanupInterval = 1 * time.Second
srvCfg1 := server.ListenerConfig{
Address: "localhost:1234",
}
srv1, err := server.NewServer(otel.Meter(""), srvCfg1.Address, false, av)
if err != nil {
t.Fatalf("failed to create server: %s", err)
}
errChan := make(chan error, 1)
go func() {
t.Log("binding server 1.")
err := srv1.Listen(srvCfg1)
if err != nil {
errChan <- err
}
}()
defer func() {
t.Logf("closing server 1.")
err := srv1.Shutdown(ctx)
if err != nil {
t.Errorf("failed to close server: %s", err)
}
t.Logf("server 1. closed")
}()
if err := waitForServerToStart(errChan); err != nil {
t.Fatalf("failed to start server: %s", err)
}
srvCfg2 := server.ListenerConfig{
Address: "localhost:2234",
}
srv2, err := server.NewServer(otel.Meter(""), srvCfg2.Address, false, av)
if err != nil {
t.Fatalf("failed to create server: %s", err)
}
errChan2 := make(chan error, 1)
go func() {
t.Log("binding server 2.")
err := srv2.Listen(srvCfg2)
if err != nil {
errChan2 <- err
}
}()
defer func() {
t.Logf("closing server 2.")
err := srv2.Shutdown(ctx)
if err != nil {
t.Errorf("failed to close server: %s", err)
}
t.Logf("server 2 closed.")
}()
if err := waitForServerToStart(errChan2); err != nil {
t.Fatalf("failed to start server: %s", err)
}
idAlice := "alice"
t.Log("connect to server 1.")
mCtx, cancel := context.WithCancel(ctx)
defer cancel()
mgr := NewManager(mCtx, toURL(srvCfg1), idAlice)
err = mgr.Serve()
if err != nil {
t.Fatalf("failed to serve manager: %s", err)
}
t.Log("open connection to another peer")
conn, err := mgr.OpenConn(toURL(srvCfg2)[0], "anotherpeer")
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}
t.Log("close conn")
err = conn.Close()
if err != nil {
t.Fatalf("failed to close connection: %s", err)
}
t.Logf("waiting for relay cleanup: %s", relayCleanupInterval+1*time.Second)
time.Sleep(relayCleanupInterval + 1*time.Second)
if len(mgr.relayClients) != 0 {
t.Errorf("expected 0, got %d", len(mgr.relayClients))
}
t.Logf("closing manager")
}
func TestAutoReconnect(t *testing.T) {
ctx := context.Background()
reconnectingTimeout = 2 * time.Second
srvCfg := server.ListenerConfig{
Address: "localhost:1234",
}
srv, err := server.NewServer(otel.Meter(""), srvCfg.Address, false, av)
if err != nil {
t.Fatalf("failed to create server: %s", err)
}
errChan := make(chan error, 1)
go func() {
err := srv.Listen(srvCfg)
if err != nil {
errChan <- err
}
}()
defer func() {
err := srv.Shutdown(ctx)
if err != nil {
log.Errorf("failed to close server: %s", err)
}
}()
if err := waitForServerToStart(errChan); err != nil {
t.Fatalf("failed to start server: %s", err)
}
mCtx, cancel := context.WithCancel(ctx)
defer cancel()
clientAlice := NewManager(mCtx, toURL(srvCfg), "alice")
err = clientAlice.Serve()
if err != nil {
t.Fatalf("failed to serve manager: %s", err)
}
ra, err := clientAlice.RelayInstanceAddress()
if err != nil {
t.Errorf("failed to get relay address: %s", err)
}
conn, err := clientAlice.OpenConn(ra, "bob")
if err != nil {
t.Errorf("failed to bind channel: %s", err)
}
t.Log("closing client relay connection")
// todo figure out moc server
_ = clientAlice.relayClient.relayConn.Close()
t.Log("start test reading")
_, err = conn.Read(make([]byte, 1))
if err == nil {
t.Errorf("unexpected reading from closed connection")
}
log.Infof("waiting for reconnection")
time.Sleep(reconnectingTimeout + 1*time.Second)
log.Infof("reopent the connection")
_, err = clientAlice.OpenConn(ra, "bob")
if err != nil {
t.Errorf("failed to open channel: %s", err)
}
}
func TestNotifierDoubleAdd(t *testing.T) {
ctx := context.Background()
srvCfg1 := server.ListenerConfig{
Address: "localhost:1234",
}
srv1, err := server.NewServer(otel.Meter(""), srvCfg1.Address, false, av)
if err != nil {
t.Fatalf("failed to create server: %s", err)
}
errChan := make(chan error, 1)
go func() {
err := srv1.Listen(srvCfg1)
if err != nil {
errChan <- err
}
}()
defer func() {
err := srv1.Shutdown(ctx)
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}()
if err := waitForServerToStart(errChan); err != nil {
t.Fatalf("failed to start server: %s", err)
}
idAlice := "alice"
log.Debugf("connect by alice")
mCtx, cancel := context.WithCancel(ctx)
defer cancel()
clientAlice := NewManager(mCtx, toURL(srvCfg1), idAlice)
err = clientAlice.Serve()
if err != nil {
t.Fatalf("failed to serve manager: %s", err)
}
conn1, err := clientAlice.OpenConn(clientAlice.ServerURLs()[0], "idBob")
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}
fnCloseListener := OnServerCloseListener(func() {
log.Infof("close listener")
})
err = clientAlice.AddCloseListener(clientAlice.ServerURLs()[0], fnCloseListener)
if err != nil {
t.Fatalf("failed to add close listener: %s", err)
}
err = clientAlice.AddCloseListener(clientAlice.ServerURLs()[0], fnCloseListener)
if err != nil {
t.Fatalf("failed to add close listener: %s", err)
}
err = conn1.Close()
if err != nil {
t.Errorf("failed to close connection: %s", err)
}
}
func toURL(address server.ListenerConfig) []string {
return []string{"rel://" + address.Address}
}