mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 08:16:39 +00:00
[relay] Feature/relay integration (#2244)
This update adds new relay integration for NetBird clients. The new relay is based on web sockets and listens on a single port. - Adds new relay implementation with websocket with single port relaying mechanism - refactor peer connection logic, allowing upgrade and downgrade from/to P2P connection - peer connections are faster since it connects first to relay and then upgrades to P2P - maintains compatibility with old clients by not using the new relay - updates infrastructure scripts with new relay service
This commit is contained in:
13
relay/client/addr.go
Normal file
13
relay/client/addr.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package client
|
||||
|
||||
type RelayAddr struct {
|
||||
addr string
|
||||
}
|
||||
|
||||
func (a RelayAddr) Network() string {
|
||||
return "relay"
|
||||
}
|
||||
|
||||
func (a RelayAddr) String() string {
|
||||
return a.addr
|
||||
}
|
||||
553
relay/client/client.go
Normal file
553
relay/client/client.go
Normal file
@@ -0,0 +1,553 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
auth "github.com/netbirdio/netbird/relay/auth/hmac"
|
||||
"github.com/netbirdio/netbird/relay/client/dialer/ws"
|
||||
"github.com/netbirdio/netbird/relay/healthcheck"
|
||||
"github.com/netbirdio/netbird/relay/messages"
|
||||
"github.com/netbirdio/netbird/relay/messages/address"
|
||||
auth2 "github.com/netbirdio/netbird/relay/messages/auth"
|
||||
)
|
||||
|
||||
const (
|
||||
bufferSize = 8820
|
||||
serverResponseTimeout = 8 * time.Second
|
||||
)
|
||||
|
||||
var (
|
||||
ErrConnAlreadyExists = fmt.Errorf("connection already exists")
|
||||
)
|
||||
|
||||
type internalStopFlag struct {
|
||||
sync.Mutex
|
||||
stop bool
|
||||
}
|
||||
|
||||
func newInternalStopFlag() *internalStopFlag {
|
||||
return &internalStopFlag{}
|
||||
}
|
||||
|
||||
func (isf *internalStopFlag) set() {
|
||||
isf.Lock()
|
||||
defer isf.Unlock()
|
||||
isf.stop = true
|
||||
}
|
||||
|
||||
func (isf *internalStopFlag) isSet() bool {
|
||||
isf.Lock()
|
||||
defer isf.Unlock()
|
||||
return isf.stop
|
||||
}
|
||||
|
||||
// Msg carry the payload from the server to the client. With this struct, the net.Conn can free the buffer.
|
||||
type Msg struct {
|
||||
Payload []byte
|
||||
|
||||
bufPool *sync.Pool
|
||||
bufPtr *[]byte
|
||||
}
|
||||
|
||||
func (m *Msg) Free() {
|
||||
m.bufPool.Put(m.bufPtr)
|
||||
}
|
||||
|
||||
type connContainer struct {
|
||||
conn *Conn
|
||||
messages chan Msg
|
||||
msgChanLock sync.Mutex
|
||||
closed bool // flag to check if channel is closed
|
||||
}
|
||||
|
||||
func newConnContainer(conn *Conn, messages chan Msg) *connContainer {
|
||||
return &connContainer{
|
||||
conn: conn,
|
||||
messages: messages,
|
||||
}
|
||||
}
|
||||
|
||||
func (cc *connContainer) writeMsg(msg Msg) {
|
||||
cc.msgChanLock.Lock()
|
||||
defer cc.msgChanLock.Unlock()
|
||||
if cc.closed {
|
||||
return
|
||||
}
|
||||
cc.messages <- msg
|
||||
}
|
||||
|
||||
func (cc *connContainer) close() {
|
||||
cc.msgChanLock.Lock()
|
||||
defer cc.msgChanLock.Unlock()
|
||||
if cc.closed {
|
||||
return
|
||||
}
|
||||
close(cc.messages)
|
||||
cc.closed = true
|
||||
}
|
||||
|
||||
// Client is a client for the relay server. It is responsible for establishing a connection to the relay server and
|
||||
// managing connections to other peers. All exported functions are safe to call concurrently. After close the connection,
|
||||
// the client can be reused by calling Connect again. When the client is closed, all connections are closed too.
|
||||
// While the Connect is in progress, the OpenConn function will block until the connection is established with relay server.
|
||||
type Client struct {
|
||||
log *log.Entry
|
||||
parentCtx context.Context
|
||||
connectionURL string
|
||||
authTokenStore *auth.TokenStore
|
||||
hashedID []byte
|
||||
|
||||
bufPool *sync.Pool
|
||||
|
||||
relayConn net.Conn
|
||||
conns map[string]*connContainer
|
||||
serviceIsRunning bool
|
||||
mu sync.Mutex // protect serviceIsRunning and conns
|
||||
readLoopMutex sync.Mutex
|
||||
wgReadLoop sync.WaitGroup
|
||||
instanceURL *RelayAddr
|
||||
muInstanceURL sync.Mutex
|
||||
|
||||
onDisconnectListener func()
|
||||
listenerMutex sync.Mutex
|
||||
}
|
||||
|
||||
// NewClient creates a new client for the relay server. The client is not connected to the server until the Connect
|
||||
func NewClient(ctx context.Context, serverURL string, authTokenStore *auth.TokenStore, peerID string) *Client {
|
||||
hashedID, hashedStringId := messages.HashID(peerID)
|
||||
return &Client{
|
||||
log: log.WithField("client_id", hashedStringId),
|
||||
parentCtx: ctx,
|
||||
connectionURL: serverURL,
|
||||
authTokenStore: authTokenStore,
|
||||
hashedID: hashedID,
|
||||
bufPool: &sync.Pool{
|
||||
New: func() any {
|
||||
buf := make([]byte, bufferSize)
|
||||
return &buf
|
||||
},
|
||||
},
|
||||
conns: make(map[string]*connContainer),
|
||||
}
|
||||
}
|
||||
|
||||
// Connect establishes a connection to the relay server. It blocks until the connection is established or an error occurs.
|
||||
func (c *Client) Connect() error {
|
||||
c.log.Infof("connecting to relay server: %s", c.connectionURL)
|
||||
c.readLoopMutex.Lock()
|
||||
defer c.readLoopMutex.Unlock()
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.serviceIsRunning {
|
||||
return nil
|
||||
}
|
||||
|
||||
err := c.connect()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.serviceIsRunning = true
|
||||
|
||||
c.wgReadLoop.Add(1)
|
||||
go c.readLoop(c.relayConn)
|
||||
|
||||
c.log.Infof("relay connection established with: %s", c.connectionURL)
|
||||
return nil
|
||||
}
|
||||
|
||||
// OpenConn create a new net.Conn for the destination peer ID. In case if the connection is in progress
|
||||
// to the relay server, the function will block until the connection is established or timed out. Otherwise,
|
||||
// it will return immediately.
|
||||
// todo: what should happen if call with the same peerID with multiple times?
|
||||
func (c *Client) OpenConn(dstPeerID string) (net.Conn, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if !c.serviceIsRunning {
|
||||
return nil, fmt.Errorf("relay connection is not established")
|
||||
}
|
||||
|
||||
hashedID, hashedStringID := messages.HashID(dstPeerID)
|
||||
_, ok := c.conns[hashedStringID]
|
||||
if ok {
|
||||
return nil, ErrConnAlreadyExists
|
||||
}
|
||||
|
||||
log.Infof("open connection to peer: %s", hashedStringID)
|
||||
msgChannel := make(chan Msg, 2)
|
||||
conn := NewConn(c, hashedID, hashedStringID, msgChannel, c.instanceURL)
|
||||
|
||||
c.conns[hashedStringID] = newConnContainer(conn, msgChannel)
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// ServerInstanceURL returns the address of the relay server. It could change after the close and reopen the connection.
|
||||
func (c *Client) ServerInstanceURL() (string, error) {
|
||||
c.muInstanceURL.Lock()
|
||||
defer c.muInstanceURL.Unlock()
|
||||
if c.instanceURL == nil {
|
||||
return "", fmt.Errorf("relay connection is not established")
|
||||
}
|
||||
return c.instanceURL.String(), nil
|
||||
}
|
||||
|
||||
// SetOnDisconnectListener sets a function that will be called when the connection to the relay server is closed.
|
||||
func (c *Client) SetOnDisconnectListener(fn func()) {
|
||||
c.listenerMutex.Lock()
|
||||
defer c.listenerMutex.Unlock()
|
||||
c.onDisconnectListener = fn
|
||||
}
|
||||
|
||||
// HasConns returns true if there are connections.
|
||||
func (c *Client) HasConns() bool {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return len(c.conns) > 0
|
||||
}
|
||||
|
||||
// Close closes the connection to the relay server and all connections to other peers.
|
||||
func (c *Client) Close() error {
|
||||
return c.close(true)
|
||||
}
|
||||
|
||||
func (c *Client) connect() error {
|
||||
conn, err := ws.Dial(c.connectionURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.relayConn = conn
|
||||
|
||||
err = c.handShake()
|
||||
if err != nil {
|
||||
cErr := conn.Close()
|
||||
if cErr != nil {
|
||||
log.Errorf("failed to close connection: %s", cErr)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) handShake() error {
|
||||
authMsg := &auth2.Msg{
|
||||
AuthAlgorithm: auth2.AlgoHMACSHA256,
|
||||
AdditionalData: c.authTokenStore.TokenBinary(),
|
||||
}
|
||||
|
||||
authData, err := authMsg.Marshal()
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal auth message: %w", err)
|
||||
}
|
||||
|
||||
msg, err := messages.MarshalHelloMsg(c.hashedID, authData)
|
||||
if err != nil {
|
||||
log.Errorf("failed to marshal hello message: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = c.relayConn.Write(msg)
|
||||
if err != nil {
|
||||
log.Errorf("failed to send hello message: %s", err)
|
||||
return err
|
||||
}
|
||||
buf := make([]byte, messages.MaxHandshakeSize)
|
||||
n, err := c.readWithTimeout(buf)
|
||||
if err != nil {
|
||||
log.Errorf("failed to read hello response: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = messages.ValidateVersion(buf[:n])
|
||||
if err != nil {
|
||||
return fmt.Errorf("validate version: %w", err)
|
||||
}
|
||||
|
||||
msgType, err := messages.DetermineServerMessageType(buf[messages.SizeOfVersionByte:n])
|
||||
if err != nil {
|
||||
log.Errorf("failed to determine message type: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
if msgType != messages.MsgTypeHelloResponse {
|
||||
log.Errorf("unexpected message type: %s", msgType)
|
||||
return fmt.Errorf("unexpected message type")
|
||||
}
|
||||
|
||||
additionalData, err := messages.UnmarshalHelloResponse(buf[messages.SizeOfProtoHeader:n])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
addr, err := address.Unmarshal(additionalData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unmarshal address: %w", err)
|
||||
}
|
||||
|
||||
c.muInstanceURL.Lock()
|
||||
c.instanceURL = &RelayAddr{addr: addr.URL}
|
||||
c.muInstanceURL.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) readLoop(relayConn net.Conn) {
|
||||
internallyStoppedFlag := newInternalStopFlag()
|
||||
hc := healthcheck.NewReceiver()
|
||||
go c.listenForStopEvents(hc, relayConn, internallyStoppedFlag)
|
||||
|
||||
var (
|
||||
errExit error
|
||||
n int
|
||||
)
|
||||
for {
|
||||
bufPtr := c.bufPool.Get().(*[]byte)
|
||||
buf := *bufPtr
|
||||
n, errExit = relayConn.Read(buf)
|
||||
if errExit != nil {
|
||||
c.mu.Lock()
|
||||
if c.serviceIsRunning && !internallyStoppedFlag.isSet() {
|
||||
c.log.Debugf("failed to read message from relay server: %s", errExit)
|
||||
}
|
||||
c.mu.Unlock()
|
||||
break
|
||||
}
|
||||
|
||||
_, err := messages.ValidateVersion(buf[:n])
|
||||
if err != nil {
|
||||
c.log.Errorf("failed to validate protocol version: %s", err)
|
||||
c.bufPool.Put(bufPtr)
|
||||
continue
|
||||
}
|
||||
|
||||
msgType, err := messages.DetermineServerMessageType(buf[messages.SizeOfVersionByte:n])
|
||||
if err != nil {
|
||||
c.log.Errorf("failed to determine message type: %s", err)
|
||||
c.bufPool.Put(bufPtr)
|
||||
continue
|
||||
}
|
||||
|
||||
if !c.handleMsg(msgType, buf[messages.SizeOfProtoHeader:n], bufPtr, hc, internallyStoppedFlag) {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
hc.Stop()
|
||||
|
||||
c.muInstanceURL.Lock()
|
||||
c.instanceURL = nil
|
||||
c.muInstanceURL.Unlock()
|
||||
|
||||
c.notifyDisconnected()
|
||||
c.wgReadLoop.Done()
|
||||
_ = c.close(false)
|
||||
}
|
||||
|
||||
func (c *Client) handleMsg(msgType messages.MsgType, buf []byte, bufPtr *[]byte, hc *healthcheck.Receiver, internallyStoppedFlag *internalStopFlag) (continueLoop bool) {
|
||||
switch msgType {
|
||||
case messages.MsgTypeHealthCheck:
|
||||
c.handleHealthCheck(hc, internallyStoppedFlag)
|
||||
c.bufPool.Put(bufPtr)
|
||||
case messages.MsgTypeTransport:
|
||||
return c.handleTransportMsg(buf, bufPtr, internallyStoppedFlag)
|
||||
case messages.MsgTypeClose:
|
||||
log.Debugf("relay connection close by server")
|
||||
c.bufPool.Put(bufPtr)
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *Client) handleHealthCheck(hc *healthcheck.Receiver, internallyStoppedFlag *internalStopFlag) {
|
||||
msg := messages.MarshalHealthcheck()
|
||||
_, wErr := c.relayConn.Write(msg)
|
||||
if wErr != nil {
|
||||
if c.serviceIsRunning && !internallyStoppedFlag.isSet() {
|
||||
c.log.Errorf("failed to send heartbeat: %s", wErr)
|
||||
}
|
||||
}
|
||||
hc.Heartbeat()
|
||||
}
|
||||
|
||||
func (c *Client) handleTransportMsg(buf []byte, bufPtr *[]byte, internallyStoppedFlag *internalStopFlag) bool {
|
||||
peerID, payload, err := messages.UnmarshalTransportMsg(buf)
|
||||
if err != nil {
|
||||
if c.serviceIsRunning && !internallyStoppedFlag.isSet() {
|
||||
c.log.Errorf("failed to parse transport message: %v", err)
|
||||
}
|
||||
|
||||
c.bufPool.Put(bufPtr)
|
||||
return true
|
||||
}
|
||||
|
||||
stringID := messages.HashIDToString(peerID)
|
||||
|
||||
c.mu.Lock()
|
||||
if !c.serviceIsRunning {
|
||||
c.mu.Unlock()
|
||||
c.bufPool.Put(bufPtr)
|
||||
return false
|
||||
}
|
||||
container, ok := c.conns[stringID]
|
||||
c.mu.Unlock()
|
||||
if !ok {
|
||||
c.log.Errorf("peer not found: %s", stringID)
|
||||
c.bufPool.Put(bufPtr)
|
||||
return true
|
||||
}
|
||||
msg := Msg{
|
||||
bufPool: c.bufPool,
|
||||
bufPtr: bufPtr,
|
||||
Payload: payload,
|
||||
}
|
||||
container.writeMsg(msg)
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *Client) writeTo(connReference *Conn, id string, dstID []byte, payload []byte) (int, error) {
|
||||
c.mu.Lock()
|
||||
conn, ok := c.conns[id]
|
||||
c.mu.Unlock()
|
||||
if !ok {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
if conn.conn != connReference {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
// todo: use buffer pool instead of create new transport msg.
|
||||
msg, err := messages.MarshalTransportMsg(dstID, payload)
|
||||
if err != nil {
|
||||
log.Errorf("failed to marshal transport message: %s", err)
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// the write always return with 0 length because the underling does not support the size feedback.
|
||||
_, err = c.relayConn.Write(msg)
|
||||
if err != nil {
|
||||
log.Errorf("failed to write transport message: %s", err)
|
||||
}
|
||||
return len(payload), err
|
||||
}
|
||||
|
||||
func (c *Client) listenForStopEvents(hc *healthcheck.Receiver, conn net.Conn, internalStopFlag *internalStopFlag) {
|
||||
for {
|
||||
select {
|
||||
case _, ok := <-hc.OnTimeout:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
c.log.Errorf("health check timeout")
|
||||
internalStopFlag.set()
|
||||
_ = conn.Close() // ignore the err because the readLoop will handle it
|
||||
return
|
||||
case <-c.parentCtx.Done():
|
||||
err := c.close(true)
|
||||
if err != nil {
|
||||
log.Errorf("failed to teardown connection: %s", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) closeAllConns() {
|
||||
for _, container := range c.conns {
|
||||
container.close()
|
||||
}
|
||||
c.conns = make(map[string]*connContainer)
|
||||
}
|
||||
|
||||
func (c *Client) closeConn(connReference *Conn, id string) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
container, ok := c.conns[id]
|
||||
if !ok {
|
||||
return fmt.Errorf("connection already closed")
|
||||
}
|
||||
|
||||
if container.conn != connReference {
|
||||
return fmt.Errorf("conn reference mismatch")
|
||||
}
|
||||
container.close()
|
||||
delete(c.conns, id)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) close(gracefullyExit bool) error {
|
||||
c.readLoopMutex.Lock()
|
||||
defer c.readLoopMutex.Unlock()
|
||||
|
||||
c.mu.Lock()
|
||||
var err error
|
||||
if !c.serviceIsRunning {
|
||||
c.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
c.serviceIsRunning = false
|
||||
c.closeAllConns()
|
||||
if gracefullyExit {
|
||||
c.writeCloseMsg()
|
||||
}
|
||||
err = c.relayConn.Close()
|
||||
c.mu.Unlock()
|
||||
|
||||
c.wgReadLoop.Wait()
|
||||
c.log.Infof("relay connection closed with: %s", c.connectionURL)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Client) notifyDisconnected() {
|
||||
c.listenerMutex.Lock()
|
||||
defer c.listenerMutex.Unlock()
|
||||
|
||||
if c.onDisconnectListener == nil {
|
||||
return
|
||||
}
|
||||
go c.onDisconnectListener()
|
||||
}
|
||||
|
||||
func (c *Client) writeCloseMsg() {
|
||||
msg := messages.MarshalCloseMsg()
|
||||
_, err := c.relayConn.Write(msg)
|
||||
if err != nil {
|
||||
c.log.Errorf("failed to send close message: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) readWithTimeout(buf []byte) (int, error) {
|
||||
ctx, cancel := context.WithTimeout(c.parentCtx, serverResponseTimeout)
|
||||
defer cancel()
|
||||
|
||||
readDone := make(chan struct{})
|
||||
var (
|
||||
n int
|
||||
err error
|
||||
)
|
||||
|
||||
go func() {
|
||||
n, err = c.relayConn.Read(buf)
|
||||
close(readDone)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return 0, fmt.Errorf("read operation timed out")
|
||||
case <-readDone:
|
||||
return n, err
|
||||
}
|
||||
}
|
||||
631
relay/client/client_test.go
Normal file
631
relay/client/client_test.go
Normal file
@@ -0,0 +1,631 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"go.opentelemetry.io/otel"
|
||||
|
||||
"github.com/netbirdio/netbird/relay/auth/allow"
|
||||
"github.com/netbirdio/netbird/relay/auth/hmac"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
|
||||
"github.com/netbirdio/netbird/relay/server"
|
||||
)
|
||||
|
||||
var (
|
||||
av = &allow.Auth{}
|
||||
hmacTokenStore = &hmac.TokenStore{}
|
||||
serverListenAddr = "127.0.0.1:1234"
|
||||
serverURL = "rel://127.0.0.1:1234"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
_ = util.InitLog("error", "console")
|
||||
code := m.Run()
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
func TestClient(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create server: %s", err)
|
||||
}
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
listenCfg := server.ListenerConfig{Address: serverListenAddr}
|
||||
err := srv.Listen(listenCfg)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
err := srv.Shutdown(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("failed to close server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// wait for server to start
|
||||
if err := waitForServerToStart(errChan); err != nil {
|
||||
t.Fatalf("failed to start server: %s", err)
|
||||
}
|
||||
t.Log("alice connecting to server")
|
||||
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice")
|
||||
err = clientAlice.Connect()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect to server: %s", err)
|
||||
}
|
||||
defer clientAlice.Close()
|
||||
|
||||
t.Log("placeholder connecting to server")
|
||||
clientPlaceHolder := NewClient(ctx, serverURL, hmacTokenStore, "clientPlaceHolder")
|
||||
err = clientPlaceHolder.Connect()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect to server: %s", err)
|
||||
}
|
||||
defer clientPlaceHolder.Close()
|
||||
|
||||
t.Log("Bob connecting to server")
|
||||
clientBob := NewClient(ctx, serverURL, hmacTokenStore, "bob")
|
||||
err = clientBob.Connect()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect to server: %s", err)
|
||||
}
|
||||
defer clientBob.Close()
|
||||
|
||||
t.Log("Alice open connection to Bob")
|
||||
connAliceToBob, err := clientAlice.OpenConn("bob")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind channel: %s", err)
|
||||
}
|
||||
|
||||
t.Log("Bob open connection to Alice")
|
||||
connBobToAlice, err := clientBob.OpenConn("alice")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind channel: %s", err)
|
||||
}
|
||||
|
||||
payload := "hello bob, I am alice"
|
||||
_, err = connAliceToBob.Write([]byte(payload))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to write to channel: %s", err)
|
||||
}
|
||||
log.Debugf("alice sent message to bob")
|
||||
|
||||
buf := make([]byte, 65535)
|
||||
n, err := connBobToAlice.Read(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read from channel: %s", err)
|
||||
}
|
||||
log.Debugf("on new message from alice to bob")
|
||||
|
||||
if payload != string(buf[:n]) {
|
||||
t.Fatalf("expected %s, got %s", payload, string(buf[:n]))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistration(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
srvCfg := server.ListenerConfig{Address: serverListenAddr}
|
||||
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create server: %s", err)
|
||||
}
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
err := srv.Listen(srvCfg)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
}()
|
||||
|
||||
// wait for server to start
|
||||
if err := waitForServerToStart(errChan); err != nil {
|
||||
t.Fatalf("failed to start server: %s", err)
|
||||
}
|
||||
|
||||
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice")
|
||||
err = clientAlice.Connect()
|
||||
if err != nil {
|
||||
_ = srv.Shutdown(ctx)
|
||||
t.Fatalf("failed to connect to server: %s", err)
|
||||
}
|
||||
err = clientAlice.Close()
|
||||
if err != nil {
|
||||
t.Errorf("failed to close conn: %s", err)
|
||||
}
|
||||
err = srv.Shutdown(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("failed to close server: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistrationTimeout(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
fakeUDPListener, err := net.ListenUDP("udp", &net.UDPAddr{
|
||||
Port: 1234,
|
||||
IP: net.ParseIP("0.0.0.0"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind UDP server: %s", err)
|
||||
}
|
||||
defer func(fakeUDPListener *net.UDPConn) {
|
||||
_ = fakeUDPListener.Close()
|
||||
}(fakeUDPListener)
|
||||
|
||||
fakeTCPListener, err := net.ListenTCP("tcp", &net.TCPAddr{
|
||||
Port: 1234,
|
||||
IP: net.ParseIP("0.0.0.0"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind TCP server: %s", err)
|
||||
}
|
||||
defer func(fakeTCPListener *net.TCPListener) {
|
||||
_ = fakeTCPListener.Close()
|
||||
}(fakeTCPListener)
|
||||
|
||||
clientAlice := NewClient(ctx, "127.0.0.1:1234", hmacTokenStore, "alice")
|
||||
err = clientAlice.Connect()
|
||||
if err == nil {
|
||||
t.Errorf("failed to connect to server: %s", err)
|
||||
}
|
||||
log.Debugf("%s", err)
|
||||
err = clientAlice.Close()
|
||||
if err != nil {
|
||||
t.Errorf("failed to close conn: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEcho(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
idAlice := "alice"
|
||||
idBob := "bob"
|
||||
srvCfg := server.ListenerConfig{Address: serverListenAddr}
|
||||
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create server: %s", err)
|
||||
}
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
err := srv.Listen(srvCfg)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
err := srv.Shutdown(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("failed to close server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// wait for servers to start
|
||||
if err := waitForServerToStart(errChan); err != nil {
|
||||
t.Fatalf("failed to start server: %s", err)
|
||||
}
|
||||
|
||||
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, idAlice)
|
||||
err = clientAlice.Connect()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect to server: %s", err)
|
||||
}
|
||||
defer func() {
|
||||
err := clientAlice.Close()
|
||||
if err != nil {
|
||||
t.Errorf("failed to close Alice client: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
clientBob := NewClient(ctx, serverURL, hmacTokenStore, idBob)
|
||||
err = clientBob.Connect()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect to server: %s", err)
|
||||
}
|
||||
defer func() {
|
||||
err := clientBob.Close()
|
||||
if err != nil {
|
||||
t.Errorf("failed to close Bob client: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
connAliceToBob, err := clientAlice.OpenConn(idBob)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind channel: %s", err)
|
||||
}
|
||||
|
||||
connBobToAlice, err := clientBob.OpenConn(idAlice)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind channel: %s", err)
|
||||
}
|
||||
|
||||
payload := "hello bob, I am alice"
|
||||
_, err = connAliceToBob.Write([]byte(payload))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to write to channel: %s", err)
|
||||
}
|
||||
|
||||
buf := make([]byte, 65535)
|
||||
n, err := connBobToAlice.Read(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read from channel: %s", err)
|
||||
}
|
||||
|
||||
_, err = connBobToAlice.Write(buf[:n])
|
||||
if err != nil {
|
||||
t.Fatalf("failed to write to channel: %s", err)
|
||||
}
|
||||
|
||||
n, err = connAliceToBob.Read(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read from channel: %s", err)
|
||||
}
|
||||
|
||||
if payload != string(buf[:n]) {
|
||||
t.Fatalf("expected %s, got %s", payload, string(buf[:n]))
|
||||
}
|
||||
}
|
||||
|
||||
func TestBindToUnavailabePeer(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
srvCfg := server.ListenerConfig{Address: serverListenAddr}
|
||||
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create server: %s", err)
|
||||
}
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
err := srv.Listen(srvCfg)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
log.Infof("closing server")
|
||||
err := srv.Shutdown(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("failed to close server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// wait for servers to start
|
||||
if err := waitForServerToStart(errChan); err != nil {
|
||||
t.Fatalf("failed to start server: %s", err)
|
||||
}
|
||||
|
||||
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice")
|
||||
err = clientAlice.Connect()
|
||||
if err != nil {
|
||||
t.Errorf("failed to connect to server: %s", err)
|
||||
}
|
||||
_, err = clientAlice.OpenConn("bob")
|
||||
if err != nil {
|
||||
t.Errorf("failed to bind channel: %s", err)
|
||||
}
|
||||
|
||||
log.Infof("closing client")
|
||||
err = clientAlice.Close()
|
||||
if err != nil {
|
||||
t.Errorf("failed to close client: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBindReconnect(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
srvCfg := server.ListenerConfig{Address: serverListenAddr}
|
||||
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create server: %s", err)
|
||||
}
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
err := srv.Listen(srvCfg)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
log.Infof("closing server")
|
||||
err := srv.Shutdown(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("failed to close server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// wait for servers to start
|
||||
if err := waitForServerToStart(errChan); err != nil {
|
||||
t.Fatalf("failed to start server: %s", err)
|
||||
}
|
||||
|
||||
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice")
|
||||
err = clientAlice.Connect()
|
||||
if err != nil {
|
||||
t.Errorf("failed to connect to server: %s", err)
|
||||
}
|
||||
|
||||
_, err = clientAlice.OpenConn("bob")
|
||||
if err != nil {
|
||||
t.Errorf("failed to bind channel: %s", err)
|
||||
}
|
||||
|
||||
clientBob := NewClient(ctx, serverURL, hmacTokenStore, "bob")
|
||||
err = clientBob.Connect()
|
||||
if err != nil {
|
||||
t.Errorf("failed to connect to server: %s", err)
|
||||
}
|
||||
|
||||
chBob, err := clientBob.OpenConn("alice")
|
||||
if err != nil {
|
||||
t.Errorf("failed to bind channel: %s", err)
|
||||
}
|
||||
|
||||
log.Infof("closing client Alice")
|
||||
err = clientAlice.Close()
|
||||
if err != nil {
|
||||
t.Errorf("failed to close client: %s", err)
|
||||
}
|
||||
|
||||
clientAlice = NewClient(ctx, serverURL, hmacTokenStore, "alice")
|
||||
err = clientAlice.Connect()
|
||||
if err != nil {
|
||||
t.Errorf("failed to connect to server: %s", err)
|
||||
}
|
||||
|
||||
chAlice, err := clientAlice.OpenConn("bob")
|
||||
if err != nil {
|
||||
t.Errorf("failed to bind channel: %s", err)
|
||||
}
|
||||
|
||||
testString := "hello alice, I am bob"
|
||||
_, err = chBob.Write([]byte(testString))
|
||||
if err != nil {
|
||||
t.Errorf("failed to write to channel: %s", err)
|
||||
}
|
||||
|
||||
buf := make([]byte, 65535)
|
||||
n, err := chAlice.Read(buf)
|
||||
if err != nil {
|
||||
t.Errorf("failed to read from channel: %s", err)
|
||||
}
|
||||
|
||||
if testString != string(buf[:n]) {
|
||||
t.Errorf("expected %s, got %s", testString, string(buf[:n]))
|
||||
}
|
||||
|
||||
log.Infof("closing client")
|
||||
err = clientAlice.Close()
|
||||
if err != nil {
|
||||
t.Errorf("failed to close client: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloseConn(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
srvCfg := server.ListenerConfig{Address: serverListenAddr}
|
||||
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create server: %s", err)
|
||||
}
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
err := srv.Listen(srvCfg)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
log.Infof("closing server")
|
||||
err := srv.Shutdown(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("failed to close server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// wait for servers to start
|
||||
if err := waitForServerToStart(errChan); err != nil {
|
||||
t.Fatalf("failed to start server: %s", err)
|
||||
}
|
||||
|
||||
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice")
|
||||
err = clientAlice.Connect()
|
||||
if err != nil {
|
||||
t.Errorf("failed to connect to server: %s", err)
|
||||
}
|
||||
|
||||
conn, err := clientAlice.OpenConn("bob")
|
||||
if err != nil {
|
||||
t.Errorf("failed to bind channel: %s", err)
|
||||
}
|
||||
|
||||
log.Infof("closing connection")
|
||||
err = conn.Close()
|
||||
if err != nil {
|
||||
t.Errorf("failed to close connection: %s", err)
|
||||
}
|
||||
|
||||
_, err = conn.Read(make([]byte, 1))
|
||||
if err == nil {
|
||||
t.Errorf("unexpected reading from closed connection")
|
||||
}
|
||||
|
||||
_, err = conn.Write([]byte("hello"))
|
||||
if err == nil {
|
||||
t.Errorf("unexpected writing from closed connection")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloseRelayConn(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
srvCfg := server.ListenerConfig{Address: serverListenAddr}
|
||||
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create server: %s", err)
|
||||
}
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
err := srv.Listen(srvCfg)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
err := srv.Shutdown(ctx)
|
||||
if err != nil {
|
||||
log.Errorf("failed to close server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// wait for servers to start
|
||||
if err := waitForServerToStart(errChan); err != nil {
|
||||
t.Fatalf("failed to start server: %s", err)
|
||||
}
|
||||
|
||||
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice")
|
||||
err = clientAlice.Connect()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect to server: %s", err)
|
||||
}
|
||||
|
||||
conn, err := clientAlice.OpenConn("bob")
|
||||
if err != nil {
|
||||
t.Errorf("failed to bind channel: %s", err)
|
||||
}
|
||||
|
||||
_ = clientAlice.relayConn.Close()
|
||||
|
||||
_, err = conn.Read(make([]byte, 1))
|
||||
if err == nil {
|
||||
t.Errorf("unexpected reading from closed connection")
|
||||
}
|
||||
|
||||
_, err = clientAlice.OpenConn("bob")
|
||||
if err == nil {
|
||||
t.Errorf("unexpected opening connection to closed server")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloseByServer(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
srvCfg := server.ListenerConfig{Address: serverListenAddr}
|
||||
srv1, err := server.NewServer(otel.Meter(""), serverURL, false, av)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create server: %s", err)
|
||||
}
|
||||
errChan := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
err := srv1.Listen(srvCfg)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
}()
|
||||
|
||||
// wait for servers to start
|
||||
if err := waitForServerToStart(errChan); err != nil {
|
||||
t.Fatalf("failed to start server: %s", err)
|
||||
}
|
||||
|
||||
idAlice := "alice"
|
||||
log.Debugf("connect by alice")
|
||||
relayClient := NewClient(ctx, serverURL, hmacTokenStore, idAlice)
|
||||
err = relayClient.Connect()
|
||||
if err != nil {
|
||||
log.Fatalf("failed to connect to server: %s", err)
|
||||
}
|
||||
|
||||
disconnected := make(chan struct{})
|
||||
relayClient.SetOnDisconnectListener(func() {
|
||||
log.Infof("client disconnected")
|
||||
close(disconnected)
|
||||
})
|
||||
|
||||
err = srv1.Shutdown(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to close server: %s", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-disconnected:
|
||||
case <-time.After(3 * time.Second):
|
||||
log.Fatalf("timeout waiting for client to disconnect")
|
||||
}
|
||||
|
||||
_, err = relayClient.OpenConn("bob")
|
||||
if err == nil {
|
||||
t.Errorf("unexpected opening connection to closed server")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloseByClient(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
srvCfg := server.ListenerConfig{Address: serverListenAddr}
|
||||
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create server: %s", err)
|
||||
}
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
err := srv.Listen(srvCfg)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
}()
|
||||
|
||||
// wait for servers to start
|
||||
if err := waitForServerToStart(errChan); err != nil {
|
||||
t.Fatalf("failed to start server: %s", err)
|
||||
}
|
||||
|
||||
idAlice := "alice"
|
||||
log.Debugf("connect by alice")
|
||||
relayClient := NewClient(ctx, serverURL, hmacTokenStore, idAlice)
|
||||
err = relayClient.Connect()
|
||||
if err != nil {
|
||||
log.Fatalf("failed to connect to server: %s", err)
|
||||
}
|
||||
|
||||
err = relayClient.Close()
|
||||
if err != nil {
|
||||
t.Errorf("failed to close client: %s", err)
|
||||
}
|
||||
|
||||
_, err = relayClient.OpenConn("bob")
|
||||
if err == nil {
|
||||
t.Errorf("unexpected opening connection to closed server")
|
||||
}
|
||||
|
||||
err = srv.Shutdown(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to close server: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func waitForServerToStart(errChan chan error) error {
|
||||
select {
|
||||
case err := <-errChan:
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
case <-time.After(300 * time.Millisecond):
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
76
relay/client/conn.go
Normal file
76
relay/client/conn.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Conn represent a connection to a relayed remote peer.
|
||||
type Conn struct {
|
||||
client *Client
|
||||
dstID []byte
|
||||
dstStringID string
|
||||
messageChan chan Msg
|
||||
instanceURL *RelayAddr
|
||||
}
|
||||
|
||||
// NewConn creates a new connection to a relayed remote peer.
|
||||
// client: the client instance, it used to send messages to the destination peer
|
||||
// dstID: the destination peer ID
|
||||
// dstStringID: the destination peer ID in string format
|
||||
// messageChan: the channel where the messages will be received
|
||||
// instanceURL: the relay instance URL, it used to get the proper server instance address for the remote peer
|
||||
func NewConn(client *Client, dstID []byte, dstStringID string, messageChan chan Msg, instanceURL *RelayAddr) *Conn {
|
||||
c := &Conn{
|
||||
client: client,
|
||||
dstID: dstID,
|
||||
dstStringID: dstStringID,
|
||||
messageChan: messageChan,
|
||||
instanceURL: instanceURL,
|
||||
}
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *Conn) Write(p []byte) (n int, err error) {
|
||||
return c.client.writeTo(c, c.dstStringID, c.dstID, p)
|
||||
}
|
||||
|
||||
func (c *Conn) Read(b []byte) (n int, err error) {
|
||||
msg, ok := <-c.messageChan
|
||||
if !ok {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
n = copy(b, msg.Payload)
|
||||
msg.Free()
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (c *Conn) Close() error {
|
||||
return c.client.closeConn(c, c.dstStringID)
|
||||
}
|
||||
|
||||
func (c *Conn) LocalAddr() net.Addr {
|
||||
return c.client.relayConn.LocalAddr()
|
||||
}
|
||||
|
||||
func (c *Conn) RemoteAddr() net.Addr {
|
||||
return c.instanceURL
|
||||
}
|
||||
|
||||
func (c *Conn) SetDeadline(t time.Time) error {
|
||||
//TODO implement me
|
||||
panic("SetDeadline is not implemented")
|
||||
}
|
||||
|
||||
func (c *Conn) SetReadDeadline(t time.Time) error {
|
||||
//TODO implement me
|
||||
panic("SetReadDeadline is not implemented")
|
||||
}
|
||||
|
||||
func (c *Conn) SetWriteDeadline(t time.Time) error {
|
||||
//TODO implement me
|
||||
panic("SetReadDeadline is not implemented")
|
||||
}
|
||||
13
relay/client/dialer/ws/addr.go
Normal file
13
relay/client/dialer/ws/addr.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package ws
|
||||
|
||||
type WebsocketAddr struct {
|
||||
addr string
|
||||
}
|
||||
|
||||
func (a WebsocketAddr) Network() string {
|
||||
return "websocket"
|
||||
}
|
||||
|
||||
func (a WebsocketAddr) String() string {
|
||||
return a.addr
|
||||
}
|
||||
66
relay/client/dialer/ws/conn.go
Normal file
66
relay/client/dialer/ws/conn.go
Normal file
@@ -0,0 +1,66 @@
|
||||
package ws
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"nhooyr.io/websocket"
|
||||
)
|
||||
|
||||
type Conn struct {
|
||||
ctx context.Context
|
||||
*websocket.Conn
|
||||
remoteAddr WebsocketAddr
|
||||
}
|
||||
|
||||
func NewConn(wsConn *websocket.Conn, serverAddress string) net.Conn {
|
||||
return &Conn{
|
||||
ctx: context.Background(),
|
||||
Conn: wsConn,
|
||||
remoteAddr: WebsocketAddr{serverAddress},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) Read(b []byte) (n int, err error) {
|
||||
t, ioReader, err := c.Conn.Reader(c.ctx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if t != websocket.MessageBinary {
|
||||
return 0, fmt.Errorf("unexpected message type")
|
||||
}
|
||||
|
||||
return ioReader.Read(b)
|
||||
}
|
||||
|
||||
func (c *Conn) Write(b []byte) (n int, err error) {
|
||||
err = c.Conn.Write(c.ctx, websocket.MessageBinary, b)
|
||||
return 0, err
|
||||
}
|
||||
|
||||
func (c *Conn) RemoteAddr() net.Addr {
|
||||
return c.remoteAddr
|
||||
}
|
||||
|
||||
func (c *Conn) LocalAddr() net.Addr {
|
||||
return WebsocketAddr{addr: "unknown"}
|
||||
}
|
||||
|
||||
func (c *Conn) SetReadDeadline(t time.Time) error {
|
||||
return fmt.Errorf("SetReadDeadline is not implemented")
|
||||
}
|
||||
|
||||
func (c *Conn) SetWriteDeadline(t time.Time) error {
|
||||
return fmt.Errorf("SetWriteDeadline is not implemented")
|
||||
}
|
||||
|
||||
func (c *Conn) SetDeadline(t time.Time) error {
|
||||
return fmt.Errorf("SetDeadline is not implemented")
|
||||
}
|
||||
|
||||
func (c *Conn) Close() error {
|
||||
return c.Conn.CloseNow()
|
||||
}
|
||||
67
relay/client/dialer/ws/ws.go
Normal file
67
relay/client/dialer/ws/ws.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package ws
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"nhooyr.io/websocket"
|
||||
|
||||
"github.com/netbirdio/netbird/relay/server/listener/ws"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
func Dial(address string) (net.Conn, error) {
|
||||
wsURL, err := prepareURL(address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
opts := &websocket.DialOptions{
|
||||
HTTPClient: httpClientNbDialer(),
|
||||
}
|
||||
|
||||
parsedURL, err := url.Parse(wsURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
parsedURL.Path = ws.URLPath
|
||||
|
||||
wsConn, resp, err := websocket.Dial(context.Background(), parsedURL.String(), opts)
|
||||
if err != nil {
|
||||
log.Errorf("failed to dial to Relay server '%s': %s", wsURL, err)
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
_ = resp.Body.Close()
|
||||
}
|
||||
|
||||
conn := NewConn(wsConn, address)
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func prepareURL(address string) (string, error) {
|
||||
if !strings.HasPrefix(address, "rel:") && !strings.HasPrefix(address, "rels:") {
|
||||
return "", fmt.Errorf("unsupported scheme: %s", address)
|
||||
}
|
||||
|
||||
return strings.Replace(address, "rel", "ws", 1), nil
|
||||
}
|
||||
|
||||
func httpClientNbDialer() *http.Client {
|
||||
customDialer := nbnet.NewDialer()
|
||||
|
||||
customTransport := &http.Transport{
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return customDialer.DialContext(ctx, network, addr)
|
||||
},
|
||||
}
|
||||
|
||||
return &http.Client{
|
||||
Transport: customTransport,
|
||||
}
|
||||
}
|
||||
12
relay/client/doc.go
Normal file
12
relay/client/doc.go
Normal file
@@ -0,0 +1,12 @@
|
||||
/*
|
||||
Package client contains the implementation of the Relay client.
|
||||
|
||||
The Relay client is responsible for establishing a connection with the Relay server and sending and receiving messages,
|
||||
Keep persistent connection with the Relay server and handle the connection issues.
|
||||
It uses the WebSocket protocol for communication and optionally supports TLS (Transport Layer Security).
|
||||
|
||||
If a peer wants to communicate with a peer on a different relay server, the manager will establish a new connection to
|
||||
the relay server. The connection with these relay servers will be closed if there is no active connection. The peers
|
||||
negotiate the common relay instance via signaling service.
|
||||
*/
|
||||
package client
|
||||
48
relay/client/guard.go
Normal file
48
relay/client/guard.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var (
|
||||
reconnectingTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
// Guard manage the reconnection tries to the Relay server in case of disconnection event.
|
||||
type Guard struct {
|
||||
ctx context.Context
|
||||
relayClient *Client
|
||||
}
|
||||
|
||||
// NewGuard creates a new guard for the relay client.
|
||||
func NewGuard(context context.Context, relayClient *Client) *Guard {
|
||||
g := &Guard{
|
||||
ctx: context,
|
||||
relayClient: relayClient,
|
||||
}
|
||||
return g
|
||||
}
|
||||
|
||||
// OnDisconnected is called when the relay client is disconnected from the relay server. It will trigger the reconnection
|
||||
// todo prevent multiple reconnection instances. In the current usage it should not happen, but it is better to prevent
|
||||
func (g *Guard) OnDisconnected() {
|
||||
ticker := time.NewTicker(reconnectingTimeout)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
err := g.relayClient.Connect()
|
||||
if err != nil {
|
||||
log.Errorf("failed to reconnect to relay server: %s", err)
|
||||
continue
|
||||
}
|
||||
return
|
||||
case <-g.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
365
relay/client/manager.go
Normal file
365
relay/client/manager.go
Normal file
@@ -0,0 +1,365 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"reflect"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
relayAuth "github.com/netbirdio/netbird/relay/auth/hmac"
|
||||
)
|
||||
|
||||
var (
|
||||
relayCleanupInterval = 60 * time.Second
|
||||
connectionTimeout = 30 * time.Second
|
||||
maxConcurrentServers = 7
|
||||
|
||||
ErrRelayClientNotConnected = fmt.Errorf("relay client not connected")
|
||||
)
|
||||
|
||||
// RelayTrack hold the relay clients for the foreign relay servers.
|
||||
// With the mutex can ensure we can open new connection in case the relay connection has been established with
|
||||
// the relay server.
|
||||
type RelayTrack struct {
|
||||
sync.RWMutex
|
||||
relayClient *Client
|
||||
}
|
||||
|
||||
func NewRelayTrack() *RelayTrack {
|
||||
return &RelayTrack{}
|
||||
}
|
||||
|
||||
type OnServerCloseListener func()
|
||||
|
||||
// ManagerService is the interface for the relay manager.
|
||||
type ManagerService interface {
|
||||
Serve() error
|
||||
OpenConn(serverAddress, peerKey string) (net.Conn, error)
|
||||
AddCloseListener(serverAddress string, onClosedListener OnServerCloseListener) error
|
||||
RelayInstanceAddress() (string, error)
|
||||
ServerURLs() []string
|
||||
HasRelayAddress() bool
|
||||
UpdateToken(token *relayAuth.Token) error
|
||||
}
|
||||
|
||||
// Manager is a manager for the relay client instances. It establishes one persistent connection to the given relay URL
|
||||
// and automatically reconnect to them in case disconnection.
|
||||
// The manager also manage temporary relay connection. If a client wants to communicate with a client on a
|
||||
// different relay server, the manager will establish a new connection to the relay server. The connection with these
|
||||
// relay servers will be closed if there is no active connection. Periodically the manager will check if there is any
|
||||
// unused relay connection and close it.
|
||||
type Manager struct {
|
||||
ctx context.Context
|
||||
serverURLs []string
|
||||
peerID string
|
||||
tokenStore *relayAuth.TokenStore
|
||||
|
||||
relayClient *Client
|
||||
reconnectGuard *Guard
|
||||
|
||||
relayClients map[string]*RelayTrack
|
||||
relayClientsMutex sync.RWMutex
|
||||
|
||||
onDisconnectedListeners map[string]*list.List
|
||||
listenerLock sync.Mutex
|
||||
}
|
||||
|
||||
// NewManager creates a new manager instance.
|
||||
// The serverURL address can be empty. In this case, the manager will not serve.
|
||||
func NewManager(ctx context.Context, serverURLs []string, peerID string) *Manager {
|
||||
return &Manager{
|
||||
ctx: ctx,
|
||||
serverURLs: serverURLs,
|
||||
peerID: peerID,
|
||||
tokenStore: &relayAuth.TokenStore{},
|
||||
relayClients: make(map[string]*RelayTrack),
|
||||
onDisconnectedListeners: make(map[string]*list.List),
|
||||
}
|
||||
}
|
||||
|
||||
// Serve starts the manager. It will establish a connection to the relay server and start the relay cleanup loop for
|
||||
// the unused relay connections. The manager will automatically reconnect to the relay server in case of disconnection.
|
||||
func (m *Manager) Serve() error {
|
||||
if m.relayClient != nil {
|
||||
return fmt.Errorf("manager already serving")
|
||||
}
|
||||
log.Debugf("starting relay client manager with %v relay servers", m.serverURLs)
|
||||
|
||||
totalServers := len(m.serverURLs)
|
||||
|
||||
successChan := make(chan *Client, 1)
|
||||
errChan := make(chan error, len(m.serverURLs))
|
||||
|
||||
ctx, cancel := context.WithTimeout(m.ctx, connectionTimeout)
|
||||
defer cancel()
|
||||
|
||||
sem := make(chan struct{}, maxConcurrentServers)
|
||||
|
||||
for _, url := range m.serverURLs {
|
||||
sem <- struct{}{}
|
||||
go func(url string) {
|
||||
defer func() { <-sem }()
|
||||
m.connect(m.ctx, url, successChan, errChan)
|
||||
}(url)
|
||||
}
|
||||
|
||||
var errCount int
|
||||
|
||||
for {
|
||||
select {
|
||||
case client := <-successChan:
|
||||
log.Infof("Successfully connected to relay server: %s", client.connectionURL)
|
||||
|
||||
m.relayClient = client
|
||||
|
||||
m.reconnectGuard = NewGuard(m.ctx, m.relayClient)
|
||||
m.relayClient.SetOnDisconnectListener(func() {
|
||||
m.onServerDisconnected(client.connectionURL)
|
||||
})
|
||||
m.startCleanupLoop()
|
||||
return nil
|
||||
case err := <-errChan:
|
||||
errCount++
|
||||
log.Warnf("Connection attempt failed: %v", err)
|
||||
if errCount == totalServers {
|
||||
return errors.New("failed to connect to any relay server: all attempts failed")
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return fmt.Errorf("failed to connect to any relay server: %w", ctx.Err())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) connect(ctx context.Context, serverURL string, successChan chan<- *Client, errChan chan<- error) {
|
||||
// TODO: abort the connection if another connection was successful
|
||||
relayClient := NewClient(ctx, serverURL, m.tokenStore, m.peerID)
|
||||
if err := relayClient.Connect(); err != nil {
|
||||
errChan <- fmt.Errorf("failed to connect to %s: %w", serverURL, err)
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case successChan <- relayClient:
|
||||
// This client was the first to connect successfully
|
||||
default:
|
||||
if err := relayClient.Close(); err != nil {
|
||||
log.Debugf("failed to close relay client: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OpenConn opens a connection to the given peer key. If the peer is on the same relay server, the connection will be
|
||||
// established via the relay server. If the peer is on a different relay server, the manager will establish a new
|
||||
// connection to the relay server. It returns back with a net.Conn what represent the remote peer connection.
|
||||
func (m *Manager) OpenConn(serverAddress, peerKey string) (net.Conn, error) {
|
||||
if m.relayClient == nil {
|
||||
return nil, ErrRelayClientNotConnected
|
||||
}
|
||||
|
||||
foreign, err := m.isForeignServer(serverAddress)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var (
|
||||
netConn net.Conn
|
||||
)
|
||||
if !foreign {
|
||||
log.Debugf("open peer connection via permanent server: %s", peerKey)
|
||||
netConn, err = m.relayClient.OpenConn(peerKey)
|
||||
} else {
|
||||
log.Debugf("open peer connection via foreign server: %s", serverAddress)
|
||||
netConn, err = m.openConnVia(serverAddress, peerKey)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return netConn, err
|
||||
}
|
||||
|
||||
// AddCloseListener adds a listener to the given server instance address. The listener will be called if the connection
|
||||
// closed.
|
||||
func (m *Manager) AddCloseListener(serverAddress string, onClosedListener OnServerCloseListener) error {
|
||||
foreign, err := m.isForeignServer(serverAddress)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var listenerAddr string
|
||||
if foreign {
|
||||
listenerAddr = serverAddress
|
||||
} else {
|
||||
listenerAddr = m.relayClient.connectionURL
|
||||
}
|
||||
m.addListener(listenerAddr, onClosedListener)
|
||||
return nil
|
||||
}
|
||||
|
||||
// RelayInstanceAddress returns the address of the permanent relay server. It could change if the network connection is
|
||||
// lost. This address will be sent to the target peer to choose the common relay server for the communication.
|
||||
func (m *Manager) RelayInstanceAddress() (string, error) {
|
||||
if m.relayClient == nil {
|
||||
return "", ErrRelayClientNotConnected
|
||||
}
|
||||
return m.relayClient.ServerInstanceURL()
|
||||
}
|
||||
|
||||
// ServerURLs returns the addresses of the relay servers.
|
||||
func (m *Manager) ServerURLs() []string {
|
||||
return m.serverURLs
|
||||
}
|
||||
|
||||
// HasRelayAddress returns true if the manager is serving. With this method can check if the peer can communicate with
|
||||
// Relay service.
|
||||
func (m *Manager) HasRelayAddress() bool {
|
||||
return len(m.serverURLs) > 0
|
||||
}
|
||||
|
||||
// UpdateToken updates the token in the token store.
|
||||
func (m *Manager) UpdateToken(token *relayAuth.Token) error {
|
||||
return m.tokenStore.UpdateToken(token)
|
||||
}
|
||||
|
||||
func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) {
|
||||
// check if already has a connection to the desired relay server
|
||||
m.relayClientsMutex.RLock()
|
||||
rt, ok := m.relayClients[serverAddress]
|
||||
if ok {
|
||||
rt.RLock()
|
||||
m.relayClientsMutex.RUnlock()
|
||||
defer rt.RUnlock()
|
||||
return rt.relayClient.OpenConn(peerKey)
|
||||
}
|
||||
m.relayClientsMutex.RUnlock()
|
||||
|
||||
// if not, establish a new connection but check it again (because changed the lock type) before starting the
|
||||
// connection
|
||||
m.relayClientsMutex.Lock()
|
||||
rt, ok = m.relayClients[serverAddress]
|
||||
if ok {
|
||||
rt.RLock()
|
||||
m.relayClientsMutex.Unlock()
|
||||
defer rt.RUnlock()
|
||||
return rt.relayClient.OpenConn(peerKey)
|
||||
}
|
||||
|
||||
// create a new relay client and store it in the relayClients map
|
||||
rt = NewRelayTrack()
|
||||
rt.Lock()
|
||||
m.relayClients[serverAddress] = rt
|
||||
m.relayClientsMutex.Unlock()
|
||||
|
||||
relayClient := NewClient(m.ctx, serverAddress, m.tokenStore, m.peerID)
|
||||
err := relayClient.Connect()
|
||||
if err != nil {
|
||||
rt.Unlock()
|
||||
m.relayClientsMutex.Lock()
|
||||
delete(m.relayClients, serverAddress)
|
||||
m.relayClientsMutex.Unlock()
|
||||
return nil, err
|
||||
}
|
||||
// if connection closed then delete the relay client from the list
|
||||
relayClient.SetOnDisconnectListener(func() {
|
||||
m.onServerDisconnected(serverAddress)
|
||||
})
|
||||
rt.relayClient = relayClient
|
||||
rt.Unlock()
|
||||
|
||||
conn, err := relayClient.OpenConn(peerKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (m *Manager) onServerDisconnected(serverAddress string) {
|
||||
if serverAddress == m.relayClient.connectionURL {
|
||||
go m.reconnectGuard.OnDisconnected()
|
||||
}
|
||||
|
||||
m.notifyOnDisconnectListeners(serverAddress)
|
||||
}
|
||||
|
||||
func (m *Manager) isForeignServer(address string) (bool, error) {
|
||||
rAddr, err := m.relayClient.ServerInstanceURL()
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("relay client not connected")
|
||||
}
|
||||
return rAddr != address, nil
|
||||
}
|
||||
|
||||
func (m *Manager) startCleanupLoop() {
|
||||
if m.ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(relayCleanupInterval)
|
||||
go func() {
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-m.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
m.cleanUpUnusedRelays()
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (m *Manager) cleanUpUnusedRelays() {
|
||||
m.relayClientsMutex.Lock()
|
||||
defer m.relayClientsMutex.Unlock()
|
||||
|
||||
for addr, rt := range m.relayClients {
|
||||
rt.Lock()
|
||||
if rt.relayClient.HasConns() {
|
||||
rt.Unlock()
|
||||
continue
|
||||
}
|
||||
rt.relayClient.SetOnDisconnectListener(nil)
|
||||
go func() {
|
||||
_ = rt.relayClient.Close()
|
||||
}()
|
||||
log.Debugf("clean up unused relay server connection: %s", addr)
|
||||
delete(m.relayClients, addr)
|
||||
rt.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) addListener(serverAddress string, onClosedListener OnServerCloseListener) {
|
||||
m.listenerLock.Lock()
|
||||
defer m.listenerLock.Unlock()
|
||||
l, ok := m.onDisconnectedListeners[serverAddress]
|
||||
if !ok {
|
||||
l = list.New()
|
||||
}
|
||||
for e := l.Front(); e != nil; e = e.Next() {
|
||||
if reflect.ValueOf(e.Value).Pointer() == reflect.ValueOf(onClosedListener).Pointer() {
|
||||
return
|
||||
}
|
||||
}
|
||||
l.PushBack(onClosedListener)
|
||||
m.onDisconnectedListeners[serverAddress] = l
|
||||
}
|
||||
|
||||
func (m *Manager) notifyOnDisconnectListeners(serverAddress string) {
|
||||
m.listenerLock.Lock()
|
||||
defer m.listenerLock.Unlock()
|
||||
|
||||
l, ok := m.onDisconnectedListeners[serverAddress]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
for e := l.Front(); e != nil; e = e.Next() {
|
||||
go e.Value.(OnServerCloseListener)()
|
||||
}
|
||||
delete(m.onDisconnectedListeners, serverAddress)
|
||||
}
|
||||
432
relay/client/manager_test.go
Normal file
432
relay/client/manager_test.go
Normal file
@@ -0,0 +1,432 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"go.opentelemetry.io/otel"
|
||||
|
||||
"github.com/netbirdio/netbird/relay/server"
|
||||
)
|
||||
|
||||
func TestEmptyURL(t *testing.T) {
|
||||
mgr := NewManager(context.Background(), nil, "alice")
|
||||
err := mgr.Serve()
|
||||
if err == nil {
|
||||
t.Errorf("expected error, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestForeignConn(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
srvCfg1 := server.ListenerConfig{
|
||||
Address: "localhost:1234",
|
||||
}
|
||||
srv1, err := server.NewServer(otel.Meter(""), srvCfg1.Address, false, av)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create server: %s", err)
|
||||
}
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
err := srv1.Listen(srvCfg1)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
err := srv1.Shutdown(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("failed to close server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := waitForServerToStart(errChan); err != nil {
|
||||
t.Fatalf("failed to start server: %s", err)
|
||||
}
|
||||
|
||||
srvCfg2 := server.ListenerConfig{
|
||||
Address: "localhost:2234",
|
||||
}
|
||||
srv2, err := server.NewServer(otel.Meter(""), srvCfg2.Address, false, av)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create server: %s", err)
|
||||
}
|
||||
errChan2 := make(chan error, 1)
|
||||
go func() {
|
||||
err := srv2.Listen(srvCfg2)
|
||||
if err != nil {
|
||||
errChan2 <- err
|
||||
}
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
err := srv2.Shutdown(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("failed to close server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := waitForServerToStart(errChan2); err != nil {
|
||||
t.Fatalf("failed to start server: %s", err)
|
||||
}
|
||||
|
||||
idAlice := "alice"
|
||||
log.Debugf("connect by alice")
|
||||
mCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
clientAlice := NewManager(mCtx, toURL(srvCfg1), idAlice)
|
||||
err = clientAlice.Serve()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to serve manager: %s", err)
|
||||
}
|
||||
|
||||
idBob := "bob"
|
||||
log.Debugf("connect by bob")
|
||||
clientBob := NewManager(mCtx, toURL(srvCfg2), idBob)
|
||||
err = clientBob.Serve()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to serve manager: %s", err)
|
||||
}
|
||||
bobsSrvAddr, err := clientBob.RelayInstanceAddress()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get relay address: %s", err)
|
||||
}
|
||||
connAliceToBob, err := clientAlice.OpenConn(bobsSrvAddr, idBob)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind channel: %s", err)
|
||||
}
|
||||
connBobToAlice, err := clientBob.OpenConn(bobsSrvAddr, idAlice)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind channel: %s", err)
|
||||
}
|
||||
|
||||
payload := "hello bob, I am alice"
|
||||
_, err = connAliceToBob.Write([]byte(payload))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to write to channel: %s", err)
|
||||
}
|
||||
|
||||
buf := make([]byte, 65535)
|
||||
n, err := connBobToAlice.Read(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read from channel: %s", err)
|
||||
}
|
||||
|
||||
_, err = connBobToAlice.Write(buf[:n])
|
||||
if err != nil {
|
||||
t.Fatalf("failed to write to channel: %s", err)
|
||||
}
|
||||
|
||||
n, err = connAliceToBob.Read(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read from channel: %s", err)
|
||||
}
|
||||
|
||||
if payload != string(buf[:n]) {
|
||||
t.Fatalf("expected %s, got %s", payload, string(buf[:n]))
|
||||
}
|
||||
}
|
||||
|
||||
func TestForeginConnClose(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
srvCfg1 := server.ListenerConfig{
|
||||
Address: "localhost:1234",
|
||||
}
|
||||
srv1, err := server.NewServer(otel.Meter(""), srvCfg1.Address, false, av)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create server: %s", err)
|
||||
}
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
err := srv1.Listen(srvCfg1)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
err := srv1.Shutdown(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("failed to close server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := waitForServerToStart(errChan); err != nil {
|
||||
t.Fatalf("failed to start server: %s", err)
|
||||
}
|
||||
|
||||
srvCfg2 := server.ListenerConfig{
|
||||
Address: "localhost:2234",
|
||||
}
|
||||
srv2, err := server.NewServer(otel.Meter(""), srvCfg2.Address, false, av)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create server: %s", err)
|
||||
}
|
||||
errChan2 := make(chan error, 1)
|
||||
go func() {
|
||||
err := srv2.Listen(srvCfg2)
|
||||
if err != nil {
|
||||
errChan2 <- err
|
||||
}
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
err := srv2.Shutdown(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("failed to close server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := waitForServerToStart(errChan2); err != nil {
|
||||
t.Fatalf("failed to start server: %s", err)
|
||||
}
|
||||
|
||||
idAlice := "alice"
|
||||
log.Debugf("connect by alice")
|
||||
mCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
mgr := NewManager(mCtx, toURL(srvCfg1), idAlice)
|
||||
err = mgr.Serve()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to serve manager: %s", err)
|
||||
}
|
||||
conn, err := mgr.OpenConn(toURL(srvCfg2)[0], "anotherpeer")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind channel: %s", err)
|
||||
}
|
||||
|
||||
err = conn.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to close connection: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestForeginAutoClose(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
relayCleanupInterval = 1 * time.Second
|
||||
srvCfg1 := server.ListenerConfig{
|
||||
Address: "localhost:1234",
|
||||
}
|
||||
srv1, err := server.NewServer(otel.Meter(""), srvCfg1.Address, false, av)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create server: %s", err)
|
||||
}
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
t.Log("binding server 1.")
|
||||
err := srv1.Listen(srvCfg1)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
t.Logf("closing server 1.")
|
||||
err := srv1.Shutdown(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("failed to close server: %s", err)
|
||||
}
|
||||
t.Logf("server 1. closed")
|
||||
}()
|
||||
|
||||
if err := waitForServerToStart(errChan); err != nil {
|
||||
t.Fatalf("failed to start server: %s", err)
|
||||
}
|
||||
|
||||
srvCfg2 := server.ListenerConfig{
|
||||
Address: "localhost:2234",
|
||||
}
|
||||
srv2, err := server.NewServer(otel.Meter(""), srvCfg2.Address, false, av)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create server: %s", err)
|
||||
}
|
||||
errChan2 := make(chan error, 1)
|
||||
go func() {
|
||||
t.Log("binding server 2.")
|
||||
err := srv2.Listen(srvCfg2)
|
||||
if err != nil {
|
||||
errChan2 <- err
|
||||
}
|
||||
}()
|
||||
defer func() {
|
||||
t.Logf("closing server 2.")
|
||||
err := srv2.Shutdown(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("failed to close server: %s", err)
|
||||
}
|
||||
t.Logf("server 2 closed.")
|
||||
}()
|
||||
|
||||
if err := waitForServerToStart(errChan2); err != nil {
|
||||
t.Fatalf("failed to start server: %s", err)
|
||||
}
|
||||
|
||||
idAlice := "alice"
|
||||
t.Log("connect to server 1.")
|
||||
mCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
mgr := NewManager(mCtx, toURL(srvCfg1), idAlice)
|
||||
err = mgr.Serve()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to serve manager: %s", err)
|
||||
}
|
||||
|
||||
t.Log("open connection to another peer")
|
||||
conn, err := mgr.OpenConn(toURL(srvCfg2)[0], "anotherpeer")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind channel: %s", err)
|
||||
}
|
||||
|
||||
t.Log("close conn")
|
||||
err = conn.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to close connection: %s", err)
|
||||
}
|
||||
|
||||
t.Logf("waiting for relay cleanup: %s", relayCleanupInterval+1*time.Second)
|
||||
time.Sleep(relayCleanupInterval + 1*time.Second)
|
||||
if len(mgr.relayClients) != 0 {
|
||||
t.Errorf("expected 0, got %d", len(mgr.relayClients))
|
||||
}
|
||||
|
||||
t.Logf("closing manager")
|
||||
}
|
||||
|
||||
func TestAutoReconnect(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
reconnectingTimeout = 2 * time.Second
|
||||
|
||||
srvCfg := server.ListenerConfig{
|
||||
Address: "localhost:1234",
|
||||
}
|
||||
srv, err := server.NewServer(otel.Meter(""), srvCfg.Address, false, av)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create server: %s", err)
|
||||
}
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
err := srv.Listen(srvCfg)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
err := srv.Shutdown(ctx)
|
||||
if err != nil {
|
||||
log.Errorf("failed to close server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := waitForServerToStart(errChan); err != nil {
|
||||
t.Fatalf("failed to start server: %s", err)
|
||||
}
|
||||
|
||||
mCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
clientAlice := NewManager(mCtx, toURL(srvCfg), "alice")
|
||||
err = clientAlice.Serve()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to serve manager: %s", err)
|
||||
}
|
||||
ra, err := clientAlice.RelayInstanceAddress()
|
||||
if err != nil {
|
||||
t.Errorf("failed to get relay address: %s", err)
|
||||
}
|
||||
conn, err := clientAlice.OpenConn(ra, "bob")
|
||||
if err != nil {
|
||||
t.Errorf("failed to bind channel: %s", err)
|
||||
}
|
||||
|
||||
t.Log("closing client relay connection")
|
||||
// todo figure out moc server
|
||||
_ = clientAlice.relayClient.relayConn.Close()
|
||||
t.Log("start test reading")
|
||||
_, err = conn.Read(make([]byte, 1))
|
||||
if err == nil {
|
||||
t.Errorf("unexpected reading from closed connection")
|
||||
}
|
||||
|
||||
log.Infof("waiting for reconnection")
|
||||
time.Sleep(reconnectingTimeout + 1*time.Second)
|
||||
|
||||
log.Infof("reopent the connection")
|
||||
_, err = clientAlice.OpenConn(ra, "bob")
|
||||
if err != nil {
|
||||
t.Errorf("failed to open channel: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNotifierDoubleAdd(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
srvCfg1 := server.ListenerConfig{
|
||||
Address: "localhost:1234",
|
||||
}
|
||||
srv1, err := server.NewServer(otel.Meter(""), srvCfg1.Address, false, av)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create server: %s", err)
|
||||
}
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
err := srv1.Listen(srvCfg1)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
err := srv1.Shutdown(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("failed to close server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := waitForServerToStart(errChan); err != nil {
|
||||
t.Fatalf("failed to start server: %s", err)
|
||||
}
|
||||
|
||||
idAlice := "alice"
|
||||
log.Debugf("connect by alice")
|
||||
mCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
clientAlice := NewManager(mCtx, toURL(srvCfg1), idAlice)
|
||||
err = clientAlice.Serve()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to serve manager: %s", err)
|
||||
}
|
||||
|
||||
conn1, err := clientAlice.OpenConn(clientAlice.ServerURLs()[0], "idBob")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind channel: %s", err)
|
||||
}
|
||||
|
||||
fnCloseListener := OnServerCloseListener(func() {
|
||||
log.Infof("close listener")
|
||||
})
|
||||
|
||||
err = clientAlice.AddCloseListener(clientAlice.ServerURLs()[0], fnCloseListener)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to add close listener: %s", err)
|
||||
}
|
||||
|
||||
err = clientAlice.AddCloseListener(clientAlice.ServerURLs()[0], fnCloseListener)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to add close listener: %s", err)
|
||||
}
|
||||
|
||||
err = conn1.Close()
|
||||
if err != nil {
|
||||
t.Errorf("failed to close connection: %s", err)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func toURL(address server.ListenerConfig) []string {
|
||||
return []string{"rel://" + address.Address}
|
||||
}
|
||||
Reference in New Issue
Block a user