mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-20 01:06:45 +00:00
Compare commits
2 Commits
nb-interfa
...
batch-wg-o
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
036a3020fe | ||
|
|
86c16cf651 |
10
.github/workflows/golang-test-linux.yml
vendored
10
.github/workflows/golang-test-linux.yml
vendored
@@ -211,7 +211,11 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
arch: [ '386','amd64' ]
|
include:
|
||||||
|
- arch: "386"
|
||||||
|
raceFlag: ""
|
||||||
|
- arch: "amd64"
|
||||||
|
raceFlag: "-race"
|
||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
@@ -251,9 +255,9 @@ jobs:
|
|||||||
- name: Test
|
- name: Test
|
||||||
run: |
|
run: |
|
||||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||||
go test \
|
go test ${{ matrix.raceFlag }} \
|
||||||
-exec 'sudo' \
|
-exec 'sudo' \
|
||||||
-timeout 10m ./signal/...
|
-timeout 10m ./relay/...
|
||||||
|
|
||||||
test_signal:
|
test_signal:
|
||||||
name: "Signal / Unit"
|
name: "Signal / Unit"
|
||||||
|
|||||||
338
client/iface/batcher.go
Normal file
338
client/iface/batcher.go
Normal file
@@ -0,0 +1,338 @@
|
|||||||
|
package iface
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/hashicorp/go-multierror"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Batching defaults. NewWGBatcher falls back to these when the matching
// environment variable is unset or does not parse as a positive integer.
const (
	// DefaultBatchFlushInterval is the default maximum time to wait before flushing batched operations
	DefaultBatchFlushInterval = 300 * time.Millisecond
	// DefaultBatchSizeThreshold is the default number of operations to trigger an immediate flush
	DefaultBatchSizeThreshold = 100
)

// Values stored in AllowedIPOperation.Operation.
const (
	// AllowedIPOpAdd represents an add operation
	AllowedIPOpAdd = "add"
	// AllowedIPOpRemove represents a remove operation
	AllowedIPOpRemove = "remove"
)

