Compare commits

...

10 Commits

Author SHA1 Message Date
Zoltan Papp
23cc698c4d Fix warn message 2025-07-16 21:25:35 +02:00
Zoltan Papp
929388c5f4 Fix sonar issue 2025-07-16 21:19:55 +02:00
Zoltan Papp
fcb06d1a23 Revert error handling behaviour in test to support auto net.Conn creation 2025-07-16 21:16:07 +02:00
Zoltan Papp
04619fe54f Revert error handling behaviour in test 2025-07-16 21:09:49 +02:00
Zoltan Papp
0bad81b99e Revert function signature usage 2025-07-16 20:58:52 +02:00
Zoltan Papp
7be389289c Fix import 2025-07-16 20:56:11 +02:00
Zoltan Papp
0568622d63 Add client part of the relay version changes 2025-07-16 20:45:53 +02:00
Maycon Santos
58185ced16 [misc] add forum post and update sign pipeline (#4155)
use old git-town version
2025-07-16 14:10:28 +02:00
Pedro Maia Costa
e67f44f47c [client] fix test (#4156) 2025-07-16 12:09:38 +02:00
Zoltan Papp
b524f486e2 [client] Fix/nil relayed address (#4153)
Fix nil pointer in Relay conn address

While creating a relayed net.Conn struct instance, it is possible for the relayedURL to be set to nil.

panic: value method github.com/netbirdio/netbird/relay/client.RelayAddr.String called using nil *RelayAddr pointer

Fix relayed URL variable protection
Protect the channel closing
2025-07-16 00:00:18 +02:00
14 changed files with 220 additions and 103 deletions

View File

@@ -16,6 +16,6 @@ jobs:
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- uses: git-town/action@v1 - uses: git-town/action@v1.2.1
with: with:
skip-single-stacks: true skip-single-stacks: true

View File

@@ -9,7 +9,7 @@ on:
pull_request: pull_request:
env: env:
SIGN_PIPE_VER: "v0.0.20" SIGN_PIPE_VER: "v0.0.21"
GORELEASER_VER: "v2.3.2" GORELEASER_VER: "v2.3.2"
PRODUCT_NAME: "NetBird" PRODUCT_NAME: "NetBird"
COPYRIGHT: "NetBird GmbH" COPYRIGHT: "NetBird GmbH"
@@ -231,3 +231,17 @@ jobs:
ref: ${{ env.SIGN_PIPE_VER }} ref: ${{ env.SIGN_PIPE_VER }}
token: ${{ secrets.SIGN_GITHUB_TOKEN }} token: ${{ secrets.SIGN_GITHUB_TOKEN }}
inputs: '{ "tag": "${{ github.ref }}", "skipRelease": false }' inputs: '{ "tag": "${{ github.ref }}", "skipRelease": false }'
post_on_forum:
runs-on: ubuntu-latest
continue-on-error: true
needs: [trigger_signer]
steps:
- uses: Codixer/discourse-topic-github-release-action@v2.0.1
with:
discourse-api-key: ${{ secrets.DISCOURSE_RELEASES_API_KEY }}
discourse-base-url: https://forum.netbird.io
discourse-author-username: NetBird
discourse-category: 17
discourse-tags:
releases

View File

@@ -1481,6 +1481,10 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
GetSettings(gomock.Any(), gomock.Any(), gomock.Any()). GetSettings(gomock.Any(), gomock.Any(), gomock.Any()).
Return(&types.Settings{}, nil). Return(&types.Settings{}, nil).
AnyTimes() AnyTimes()
settingsMockManager.EXPECT().
GetExtraSettings(gomock.Any(), gomock.Any()).
Return(&types.ExtraSettings{}, nil).
AnyTimes()
permissionsManager := permissions.NewManager(store) permissionsManager := permissions.NewManager(store)

View File

@@ -1,13 +0,0 @@
package client
type RelayAddr struct {
addr string
}
func (a RelayAddr) Network() string {
return "relay"
}
func (a RelayAddr) String() string {
return a.addr
}

View File

@@ -136,7 +136,7 @@ type Client struct {
mu sync.Mutex // protect serviceIsRunning and conns mu sync.Mutex // protect serviceIsRunning and conns
readLoopMutex sync.Mutex readLoopMutex sync.Mutex
wgReadLoop sync.WaitGroup wgReadLoop sync.WaitGroup
instanceURL *RelayAddr instanceURL *messages.RelayAddr
muInstanceURL sync.Mutex muInstanceURL sync.Mutex
onDisconnectListener func(string) onDisconnectListener func(string)
@@ -181,13 +181,21 @@ func (c *Client) Connect(ctx context.Context) error {
return nil return nil
} }
if err := c.connect(ctx); err != nil { instanceURL, err := c.connect(ctx)
if err != nil {
return err return err
} }
c.muInstanceURL.Lock()
c.instanceURL = instanceURL
c.muInstanceURL.Unlock()
if c.instanceURL.FeatureVersionCode < messages.VersionSubscription {
c.log.Warnf("server is deprecated, peer state subscription feature will not work")
} else {
c.stateSubscription = NewPeersStateSubscription(c.log, c.relayConn, c.closeConnsByPeerID) c.stateSubscription = NewPeersStateSubscription(c.log, c.relayConn, c.closeConnsByPeerID)
}
c.log = c.log.WithField("relay", c.instanceURL.String()) c.log = c.log.WithField("relay", instanceURL.String())
c.log.Infof("relay connection established") c.log.Infof("relay connection established")
c.serviceIsRunning = true c.serviceIsRunning = true
@@ -229,9 +237,18 @@ func (c *Client) OpenConn(ctx context.Context, dstPeerID string) (net.Conn, erro
c.log.Infof("remote peer is available, prepare the relayed connection: %s", peerID) c.log.Infof("remote peer is available, prepare the relayed connection: %s", peerID)
msgChannel := make(chan Msg, 100) msgChannel := make(chan Msg, 100)
conn := NewConn(c, peerID, msgChannel, c.instanceURL)
c.mu.Lock() c.mu.Lock()
if !c.serviceIsRunning {
c.mu.Unlock()
return nil, fmt.Errorf("relay connection is not established")
}
c.muInstanceURL.Lock()
instanceURL := c.instanceURL
c.muInstanceURL.Unlock()
conn := NewConn(c, peerID, msgChannel, instanceURL)
_, ok = c.conns[peerID] _, ok = c.conns[peerID]
if ok { if ok {
c.mu.Unlock() c.mu.Unlock()
@@ -278,69 +295,71 @@ func (c *Client) Close() error {
return c.close(true) return c.close(true)
} }
func (c *Client) connect(ctx context.Context) error { func (c *Client) connect(ctx context.Context) (*messages.RelayAddr, error) {
rd := dialer.NewRaceDial(c.log, c.connectionURL, quic.Dialer{}, ws.Dialer{}) rd := dialer.NewRaceDial(c.log, c.connectionURL, quic.Dialer{}, ws.Dialer{})
conn, err := rd.Dial() conn, err := rd.Dial()
if err != nil { if err != nil {
return err return nil, err
} }
c.relayConn = conn c.relayConn = conn
if err = c.handShake(ctx); err != nil { instanceURL, err := c.handShake(ctx)
if err != nil {
cErr := conn.Close() cErr := conn.Close()
if cErr != nil { if cErr != nil {
c.log.Errorf("failed to close connection: %s", cErr) c.log.Errorf("failed to close connection: %s", cErr)
} }
return err return nil, err
} }
return nil return instanceURL, nil
} }
func (c *Client) handShake(ctx context.Context) error { func (c *Client) handShake(ctx context.Context) (*messages.RelayAddr, error) {
msg, err := messages.MarshalAuthMsg(c.hashedID, c.authTokenStore.TokenBinary()) msg, err := messages.MarshalAuthMsg(c.hashedID, c.authTokenStore.TokenBinary())
if err != nil { if err != nil {
c.log.Errorf("failed to marshal auth message: %s", err) c.log.Errorf("failed to marshal auth message: %s", err)
return err return nil, err
} }
_, err = c.relayConn.Write(msg) _, err = c.relayConn.Write(msg)
if err != nil { if err != nil {
c.log.Errorf("failed to send auth message: %s", err) c.log.Errorf("failed to send auth message: %s", err)
return err return nil, err
} }
buf := make([]byte, messages.MaxHandshakeRespSize) buf := make([]byte, messages.MaxHandshakeRespSize)
n, err := c.readWithTimeout(ctx, buf) n, err := c.readWithTimeout(ctx, buf)
if err != nil { if err != nil {
c.log.Errorf("failed to read auth response: %s", err) c.log.Errorf("failed to read auth response: %s", err)
return err return nil, err
} }
_, err = messages.ValidateVersion(buf[:n]) _, err = messages.ValidateVersion(buf[:n])
if err != nil { if err != nil {
return fmt.Errorf("validate version: %w", err) return nil, fmt.Errorf("validate version: %w", err)
} }
msgType, err := messages.DetermineServerMessageType(buf[:n]) msgType, err := messages.DetermineServerMessageType(buf[:n])
if err != nil { if err != nil {
c.log.Errorf("failed to determine message type: %s", err) c.log.Errorf("failed to determine message type: %s", err)
return err return nil, err
} }
if msgType != messages.MsgTypeAuthResponse { if msgType != messages.MsgTypeAuthResponse {
c.log.Errorf("unexpected message type: %s", msgType) c.log.Errorf("unexpected message type: %s", msgType)
return fmt.Errorf("unexpected message type") return nil, fmt.Errorf("unexpected message type")
} }
addr, err := messages.UnmarshalAuthResponse(buf[:n]) payload, err := messages.UnmarshalAuthResponse(buf[:n])
if err != nil { if err != nil {
return err return nil, err
} }
c.muInstanceURL.Lock() relayAddr, err := messages.UnmarshalRelayAddr(payload)
c.instanceURL = &RelayAddr{addr: addr} if err != nil {
c.muInstanceURL.Unlock() return nil, err
return nil }
return relayAddr, nil
} }
func (c *Client) readLoop(hc *healthcheck.Receiver, relayConn net.Conn, internallyStoppedFlag *internalStopFlag) { func (c *Client) readLoop(hc *healthcheck.Receiver, relayConn net.Conn, internallyStoppedFlag *internalStopFlag) {
@@ -386,10 +405,6 @@ func (c *Client) readLoop(hc *healthcheck.Receiver, relayConn net.Conn, internal
hc.Stop() hc.Stop()
c.muInstanceURL.Lock()
c.instanceURL = nil
c.muInstanceURL.Unlock()
c.stateSubscription.Cleanup() c.stateSubscription.Cleanup()
c.wgReadLoop.Done() c.wgReadLoop.Done()
_ = c.close(false) _ = c.close(false)
@@ -578,8 +593,12 @@ func (c *Client) close(gracefullyExit bool) error {
c.log.Warn("relay connection was already marked as not running") c.log.Warn("relay connection was already marked as not running")
return nil return nil
} }
c.serviceIsRunning = false c.serviceIsRunning = false
c.muInstanceURL.Lock()
c.instanceURL = nil
c.muInstanceURL.Unlock()
c.log.Infof("closing all peer connections") c.log.Infof("closing all peer connections")
c.closeAllConns() c.closeAllConns()
if gracefullyExit { if gracefullyExit {
@@ -636,6 +655,11 @@ func (c *Client) readWithTimeout(ctx context.Context, buf []byte) (int, error) {
} }
func (c *Client) handlePeersOnlineMsg(buf []byte) { func (c *Client) handlePeersOnlineMsg(buf []byte) {
if c.stateSubscription == nil {
c.log.Warnf("message type %d is not supported by the server, peer state subscription feature is not available)", messages.MsgTypePeersOnline)
return
}
peersID, err := messages.UnmarshalPeersOnlineMsg(buf) peersID, err := messages.UnmarshalPeersOnlineMsg(buf)
if err != nil { if err != nil {
c.log.Errorf("failed to unmarshal peers online msg: %s", err) c.log.Errorf("failed to unmarshal peers online msg: %s", err)
@@ -645,6 +669,11 @@ func (c *Client) handlePeersOnlineMsg(buf []byte) {
} }
func (c *Client) handlePeersWentOfflineMsg(buf []byte) { func (c *Client) handlePeersWentOfflineMsg(buf []byte) {
if c.stateSubscription == nil {
c.log.Warnf("message type %d is not supported by the server, peer state subscription feature is not available)", messages.MsgTypePeersWentOffline)
return
}
peersID, err := messages.UnMarshalPeersWentOffline(buf) peersID, err := messages.UnMarshalPeersWentOffline(buf)
if err != nil { if err != nil {
c.log.Errorf("failed to unmarshal peers went offline msg: %s", err) c.log.Errorf("failed to unmarshal peers went offline msg: %s", err)

View File

@@ -314,8 +314,8 @@ func TestBindToUnavailabePeer(t *testing.T) {
t.Errorf("failed to connect to server: %s", err) t.Errorf("failed to connect to server: %s", err)
} }
_, err = clientAlice.OpenConn(ctx, "bob") _, err = clientAlice.OpenConn(ctx, "bob")
if err == nil { if err != nil {
t.Errorf("expected error when binding to unavailable peer, got nil") t.Errorf("failed to open bob: %s", err)
} }
log.Infof("closing client") log.Infof("closing client")
@@ -390,18 +390,13 @@ func TestBindReconnect(t *testing.T) {
chAlice, err := clientAlice.OpenConn(ctx, "bob") chAlice, err := clientAlice.OpenConn(ctx, "bob")
if err != nil { if err != nil {
t.Errorf("failed to bind channel: %s", err) t.Fatalf("failed to bind channel: %s", err)
} }
testString := "hello alice, I am bob" testString := "hello alice, I am bob"
_, err = chBob.Write([]byte(testString)) _, err = chBob.Write([]byte(testString))
if err == nil {
t.Errorf("expected error when writing to channel, got nil")
}
chBob, err = clientBob.OpenConn(ctx, "alice")
if err != nil { if err != nil {
t.Errorf("failed to bind channel: %s", err) t.Errorf("expected error when writing to channel, got nil")
} }
_, err = chBob.Write([]byte(testString)) _, err = chBob.Write([]byte(testString))

View File

@@ -12,7 +12,7 @@ type Conn struct {
client *Client client *Client
dstID messages.PeerID dstID messages.PeerID
messageChan chan Msg messageChan chan Msg
instanceURL *RelayAddr instanceURL *messages.RelayAddr
} }
// NewConn creates a new connection to a relayed remote peer. // NewConn creates a new connection to a relayed remote peer.
@@ -20,7 +20,7 @@ type Conn struct {
// dstID: the destination peer ID // dstID: the destination peer ID
// messageChan: the channel where the messages will be received // messageChan: the channel where the messages will be received
// instanceURL: the relay instance URL, it used to get the proper server instance address for the remote peer // instanceURL: the relay instance URL, it used to get the proper server instance address for the remote peer
func NewConn(client *Client, dstID messages.PeerID, messageChan chan Msg, instanceURL *RelayAddr) *Conn { func NewConn(client *Client, dstID messages.PeerID, messageChan chan Msg, instanceURL *messages.RelayAddr) *Conn {
c := &Conn{ c := &Conn{
client: client, client: client,
dstID: dstID, dstID: dstID,

View File

@@ -229,16 +229,14 @@ func TestForeginAutoClose(t *testing.T) {
errChan := make(chan error, 1) errChan := make(chan error, 1)
go func() { go func() {
t.Log("binding server 1.") t.Log("binding server 1.")
err := srv1.Listen(srvCfg1) if err := srv1.Listen(srvCfg1); err != nil {
if err != nil {
errChan <- err errChan <- err
} }
}() }()
defer func() { defer func() {
t.Logf("closing server 1.") t.Logf("closing server 1.")
err := srv1.Shutdown(ctx) if err := srv1.Shutdown(ctx); err != nil {
if err != nil {
t.Errorf("failed to close server: %s", err) t.Errorf("failed to close server: %s", err)
} }
t.Logf("server 1. closed") t.Logf("server 1. closed")
@@ -287,22 +285,15 @@ func TestForeginAutoClose(t *testing.T) {
} }
t.Log("open connection to another peer") t.Log("open connection to another peer")
conn, err := mgr.OpenConn(ctx, toURL(srvCfg2)[0], "anotherpeer") if _, err = mgr.OpenConn(ctx, toURL(srvCfg2)[0], "anotherpeer"); err != nil {
if err != nil { t.Fatalf("failed to open connection: %s", err)
t.Fatalf("failed to bind channel: %s", err)
}
t.Log("close conn")
err = conn.Close()
if err != nil {
t.Fatalf("failed to close connection: %s", err)
} }
timeout := relayCleanupInterval + keepUnusedServerTime + 1*time.Second timeout := relayCleanupInterval + keepUnusedServerTime + 1*time.Second
t.Logf("waiting for relay cleanup: %s", timeout) t.Logf("waiting for relay cleanup: %s", timeout)
time.Sleep(timeout) time.Sleep(timeout)
if len(mgr.relayClients) != 0 { if len(mgr.relayClients) != 1 {
t.Errorf("expected 0, got %d", len(mgr.relayClients)) t.Errorf("expected 1 relay client, got %d", len(mgr.relayClients))
} }
t.Logf("closing manager") t.Logf("closing manager")

View File

@@ -3,6 +3,8 @@ package client
import ( import (
"context" "context"
"errors" "errors"
"fmt"
"sync"
"time" "time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -28,6 +30,7 @@ type PeersStateSubscription struct {
listenForOfflinePeers map[messages.PeerID]struct{} listenForOfflinePeers map[messages.PeerID]struct{}
waitingPeers map[messages.PeerID]chan struct{} waitingPeers map[messages.PeerID]chan struct{}
mu sync.Mutex // Mutex to protect access to waitingPeers and listenForOfflinePeers
} }
func NewPeersStateSubscription(log *log.Entry, relayConn relayedConnWriter, offlineCallback func(peerIDs []messages.PeerID)) *PeersStateSubscription { func NewPeersStateSubscription(log *log.Entry, relayConn relayedConnWriter, offlineCallback func(peerIDs []messages.PeerID)) *PeersStateSubscription {
@@ -43,24 +46,39 @@ func NewPeersStateSubscription(log *log.Entry, relayConn relayedConnWriter, offl
// OnPeersOnline should be called when a notification is received that certain peers have come online. // OnPeersOnline should be called when a notification is received that certain peers have come online.
// It checks if any of the peers are being waited on and signals their availability. // It checks if any of the peers are being waited on and signals their availability.
func (s *PeersStateSubscription) OnPeersOnline(peersID []messages.PeerID) { func (s *PeersStateSubscription) OnPeersOnline(peersID []messages.PeerID) {
if s == nil {
return
}
s.mu.Lock()
defer s.mu.Unlock()
for _, peerID := range peersID { for _, peerID := range peersID {
waitCh, ok := s.waitingPeers[peerID] waitCh, ok := s.waitingPeers[peerID]
if !ok { if !ok {
// If meanwhile the peer was unsubscribed, we don't need to signal it
continue continue
} }
close(waitCh) waitCh <- struct{}{}
delete(s.waitingPeers, peerID) delete(s.waitingPeers, peerID)
close(waitCh)
} }
} }
func (s *PeersStateSubscription) OnPeersWentOffline(peersID []messages.PeerID) { func (s *PeersStateSubscription) OnPeersWentOffline(peersID []messages.PeerID) {
if s == nil {
return
}
s.mu.Lock()
relevantPeers := make([]messages.PeerID, 0, len(peersID)) relevantPeers := make([]messages.PeerID, 0, len(peersID))
for _, peerID := range peersID { for _, peerID := range peersID {
if _, ok := s.listenForOfflinePeers[peerID]; ok { if _, ok := s.listenForOfflinePeers[peerID]; ok {
relevantPeers = append(relevantPeers, peerID) relevantPeers = append(relevantPeers, peerID)
} }
} }
s.mu.Unlock()
if len(relevantPeers) > 0 { if len(relevantPeers) > 0 {
s.offlineCallback(relevantPeers) s.offlineCallback(relevantPeers)
@@ -68,36 +86,44 @@ func (s *PeersStateSubscription) OnPeersWentOffline(peersID []messages.PeerID) {
} }
// WaitToBeOnlineAndSubscribe waits for a specific peer to come online and subscribes to its state changes. // WaitToBeOnlineAndSubscribe waits for a specific peer to come online and subscribes to its state changes.
// todo: when we unsubscribe while this is running, this will not return with error
func (s *PeersStateSubscription) WaitToBeOnlineAndSubscribe(ctx context.Context, peerID messages.PeerID) error { func (s *PeersStateSubscription) WaitToBeOnlineAndSubscribe(ctx context.Context, peerID messages.PeerID) error {
if s == nil {
return nil
}
// Check if already waiting for this peer // Check if already waiting for this peer
s.mu.Lock()
if _, exists := s.waitingPeers[peerID]; exists { if _, exists := s.waitingPeers[peerID]; exists {
s.mu.Unlock()
return errors.New("already waiting for peer to come online") return errors.New("already waiting for peer to come online")
} }
// Create a channel to wait for the peer to come online // Create a channel to wait for the peer to come online
waitCh := make(chan struct{}) waitCh := make(chan struct{}, 1)
s.waitingPeers[peerID] = waitCh s.waitingPeers[peerID] = waitCh
s.listenForOfflinePeers[peerID] = struct{}{}
s.mu.Unlock()
if err := s.subscribeStateChange([]messages.PeerID{peerID}); err != nil { if err := s.subscribeStateChange(peerID); err != nil {
s.log.Errorf("failed to subscribe to peer state: %s", err) s.log.Errorf("failed to subscribe to peer state: %s", err)
close(waitCh) s.mu.Lock()
delete(s.waitingPeers, peerID)
return err
}
defer func() {
if ch, exists := s.waitingPeers[peerID]; exists && ch == waitCh { if ch, exists := s.waitingPeers[peerID]; exists && ch == waitCh {
close(waitCh) close(waitCh)
delete(s.waitingPeers, peerID) delete(s.waitingPeers, peerID)
delete(s.listenForOfflinePeers, peerID)
}
s.mu.Unlock()
return err
} }
}()
// Wait for peer to come online or context to be cancelled // Wait for peer to come online or context to be cancelled
timeoutCtx, cancel := context.WithTimeout(ctx, OpenConnectionTimeout) timeoutCtx, cancel := context.WithTimeout(ctx, OpenConnectionTimeout)
defer cancel() defer cancel()
select { select {
case <-waitCh: case _, ok := <-waitCh:
if !ok {
return fmt.Errorf("wait for peer to come online has been cancelled")
}
s.log.Debugf("peer %s is now online", peerID) s.log.Debugf("peer %s is now online", peerID)
return nil return nil
case <-timeoutCtx.Done(): case <-timeoutCtx.Done():
@@ -105,13 +131,25 @@ func (s *PeersStateSubscription) WaitToBeOnlineAndSubscribe(ctx context.Context,
if err := s.unsubscribeStateChange([]messages.PeerID{peerID}); err != nil { if err := s.unsubscribeStateChange([]messages.PeerID{peerID}); err != nil {
s.log.Errorf("failed to unsubscribe from peer state: %s", err) s.log.Errorf("failed to unsubscribe from peer state: %s", err)
} }
s.mu.Lock()
if ch, exists := s.waitingPeers[peerID]; exists && ch == waitCh {
close(waitCh)
delete(s.waitingPeers, peerID)
delete(s.listenForOfflinePeers, peerID)
}
s.mu.Unlock()
return timeoutCtx.Err() return timeoutCtx.Err()
} }
} }
func (s *PeersStateSubscription) UnsubscribeStateChange(peerIDs []messages.PeerID) error { func (s *PeersStateSubscription) UnsubscribeStateChange(peerIDs []messages.PeerID) error {
if s == nil {
return nil
}
msgErr := s.unsubscribeStateChange(peerIDs) msgErr := s.unsubscribeStateChange(peerIDs)
s.mu.Lock()
for _, peerID := range peerIDs { for _, peerID := range peerIDs {
if wch, ok := s.waitingPeers[peerID]; ok { if wch, ok := s.waitingPeers[peerID]; ok {
close(wch) close(wch)
@@ -120,11 +158,19 @@ func (s *PeersStateSubscription) UnsubscribeStateChange(peerIDs []messages.PeerI
delete(s.listenForOfflinePeers, peerID) delete(s.listenForOfflinePeers, peerID)
} }
s.mu.Unlock()
return msgErr return msgErr
} }
func (s *PeersStateSubscription) Cleanup() { func (s *PeersStateSubscription) Cleanup() {
if s == nil {
return
}
s.mu.Lock()
defer s.mu.Unlock()
for _, waitCh := range s.waitingPeers { for _, waitCh := range s.waitingPeers {
close(waitCh) close(waitCh)
} }
@@ -133,16 +179,12 @@ func (s *PeersStateSubscription) Cleanup() {
s.listenForOfflinePeers = make(map[messages.PeerID]struct{}) s.listenForOfflinePeers = make(map[messages.PeerID]struct{})
} }
func (s *PeersStateSubscription) subscribeStateChange(peerIDs []messages.PeerID) error { func (s *PeersStateSubscription) subscribeStateChange(peerID messages.PeerID) error {
msgs, err := messages.MarshalSubPeerStateMsg(peerIDs) msgs, err := messages.MarshalSubPeerStateMsg([]messages.PeerID{peerID})
if err != nil { if err != nil {
return err return err
} }
for _, peer := range peerIDs {
s.listenForOfflinePeers[peer] = struct{}{}
}
for _, msg := range msgs { for _, msg := range msgs {
if _, err := s.relayConn.Write(msg); err != nil { if _, err := s.relayConn.Write(msg); err != nil {
return err return err

56
relay/messages/addr.go Normal file
View File

@@ -0,0 +1,56 @@
package messages
import (
"encoding/json"
"fmt"
"strings"
)
// FeatureVersionCode identifies the relay server's protocol feature level,
// communicated to the client in the auth response.
type FeatureVersionCode uint16

const (
	// VersionUnknown is the implied level for legacy servers that predate
	// feature versioning (old-format, non-JSON auth responses).
	VersionUnknown FeatureVersionCode = 0
	// VersionSubscription marks servers that support the peer state
	// subscription feature.
	VersionSubscription FeatureVersionCode = 1
)

// RelayAddr is the address of a relay server instance as reported in the
// auth response. It implements net.Addr.
type RelayAddr struct {
	Addr               string             `json:"ExposedAddr,omitempty"`
	FeatureVersionCode FeatureVersionCode `json:"Version,omitempty"`
}

// Network returns the network name for relayed connections.
func (a RelayAddr) Network() string {
	return "relay"
}

// String returns the exposed address of the relay instance.
func (a RelayAddr) String() string {
	return a.Addr
}

// UnmarshalRelayAddr decodes the auth-response payload into a RelayAddr.
//
// Modern servers send a JSON-encoded RelayAddr; older servers send the bare
// instance URL as a plain string ("rel://..." or "rels://..."). If JSON
// decoding fails, the payload is interpreted in the legacy format, leaving
// FeatureVersionCode at VersionUnknown.
func UnmarshalRelayAddr(data []byte) (*RelayAddr, error) {
	if len(data) == 0 {
		return nil, fmt.Errorf("unmarshalRelayAddr: empty data")
	}

	var addr RelayAddr
	if err := json.Unmarshal(data, &addr); err != nil {
		// Not JSON: assume a legacy server that sent the raw address string.
		addrString, fbErr := fallbackToOldFormat(data)
		if fbErr != nil {
			return nil, fmt.Errorf("failed to fallback to old auth message: %w", fbErr)
		}
		return &RelayAddr{Addr: addrString}, nil
	}

	if addr.Addr == "" {
		return nil, fmt.Errorf("unmarshalRelayAddr: empty address in RelayAddr")
	}
	return &addr, nil
}

// fallbackToOldFormat validates a legacy (pre-JSON) auth payload, which is
// the relay instance URL sent as a plain string. The URL must use the
// rel:// or rels:// scheme; anything else is rejected.
func fallbackToOldFormat(data []byte) (string, error) {
	addr := string(data)
	if !strings.HasPrefix(addr, "rel://") && !strings.HasPrefix(addr, "rels://") {
		return "", fmt.Errorf("invalid address: must start with rel:// or rels://: %s", addr)
	}
	return addr, nil
}

View File

@@ -11,7 +11,7 @@ const (
MaxHandshakeRespSize = 8192 MaxHandshakeRespSize = 8192
MaxMessageSize = 8820 MaxMessageSize = 8820
CurrentProtocolVersion = 1 CurrentProtocolVersion = 2
MsgTypeUnknown MsgType = 0 MsgTypeUnknown MsgType = 0
// Deprecated: Use MsgTypeAuth instead. // Deprecated: Use MsgTypeAuth instead.
@@ -264,11 +264,11 @@ func MarshalAuthResponse(address string) ([]byte, error) {
} }
// UnmarshalAuthResponse it is a confirmation message to auth success // UnmarshalAuthResponse it is a confirmation message to auth success
func UnmarshalAuthResponse(msg []byte) (string, error) { func UnmarshalAuthResponse(msg []byte) ([]byte, error) {
if len(msg) < sizeOfProtoHeader+1 { if len(msg) < sizeOfProtoHeader+1 {
return "", ErrInvalidMessageLength return nil, ErrInvalidMessageLength
} }
return string(msg[sizeOfProtoHeader:]), nil return msg[sizeOfProtoHeader:], nil
} }
// MarshalCloseMsg creates a close message. // MarshalCloseMsg creates a close message.

View File

@@ -74,7 +74,7 @@ func TestMarshalAuthResponse(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("error: %v", err) t.Fatalf("error: %v", err)
} }
if respAddr != address { if string(respAddr) != address {
t.Errorf("expected %s, got %s", address, respAddr) t.Errorf("expected %s, got %s", address, respAddr)
} }
} }

View File

@@ -8,6 +8,7 @@ import (
) )
type Listener struct { type Listener struct {
ctx context.Context
store *Store store *Store
onlineChan chan messages.PeerID onlineChan chan messages.PeerID
@@ -15,12 +16,11 @@ type Listener struct {
interestedPeersForOffline map[messages.PeerID]struct{} interestedPeersForOffline map[messages.PeerID]struct{}
interestedPeersForOnline map[messages.PeerID]struct{} interestedPeersForOnline map[messages.PeerID]struct{}
mu sync.RWMutex mu sync.RWMutex
listenerCtx context.Context
} }
func newListener(store *Store) *Listener { func newListener(ctx context.Context, store *Store) *Listener {
l := &Listener{ l := &Listener{
ctx: ctx,
store: store, store: store,
onlineChan: make(chan messages.PeerID, 244), //244 is the message size limit in the relay protocol onlineChan: make(chan messages.PeerID, 244), //244 is the message size limit in the relay protocol
@@ -65,11 +65,10 @@ func (l *Listener) RemoveInterestedPeer(peerIDs []messages.PeerID) {
} }
} }
func (l *Listener) listenForEvents(ctx context.Context, onPeersComeOnline, onPeersWentOffline func([]messages.PeerID)) { func (l *Listener) listenForEvents(onPeersComeOnline, onPeersWentOffline func([]messages.PeerID)) {
l.listenerCtx = ctx
for { for {
select { select {
case <-ctx.Done(): case <-l.ctx.Done():
return return
case pID := <-l.onlineChan: case pID := <-l.onlineChan:
peers := make([]messages.PeerID, 0) peers := make([]messages.PeerID, 0)
@@ -102,7 +101,7 @@ func (l *Listener) peerWentOffline(peerID messages.PeerID) {
if _, ok := l.interestedPeersForOffline[peerID]; ok { if _, ok := l.interestedPeersForOffline[peerID]; ok {
select { select {
case l.offlineChan <- peerID: case l.offlineChan <- peerID:
case <-l.listenerCtx.Done(): case <-l.ctx.Done():
} }
} }
} }
@@ -114,7 +113,7 @@ func (l *Listener) peerComeOnline(peerID messages.PeerID) {
if _, ok := l.interestedPeersForOnline[peerID]; ok { if _, ok := l.interestedPeersForOnline[peerID]; ok {
select { select {
case l.onlineChan <- peerID: case l.onlineChan <- peerID:
case <-l.listenerCtx.Done(): case <-l.ctx.Done():
} }
delete(l.interestedPeersForOnline, peerID) delete(l.interestedPeersForOnline, peerID)
} }

View File

@@ -24,8 +24,8 @@ func NewPeerNotifier(store *Store) *PeerNotifier {
func (pn *PeerNotifier) NewListener(onPeersComeOnline, onPeersWentOffline func([]messages.PeerID)) *Listener { func (pn *PeerNotifier) NewListener(onPeersComeOnline, onPeersWentOffline func([]messages.PeerID)) *Listener {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
listener := newListener(pn.store) listener := newListener(ctx, pn.store)
go listener.listenForEvents(ctx, onPeersComeOnline, onPeersWentOffline) go listener.listenForEvents(onPeersComeOnline, onPeersWentOffline)
pn.listenersMutex.Lock() pn.listenersMutex.Lock()
pn.listeners[listener] = cancel pn.listeners[listener] = cancel