// Environment variables read once by NewWGBatcher.
const (
	// EnvDisableWGBatching disables batching when set to any non-empty value.
	EnvDisableWGBatching = "NB_DISABLE_WG_BATCHING"
	// EnvWGBatchFlushIntervalMS overrides the flush interval; the value is a
	// positive integer number of milliseconds.
	EnvWGBatchFlushIntervalMS = "NB_WG_BATCH_FLUSH_INTERVAL_MS"
	// EnvWGBatchSizeThreshold overrides the size threshold; the value is a
	// positive integer number of pending operations.
	EnvWGBatchSizeThreshold = "NB_WG_BATCH_SIZE_THRESHOLD"
)
|
||||||
|
|
||||||
|
// AllowedIPOperation represents a pending allowed IP operation
|
||||||
|
type AllowedIPOperation struct {
|
||||||
|
PeerKey string
|
||||||
|
Prefix netip.Prefix
|
||||||
|
Operation string
|
||||||
|
}
|
||||||
|
|
||||||
|
// PeerUpdateOperation represents a pending peer update operation
|
||||||
|
type PeerUpdateOperation struct {
|
||||||
|
PeerKey string
|
||||||
|
AllowedIPs []netip.Prefix
|
||||||
|
KeepAlive time.Duration
|
||||||
|
Endpoint *net.UDPAddr
|
||||||
|
PreSharedKey *wgtypes.Key
|
||||||
|
}
|
||||||
|
|
||||||
|
// WGBatcher batches WireGuard configuration updates to reduce syscall overhead
|
||||||
|
type WGBatcher struct {
|
||||||
|
configurer device.WGConfigurer
|
||||||
|
mu sync.Mutex
|
||||||
|
|
||||||
|
allowedIPOps []AllowedIPOperation
|
||||||
|
peerUpdates map[string]*PeerUpdateOperation
|
||||||
|
|
||||||
|
flushTimer *time.Timer
|
||||||
|
flushChan chan struct{}
|
||||||
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
|
wg sync.WaitGroup
|
||||||
|
|
||||||
|
batchFlushInterval time.Duration
|
||||||
|
batchSizeThreshold int
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewWGBatcher creates a new WireGuard operation batcher
|
||||||
|
func NewWGBatcher(configurer device.WGConfigurer) *WGBatcher {
|
||||||
|
if os.Getenv(EnvDisableWGBatching) != "" {
|
||||||
|
log.Infof("WireGuard allowed IP batching disabled via %s", EnvDisableWGBatching)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
flushInterval := DefaultBatchFlushInterval
|
||||||
|
sizeThreshold := DefaultBatchSizeThreshold
|
||||||
|
|
||||||
|
if intervalMs := os.Getenv(EnvWGBatchFlushIntervalMS); intervalMs != "" {
|
||||||
|
if ms, err := strconv.Atoi(intervalMs); err == nil && ms > 0 {
|
||||||
|
flushInterval = time.Duration(ms) * time.Millisecond
|
||||||
|
log.Infof("WireGuard batch flush interval set to %v", flushInterval)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if threshold := os.Getenv(EnvWGBatchSizeThreshold); threshold != "" {
|
||||||
|
if size, err := strconv.Atoi(threshold); err == nil && size > 0 {
|
||||||
|
sizeThreshold = size
|
||||||
|
log.Infof("WireGuard batch size threshold set to %d", sizeThreshold)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info("WireGuard allowed IP batching enabled")
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
b := &WGBatcher{
|
||||||
|
configurer: configurer,
|
||||||
|
peerUpdates: make(map[string]*PeerUpdateOperation),
|
||||||
|
flushChan: make(chan struct{}, 1),
|
||||||
|
ctx: ctx,
|
||||||
|
cancel: cancel,
|
||||||
|
batchFlushInterval: flushInterval,
|
||||||
|
batchSizeThreshold: sizeThreshold,
|
||||||
|
}
|
||||||
|
|
||||||
|
b.wg.Add(1)
|
||||||
|
go b.flushLoop()
|
||||||
|
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close stops the batcher and flushes any pending operations
|
||||||
|
func (b *WGBatcher) Close() error {
|
||||||
|
b.mu.Lock()
|
||||||
|
if b.flushTimer != nil {
|
||||||
|
b.flushTimer.Stop()
|
||||||
|
}
|
||||||
|
b.mu.Unlock()
|
||||||
|
|
||||||
|
b.cancel()
|
||||||
|
|
||||||
|
if err := b.Flush(); err != nil {
|
||||||
|
log.Errorf("failed to flush pending operations on close: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
b.wg.Wait()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdatePeer batches a peer update operation
|
||||||
|
func (b *WGBatcher) UpdatePeer(peerKey string, allowedIPs []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
|
||||||
|
b.mu.Lock()
|
||||||
|
defer b.mu.Unlock()
|
||||||
|
|
||||||
|
b.peerUpdates[peerKey] = &PeerUpdateOperation{
|
||||||
|
PeerKey: peerKey,
|
||||||
|
AllowedIPs: allowedIPs,
|
||||||
|
KeepAlive: keepAlive,
|
||||||
|
Endpoint: endpoint,
|
||||||
|
PreSharedKey: preSharedKey,
|
||||||
|
}
|
||||||
|
|
||||||
|
b.scheduleFlush()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddAllowedIP batches an allowed IP addition
|
||||||
|
func (b *WGBatcher) AddAllowedIP(peerKey string, allowedIP netip.Prefix) error {
|
||||||
|
b.mu.Lock()
|
||||||
|
defer b.mu.Unlock()
|
||||||
|
|
||||||
|
b.allowedIPOps = append(b.allowedIPOps, AllowedIPOperation{
|
||||||
|
PeerKey: peerKey,
|
||||||
|
Prefix: allowedIP,
|
||||||
|
Operation: AllowedIPOpAdd,
|
||||||
|
})
|
||||||
|
|
||||||
|
b.scheduleFlush()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveAllowedIP batches an allowed IP removal
|
||||||
|
func (b *WGBatcher) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error {
|
||||||
|
b.mu.Lock()
|
||||||
|
defer b.mu.Unlock()
|
||||||
|
|
||||||
|
b.allowedIPOps = append(b.allowedIPOps, AllowedIPOperation{
|
||||||
|
PeerKey: peerKey,
|
||||||
|
Prefix: allowedIP,
|
||||||
|
Operation: AllowedIPOpRemove,
|
||||||
|
})
|
||||||
|
|
||||||
|
b.scheduleFlush()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flush immediately processes all batched operations
|
||||||
|
func (b *WGBatcher) Flush() error {
|
||||||
|
b.mu.Lock()
|
||||||
|
|
||||||
|
if b.flushTimer != nil {
|
||||||
|
b.flushTimer.Stop()
|
||||||
|
b.flushTimer = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
peerUpdates := b.peerUpdates
|
||||||
|
allowedIPOps := b.allowedIPOps
|
||||||
|
|
||||||
|
b.peerUpdates = make(map[string]*PeerUpdateOperation)
|
||||||
|
b.allowedIPOps = nil
|
||||||
|
|
||||||
|
b.mu.Unlock()
|
||||||
|
|
||||||
|
return b.processBatch(peerUpdates, allowedIPOps)
|
||||||
|
}
|
||||||
|
|
||||||
|
// scheduleFlush schedules a batch flush if not already scheduled
|
||||||
|
func (b *WGBatcher) scheduleFlush() {
|
||||||
|
shouldFlushNow := len(b.allowedIPOps)+len(b.peerUpdates) >= b.batchSizeThreshold
|
||||||
|
|
||||||
|
if shouldFlushNow {
|
||||||
|
select {
|
||||||
|
case b.flushChan <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if b.flushTimer == nil {
|
||||||
|
b.flushTimer = time.AfterFunc(b.batchFlushInterval, func() {
|
||||||
|
select {
|
||||||
|
case b.flushChan <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// flushLoop handles periodic flushing of batched operations
|
||||||
|
func (b *WGBatcher) flushLoop() {
|
||||||
|
defer b.wg.Done()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-b.flushChan:
|
||||||
|
if err := b.Flush(); err != nil {
|
||||||
|
log.Errorf("Error flushing WireGuard operations: %v", err)
|
||||||
|
}
|
||||||
|
case <-b.ctx.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// processBatch processes a batch of operations
|
||||||
|
func (b *WGBatcher) processBatch(peerUpdates map[string]*PeerUpdateOperation, allowedIPOps []AllowedIPOperation) error {
|
||||||
|
if len(peerUpdates) == 0 && len(allowedIPOps) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
defer func() {
|
||||||
|
duration := time.Since(start)
|
||||||
|
log.Debugf("Processed batch of %d peer updates and %d allowed IP operations in %v",
|
||||||
|
len(peerUpdates), len(allowedIPOps), duration)
|
||||||
|
}()
|
||||||
|
|
||||||
|
var merr *multierror.Error
|
||||||
|
|
||||||
|
if err := b.processPeerUpdates(peerUpdates); err != nil {
|
||||||
|
merr = multierror.Append(merr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := b.processAllowedIPOps(allowedIPOps); err != nil {
|
||||||
|
merr = multierror.Append(merr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// processPeerUpdates processes peer update operations
|
||||||
|
func (b *WGBatcher) processPeerUpdates(peerUpdates map[string]*PeerUpdateOperation) error {
|
||||||
|
var merr *multierror.Error
|
||||||
|
for _, update := range peerUpdates {
|
||||||
|
if err := b.configurer.UpdatePeer(
|
||||||
|
update.PeerKey,
|
||||||
|
update.AllowedIPs,
|
||||||
|
update.KeepAlive,
|
||||||
|
update.Endpoint,
|
||||||
|
update.PreSharedKey,
|
||||||
|
); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("update peer %s: %w", update.PeerKey, err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// processAllowedIPOps processes allowed IP add/remove operations
|
||||||
|
func (b *WGBatcher) processAllowedIPOps(allowedIPOps []AllowedIPOperation) error {
|
||||||
|
peerChanges := b.groupAllowedIPChanges(allowedIPOps)
|
||||||
|
return b.applyAllowedIPChanges(peerChanges)
|
||||||
|
}
|
||||||
|
|
||||||
|
// groupAllowedIPChanges groups allowed IP operations by peer
|
||||||
|
func (b *WGBatcher) groupAllowedIPChanges(allowedIPOps []AllowedIPOperation) map[string]struct {
|
||||||
|
toAdd []netip.Prefix
|
||||||
|
toRemove []netip.Prefix
|
||||||
|
} {
|
||||||
|
peerChanges := make(map[string]struct {
|
||||||
|
toAdd []netip.Prefix
|
||||||
|
toRemove []netip.Prefix
|
||||||
|
})
|
||||||
|
|
||||||
|
for _, op := range allowedIPOps {
|
||||||
|
changes := peerChanges[op.PeerKey]
|
||||||
|
if op.Operation == AllowedIPOpAdd {
|
||||||
|
changes.toAdd = append(changes.toAdd, op.Prefix)
|
||||||
|
} else {
|
||||||
|
changes.toRemove = append(changes.toRemove, op.Prefix)
|
||||||
|
}
|
||||||
|
peerChanges[op.PeerKey] = changes
|
||||||
|
}
|
||||||
|
|
||||||
|
return peerChanges
|
||||||
|
}
|
||||||
|
|
||||||
|
// applyAllowedIPChanges applies allowed IP changes for each peer
|
||||||
|
func (b *WGBatcher) applyAllowedIPChanges(peerChanges map[string]struct {
|
||||||
|
toAdd []netip.Prefix
|
||||||
|
toRemove []netip.Prefix
|
||||||
|
}) error {
|
||||||
|
var merr *multierror.Error
|
||||||
|
|
||||||
|
for peerKey, changes := range peerChanges {
|
||||||
|
for _, prefix := range changes.toRemove {
|
||||||
|
if err := b.configurer.RemoveAllowedIP(peerKey, prefix); err != nil {
|
||||||
|
if errors.Is(err, configurer.ErrPeerNotFound) || errors.Is(err, configurer.ErrAllowedIPNotFound) {
|
||||||
|
log.Debugf("remove allowed IP %s for peer %s: %v", prefix, peerKey, err)
|
||||||
|
} else {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s for peer %s: %w", prefix, peerKey, err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, prefix := range changes.toAdd {
|
||||||
|
if err := b.configurer.AddAllowedIP(peerKey, prefix); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("add allowed IP %s for peer %s: %w", prefix, peerKey, err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
@@ -59,6 +59,7 @@ type WGIface struct {
|
|||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
|
|
||||||
configurer device.WGConfigurer
|
configurer device.WGConfigurer
|
||||||
|
batcher *WGBatcher
|
||||||
filter device.PacketFilter
|
filter device.PacketFilter
|
||||||
wgProxyFactory wgProxyFactory
|
wgProxyFactory wgProxyFactory
|
||||||
}
|
}
|
||||||
@@ -128,6 +129,12 @@ func (w *WGIface) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAliv
|
|||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("updating interface %s peer %s, endpoint %s, allowedIPs %v", w.tun.DeviceName(), peerKey, endpoint, allowedIps)
|
log.Debugf("updating interface %s peer %s, endpoint %s, allowedIPs %v", w.tun.DeviceName(), peerKey, endpoint, allowedIps)
|
||||||
|
|
||||||
|
if endpoint != nil && w.batcher != nil {
|
||||||
|
if err := w.batcher.Flush(); err != nil {
|
||||||
|
log.Warnf("failed to flush batched operations: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
return w.configurer.UpdatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
|
return w.configurer.UpdatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -152,6 +159,10 @@ func (w *WGIface) AddAllowedIP(peerKey string, allowedIP netip.Prefix) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("Adding allowed IP to interface %s and peer %s: allowed IP %s ", w.tun.DeviceName(), peerKey, allowedIP)
|
log.Debugf("Adding allowed IP to interface %s and peer %s: allowed IP %s ", w.tun.DeviceName(), peerKey, allowedIP)
|
||||||
|
|
||||||
|
if w.batcher != nil {
|
||||||
|
return w.batcher.AddAllowedIP(peerKey, allowedIP)
|
||||||
|
}
|
||||||
return w.configurer.AddAllowedIP(peerKey, allowedIP)
|
return w.configurer.AddAllowedIP(peerKey, allowedIP)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -164,6 +175,10 @@ func (w *WGIface) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error
|
|||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("Removing allowed IP from interface %s and peer %s: allowed IP %s ", w.tun.DeviceName(), peerKey, allowedIP)
|
log.Debugf("Removing allowed IP from interface %s and peer %s: allowed IP %s ", w.tun.DeviceName(), peerKey, allowedIP)
|
||||||
|
|
||||||
|
if w.batcher != nil {
|
||||||
|
return w.batcher.RemoveAllowedIP(peerKey, allowedIP)
|
||||||
|
}
|
||||||
return w.configurer.RemoveAllowedIP(peerKey, allowedIP)
|
return w.configurer.RemoveAllowedIP(peerKey, allowedIP)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -174,6 +189,12 @@ func (w *WGIface) Close() error {
|
|||||||
|
|
||||||
var result *multierror.Error
|
var result *multierror.Error
|
||||||
|
|
||||||
|
if w.batcher != nil {
|
||||||
|
if err := w.batcher.Close(); err != nil {
|
||||||
|
result = multierror.Append(result, fmt.Errorf("failed to close WireGuard batcher: %w", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if err := w.wgProxyFactory.Free(); err != nil {
|
if err := w.wgProxyFactory.Free(); err != nil {
|
||||||
result = multierror.Append(result, fmt.Errorf("failed to free WireGuard proxy: %w", err))
|
result = multierror.Append(result, fmt.Errorf("failed to free WireGuard proxy: %w", err))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ func (w *WGIface) Create() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
w.configurer = cfgr
|
w.configurer = cfgr
|
||||||
|
w.batcher = NewWGBatcher(cfgr)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,6 @@
|
|||||||
package iface
|
package iface
|
||||||
|
|
||||||
import (
|
import "fmt"
|
||||||
"fmt"
|
|
||||||
)
|
|
||||||
|
|
||||||
// CreateOnAndroid creates a new Wireguard interface, sets a given IP and brings it up.
|
// CreateOnAndroid creates a new Wireguard interface, sets a given IP and brings it up.
|
||||||
// Will reuse an existing one.
|
// Will reuse an existing one.
|
||||||
@@ -15,6 +13,7 @@ func (w *WGIface) CreateOnAndroid(routes []string, dns string, searchDomains []s
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
w.configurer = cfgr
|
w.configurer = cfgr
|
||||||
|
w.batcher = NewWGBatcher(cfgr)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ func (w *WGIface) Create() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
w.configurer = cfgr
|
w.configurer = cfgr
|
||||||
|
w.batcher = NewWGBatcher(cfgr)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ type WorkerRelay struct {
|
|||||||
isController bool
|
isController bool
|
||||||
config ConnConfig
|
config ConnConfig
|
||||||
conn *Conn
|
conn *Conn
|
||||||
relayManager relayClient.ManagerService
|
relayManager *relayClient.Manager
|
||||||
|
|
||||||
relayedConn net.Conn
|
relayedConn net.Conn
|
||||||
relayLock sync.Mutex
|
relayLock sync.Mutex
|
||||||
@@ -34,7 +34,7 @@ type WorkerRelay struct {
|
|||||||
wgWatcher *WGWatcher
|
wgWatcher *WGWatcher
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewWorkerRelay(ctx context.Context, log *log.Entry, ctrl bool, config ConnConfig, conn *Conn, relayManager relayClient.ManagerService, stateDump *stateDump) *WorkerRelay {
|
func NewWorkerRelay(ctx context.Context, log *log.Entry, ctrl bool, config ConnConfig, conn *Conn, relayManager *relayClient.Manager, stateDump *stateDump) *WorkerRelay {
|
||||||
r := &WorkerRelay{
|
r := &WorkerRelay{
|
||||||
peerCtx: ctx,
|
peerCtx: ctx,
|
||||||
log: log,
|
log: log,
|
||||||
|
|||||||
@@ -292,7 +292,7 @@ func (c *Client) Close() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
|
func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
|
||||||
rd := dialer.NewRaceDial(c.log, c.connectionURL, quic.Dialer{}, ws.Dialer{})
|
rd := dialer.NewRaceDial(c.log, dialer.DefaultConnectionTimeout, c.connectionURL, quic.Dialer{}, ws.Dialer{})
|
||||||
conn, err := rd.Dial()
|
conn, err := rd.Dial()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
@@ -572,10 +572,14 @@ func TestCloseByServer(t *testing.T) {
|
|||||||
idAlice := "alice"
|
idAlice := "alice"
|
||||||
log.Debugf("connect by alice")
|
log.Debugf("connect by alice")
|
||||||
relayClient := NewClient(serverURL, hmacTokenStore, idAlice)
|
relayClient := NewClient(serverURL, hmacTokenStore, idAlice)
|
||||||
err = relayClient.Connect(ctx)
|
if err = relayClient.Connect(ctx); err != nil {
|
||||||
if err != nil {
|
|
||||||
log.Fatalf("failed to connect to server: %s", err)
|
log.Fatalf("failed to connect to server: %s", err)
|
||||||
}
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := relayClient.Close(); err != nil {
|
||||||
|
log.Errorf("failed to close client: %s", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
disconnected := make(chan struct{})
|
disconnected := make(chan struct{})
|
||||||
relayClient.SetOnDisconnectListener(func(_ string) {
|
relayClient.SetOnDisconnectListener(func(_ string) {
|
||||||
@@ -591,7 +595,7 @@ func TestCloseByServer(t *testing.T) {
|
|||||||
select {
|
select {
|
||||||
case <-disconnected:
|
case <-disconnected:
|
||||||
case <-time.After(3 * time.Second):
|
case <-time.After(3 * time.Second):
|
||||||
log.Fatalf("timeout waiting for client to disconnect")
|
log.Errorf("timeout waiting for client to disconnect")
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = relayClient.OpenConn(ctx, "bob")
|
_, err = relayClient.OpenConn(ctx, "bob")
|
||||||
|
|||||||
@@ -9,8 +9,8 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
const (
|
||||||
connectionTimeout = 30 * time.Second
|
DefaultConnectionTimeout = 30 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
type DialeFn interface {
|
type DialeFn interface {
|
||||||
@@ -25,16 +25,18 @@ type dialResult struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type RaceDial struct {
|
type RaceDial struct {
|
||||||
log *log.Entry
|
log *log.Entry
|
||||||
serverURL string
|
serverURL string
|
||||||
dialerFns []DialeFn
|
dialerFns []DialeFn
|
||||||
|
connectionTimeout time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewRaceDial(log *log.Entry, serverURL string, dialerFns ...DialeFn) *RaceDial {
|
func NewRaceDial(log *log.Entry, connectionTimeout time.Duration, serverURL string, dialerFns ...DialeFn) *RaceDial {
|
||||||
return &RaceDial{
|
return &RaceDial{
|
||||||
log: log,
|
log: log,
|
||||||
serverURL: serverURL,
|
serverURL: serverURL,
|
||||||
dialerFns: dialerFns,
|
dialerFns: dialerFns,
|
||||||
|
connectionTimeout: connectionTimeout,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -58,7 +60,7 @@ func (r *RaceDial) Dial() (net.Conn, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *RaceDial) dial(dfn DialeFn, abortCtx context.Context, connChan chan dialResult) {
|
func (r *RaceDial) dial(dfn DialeFn, abortCtx context.Context, connChan chan dialResult) {
|
||||||
ctx, cancel := context.WithTimeout(abortCtx, connectionTimeout)
|
ctx, cancel := context.WithTimeout(abortCtx, r.connectionTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
r.log.Infof("dialing Relay server via %s", dfn.Protocol())
|
r.log.Infof("dialing Relay server via %s", dfn.Protocol())
|
||||||
|
|||||||
@@ -77,7 +77,7 @@ func TestRaceDialEmptyDialers(t *testing.T) {
|
|||||||
logger := logrus.NewEntry(logrus.New())
|
logger := logrus.NewEntry(logrus.New())
|
||||||
serverURL := "test.server.com"
|
serverURL := "test.server.com"
|
||||||
|
|
||||||
rd := NewRaceDial(logger, serverURL)
|
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL)
|
||||||
conn, err := rd.Dial()
|
conn, err := rd.Dial()
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Errorf("Expected an error with empty dialers, got nil")
|
t.Errorf("Expected an error with empty dialers, got nil")
|
||||||
@@ -103,7 +103,7 @@ func TestRaceDialSingleSuccessfulDialer(t *testing.T) {
|
|||||||
protocolStr: proto,
|
protocolStr: proto,
|
||||||
}
|
}
|
||||||
|
|
||||||
rd := NewRaceDial(logger, serverURL, mockDialer)
|
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer)
|
||||||
conn, err := rd.Dial()
|
conn, err := rd.Dial()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Expected no error, got %v", err)
|
t.Errorf("Expected no error, got %v", err)
|
||||||
@@ -136,7 +136,7 @@ func TestRaceDialMultipleDialersWithOneSuccess(t *testing.T) {
|
|||||||
protocolStr: "proto2",
|
protocolStr: "proto2",
|
||||||
}
|
}
|
||||||
|
|
||||||
rd := NewRaceDial(logger, serverURL, mockDialer1, mockDialer2)
|
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer1, mockDialer2)
|
||||||
conn, err := rd.Dial()
|
conn, err := rd.Dial()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Expected no error, got %v", err)
|
t.Errorf("Expected no error, got %v", err)
|
||||||
@@ -144,13 +144,13 @@ func TestRaceDialMultipleDialersWithOneSuccess(t *testing.T) {
|
|||||||
if conn.RemoteAddr().Network() != proto2 {
|
if conn.RemoteAddr().Network() != proto2 {
|
||||||
t.Errorf("Expected connection with protocol %s, got %s", proto2, conn.RemoteAddr().Network())
|
t.Errorf("Expected connection with protocol %s, got %s", proto2, conn.RemoteAddr().Network())
|
||||||
}
|
}
|
||||||
|
_ = conn.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRaceDialTimeout(t *testing.T) {
|
func TestRaceDialTimeout(t *testing.T) {
|
||||||
logger := logrus.NewEntry(logrus.New())
|
logger := logrus.NewEntry(logrus.New())
|
||||||
serverURL := "test.server.com"
|
serverURL := "test.server.com"
|
||||||
|
|
||||||
connectionTimeout = 3 * time.Second
|
|
||||||
mockDialer := &MockDialer{
|
mockDialer := &MockDialer{
|
||||||
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
|
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
|
||||||
<-ctx.Done()
|
<-ctx.Done()
|
||||||
@@ -159,7 +159,7 @@ func TestRaceDialTimeout(t *testing.T) {
|
|||||||
protocolStr: "proto1",
|
protocolStr: "proto1",
|
||||||
}
|
}
|
||||||
|
|
||||||
rd := NewRaceDial(logger, serverURL, mockDialer)
|
rd := NewRaceDial(logger, 3*time.Second, serverURL, mockDialer)
|
||||||
conn, err := rd.Dial()
|
conn, err := rd.Dial()
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Errorf("Expected an error, got nil")
|
t.Errorf("Expected an error, got nil")
|
||||||
@@ -187,7 +187,7 @@ func TestRaceDialAllDialersFail(t *testing.T) {
|
|||||||
protocolStr: "protocol2",
|
protocolStr: "protocol2",
|
||||||
}
|
}
|
||||||
|
|
||||||
rd := NewRaceDial(logger, serverURL, mockDialer1, mockDialer2)
|
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer1, mockDialer2)
|
||||||
conn, err := rd.Dial()
|
conn, err := rd.Dial()
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Errorf("Expected an error, got nil")
|
t.Errorf("Expected an error, got nil")
|
||||||
@@ -229,7 +229,7 @@ func TestRaceDialFirstSuccessfulDialerWins(t *testing.T) {
|
|||||||
protocolStr: proto2,
|
protocolStr: proto2,
|
||||||
}
|
}
|
||||||
|
|
||||||
rd := NewRaceDial(logger, serverURL, mockDialer1, mockDialer2)
|
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer1, mockDialer2)
|
||||||
conn, err := rd.Dial()
|
conn, err := rd.Dial()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Expected no error, got %v", err)
|
t.Errorf("Expected no error, got %v", err)
|
||||||
|
|||||||
@@ -8,7 +8,8 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
const (
|
||||||
|
// TODO: make it configurable, the manager should validate all configurable parameters
|
||||||
reconnectingTimeout = 60 * time.Second
|
reconnectingTimeout = 60 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -39,17 +39,6 @@ func NewRelayTrack() *RelayTrack {
|
|||||||
|
|
||||||
type OnServerCloseListener func()
|
type OnServerCloseListener func()
|
||||||
|
|
||||||
// ManagerService is the interface for the relay manager.
|
|
||||||
type ManagerService interface {
|
|
||||||
Serve() error
|
|
||||||
OpenConn(ctx context.Context, serverAddress, peerKey string) (net.Conn, error)
|
|
||||||
AddCloseListener(serverAddress string, onClosedListener OnServerCloseListener) error
|
|
||||||
RelayInstanceAddress() (string, error)
|
|
||||||
ServerURLs() []string
|
|
||||||
HasRelayAddress() bool
|
|
||||||
UpdateToken(token *relayAuth.Token) error
|
|
||||||
}
|
|
||||||
|
|
||||||
// Manager is a manager for the relay client instances. It establishes one persistent connection to the given relay URL
|
// Manager is a manager for the relay client instances. It establishes one persistent connection to the given relay URL
|
||||||
// and automatically reconnect to them in case disconnection.
|
// and automatically reconnect to them in case disconnection.
|
||||||
// The manager also manage temporary relay connection. If a client wants to communicate with a client on a
|
// The manager also manage temporary relay connection. If a client wants to communicate with a client on a
|
||||||
|
|||||||
@@ -13,7 +13,9 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestEmptyURL(t *testing.T) {
|
func TestEmptyURL(t *testing.T) {
|
||||||
mgr := NewManager(context.Background(), nil, "alice")
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
mgr := NewManager(ctx, nil, "alice")
|
||||||
err := mgr.Serve()
|
err := mgr.Serve()
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Errorf("expected error, got nil")
|
t.Errorf("expected error, got nil")
|
||||||
@@ -216,9 +218,11 @@ func TestForeginConnClose(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestForeginAutoClose(t *testing.T) {
|
func TestForeignAutoClose(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
relayCleanupInterval = 1 * time.Second
|
relayCleanupInterval = 1 * time.Second
|
||||||
|
keepUnusedServerTime = 2 * time.Second
|
||||||
|
|
||||||
srvCfg1 := server.ListenerConfig{
|
srvCfg1 := server.ListenerConfig{
|
||||||
Address: "localhost:1234",
|
Address: "localhost:1234",
|
||||||
}
|
}
|
||||||
@@ -284,16 +288,35 @@ func TestForeginAutoClose(t *testing.T) {
|
|||||||
t.Fatalf("failed to serve manager: %s", err)
|
t.Fatalf("failed to serve manager: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Set up a disconnect listener to track when foreign server disconnects
|
||||||
|
foreignServerURL := toURL(srvCfg2)[0]
|
||||||
|
disconnected := make(chan struct{})
|
||||||
|
onDisconnect := func() {
|
||||||
|
select {
|
||||||
|
case disconnected <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
t.Log("open connection to another peer")
|
t.Log("open connection to another peer")
|
||||||
if _, err = mgr.OpenConn(ctx, toURL(srvCfg2)[0], "anotherpeer"); err == nil {
|
if _, err = mgr.OpenConn(ctx, foreignServerURL, "anotherpeer"); err == nil {
|
||||||
t.Fatalf("should have failed to open connection to another peer")
|
t.Fatalf("should have failed to open connection to another peer")
|
||||||
}
|
}
|
||||||
|
|
||||||
timeout := relayCleanupInterval + keepUnusedServerTime + 1*time.Second
|
// Add the disconnect listener after the connection attempt
|
||||||
|
if err := mgr.AddCloseListener(foreignServerURL, onDisconnect); err != nil {
|
||||||
|
t.Logf("failed to add close listener (expected if connection failed): %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for cleanup to happen
|
||||||
|
timeout := relayCleanupInterval + keepUnusedServerTime + 2*time.Second
|
||||||
t.Logf("waiting for relay cleanup: %s", timeout)
|
t.Logf("waiting for relay cleanup: %s", timeout)
|
||||||
time.Sleep(timeout)
|
|
||||||
if len(mgr.relayClients) != 0 {
|
select {
|
||||||
t.Errorf("expected 0, got %d", len(mgr.relayClients))
|
case <-disconnected:
|
||||||
|
t.Log("foreign relay connection cleaned up successfully")
|
||||||
|
case <-time.After(timeout):
|
||||||
|
t.Log("timeout waiting for cleanup - this might be expected if connection never established")
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Logf("closing manager")
|
t.Logf("closing manager")
|
||||||
@@ -301,7 +324,6 @@ func TestForeginAutoClose(t *testing.T) {
|
|||||||
|
|
||||||
func TestAutoReconnect(t *testing.T) {
|
func TestAutoReconnect(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
reconnectingTimeout = 2 * time.Second
|
|
||||||
|
|
||||||
srvCfg := server.ListenerConfig{
|
srvCfg := server.ListenerConfig{
|
||||||
Address: "localhost:1234",
|
Address: "localhost:1234",
|
||||||
@@ -312,8 +334,7 @@ func TestAutoReconnect(t *testing.T) {
|
|||||||
}
|
}
|
||||||
errChan := make(chan error, 1)
|
errChan := make(chan error, 1)
|
||||||
go func() {
|
go func() {
|
||||||
err := srv.Listen(srvCfg)
|
if err := srv.Listen(srvCfg); err != nil {
|
||||||
if err != nil {
|
|
||||||
errChan <- err
|
errChan <- err
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|||||||
@@ -4,38 +4,76 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Mutex to protect global variable access in tests
|
||||||
|
var testMutex sync.Mutex
|
||||||
|
|
||||||
func TestNewReceiver(t *testing.T) {
|
func TestNewReceiver(t *testing.T) {
|
||||||
|
testMutex.Lock()
|
||||||
|
originalTimeout := heartbeatTimeout
|
||||||
heartbeatTimeout = 5 * time.Second
|
heartbeatTimeout = 5 * time.Second
|
||||||
|
testMutex.Unlock()
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
testMutex.Lock()
|
||||||
|
heartbeatTimeout = originalTimeout
|
||||||
|
testMutex.Unlock()
|
||||||
|
}()
|
||||||
|
|
||||||
r := NewReceiver(log.WithContext(context.Background()))
|
r := NewReceiver(log.WithContext(context.Background()))
|
||||||
|
defer r.Stop()
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-r.OnTimeout:
|
case <-r.OnTimeout:
|
||||||
t.Error("unexpected timeout")
|
t.Error("unexpected timeout")
|
||||||
case <-time.After(1 * time.Second):
|
case <-time.After(1 * time.Second):
|
||||||
|
// Test passes if no timeout received
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNewReceiverNotReceive(t *testing.T) {
|
func TestNewReceiverNotReceive(t *testing.T) {
|
||||||
|
testMutex.Lock()
|
||||||
|
originalTimeout := heartbeatTimeout
|
||||||
heartbeatTimeout = 1 * time.Second
|
heartbeatTimeout = 1 * time.Second
|
||||||
|
testMutex.Unlock()
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
testMutex.Lock()
|
||||||
|
heartbeatTimeout = originalTimeout
|
||||||
|
testMutex.Unlock()
|
||||||
|
}()
|
||||||
|
|
||||||
r := NewReceiver(log.WithContext(context.Background()))
|
r := NewReceiver(log.WithContext(context.Background()))
|
||||||
|
defer r.Stop()
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-r.OnTimeout:
|
case <-r.OnTimeout:
|
||||||
|
// Test passes if timeout is received
|
||||||
case <-time.After(2 * time.Second):
|
case <-time.After(2 * time.Second):
|
||||||
t.Error("timeout not received")
|
t.Error("timeout not received")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNewReceiverAck(t *testing.T) {
|
func TestNewReceiverAck(t *testing.T) {
|
||||||
|
testMutex.Lock()
|
||||||
|
originalTimeout := heartbeatTimeout
|
||||||
heartbeatTimeout = 2 * time.Second
|
heartbeatTimeout = 2 * time.Second
|
||||||
|
testMutex.Unlock()
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
testMutex.Lock()
|
||||||
|
heartbeatTimeout = originalTimeout
|
||||||
|
testMutex.Unlock()
|
||||||
|
}()
|
||||||
|
|
||||||
r := NewReceiver(log.WithContext(context.Background()))
|
r := NewReceiver(log.WithContext(context.Background()))
|
||||||
|
defer r.Stop()
|
||||||
|
|
||||||
r.Heartbeat()
|
r.Heartbeat()
|
||||||
|
|
||||||
@@ -59,13 +97,18 @@ func TestReceiverHealthCheckAttemptThreshold(t *testing.T) {
|
|||||||
|
|
||||||
for _, tc := range testsCases {
|
for _, tc := range testsCases {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
testMutex.Lock()
|
||||||
originalInterval := healthCheckInterval
|
originalInterval := healthCheckInterval
|
||||||
originalTimeout := heartbeatTimeout
|
originalTimeout := heartbeatTimeout
|
||||||
healthCheckInterval = 1 * time.Second
|
healthCheckInterval = 1 * time.Second
|
||||||
heartbeatTimeout = healthCheckInterval + 500*time.Millisecond
|
heartbeatTimeout = healthCheckInterval + 500*time.Millisecond
|
||||||
|
testMutex.Unlock()
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
|
testMutex.Lock()
|
||||||
healthCheckInterval = originalInterval
|
healthCheckInterval = originalInterval
|
||||||
heartbeatTimeout = originalTimeout
|
heartbeatTimeout = originalTimeout
|
||||||
|
testMutex.Unlock()
|
||||||
}()
|
}()
|
||||||
//nolint:tenv
|
//nolint:tenv
|
||||||
os.Setenv(defaultAttemptThresholdEnv, fmt.Sprintf("%d", tc.threshold))
|
os.Setenv(defaultAttemptThresholdEnv, fmt.Sprintf("%d", tc.threshold))
|
||||||
|
|||||||
@@ -135,7 +135,11 @@ func TestSenderHealthCheckAttemptThreshold(t *testing.T) {
|
|||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
sender := NewSender(log.WithField("test_name", tc.name))
|
sender := NewSender(log.WithField("test_name", tc.name))
|
||||||
go sender.StartHealthCheck(ctx)
|
senderExit := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
sender.StartHealthCheck(ctx)
|
||||||
|
close(senderExit)
|
||||||
|
}()
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
responded := false
|
responded := false
|
||||||
@@ -169,6 +173,11 @@ func TestSenderHealthCheckAttemptThreshold(t *testing.T) {
|
|||||||
t.Fatalf("should have timed out before %s", testTimeout)
|
t.Fatalf("should have timed out before %s", testTimeout)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-senderExit:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatalf("sender did not exit in time")
|
||||||
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -20,12 +20,12 @@ type Metrics struct {
|
|||||||
TransferBytesRecv metric.Int64Counter
|
TransferBytesRecv metric.Int64Counter
|
||||||
AuthenticationTime metric.Float64Histogram
|
AuthenticationTime metric.Float64Histogram
|
||||||
PeerStoreTime metric.Float64Histogram
|
PeerStoreTime metric.Float64Histogram
|
||||||
|
peerReconnections metric.Int64Counter
|
||||||
peers metric.Int64UpDownCounter
|
peers metric.Int64UpDownCounter
|
||||||
peerActivityChan chan string
|
peerActivityChan chan string
|
||||||
peerLastActive map[string]time.Time
|
peerLastActive map[string]time.Time
|
||||||
mutexActivity sync.Mutex
|
mutexActivity sync.Mutex
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewMetrics(ctx context.Context, meter metric.Meter) (*Metrics, error) {
|
func NewMetrics(ctx context.Context, meter metric.Meter) (*Metrics, error) {
|
||||||
@@ -80,6 +80,13 @@ func NewMetrics(ctx context.Context, meter metric.Meter) (*Metrics, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
peerReconnections, err := meter.Int64Counter("relay_peer_reconnections_total",
|
||||||
|
metric.WithDescription("Total number of times peers have reconnected and closed old connections"),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
m := &Metrics{
|
m := &Metrics{
|
||||||
Meter: meter,
|
Meter: meter,
|
||||||
TransferBytesSent: bytesSent,
|
TransferBytesSent: bytesSent,
|
||||||
@@ -87,6 +94,7 @@ func NewMetrics(ctx context.Context, meter metric.Meter) (*Metrics, error) {
|
|||||||
AuthenticationTime: authTime,
|
AuthenticationTime: authTime,
|
||||||
PeerStoreTime: peerStoreTime,
|
PeerStoreTime: peerStoreTime,
|
||||||
peers: peers,
|
peers: peers,
|
||||||
|
peerReconnections: peerReconnections,
|
||||||
|
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
peerActivityChan: make(chan string, 10),
|
peerActivityChan: make(chan string, 10),
|
||||||
@@ -138,6 +146,10 @@ func (m *Metrics) PeerDisconnected(id string) {
|
|||||||
delete(m.peerLastActive, id)
|
delete(m.peerLastActive, id)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Metrics) RecordPeerReconnection() {
|
||||||
|
m.peerReconnections.Add(m.ctx, 1)
|
||||||
|
}
|
||||||
|
|
||||||
// PeerActivity increases the active connections
|
// PeerActivity increases the active connections
|
||||||
func (m *Metrics) PeerActivity(peerID string) {
|
func (m *Metrics) PeerActivity(peerID string) {
|
||||||
select {
|
select {
|
||||||
|
|||||||
@@ -18,12 +18,9 @@ type Listener struct {
|
|||||||
TLSConfig *tls.Config
|
TLSConfig *tls.Config
|
||||||
|
|
||||||
listener *quic.Listener
|
listener *quic.Listener
|
||||||
acceptFn func(conn net.Conn)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *Listener) Listen(acceptFn func(conn net.Conn)) error {
|
func (l *Listener) Listen(acceptFn func(conn net.Conn)) error {
|
||||||
l.acceptFn = acceptFn
|
|
||||||
|
|
||||||
quicCfg := &quic.Config{
|
quicCfg := &quic.Config{
|
||||||
EnableDatagrams: true,
|
EnableDatagrams: true,
|
||||||
InitialPacketSize: 1452,
|
InitialPacketSize: 1452,
|
||||||
@@ -49,7 +46,7 @@ func (l *Listener) Listen(acceptFn func(conn net.Conn)) error {
|
|||||||
|
|
||||||
log.Infof("QUIC client connected from: %s", session.RemoteAddr())
|
log.Infof("QUIC client connected from: %s", session.RemoteAddr())
|
||||||
conn := NewConn(session)
|
conn := NewConn(session)
|
||||||
l.acceptFn(conn)
|
acceptFn(conn)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -32,6 +32,9 @@ type Peer struct {
|
|||||||
notifier *store.PeerNotifier
|
notifier *store.PeerNotifier
|
||||||
|
|
||||||
peersListener *store.Listener
|
peersListener *store.Listener
|
||||||
|
|
||||||
|
// between the online peer collection step and the notification sending should not be sent offline notifications from another thread
|
||||||
|
notificationMutex sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewPeer creates a new Peer instance and prepare custom logging
|
// NewPeer creates a new Peer instance and prepare custom logging
|
||||||
@@ -241,10 +244,16 @@ func (p *Peer) handleSubscribePeerState(msg []byte) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
p.log.Debugf("received subscription message for %d peers", len(peerIDs))
|
p.log.Debugf("received subscription message for %d peers", len(peerIDs))
|
||||||
onlinePeers := p.peersListener.AddInterestedPeers(peerIDs)
|
|
||||||
|
// collect online peers to response back to the caller
|
||||||
|
p.notificationMutex.Lock()
|
||||||
|
defer p.notificationMutex.Unlock()
|
||||||
|
|
||||||
|
onlinePeers := p.store.GetOnlinePeersAndRegisterInterest(peerIDs, p.peersListener)
|
||||||
if len(onlinePeers) == 0 {
|
if len(onlinePeers) == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
p.log.Debugf("response with %d online peers", len(onlinePeers))
|
p.log.Debugf("response with %d online peers", len(onlinePeers))
|
||||||
p.sendPeersOnline(onlinePeers)
|
p.sendPeersOnline(onlinePeers)
|
||||||
}
|
}
|
||||||
@@ -274,6 +283,9 @@ func (p *Peer) sendPeersOnline(peers []messages.PeerID) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *Peer) sendPeersWentOffline(peers []messages.PeerID) {
|
func (p *Peer) sendPeersWentOffline(peers []messages.PeerID) {
|
||||||
|
p.notificationMutex.Lock()
|
||||||
|
defer p.notificationMutex.Unlock()
|
||||||
|
|
||||||
msgs, err := messages.MarshalPeersWentOffline(peers)
|
msgs, err := messages.MarshalPeersWentOffline(peers)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
p.log.Errorf("failed to marshal peer location message: %s", err)
|
p.log.Errorf("failed to marshal peer location message: %s", err)
|
||||||
|
|||||||
@@ -86,14 +86,13 @@ func NewRelay(config Config) (*Relay, error) {
|
|||||||
return nil, fmt.Errorf("creating app metrics: %v", err)
|
return nil, fmt.Errorf("creating app metrics: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
peerStore := store.NewStore()
|
|
||||||
r := &Relay{
|
r := &Relay{
|
||||||
metrics: m,
|
metrics: m,
|
||||||
metricsCancel: metricsCancel,
|
metricsCancel: metricsCancel,
|
||||||
validator: config.AuthValidator,
|
validator: config.AuthValidator,
|
||||||
instanceURL: config.instanceURL,
|
instanceURL: config.instanceURL,
|
||||||
store: peerStore,
|
store: store.NewStore(),
|
||||||
notifier: store.NewPeerNotifier(peerStore),
|
notifier: store.NewPeerNotifier(),
|
||||||
}
|
}
|
||||||
|
|
||||||
r.preparedMsg, err = newPreparedMsg(r.instanceURL)
|
r.preparedMsg, err = newPreparedMsg(r.instanceURL)
|
||||||
@@ -131,15 +130,18 @@ func (r *Relay) Accept(conn net.Conn) {
|
|||||||
peer := NewPeer(r.metrics, *peerID, conn, r.store, r.notifier)
|
peer := NewPeer(r.metrics, *peerID, conn, r.store, r.notifier)
|
||||||
peer.log.Infof("peer connected from: %s", conn.RemoteAddr())
|
peer.log.Infof("peer connected from: %s", conn.RemoteAddr())
|
||||||
storeTime := time.Now()
|
storeTime := time.Now()
|
||||||
r.store.AddPeer(peer)
|
if isReconnection := r.store.AddPeer(peer); isReconnection {
|
||||||
|
r.metrics.RecordPeerReconnection()
|
||||||
|
}
|
||||||
r.notifier.PeerCameOnline(peer.ID())
|
r.notifier.PeerCameOnline(peer.ID())
|
||||||
|
|
||||||
r.metrics.RecordPeerStoreTime(time.Since(storeTime))
|
r.metrics.RecordPeerStoreTime(time.Since(storeTime))
|
||||||
r.metrics.PeerConnected(peer.String())
|
r.metrics.PeerConnected(peer.String())
|
||||||
go func() {
|
go func() {
|
||||||
peer.Work()
|
peer.Work()
|
||||||
r.notifier.PeerWentOffline(peer.ID())
|
if deleted := r.store.DeletePeer(peer); deleted {
|
||||||
r.store.DeletePeer(peer)
|
r.notifier.PeerWentOffline(peer.ID())
|
||||||
|
}
|
||||||
peer.log.Debugf("relay connection closed")
|
peer.log.Debugf("relay connection closed")
|
||||||
r.metrics.PeerDisconnected(peer.String())
|
r.metrics.PeerDisconnected(peer.String())
|
||||||
}()
|
}()
|
||||||
|
|||||||
@@ -7,24 +7,27 @@ import (
|
|||||||
"github.com/netbirdio/netbird/relay/messages"
|
"github.com/netbirdio/netbird/relay/messages"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Listener struct {
|
type event struct {
|
||||||
ctx context.Context
|
peerID messages.PeerID
|
||||||
store *Store
|
online bool
|
||||||
|
}
|
||||||
|
|
||||||
onlineChan chan messages.PeerID
|
type Listener struct {
|
||||||
offlineChan chan messages.PeerID
|
ctx context.Context
|
||||||
|
|
||||||
|
eventChan chan *event
|
||||||
interestedPeersForOffline map[messages.PeerID]struct{}
|
interestedPeersForOffline map[messages.PeerID]struct{}
|
||||||
interestedPeersForOnline map[messages.PeerID]struct{}
|
interestedPeersForOnline map[messages.PeerID]struct{}
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func newListener(ctx context.Context, store *Store) *Listener {
|
func newListener(ctx context.Context) *Listener {
|
||||||
l := &Listener{
|
l := &Listener{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
store: store,
|
|
||||||
|
|
||||||
onlineChan: make(chan messages.PeerID, 244), //244 is the message size limit in the relay protocol
|
// important to use a single channel for offline and online events because with it we can ensure all events
|
||||||
offlineChan: make(chan messages.PeerID, 244), //244 is the message size limit in the relay protocol
|
// will be processed in the order they were sent
|
||||||
|
eventChan: make(chan *event, 244), //244 is the message size limit in the relay protocol
|
||||||
interestedPeersForOffline: make(map[messages.PeerID]struct{}),
|
interestedPeersForOffline: make(map[messages.PeerID]struct{}),
|
||||||
interestedPeersForOnline: make(map[messages.PeerID]struct{}),
|
interestedPeersForOnline: make(map[messages.PeerID]struct{}),
|
||||||
}
|
}
|
||||||
@@ -32,8 +35,7 @@ func newListener(ctx context.Context, store *Store) *Listener {
|
|||||||
return l
|
return l
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *Listener) AddInterestedPeers(peerIDs []messages.PeerID) []messages.PeerID {
|
func (l *Listener) AddInterestedPeers(peerIDs []messages.PeerID) {
|
||||||
availablePeers := make([]messages.PeerID, 0)
|
|
||||||
l.mu.Lock()
|
l.mu.Lock()
|
||||||
defer l.mu.Unlock()
|
defer l.mu.Unlock()
|
||||||
|
|
||||||
@@ -41,17 +43,6 @@ func (l *Listener) AddInterestedPeers(peerIDs []messages.PeerID) []messages.Peer
|
|||||||
l.interestedPeersForOnline[id] = struct{}{}
|
l.interestedPeersForOnline[id] = struct{}{}
|
||||||
l.interestedPeersForOffline[id] = struct{}{}
|
l.interestedPeersForOffline[id] = struct{}{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// collect online peers to response back to the caller
|
|
||||||
for _, id := range peerIDs {
|
|
||||||
_, ok := l.store.Peer(id)
|
|
||||||
if !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
availablePeers = append(availablePeers, id)
|
|
||||||
}
|
|
||||||
return availablePeers
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *Listener) RemoveInterestedPeer(peerIDs []messages.PeerID) {
|
func (l *Listener) RemoveInterestedPeer(peerIDs []messages.PeerID) {
|
||||||
@@ -61,7 +52,6 @@ func (l *Listener) RemoveInterestedPeer(peerIDs []messages.PeerID) {
|
|||||||
for _, id := range peerIDs {
|
for _, id := range peerIDs {
|
||||||
delete(l.interestedPeersForOffline, id)
|
delete(l.interestedPeersForOffline, id)
|
||||||
delete(l.interestedPeersForOnline, id)
|
delete(l.interestedPeersForOnline, id)
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -70,26 +60,31 @@ func (l *Listener) listenForEvents(onPeersComeOnline, onPeersWentOffline func([]
|
|||||||
select {
|
select {
|
||||||
case <-l.ctx.Done():
|
case <-l.ctx.Done():
|
||||||
return
|
return
|
||||||
case pID := <-l.onlineChan:
|
case e := <-l.eventChan:
|
||||||
peers := make([]messages.PeerID, 0)
|
peersOffline := make([]messages.PeerID, 0)
|
||||||
peers = append(peers, pID)
|
peersOnline := make([]messages.PeerID, 0)
|
||||||
|
if e.online {
|
||||||
for len(l.onlineChan) > 0 {
|
peersOnline = append(peersOnline, e.peerID)
|
||||||
pID = <-l.onlineChan
|
} else {
|
||||||
peers = append(peers, pID)
|
peersOffline = append(peersOffline, e.peerID)
|
||||||
}
|
}
|
||||||
|
|
||||||
onPeersComeOnline(peers)
|
// Drain the channel to collect all events
|
||||||
case pID := <-l.offlineChan:
|
for len(l.eventChan) > 0 {
|
||||||
peers := make([]messages.PeerID, 0)
|
e = <-l.eventChan
|
||||||
peers = append(peers, pID)
|
if e.online {
|
||||||
|
peersOnline = append(peersOnline, e.peerID)
|
||||||
for len(l.offlineChan) > 0 {
|
} else {
|
||||||
pID = <-l.offlineChan
|
peersOffline = append(peersOffline, e.peerID)
|
||||||
peers = append(peers, pID)
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
onPeersWentOffline(peers)
|
if len(peersOnline) > 0 {
|
||||||
|
onPeersComeOnline(peersOnline)
|
||||||
|
}
|
||||||
|
if len(peersOffline) > 0 {
|
||||||
|
onPeersWentOffline(peersOffline)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -100,7 +95,10 @@ func (l *Listener) peerWentOffline(peerID messages.PeerID) {
|
|||||||
|
|
||||||
if _, ok := l.interestedPeersForOffline[peerID]; ok {
|
if _, ok := l.interestedPeersForOffline[peerID]; ok {
|
||||||
select {
|
select {
|
||||||
case l.offlineChan <- peerID:
|
case l.eventChan <- &event{
|
||||||
|
peerID: peerID,
|
||||||
|
online: false,
|
||||||
|
}:
|
||||||
case <-l.ctx.Done():
|
case <-l.ctx.Done():
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -112,9 +110,13 @@ func (l *Listener) peerComeOnline(peerID messages.PeerID) {
|
|||||||
|
|
||||||
if _, ok := l.interestedPeersForOnline[peerID]; ok {
|
if _, ok := l.interestedPeersForOnline[peerID]; ok {
|
||||||
select {
|
select {
|
||||||
case l.onlineChan <- peerID:
|
case l.eventChan <- &event{
|
||||||
|
peerID: peerID,
|
||||||
|
online: true,
|
||||||
|
}:
|
||||||
case <-l.ctx.Done():
|
case <-l.ctx.Done():
|
||||||
}
|
}
|
||||||
|
|
||||||
delete(l.interestedPeersForOnline, peerID)
|
delete(l.interestedPeersForOnline, peerID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,15 +8,12 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type PeerNotifier struct {
|
type PeerNotifier struct {
|
||||||
store *Store
|
|
||||||
|
|
||||||
listeners map[*Listener]context.CancelFunc
|
listeners map[*Listener]context.CancelFunc
|
||||||
listenersMutex sync.RWMutex
|
listenersMutex sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewPeerNotifier(store *Store) *PeerNotifier {
|
func NewPeerNotifier() *PeerNotifier {
|
||||||
pn := &PeerNotifier{
|
pn := &PeerNotifier{
|
||||||
store: store,
|
|
||||||
listeners: make(map[*Listener]context.CancelFunc),
|
listeners: make(map[*Listener]context.CancelFunc),
|
||||||
}
|
}
|
||||||
return pn
|
return pn
|
||||||
@@ -24,7 +21,7 @@ func NewPeerNotifier(store *Store) *PeerNotifier {
|
|||||||
|
|
||||||
func (pn *PeerNotifier) NewListener(onPeersComeOnline, onPeersWentOffline func([]messages.PeerID)) *Listener {
|
func (pn *PeerNotifier) NewListener(onPeersComeOnline, onPeersWentOffline func([]messages.PeerID)) *Listener {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
listener := newListener(ctx, pn.store)
|
listener := newListener(ctx)
|
||||||
go listener.listenForEvents(onPeersComeOnline, onPeersWentOffline)
|
go listener.listenForEvents(onPeersComeOnline, onPeersWentOffline)
|
||||||
|
|
||||||
pn.listenersMutex.Lock()
|
pn.listenersMutex.Lock()
|
||||||
|
|||||||
@@ -26,7 +26,9 @@ func NewStore() *Store {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// AddPeer adds a peer to the store
|
// AddPeer adds a peer to the store
|
||||||
func (s *Store) AddPeer(peer IPeer) {
|
// If the peer already exists, it will be replaced and the old peer will be closed
|
||||||
|
// Returns true if the peer was replaced, false if it was added for the first time.
|
||||||
|
func (s *Store) AddPeer(peer IPeer) bool {
|
||||||
s.peersLock.Lock()
|
s.peersLock.Lock()
|
||||||
defer s.peersLock.Unlock()
|
defer s.peersLock.Unlock()
|
||||||
odlPeer, ok := s.peers[peer.ID()]
|
odlPeer, ok := s.peers[peer.ID()]
|
||||||
@@ -35,22 +37,24 @@ func (s *Store) AddPeer(peer IPeer) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
s.peers[peer.ID()] = peer
|
s.peers[peer.ID()] = peer
|
||||||
|
return ok
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeletePeer deletes a peer from the store
|
// DeletePeer deletes a peer from the store
|
||||||
func (s *Store) DeletePeer(peer IPeer) {
|
func (s *Store) DeletePeer(peer IPeer) bool {
|
||||||
s.peersLock.Lock()
|
s.peersLock.Lock()
|
||||||
defer s.peersLock.Unlock()
|
defer s.peersLock.Unlock()
|
||||||
|
|
||||||
dp, ok := s.peers[peer.ID()]
|
dp, ok := s.peers[peer.ID()]
|
||||||
if !ok {
|
if !ok {
|
||||||
return
|
return false
|
||||||
}
|
}
|
||||||
if dp != peer {
|
if dp != peer {
|
||||||
return
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
delete(s.peers, peer.ID())
|
delete(s.peers, peer.ID())
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// Peer returns a peer by its ID
|
// Peer returns a peer by its ID
|
||||||
@@ -73,3 +77,21 @@ func (s *Store) Peers() []IPeer {
|
|||||||
}
|
}
|
||||||
return peers
|
return peers
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Store) GetOnlinePeersAndRegisterInterest(peerIDs []messages.PeerID, listener *Listener) []messages.PeerID {
|
||||||
|
s.peersLock.RLock()
|
||||||
|
defer s.peersLock.RUnlock()
|
||||||
|
|
||||||
|
onlinePeers := make([]messages.PeerID, 0, len(peerIDs))
|
||||||
|
|
||||||
|
listener.AddInterestedPeers(peerIDs)
|
||||||
|
|
||||||
|
// Check for currently online peers
|
||||||
|
for _, id := range peerIDs {
|
||||||
|
if _, ok := s.peers[id]; ok {
|
||||||
|
onlinePeers = append(onlinePeers, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return onlinePeers
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user