Compare commits

...

13 Commits

Author SHA1 Message Date
Viktor Liu
036a3020fe Batch wireguard update operations 2025-07-22 14:44:26 +02:00
Zoltan Papp
86c16cf651 [server, relay] Fix/relay race disconnection (#4174)
Avoid invalid disconnection notifications in case the closed race dials.
In this PR resolve multiple race condition questions. Easier to understand the fix based on commit by commit.

- Remove store dependency from notifier
- Enforce the notification orders
- Fix invalid disconnection notification
- Ensure the order of the events on the consumer side
2025-07-21 19:58:17 +02:00
Bethuel Mmbaga
a7af15c4fc [management] Fix group resource count mismatch in policy (#4182) 2025-07-21 15:26:06 +03:00
Viktor Liu
d6ed9c037e [client] Fix bind exclusion routes (#4154) 2025-07-21 12:13:21 +02:00
Ali Amer
40fdeda838 [client] add new filter-by-connection-type flag (#4010)
introduces a new flag --filter-by-connection-type to the status command.
It allows users to filter peers by connection type (P2P or Relayed) in both JSON and detailed views.

Input validation is added in parseFilters() to ensure proper usage, and --detail is auto-enabled if no output format is specified (consistent with other filters).
2025-07-21 11:55:17 +02:00
Zoltan Papp
f6e9d755e4 [client, relay] The openConn function no longer blocks the relayAddress function call (#4180)
The openConn function no longer blocks the relayAddress function call in manager layer
2025-07-21 09:46:53 +02:00
Maycon Santos
08fd460867 [management] Add validate flow response (#4172)
This PR adds a validate flow response feature to the management server by integrating an IntegratedValidator component. The main purpose is to enable validation of PKCE authorization flows through an integrated validator interface.

- Adds a new ValidateFlowResponse method to the IntegratedValidator interface
- Integrates the validator into the management server to validate PKCE authorization flows
- Updates dependency version for management-integrations
2025-07-18 12:18:52 +02:00
Pascal Fischer
4f74509d55 [management] fix index creation if exist on mysql (#4150) 2025-07-16 15:07:31 +02:00
Maycon Santos
58185ced16 [misc] add forum post and update sign pipeline (#4155)
use old git-town version
2025-07-16 14:10:28 +02:00
Pedro Maia Costa
e67f44f47c [client] fix test (#4156) 2025-07-16 12:09:38 +02:00
Zoltan Papp
b524f486e2 [client] Fix/nil relayed address (#4153)
Fix nil pointer in Relay conn address

Meanwhile, we create a relayed net.Conn struct instance, it is possible to set the relayedURL to nil.

panic: value method github.com/netbirdio/netbird/relay/client.RelayAddr.String called using nil *RelayAddr pointer

Fix relayed URL variable protection
Protect the channel closing
2025-07-16 00:00:18 +02:00
Zoltan Papp
0dab03252c [client, relay-server] Feature/relay notification (#4083)
- Clients now subscribe to peer status changes.
- The server manages and maintains these subscriptions.
- Replaced raw string peer IDs with a custom peer ID type for better type safety and clarity.
2025-07-15 10:43:42 +02:00
iisteev
e49bcc343d [client] Avoid parsing NB_NETSTACK_SKIP_PROXY if empty (#4145)
Signed-off-by: iisteev <isteevan.shetoo@is-info.fr>
2025-07-13 15:42:48 +02:00
106 changed files with 2602 additions and 896 deletions

View File

@@ -16,6 +16,6 @@ jobs:
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- uses: git-town/action@v1 - uses: git-town/action@v1.2.1
with: with:
skip-single-stacks: true skip-single-stacks: true

View File

@@ -211,7 +211,11 @@ jobs:
strategy: strategy:
fail-fast: false fail-fast: false
matrix: matrix:
arch: [ '386','amd64' ] include:
- arch: "386"
raceFlag: ""
- arch: "amd64"
raceFlag: "-race"
runs-on: ubuntu-22.04 runs-on: ubuntu-22.04
steps: steps:
- name: Install Go - name: Install Go
@@ -251,9 +255,9 @@ jobs:
- name: Test - name: Test
run: | run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \ CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
go test \ go test ${{ matrix.raceFlag }} \
-exec 'sudo' \ -exec 'sudo' \
-timeout 10m ./signal/... -timeout 10m ./relay/...
test_signal: test_signal:
name: "Signal / Unit" name: "Signal / Unit"

View File

@@ -9,7 +9,7 @@ on:
pull_request: pull_request:
env: env:
SIGN_PIPE_VER: "v0.0.20" SIGN_PIPE_VER: "v0.0.21"
GORELEASER_VER: "v2.3.2" GORELEASER_VER: "v2.3.2"
PRODUCT_NAME: "NetBird" PRODUCT_NAME: "NetBird"
COPYRIGHT: "NetBird GmbH" COPYRIGHT: "NetBird GmbH"
@@ -231,3 +231,17 @@ jobs:
ref: ${{ env.SIGN_PIPE_VER }} ref: ${{ env.SIGN_PIPE_VER }}
token: ${{ secrets.SIGN_GITHUB_TOKEN }} token: ${{ secrets.SIGN_GITHUB_TOKEN }}
inputs: '{ "tag": "${{ github.ref }}", "skipRelease": false }' inputs: '{ "tag": "${{ github.ref }}", "skipRelease": false }'
post_on_forum:
runs-on: ubuntu-latest
continue-on-error: true
needs: [trigger_signer]
steps:
- uses: Codixer/discourse-topic-github-release-action@v2.0.1
with:
discourse-api-key: ${{ secrets.DISCOURSE_RELEASES_API_KEY }}
discourse-base-url: https://forum.netbird.io
discourse-author-username: NetBird
discourse-category: 17
discourse-tags:
releases

View File

@@ -307,7 +307,7 @@ func getStatusOutput(cmd *cobra.Command, anon bool) string {
cmd.PrintErrf("Failed to get status: %v\n", err) cmd.PrintErrf("Failed to get status: %v\n", err)
} else { } else {
statusOutputString = nbstatus.ParseToFullDetailSummary( statusOutputString = nbstatus.ParseToFullDetailSummary(
nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil), nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, ""),
) )
} }
return statusOutputString return statusOutputString

View File

@@ -26,6 +26,7 @@ var (
statusFilter string statusFilter string
ipsFilterMap map[string]struct{} ipsFilterMap map[string]struct{}
prefixNamesFilterMap map[string]struct{} prefixNamesFilterMap map[string]struct{}
connectionTypeFilter string
) )
var statusCmd = &cobra.Command{ var statusCmd = &cobra.Command{
@@ -45,6 +46,7 @@ func init() {
statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs, e.g., --filter-by-ips 100.64.0.100,100.64.0.200") statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs, e.g., --filter-by-ips 100.64.0.100,100.64.0.200")
statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud") statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud")
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(idle|connecting|connected), e.g., --filter-by-status connected") statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(idle|connecting|connected), e.g., --filter-by-status connected")
statusCmd.PersistentFlags().StringVar(&connectionTypeFilter, "filter-by-connection-type", "", "filters the detailed output by connection type (P2P|Relayed), e.g., --filter-by-connection-type P2P")
} }
func statusFunc(cmd *cobra.Command, args []string) error { func statusFunc(cmd *cobra.Command, args []string) error {
@@ -89,7 +91,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
return nil return nil
} }
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp, anonymizeFlag, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap) var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp, anonymizeFlag, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap, connectionTypeFilter)
var statusOutputString string var statusOutputString string
switch { switch {
case detailFlag: case detailFlag:
@@ -156,6 +158,15 @@ func parseFilters() error {
enableDetailFlagWhenFilterFlag() enableDetailFlagWhenFilterFlag()
} }
switch strings.ToLower(connectionTypeFilter) {
case "", "p2p", "relayed":
if strings.ToLower(connectionTypeFilter) != "" {
enableDetailFlagWhenFilterFlag()
}
default:
return fmt.Errorf("wrong connection-type filter, should be one of P2P|Relayed, got: %s", connectionTypeFilter)
}
return nil return nil
} }

View File

@@ -109,7 +109,7 @@ func startManagement(t *testing.T, config *types.Config, testFile string) (*grpc
} }
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager) secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager)
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil) mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &mgmt.MockIntegratedValidator{})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

338
client/iface/batcher.go Normal file
View File

@@ -0,0 +1,338 @@
package iface
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"os"
"strconv"
"sync"
"time"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device"
)
const (
// DefaultBatchFlushInterval is the default maximum time to wait before flushing batched operations
DefaultBatchFlushInterval = 300 * time.Millisecond
// DefaultBatchSizeThreshold is the default number of operations to trigger an immediate flush
DefaultBatchSizeThreshold = 100
// AllowedIPOpAdd represents an add operation
AllowedIPOpAdd = "add"
// AllowedIPOpRemove represents a remove operation
AllowedIPOpRemove = "remove"
EnvDisableWGBatching = "NB_DISABLE_WG_BATCHING"
EnvWGBatchFlushIntervalMS = "NB_WG_BATCH_FLUSH_INTERVAL_MS"
EnvWGBatchSizeThreshold = "NB_WG_BATCH_SIZE_THRESHOLD"
)
// AllowedIPOperation represents a pending allowed IP operation
type AllowedIPOperation struct {
PeerKey string
Prefix netip.Prefix
Operation string
}
// PeerUpdateOperation represents a pending peer update operation
type PeerUpdateOperation struct {
PeerKey string
AllowedIPs []netip.Prefix
KeepAlive time.Duration
Endpoint *net.UDPAddr
PreSharedKey *wgtypes.Key
}
// WGBatcher batches WireGuard configuration updates to reduce syscall overhead
type WGBatcher struct {
configurer device.WGConfigurer
mu sync.Mutex
allowedIPOps []AllowedIPOperation
peerUpdates map[string]*PeerUpdateOperation
flushTimer *time.Timer
flushChan chan struct{}
ctx context.Context
cancel context.CancelFunc
wg sync.WaitGroup
batchFlushInterval time.Duration
batchSizeThreshold int
}
// NewWGBatcher creates a new WireGuard operation batcher
func NewWGBatcher(configurer device.WGConfigurer) *WGBatcher {
if os.Getenv(EnvDisableWGBatching) != "" {
log.Infof("WireGuard allowed IP batching disabled via %s", EnvDisableWGBatching)
return nil
}
flushInterval := DefaultBatchFlushInterval
sizeThreshold := DefaultBatchSizeThreshold
if intervalMs := os.Getenv(EnvWGBatchFlushIntervalMS); intervalMs != "" {
if ms, err := strconv.Atoi(intervalMs); err == nil && ms > 0 {
flushInterval = time.Duration(ms) * time.Millisecond
log.Infof("WireGuard batch flush interval set to %v", flushInterval)
}
}
if threshold := os.Getenv(EnvWGBatchSizeThreshold); threshold != "" {
if size, err := strconv.Atoi(threshold); err == nil && size > 0 {
sizeThreshold = size
log.Infof("WireGuard batch size threshold set to %d", sizeThreshold)
}
}
log.Info("WireGuard allowed IP batching enabled")
ctx, cancel := context.WithCancel(context.Background())
b := &WGBatcher{
configurer: configurer,
peerUpdates: make(map[string]*PeerUpdateOperation),
flushChan: make(chan struct{}, 1),
ctx: ctx,
cancel: cancel,
batchFlushInterval: flushInterval,
batchSizeThreshold: sizeThreshold,
}
b.wg.Add(1)
go b.flushLoop()
return b
}
// Close stops the batcher and flushes any pending operations
func (b *WGBatcher) Close() error {
b.mu.Lock()
if b.flushTimer != nil {
b.flushTimer.Stop()
}
b.mu.Unlock()
b.cancel()
if err := b.Flush(); err != nil {
log.Errorf("failed to flush pending operations on close: %v", err)
}
b.wg.Wait()
return nil
}
// UpdatePeer batches a peer update operation
func (b *WGBatcher) UpdatePeer(peerKey string, allowedIPs []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
b.mu.Lock()
defer b.mu.Unlock()
b.peerUpdates[peerKey] = &PeerUpdateOperation{
PeerKey: peerKey,
AllowedIPs: allowedIPs,
KeepAlive: keepAlive,
Endpoint: endpoint,
PreSharedKey: preSharedKey,
}
b.scheduleFlush()
return nil
}
// AddAllowedIP batches an allowed IP addition
func (b *WGBatcher) AddAllowedIP(peerKey string, allowedIP netip.Prefix) error {
b.mu.Lock()
defer b.mu.Unlock()
b.allowedIPOps = append(b.allowedIPOps, AllowedIPOperation{
PeerKey: peerKey,
Prefix: allowedIP,
Operation: AllowedIPOpAdd,
})
b.scheduleFlush()
return nil
}
// RemoveAllowedIP batches an allowed IP removal
func (b *WGBatcher) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error {
b.mu.Lock()
defer b.mu.Unlock()
b.allowedIPOps = append(b.allowedIPOps, AllowedIPOperation{
PeerKey: peerKey,
Prefix: allowedIP,
Operation: AllowedIPOpRemove,
})
b.scheduleFlush()
return nil
}
// Flush immediately processes all batched operations
func (b *WGBatcher) Flush() error {
b.mu.Lock()
if b.flushTimer != nil {
b.flushTimer.Stop()
b.flushTimer = nil
}
peerUpdates := b.peerUpdates
allowedIPOps := b.allowedIPOps
b.peerUpdates = make(map[string]*PeerUpdateOperation)
b.allowedIPOps = nil
b.mu.Unlock()
return b.processBatch(peerUpdates, allowedIPOps)
}
// scheduleFlush schedules a batch flush if not already scheduled
func (b *WGBatcher) scheduleFlush() {
shouldFlushNow := len(b.allowedIPOps)+len(b.peerUpdates) >= b.batchSizeThreshold
if shouldFlushNow {
select {
case b.flushChan <- struct{}{}:
default:
}
return
}
if b.flushTimer == nil {
b.flushTimer = time.AfterFunc(b.batchFlushInterval, func() {
select {
case b.flushChan <- struct{}{}:
default:
}
})
}
}
// flushLoop handles periodic flushing of batched operations
func (b *WGBatcher) flushLoop() {
defer b.wg.Done()
for {
select {
case <-b.flushChan:
if err := b.Flush(); err != nil {
log.Errorf("Error flushing WireGuard operations: %v", err)
}
case <-b.ctx.Done():
return
}
}
}
// processBatch processes a batch of operations
func (b *WGBatcher) processBatch(peerUpdates map[string]*PeerUpdateOperation, allowedIPOps []AllowedIPOperation) error {
if len(peerUpdates) == 0 && len(allowedIPOps) == 0 {
return nil
}
start := time.Now()
defer func() {
duration := time.Since(start)
log.Debugf("Processed batch of %d peer updates and %d allowed IP operations in %v",
len(peerUpdates), len(allowedIPOps), duration)
}()
var merr *multierror.Error
if err := b.processPeerUpdates(peerUpdates); err != nil {
merr = multierror.Append(merr, err)
}
if err := b.processAllowedIPOps(allowedIPOps); err != nil {
merr = multierror.Append(merr, err)
}
return nberrors.FormatErrorOrNil(merr)
}
// processPeerUpdates processes peer update operations
func (b *WGBatcher) processPeerUpdates(peerUpdates map[string]*PeerUpdateOperation) error {
var merr *multierror.Error
for _, update := range peerUpdates {
if err := b.configurer.UpdatePeer(
update.PeerKey,
update.AllowedIPs,
update.KeepAlive,
update.Endpoint,
update.PreSharedKey,
); err != nil {
merr = multierror.Append(merr, fmt.Errorf("update peer %s: %w", update.PeerKey, err))
}
}
return nberrors.FormatErrorOrNil(merr)
}
// processAllowedIPOps processes allowed IP add/remove operations
func (b *WGBatcher) processAllowedIPOps(allowedIPOps []AllowedIPOperation) error {
peerChanges := b.groupAllowedIPChanges(allowedIPOps)
return b.applyAllowedIPChanges(peerChanges)
}
// groupAllowedIPChanges groups allowed IP operations by peer
func (b *WGBatcher) groupAllowedIPChanges(allowedIPOps []AllowedIPOperation) map[string]struct {
toAdd []netip.Prefix
toRemove []netip.Prefix
} {
peerChanges := make(map[string]struct {
toAdd []netip.Prefix
toRemove []netip.Prefix
})
for _, op := range allowedIPOps {
changes := peerChanges[op.PeerKey]
if op.Operation == AllowedIPOpAdd {
changes.toAdd = append(changes.toAdd, op.Prefix)
} else {
changes.toRemove = append(changes.toRemove, op.Prefix)
}
peerChanges[op.PeerKey] = changes
}
return peerChanges
}
// applyAllowedIPChanges applies allowed IP changes for each peer
func (b *WGBatcher) applyAllowedIPChanges(peerChanges map[string]struct {
toAdd []netip.Prefix
toRemove []netip.Prefix
}) error {
var merr *multierror.Error
for peerKey, changes := range peerChanges {
for _, prefix := range changes.toRemove {
if err := b.configurer.RemoveAllowedIP(peerKey, prefix); err != nil {
if errors.Is(err, configurer.ErrPeerNotFound) || errors.Is(err, configurer.ErrAllowedIPNotFound) {
log.Debugf("remove allowed IP %s for peer %s: %v", prefix, peerKey, err)
} else {
merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s for peer %s: %w", prefix, peerKey, err))
}
}
}
for _, prefix := range changes.toAdd {
if err := b.configurer.AddAllowedIP(peerKey, prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add allowed IP %s for peer %s: %w", prefix, peerKey, err))
}
}
}
return nberrors.FormatErrorOrNil(merr)
}

View File

@@ -0,0 +1,15 @@
package bind
import (
wireguard "golang.zx2c4.com/wireguard/conn"
nbnet "github.com/netbirdio/netbird/util/net"
)
// TODO: This is most likely obsolete since the control fns should be called by the wrapped udpconn (ice_bind.go)
func init() {
listener := nbnet.NewListener()
if listener.ListenConfig.Control != nil {
*wireguard.ControlFns = append(*wireguard.ControlFns, listener.ListenConfig.Control)
}
}

View File

@@ -1,12 +0,0 @@
package bind
import (
wireguard "golang.zx2c4.com/wireguard/conn"
nbnet "github.com/netbirdio/netbird/util/net"
)
func init() {
// ControlFns is not thread safe and should only be modified during init.
*wireguard.ControlFns = append(*wireguard.ControlFns, nbnet.ControlProtectSocket)
}

View File

@@ -16,6 +16,7 @@ import (
wgConn "golang.zx2c4.com/wireguard/conn" wgConn "golang.zx2c4.com/wireguard/conn"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
nbnet "github.com/netbirdio/netbird/util/net"
) )
type RecvMessage struct { type RecvMessage struct {
@@ -153,7 +154,7 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
s.udpMux = NewUniversalUDPMuxDefault( s.udpMux = NewUniversalUDPMuxDefault(
UniversalUDPMuxParams{ UniversalUDPMuxParams{
UDPConn: conn, UDPConn: nbnet.WrapUDPConn(conn),
Net: s.transportNet, Net: s.transportNet,
FilterFn: s.filterFn, FilterFn: s.filterFn,
WGAddress: s.address, WGAddress: s.address,

View File

@@ -296,14 +296,20 @@ func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) {
return return
} }
m.addressMapMu.Lock() var allAddresses []string
defer m.addressMapMu.Unlock()
for _, c := range removedConns { for _, c := range removedConns {
addresses := c.getAddresses() addresses := c.getAddresses()
for _, addr := range addresses { allAddresses = append(allAddresses, addresses...)
}
m.addressMapMu.Lock()
for _, addr := range allAddresses {
delete(m.addressMap, addr) delete(m.addressMap, addr)
} }
m.addressMapMu.Unlock()
for _, addr := range allAddresses {
m.notifyAddressRemoval(addr)
} }
} }
@@ -351,14 +357,13 @@ func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr string)
} }
m.addressMapMu.Lock() m.addressMapMu.Lock()
defer m.addressMapMu.Unlock()
existing, ok := m.addressMap[addr] existing, ok := m.addressMap[addr]
if !ok { if !ok {
existing = []*udpMuxedConn{} existing = []*udpMuxedConn{}
} }
existing = append(existing, conn) existing = append(existing, conn)
m.addressMap[addr] = existing m.addressMap[addr] = existing
m.addressMapMu.Unlock()
log.Debugf("ICE: registered %s for %s", addr, conn.params.Key) log.Debugf("ICE: registered %s for %s", addr, conn.params.Key)
} }
@@ -386,12 +391,12 @@ func (m *UDPMuxDefault) HandleSTUNMessage(msg *stun.Message, addr net.Addr) erro
// If you are using the same socket for the Host and SRFLX candidates, it might be that there are more than one // If you are using the same socket for the Host and SRFLX candidates, it might be that there are more than one
// muxed connection - one for the SRFLX candidate and the other one for the HOST one. // muxed connection - one for the SRFLX candidate and the other one for the HOST one.
// We will then forward STUN packets to each of these connections. // We will then forward STUN packets to each of these connections.
m.addressMapMu.Lock() m.addressMapMu.RLock()
var destinationConnList []*udpMuxedConn var destinationConnList []*udpMuxedConn
if storedConns, ok := m.addressMap[addr.String()]; ok { if storedConns, ok := m.addressMap[addr.String()]; ok {
destinationConnList = append(destinationConnList, storedConns...) destinationConnList = append(destinationConnList, storedConns...)
} }
m.addressMapMu.Unlock() m.addressMapMu.RUnlock()
var isIPv6 bool var isIPv6 bool
if udpAddr, _ := addr.(*net.UDPAddr); udpAddr != nil && udpAddr.IP.To4() == nil { if udpAddr, _ := addr.(*net.UDPAddr); udpAddr != nil && udpAddr.IP.To4() == nil {

View File

@@ -0,0 +1,21 @@
//go:build !ios
package bind
import (
nbnet "github.com/netbirdio/netbird/util/net"
)
func (m *UDPMuxDefault) notifyAddressRemoval(addr string) {
wrapped, ok := m.params.UDPConn.(*UDPConn)
if !ok {
return
}
nbnetConn, ok := wrapped.GetPacketConn().(*nbnet.UDPConn)
if !ok {
return
}
nbnetConn.RemoveAddress(addr)
}

View File

@@ -0,0 +1,7 @@
//go:build ios
package bind
func (m *UDPMuxDefault) notifyAddressRemoval(addr string) {
// iOS doesn't support nbnet hooks, so this is a no-op
}

View File

@@ -62,7 +62,7 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
// wrap UDP connection, process server reflexive messages // wrap UDP connection, process server reflexive messages
// before they are passed to the UDPMux connection handler (connWorker) // before they are passed to the UDPMux connection handler (connWorker)
m.params.UDPConn = &udpConn{ m.params.UDPConn = &UDPConn{
PacketConn: params.UDPConn, PacketConn: params.UDPConn,
mux: m, mux: m,
logger: params.Logger, logger: params.Logger,
@@ -70,7 +70,6 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
address: params.WGAddress, address: params.WGAddress,
} }
// embed UDPMux
udpMuxParams := UDPMuxParams{ udpMuxParams := UDPMuxParams{
Logger: params.Logger, Logger: params.Logger,
UDPConn: m.params.UDPConn, UDPConn: m.params.UDPConn,
@@ -114,8 +113,8 @@ func (m *UniversalUDPMuxDefault) ReadFromConn(ctx context.Context) {
} }
} }
// udpConn is a wrapper around UDPMux conn that overrides ReadFrom and handles STUN/TURN packets // UDPConn is a wrapper around UDPMux conn that overrides ReadFrom and handles STUN/TURN packets
type udpConn struct { type UDPConn struct {
net.PacketConn net.PacketConn
mux *UniversalUDPMuxDefault mux *UniversalUDPMuxDefault
logger logging.LeveledLogger logger logging.LeveledLogger
@@ -125,7 +124,12 @@ type udpConn struct {
address wgaddr.Address address wgaddr.Address
} }
func (u *udpConn) WriteTo(b []byte, addr net.Addr) (int, error) { // GetPacketConn returns the underlying PacketConn
func (u *UDPConn) GetPacketConn() net.PacketConn {
return u.PacketConn
}
func (u *UDPConn) WriteTo(b []byte, addr net.Addr) (int, error) {
if u.filterFn == nil { if u.filterFn == nil {
return u.PacketConn.WriteTo(b, addr) return u.PacketConn.WriteTo(b, addr)
} }
@@ -137,21 +141,21 @@ func (u *udpConn) WriteTo(b []byte, addr net.Addr) (int, error) {
return u.handleUncachedAddress(b, addr) return u.handleUncachedAddress(b, addr)
} }
func (u *udpConn) handleCachedAddress(isRouted bool, b []byte, addr net.Addr) (int, error) { func (u *UDPConn) handleCachedAddress(isRouted bool, b []byte, addr net.Addr) (int, error) {
if isRouted { if isRouted {
return 0, fmt.Errorf("address %s is part of a routed network, refusing to write", addr) return 0, fmt.Errorf("address %s is part of a routed network, refusing to write", addr)
} }
return u.PacketConn.WriteTo(b, addr) return u.PacketConn.WriteTo(b, addr)
} }
func (u *udpConn) handleUncachedAddress(b []byte, addr net.Addr) (int, error) { func (u *UDPConn) handleUncachedAddress(b []byte, addr net.Addr) (int, error) {
if err := u.performFilterCheck(addr); err != nil { if err := u.performFilterCheck(addr); err != nil {
return 0, err return 0, err
} }
return u.PacketConn.WriteTo(b, addr) return u.PacketConn.WriteTo(b, addr)
} }
func (u *udpConn) performFilterCheck(addr net.Addr) error { func (u *UDPConn) performFilterCheck(addr net.Addr) error {
host, err := getHostFromAddr(addr) host, err := getHostFromAddr(addr)
if err != nil { if err != nil {
log.Errorf("Failed to get host from address %s: %v", addr, err) log.Errorf("Failed to get host from address %s: %v", addr, err)

View File

@@ -59,6 +59,7 @@ type WGIface struct {
mu sync.Mutex mu sync.Mutex
configurer device.WGConfigurer configurer device.WGConfigurer
batcher *WGBatcher
filter device.PacketFilter filter device.PacketFilter
wgProxyFactory wgProxyFactory wgProxyFactory wgProxyFactory
} }
@@ -128,6 +129,12 @@ func (w *WGIface) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAliv
} }
log.Debugf("updating interface %s peer %s, endpoint %s, allowedIPs %v", w.tun.DeviceName(), peerKey, endpoint, allowedIps) log.Debugf("updating interface %s peer %s, endpoint %s, allowedIPs %v", w.tun.DeviceName(), peerKey, endpoint, allowedIps)
if endpoint != nil && w.batcher != nil {
if err := w.batcher.Flush(); err != nil {
log.Warnf("failed to flush batched operations: %v", err)
}
}
return w.configurer.UpdatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey) return w.configurer.UpdatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
} }
@@ -152,6 +159,10 @@ func (w *WGIface) AddAllowedIP(peerKey string, allowedIP netip.Prefix) error {
} }
log.Debugf("Adding allowed IP to interface %s and peer %s: allowed IP %s ", w.tun.DeviceName(), peerKey, allowedIP) log.Debugf("Adding allowed IP to interface %s and peer %s: allowed IP %s ", w.tun.DeviceName(), peerKey, allowedIP)
if w.batcher != nil {
return w.batcher.AddAllowedIP(peerKey, allowedIP)
}
return w.configurer.AddAllowedIP(peerKey, allowedIP) return w.configurer.AddAllowedIP(peerKey, allowedIP)
} }
@@ -164,6 +175,10 @@ func (w *WGIface) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error
} }
log.Debugf("Removing allowed IP from interface %s and peer %s: allowed IP %s ", w.tun.DeviceName(), peerKey, allowedIP) log.Debugf("Removing allowed IP from interface %s and peer %s: allowed IP %s ", w.tun.DeviceName(), peerKey, allowedIP)
if w.batcher != nil {
return w.batcher.RemoveAllowedIP(peerKey, allowedIP)
}
return w.configurer.RemoveAllowedIP(peerKey, allowedIP) return w.configurer.RemoveAllowedIP(peerKey, allowedIP)
} }
@@ -174,6 +189,12 @@ func (w *WGIface) Close() error {
var result *multierror.Error var result *multierror.Error
if w.batcher != nil {
if err := w.batcher.Close(); err != nil {
result = multierror.Append(result, fmt.Errorf("failed to close WireGuard batcher: %w", err))
}
}
if err := w.wgProxyFactory.Free(); err != nil { if err := w.wgProxyFactory.Free(); err != nil {
result = multierror.Append(result, fmt.Errorf("failed to free WireGuard proxy: %w", err)) result = multierror.Append(result, fmt.Errorf("failed to free WireGuard proxy: %w", err))
} }

View File

@@ -17,6 +17,7 @@ func (w *WGIface) Create() error {
} }
w.configurer = cfgr w.configurer = cfgr
w.batcher = NewWGBatcher(cfgr)
return nil return nil
} }

View File

@@ -1,8 +1,6 @@
package iface package iface
import ( import "fmt"
"fmt"
)
// CreateOnAndroid creates a new Wireguard interface, sets a given IP and brings it up. // CreateOnAndroid creates a new Wireguard interface, sets a given IP and brings it up.
// Will reuse an existing one. // Will reuse an existing one.
@@ -15,6 +13,7 @@ func (w *WGIface) CreateOnAndroid(routes []string, dns string, searchDomains []s
return err return err
} }
w.configurer = cfgr w.configurer = cfgr
w.batcher = NewWGBatcher(cfgr)
return nil return nil
} }

View File

@@ -29,6 +29,7 @@ func (w *WGIface) Create() error {
return err return err
} }
w.configurer = cfgr w.configurer = cfgr
w.batcher = NewWGBatcher(cfgr)
return nil return nil
} }

View File

@@ -41,10 +41,13 @@ func (t *NetStackTun) Create() (tun.Device, *netstack.Net, error) {
} }
t.tundev = nsTunDev t.tundev = nsTunDev
skipProxy, err := strconv.ParseBool(os.Getenv(EnvSkipProxy)) var skipProxy bool
if val := os.Getenv(EnvSkipProxy); val != "" {
skipProxy, err = strconv.ParseBool(val)
if err != nil { if err != nil {
log.Errorf("failed to parse %s: %s", EnvSkipProxy, err) log.Errorf("failed to parse %s: %s", EnvSkipProxy, err)
} }
}
if skipProxy { if skipProxy {
return nsTunDev, tunNet, nil return nsTunDev, tunNet, nil
} }

View File

@@ -12,6 +12,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/wgproxy/listener"
) )
type ProxyBind struct { type ProxyBind struct {
@@ -28,6 +29,17 @@ type ProxyBind struct {
pausedMu sync.Mutex pausedMu sync.Mutex
paused bool paused bool
isStarted bool isStarted bool
closeListener *listener.CloseListener
}
func NewProxyBind(bind *bind.ICEBind) *ProxyBind {
p := &ProxyBind{
Bind: bind,
closeListener: listener.NewCloseListener(),
}
return p
} }
// AddTurnConn adds a new connection to the bind. // AddTurnConn adds a new connection to the bind.
@@ -54,6 +66,10 @@ func (p *ProxyBind) EndpointAddr() *net.UDPAddr {
} }
} }
func (p *ProxyBind) SetDisconnectListener(disconnected func()) {
p.closeListener.SetCloseListener(disconnected)
}
func (p *ProxyBind) Work() { func (p *ProxyBind) Work() {
if p.remoteConn == nil { if p.remoteConn == nil {
return return
@@ -96,6 +112,9 @@ func (p *ProxyBind) close() error {
if p.closed { if p.closed {
return nil return nil
} }
p.closeListener.SetCloseListener(nil)
p.closed = true p.closed = true
p.cancel() p.cancel()
@@ -122,6 +141,7 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) {
if ctx.Err() != nil { if ctx.Err() != nil {
return return
} }
p.closeListener.Notify()
log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err) log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err)
return return
} }

View File

@@ -11,6 +11,8 @@ import (
"sync" "sync"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/wgproxy/listener"
) )
// ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call // ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call
@@ -26,6 +28,15 @@ type ProxyWrapper struct {
pausedMu sync.Mutex pausedMu sync.Mutex
paused bool paused bool
isStarted bool isStarted bool
closeListener *listener.CloseListener
}
func NewProxyWrapper(WgeBPFProxy *WGEBPFProxy) *ProxyWrapper {
return &ProxyWrapper{
WgeBPFProxy: WgeBPFProxy,
closeListener: listener.NewCloseListener(),
}
} }
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error { func (p *ProxyWrapper) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error {
@@ -43,6 +54,10 @@ func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr {
return p.wgEndpointAddr return p.wgEndpointAddr
} }
func (p *ProxyWrapper) SetDisconnectListener(disconnected func()) {
p.closeListener.SetCloseListener(disconnected)
}
func (p *ProxyWrapper) Work() { func (p *ProxyWrapper) Work() {
if p.remoteConn == nil { if p.remoteConn == nil {
return return
@@ -77,6 +92,8 @@ func (e *ProxyWrapper) CloseConn() error {
e.cancel() e.cancel()
e.closeListener.SetCloseListener(nil)
if err := e.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) { if err := e.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
return fmt.Errorf("failed to close remote conn: %w", err) return fmt.Errorf("failed to close remote conn: %w", err)
} }
@@ -117,6 +134,7 @@ func (p *ProxyWrapper) readFromRemote(ctx context.Context, buf []byte) (int, err
if ctx.Err() != nil { if ctx.Err() != nil {
return 0, ctx.Err() return 0, ctx.Err()
} }
p.closeListener.Notify()
if !errors.Is(err, io.EOF) { if !errors.Is(err, io.EOF) {
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgEndpointAddr.Port, err) log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgEndpointAddr.Port, err)
} }

View File

@@ -36,9 +36,8 @@ func (w *KernelFactory) GetProxy() Proxy {
return udpProxy.NewWGUDPProxy(w.wgPort) return udpProxy.NewWGUDPProxy(w.wgPort)
} }
return &ebpf.ProxyWrapper{ return ebpf.NewProxyWrapper(w.ebpfProxy)
WgeBPFProxy: w.ebpfProxy,
}
} }
func (w *KernelFactory) Free() error { func (w *KernelFactory) Free() error {

View File

@@ -20,9 +20,7 @@ func NewUSPFactory(iceBind *bind.ICEBind) *USPFactory {
} }
func (w *USPFactory) GetProxy() Proxy { func (w *USPFactory) GetProxy() Proxy {
return &proxyBind.ProxyBind{ return proxyBind.NewProxyBind(w.bind)
Bind: w.bind,
}
} }
func (w *USPFactory) Free() error { func (w *USPFactory) Free() error {

View File

@@ -0,0 +1,19 @@
package listener
type CloseListener struct {
listener func()
}
func NewCloseListener() *CloseListener {
return &CloseListener{}
}
func (c *CloseListener) SetCloseListener(listener func()) {
c.listener = listener
}
func (c *CloseListener) Notify() {
if c.listener != nil {
c.listener()
}
}

View File

@@ -12,4 +12,5 @@ type Proxy interface {
Work() // Work start or resume the proxy Work() // Work start or resume the proxy
Pause() // Pause to forward the packages from remote connection to WireGuard. The opposite way still works. Pause() // Pause to forward the packages from remote connection to WireGuard. The opposite way still works.
CloseConn() error CloseConn() error
SetDisconnectListener(disconnected func())
} }

View File

@@ -98,9 +98,7 @@ func TestProxyCloseByRemoteConn(t *testing.T) {
t.Errorf("failed to free ebpf proxy: %s", err) t.Errorf("failed to free ebpf proxy: %s", err)
} }
}() }()
proxyWrapper := &ebpf.ProxyWrapper{ proxyWrapper := ebpf.NewProxyWrapper(ebpfProxy)
WgeBPFProxy: ebpfProxy,
}
tests = append(tests, struct { tests = append(tests, struct {
name string name string

View File

@@ -12,6 +12,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
cerrors "github.com/netbirdio/netbird/client/errors" cerrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/iface/wgproxy/listener"
) )
// WGUDPProxy proxies // WGUDPProxy proxies
@@ -28,6 +29,8 @@ type WGUDPProxy struct {
pausedMu sync.Mutex pausedMu sync.Mutex
paused bool paused bool
isStarted bool isStarted bool
closeListener *listener.CloseListener
} }
// NewWGUDPProxy instantiate a UDP based WireGuard proxy. This is not a thread safe implementation // NewWGUDPProxy instantiate a UDP based WireGuard proxy. This is not a thread safe implementation
@@ -35,6 +38,7 @@ func NewWGUDPProxy(wgPort int) *WGUDPProxy {
log.Debugf("Initializing new user space proxy with port %d", wgPort) log.Debugf("Initializing new user space proxy with port %d", wgPort)
p := &WGUDPProxy{ p := &WGUDPProxy{
localWGListenPort: wgPort, localWGListenPort: wgPort,
closeListener: listener.NewCloseListener(),
} }
return p return p
} }
@@ -67,6 +71,10 @@ func (p *WGUDPProxy) EndpointAddr() *net.UDPAddr {
return endpointUdpAddr return endpointUdpAddr
} }
func (p *WGUDPProxy) SetDisconnectListener(disconnected func()) {
p.closeListener.SetCloseListener(disconnected)
}
// Work starts the proxy or resumes it if it was paused // Work starts the proxy or resumes it if it was paused
func (p *WGUDPProxy) Work() { func (p *WGUDPProxy) Work() {
if p.remoteConn == nil { if p.remoteConn == nil {
@@ -111,6 +119,8 @@ func (p *WGUDPProxy) close() error {
if p.closed { if p.closed {
return nil return nil
} }
p.closeListener.SetCloseListener(nil)
p.closed = true p.closed = true
p.cancel() p.cancel()
@@ -141,6 +151,7 @@ func (p *WGUDPProxy) proxyToRemote(ctx context.Context) {
if ctx.Err() != nil { if ctx.Err() != nil {
return return
} }
p.closeListener.Notify()
log.Debugf("failed to read from wg interface conn: %s", err) log.Debugf("failed to read from wg interface conn: %s", err)
return return
} }

View File

@@ -61,7 +61,6 @@ import (
signal "github.com/netbirdio/netbird/signal/client" signal "github.com/netbirdio/netbird/signal/client"
sProto "github.com/netbirdio/netbird/signal/proto" sProto "github.com/netbirdio/netbird/signal/proto"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
nbnet "github.com/netbirdio/netbird/util/net"
) )
// PeerConnectionTimeoutMax is a timeout of an initial connection attempt to a remote peer. // PeerConnectionTimeoutMax is a timeout of an initial connection attempt to a remote peer.
@@ -138,9 +137,6 @@ type Engine struct {
connMgr *ConnMgr connMgr *ConnMgr
beforePeerHook nbnet.AddHookFunc
afterPeerHook nbnet.RemoveHookFunc
// rpManager is a Rosenpass manager // rpManager is a Rosenpass manager
rpManager *rosenpass.Manager rpManager *rosenpass.Manager
@@ -409,12 +405,8 @@ func (e *Engine) Start() error {
DisableClientRoutes: e.config.DisableClientRoutes, DisableClientRoutes: e.config.DisableClientRoutes,
DisableServerRoutes: e.config.DisableServerRoutes, DisableServerRoutes: e.config.DisableServerRoutes,
}) })
beforePeerHook, afterPeerHook, err := e.routeManager.Init() if err := e.routeManager.Init(); err != nil {
if err != nil {
log.Errorf("Failed to initialize route manager: %s", err) log.Errorf("Failed to initialize route manager: %s", err)
} else {
e.beforePeerHook = beforePeerHook
e.afterPeerHook = afterPeerHook
} }
e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener) e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener)
@@ -1261,10 +1253,6 @@ func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error {
return fmt.Errorf("peer already exists: %s", peerKey) return fmt.Errorf("peer already exists: %s", peerKey)
} }
if e.beforePeerHook != nil && e.afterPeerHook != nil {
conn.AddBeforeAddPeerHook(e.beforePeerHook)
conn.AddAfterRemovePeerHook(e.afterPeerHook)
}
return nil return nil
} }

View File

@@ -400,7 +400,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
StatusRecorder: engine.statusRecorder, StatusRecorder: engine.statusRecorder,
RelayManager: relayMgr, RelayManager: relayMgr,
}) })
_, _, err = engine.routeManager.Init() err = engine.routeManager.Init()
require.NoError(t, err) require.NoError(t, err)
engine.dnsServer = &dns.MockServer{ engine.dnsServer = &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil }, UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
@@ -1481,6 +1481,10 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
GetSettings(gomock.Any(), gomock.Any(), gomock.Any()). GetSettings(gomock.Any(), gomock.Any(), gomock.Any()).
Return(&types.Settings{}, nil). Return(&types.Settings{}, nil).
AnyTimes() AnyTimes()
settingsMockManager.EXPECT().
GetExtraSettings(gomock.Any(), gomock.Any()).
Return(&types.ExtraSettings{}, nil).
AnyTimes()
permissionsManager := permissions.NewManager(store) permissionsManager := permissions.NewManager(store)
@@ -1490,7 +1494,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
} }
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager) secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager)
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil) mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &server.MockIntegratedValidator{})
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }

View File

@@ -26,7 +26,6 @@ import (
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
relayClient "github.com/netbirdio/netbird/relay/client" relayClient "github.com/netbirdio/netbird/relay/client"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
nbnet "github.com/netbirdio/netbird/util/net"
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group" semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
) )
@@ -106,10 +105,6 @@ type Conn struct {
workerRelay *WorkerRelay workerRelay *WorkerRelay
wgWatcherWg sync.WaitGroup wgWatcherWg sync.WaitGroup
connIDRelay nbnet.ConnectionID
connIDICE nbnet.ConnectionID
beforeAddPeerHooks []nbnet.AddHookFunc
afterRemovePeerHooks []nbnet.RemoveHookFunc
// used to store the remote Rosenpass key for Relayed connection in case of connection update from ice // used to store the remote Rosenpass key for Relayed connection in case of connection update from ice
rosenpassRemoteKey []byte rosenpassRemoteKey []byte
@@ -167,7 +162,7 @@ func (conn *Conn) Open(engineCtx context.Context) error {
conn.ctx, conn.ctxCancel = context.WithCancel(engineCtx) conn.ctx, conn.ctxCancel = context.WithCancel(engineCtx)
conn.workerRelay = NewWorkerRelay(conn.Log, isController(conn.config), conn.config, conn, conn.relayManager, conn.dumpState) conn.workerRelay = NewWorkerRelay(conn.ctx, conn.Log, isController(conn.config), conn.config, conn, conn.relayManager, conn.dumpState)
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally() relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally) workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally)
@@ -267,8 +262,6 @@ func (conn *Conn) Close(signalToRemote bool) {
conn.Log.Errorf("failed to remove wg endpoint: %v", err) conn.Log.Errorf("failed to remove wg endpoint: %v", err)
} }
conn.freeUpConnID()
if conn.evalStatus() == StatusConnected && conn.onDisconnected != nil { if conn.evalStatus() == StatusConnected && conn.onDisconnected != nil {
conn.onDisconnected(conn.config.WgConfig.RemoteKey) conn.onDisconnected(conn.config.WgConfig.RemoteKey)
} }
@@ -293,13 +286,6 @@ func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMa
conn.workerICE.OnRemoteCandidate(candidate, haRoutes) conn.workerICE.OnRemoteCandidate(candidate, haRoutes)
} }
func (conn *Conn) AddBeforeAddPeerHook(hook nbnet.AddHookFunc) {
conn.beforeAddPeerHooks = append(conn.beforeAddPeerHooks, hook)
}
func (conn *Conn) AddAfterRemovePeerHook(hook nbnet.RemoveHookFunc) {
conn.afterRemovePeerHooks = append(conn.afterRemovePeerHooks, hook)
}
// SetOnConnected sets a handler function to be triggered by Conn when a new connection to a remote peer established // SetOnConnected sets a handler function to be triggered by Conn when a new connection to a remote peer established
func (conn *Conn) SetOnConnected(handler func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)) { func (conn *Conn) SetOnConnected(handler func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)) {
conn.onConnected = handler conn.onConnected = handler
@@ -387,10 +373,6 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
ep = directEp ep = directEp
} }
if err := conn.runBeforeAddPeerHooks(ep.IP); err != nil {
conn.Log.Errorf("Before add peer hook failed: %v", err)
}
conn.workerRelay.DisableWgWatcher() conn.workerRelay.DisableWgWatcher()
// todo consider to run conn.wgWatcherWg.Wait() here // todo consider to run conn.wgWatcherWg.Wait() here
@@ -489,6 +471,8 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
conn.Log.Errorf("failed to add relayed net.Conn to local proxy: %v", err) conn.Log.Errorf("failed to add relayed net.Conn to local proxy: %v", err)
return return
} }
wgProxy.SetDisconnectListener(conn.onRelayDisconnected)
conn.dumpState.NewLocalProxy() conn.dumpState.NewLocalProxy()
conn.Log.Infof("created new wgProxy for relay connection: %s", wgProxy.EndpointAddr().String()) conn.Log.Infof("created new wgProxy for relay connection: %s", wgProxy.EndpointAddr().String())
@@ -501,10 +485,6 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
return return
} }
if err := conn.runBeforeAddPeerHooks(wgProxy.EndpointAddr().IP); err != nil {
conn.Log.Errorf("Before add peer hook failed: %v", err)
}
wgProxy.Work() wgProxy.Work()
if err := conn.configureWGEndpoint(wgProxy.EndpointAddr(), rci.rosenpassPubKey); err != nil { if err := conn.configureWGEndpoint(wgProxy.EndpointAddr(), rci.rosenpassPubKey); err != nil {
if err := wgProxy.CloseConn(); err != nil { if err := wgProxy.CloseConn(); err != nil {
@@ -705,36 +685,6 @@ func (conn *Conn) isConnectedOnAllWay() (connected bool) {
return true return true
} }
func (conn *Conn) runBeforeAddPeerHooks(ip net.IP) error {
conn.connIDICE = nbnet.GenerateConnID()
for _, hook := range conn.beforeAddPeerHooks {
if err := hook(conn.connIDICE, ip); err != nil {
return err
}
}
return nil
}
func (conn *Conn) freeUpConnID() {
if conn.connIDRelay != "" {
for _, hook := range conn.afterRemovePeerHooks {
if err := hook(conn.connIDRelay); err != nil {
conn.Log.Errorf("After remove peer hook failed: %v", err)
}
}
conn.connIDRelay = ""
}
if conn.connIDICE != "" {
for _, hook := range conn.afterRemovePeerHooks {
if err := hook(conn.connIDICE); err != nil {
conn.Log.Errorf("After remove peer hook failed: %v", err)
}
}
conn.connIDICE = ""
}
}
func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) { func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) {
conn.Log.Debugf("setup proxied WireGuard connection") conn.Log.Debugf("setup proxied WireGuard connection")
udpAddr := &net.UDPAddr{ udpAddr := &net.UDPAddr{

View File

@@ -19,11 +19,12 @@ type RelayConnInfo struct {
} }
type WorkerRelay struct { type WorkerRelay struct {
peerCtx context.Context
log *log.Entry log *log.Entry
isController bool isController bool
config ConnConfig config ConnConfig
conn *Conn conn *Conn
relayManager relayClient.ManagerService relayManager *relayClient.Manager
relayedConn net.Conn relayedConn net.Conn
relayLock sync.Mutex relayLock sync.Mutex
@@ -33,8 +34,9 @@ type WorkerRelay struct {
wgWatcher *WGWatcher wgWatcher *WGWatcher
} }
func NewWorkerRelay(log *log.Entry, ctrl bool, config ConnConfig, conn *Conn, relayManager relayClient.ManagerService, stateDump *stateDump) *WorkerRelay { func NewWorkerRelay(ctx context.Context, log *log.Entry, ctrl bool, config ConnConfig, conn *Conn, relayManager *relayClient.Manager, stateDump *stateDump) *WorkerRelay {
r := &WorkerRelay{ r := &WorkerRelay{
peerCtx: ctx,
log: log, log: log,
isController: ctrl, isController: ctrl,
config: config, config: config,
@@ -62,7 +64,7 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
srv := w.preferredRelayServer(currentRelayAddress, remoteOfferAnswer.RelaySrvAddress) srv := w.preferredRelayServer(currentRelayAddress, remoteOfferAnswer.RelaySrvAddress)
relayedConn, err := w.relayManager.OpenConn(srv, w.config.Key) relayedConn, err := w.relayManager.OpenConn(w.peerCtx, srv, w.config.Key)
if err != nil { if err != nil {
if errors.Is(err, relayClient.ErrConnAlreadyExists) { if errors.Is(err, relayClient.ErrConnAlreadyExists) {
w.log.Debugf("handled offer by reusing existing relay connection") w.log.Debugf("handled offer by reusing existing relay connection")

View File

@@ -44,7 +44,7 @@ import (
// Manager is a route manager interface // Manager is a route manager interface
type Manager interface { type Manager interface {
Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) Init() error
UpdateRoutes(updateSerial uint64, serverRoutes map[route.ID]*route.Route, clientRoutes route.HAMap, useNewDNSRoute bool) error UpdateRoutes(updateSerial uint64, serverRoutes map[route.ID]*route.Route, clientRoutes route.HAMap, useNewDNSRoute bool) error
ClassifyRoutes(newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap) ClassifyRoutes(newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap)
TriggerSelection(route.HAMap) TriggerSelection(route.HAMap)
@@ -201,11 +201,11 @@ func (m *DefaultManager) setupRefCounters(useNoop bool) {
} }
// Init sets up the routing // Init sets up the routing
func (m *DefaultManager) Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { func (m *DefaultManager) Init() error {
m.routeSelector = m.initSelector() m.routeSelector = m.initSelector()
if nbnet.CustomRoutingDisabled() || m.disableClientRoutes { if nbnet.CustomRoutingDisabled() || m.disableClientRoutes {
return nil, nil, nil return nil
} }
if err := m.sysOps.CleanupRouting(nil); err != nil { if err := m.sysOps.CleanupRouting(nil); err != nil {
@@ -219,13 +219,12 @@ func (m *DefaultManager) Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error)
ips := resolveURLsToIPs(initialAddresses) ips := resolveURLsToIPs(initialAddresses)
beforePeerHook, afterPeerHook, err := m.sysOps.SetupRouting(ips, m.stateManager) if err := m.sysOps.SetupRouting(ips, m.stateManager); err != nil {
if err != nil { return fmt.Errorf("setup routing: %w", err)
return nil, nil, fmt.Errorf("setup routing: %w", err)
} }
log.Info("Routing setup complete") log.Info("Routing setup complete")
return beforePeerHook, afterPeerHook, nil return nil
} }
func (m *DefaultManager) initSelector() *routeselector.RouteSelector { func (m *DefaultManager) initSelector() *routeselector.RouteSelector {

View File

@@ -430,7 +430,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
StatusRecorder: statusRecorder, StatusRecorder: statusRecorder,
}) })
_, _, err = routeManager.Init() err = routeManager.Init()
require.NoError(t, err, "should init route manager") require.NoError(t, err, "should init route manager")
defer routeManager.Stop(nil) defer routeManager.Stop(nil)

View File

@@ -9,7 +9,6 @@ import (
"github.com/netbirdio/netbird/client/internal/routeselector" "github.com/netbirdio/netbird/client/internal/routeselector"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/util/net"
) )
// MockManager is the mock instance of a route manager // MockManager is the mock instance of a route manager
@@ -23,8 +22,8 @@ type MockManager struct {
StopFunc func(manager *statemanager.Manager) StopFunc func(manager *statemanager.Manager)
} }
func (m *MockManager) Init() (net.AddHookFunc, net.RemoveHookFunc, error) { func (m *MockManager) Init() error {
return nil, nil, nil return nil
} }
// InitialRouteRange mock implementation of InitialRouteRange from Manager interface // InitialRouteRange mock implementation of InitialRouteRange from Manager interface

View File

@@ -6,6 +6,7 @@ import (
"net/netip" "net/netip"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/routemanager/notifier" "github.com/netbirdio/netbird/client/internal/routemanager/notifier"
@@ -56,6 +57,10 @@ type SysOps struct {
// seq is an atomic counter for generating unique sequence numbers for route messages // seq is an atomic counter for generating unique sequence numbers for route messages
//nolint:unused // only used on BSD systems //nolint:unused // only used on BSD systems
seq atomic.Uint32 seq atomic.Uint32
localSubnetsCache []*net.IPNet
localSubnetsCacheMu sync.RWMutex
localSubnetsCacheTime time.Time
} }
func NewSysOps(wgInterface wgIface, notifier *notifier.Notifier) *SysOps { func NewSysOps(wgInterface wgIface, notifier *notifier.Notifier) *SysOps {

View File

@@ -10,11 +10,10 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/util/net"
) )
func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) error {
return nil, nil, nil return nil
} }
func (r *SysOps) CleanupRouting(*statemanager.Manager) error { func (r *SysOps) CleanupRouting(*statemanager.Manager) error {

View File

@@ -10,6 +10,7 @@ import (
"net/netip" "net/netip"
"runtime" "runtime"
"strconv" "strconv"
"time"
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
"github.com/libp2p/go-netroute" "github.com/libp2p/go-netroute"
@@ -24,6 +25,8 @@ import (
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/util/net"
) )
const localSubnetsCacheTTL = 15 * time.Minute
var splitDefaultv4_1 = netip.PrefixFrom(netip.IPv4Unspecified(), 1) var splitDefaultv4_1 = netip.PrefixFrom(netip.IPv4Unspecified(), 1)
var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1) var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1)
var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1) var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1)
@@ -31,7 +34,7 @@ var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1)
var ErrRoutingIsSeparate = errors.New("routing is separate") var ErrRoutingIsSeparate = errors.New("routing is separate")
func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemanager.Manager) error {
stateManager.RegisterState(&ShutdownState{}) stateManager.RegisterState(&ShutdownState{})
initialNextHopV4, err := GetNextHop(netip.IPv4Unspecified()) initialNextHopV4, err := GetNextHop(netip.IPv4Unspecified())
@@ -75,7 +78,10 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemana
r.refCounter = refCounter r.refCounter = refCounter
return r.setupHooks(initAddresses, stateManager) if err := r.setupHooks(initAddresses, stateManager); err != nil {
return fmt.Errorf("setup hooks: %w", err)
}
return nil
} }
// updateState updates state on every change so it will be persisted regularly // updateState updates state on every change so it will be persisted regularly
@@ -128,18 +134,14 @@ func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf wgIface, init
return Nexthop{}, fmt.Errorf("get next hop: %w", err) return Nexthop{}, fmt.Errorf("get next hop: %w", err)
} }
log.Debugf("Found next hop %s for prefix %s with interface %v", nexthop.IP, prefix, nexthop.IP) log.Debugf("Found next hop %s for prefix %s with interface %v", nexthop.IP, prefix, nexthop.Intf)
exitNextHop := Nexthop{ exitNextHop := nexthop
IP: nexthop.IP,
Intf: nexthop.Intf,
}
vpnAddr := vpnIntf.Address().IP vpnAddr := vpnIntf.Address().IP
// if next hop is the VPN address or the interface is the VPN interface, we should use the initial values // if next hop is the VPN address or the interface is the VPN interface, we should use the initial values
if exitNextHop.IP == vpnAddr || exitNextHop.Intf != nil && exitNextHop.Intf.Name == vpnIntf.Name() { if exitNextHop.IP == vpnAddr || exitNextHop.Intf != nil && exitNextHop.Intf.Name == vpnIntf.Name() {
log.Debugf("Route for prefix %s is pointing to the VPN interface, using initial next hop %v", prefix, initialNextHop) log.Debugf("Route for prefix %s is pointing to the VPN interface, using initial next hop %v", prefix, initialNextHop)
exitNextHop = initialNextHop exitNextHop = initialNextHop
} }
@@ -152,12 +154,37 @@ func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf wgIface, init
} }
func (r *SysOps) isPrefixInLocalSubnets(prefix netip.Prefix) (bool, *net.IPNet) { func (r *SysOps) isPrefixInLocalSubnets(prefix netip.Prefix) (bool, *net.IPNet) {
r.localSubnetsCacheMu.RLock()
cacheAge := time.Since(r.localSubnetsCacheTime)
subnets := r.localSubnetsCache
r.localSubnetsCacheMu.RUnlock()
if cacheAge > localSubnetsCacheTTL || subnets == nil {
r.localSubnetsCacheMu.Lock()
if time.Since(r.localSubnetsCacheTime) > localSubnetsCacheTTL || r.localSubnetsCache == nil {
r.refreshLocalSubnetsCache()
}
subnets = r.localSubnetsCache
r.localSubnetsCacheMu.Unlock()
}
for _, subnet := range subnets {
if subnet.Contains(prefix.Addr().AsSlice()) {
return true, subnet
}
}
return false, nil
}
func (r *SysOps) refreshLocalSubnetsCache() {
localInterfaces, err := net.Interfaces() localInterfaces, err := net.Interfaces()
if err != nil { if err != nil {
log.Errorf("Failed to get local interfaces: %v", err) log.Errorf("Failed to get local interfaces: %v", err)
return false, nil return
} }
var newSubnets []*net.IPNet
for _, intf := range localInterfaces { for _, intf := range localInterfaces {
addrs, err := intf.Addrs() addrs, err := intf.Addrs()
if err != nil { if err != nil {
@@ -171,14 +198,12 @@ func (r *SysOps) isPrefixInLocalSubnets(prefix netip.Prefix) (bool, *net.IPNet)
log.Errorf("Failed to convert address to IPNet: %v", addr) log.Errorf("Failed to convert address to IPNet: %v", addr)
continue continue
} }
newSubnets = append(newSubnets, ipnet)
if ipnet.Contains(prefix.Addr().AsSlice()) {
return true, ipnet
}
} }
} }
return false, nil r.localSubnetsCache = newSubnets
r.localSubnetsCacheTime = time.Now()
} }
// genericAddVPNRoute adds a new route to the vpn interface, it splits the default prefix // genericAddVPNRoute adds a new route to the vpn interface, it splits the default prefix
@@ -264,7 +289,7 @@ func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface)
return r.removeFromRouteTable(prefix, nextHop) return r.removeFromRouteTable(prefix, nextHop)
} }
func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.Manager) error {
beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error { beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error {
prefix, err := util.GetPrefixFromIP(ip) prefix, err := util.GetPrefixFromIP(ip)
if err != nil { if err != nil {
@@ -289,9 +314,11 @@ func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.M
return nil return nil
} }
var merr *multierror.Error
for _, ip := range initAddresses { for _, ip := range initAddresses {
if err := beforeHook("init", ip); err != nil { if err := beforeHook("init", ip); err != nil {
log.Errorf("Failed to add route reference: %v", err) merr = multierror.Append(merr, fmt.Errorf("add initial route for %s: %w", ip, err))
} }
} }
@@ -300,11 +327,11 @@ func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.M
return ctx.Err() return ctx.Err()
} }
var result *multierror.Error var merr *multierror.Error
for _, ip := range resolvedIPs { for _, ip := range resolvedIPs {
result = multierror.Append(result, beforeHook(connID, ip.IP)) merr = multierror.Append(merr, beforeHook(connID, ip.IP))
} }
return nberrors.FormatErrorOrNil(result) return nberrors.FormatErrorOrNil(merr)
}) })
nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error { nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error {
@@ -319,7 +346,16 @@ func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.M
return afterHook(connID) return afterHook(connID)
}) })
return beforeHook, afterHook, nil nbnet.AddListenerAddressRemoveHook(func(connID nbnet.ConnectionID, prefix netip.Prefix) error {
if _, err := r.refCounter.Decrement(prefix); err != nil {
return fmt.Errorf("remove route reference: %w", err)
}
r.updateState(stateManager)
return nil
})
return nberrors.FormatErrorOrNil(merr)
} }
func GetNextHop(ip netip.Addr) (Nexthop, error) { func GetNextHop(ip netip.Addr) (Nexthop, error) {

View File

@@ -143,7 +143,7 @@ func TestAddVPNRoute(t *testing.T) {
wgInterface := createWGInterface(t, fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100+n) wgInterface := createWGInterface(t, fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100+n)
r := NewSysOps(wgInterface, nil) r := NewSysOps(wgInterface, nil)
_, _, err := r.SetupRouting(nil, nil) err := r.SetupRouting(nil, nil)
require.NoError(t, err) require.NoError(t, err)
t.Cleanup(func() { t.Cleanup(func() {
assert.NoError(t, r.CleanupRouting(nil)) assert.NoError(t, r.CleanupRouting(nil))
@@ -341,7 +341,7 @@ func TestAddRouteToNonVPNIntf(t *testing.T) {
wgInterface := createWGInterface(t, fmt.Sprintf("utun54%d", n), "100.65.75.2/24", 33200+n) wgInterface := createWGInterface(t, fmt.Sprintf("utun54%d", n), "100.65.75.2/24", 33200+n)
r := NewSysOps(wgInterface, nil) r := NewSysOps(wgInterface, nil)
_, _, err := r.SetupRouting(nil, nil) err := r.SetupRouting(nil, nil)
require.NoError(t, err) require.NoError(t, err)
t.Cleanup(func() { t.Cleanup(func() {
assert.NoError(t, r.CleanupRouting(nil)) assert.NoError(t, r.CleanupRouting(nil))
@@ -484,7 +484,7 @@ func setupTestEnv(t *testing.T) {
}) })
r := NewSysOps(wgInterface, nil) r := NewSysOps(wgInterface, nil)
_, _, err := r.SetupRouting(nil, nil) err := r.SetupRouting(nil, nil)
require.NoError(t, err, "setupRouting should not return err") require.NoError(t, err, "setupRouting should not return err")
t.Cleanup(func() { t.Cleanup(func() {
assert.NoError(t, r.CleanupRouting(nil)) assert.NoError(t, r.CleanupRouting(nil))

View File

@@ -10,14 +10,13 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/util/net"
) )
func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) error {
r.mu.Lock() r.mu.Lock()
defer r.mu.Unlock() defer r.mu.Unlock()
r.prefixes = make(map[netip.Prefix]struct{}) r.prefixes = make(map[netip.Prefix]struct{})
return nil, nil, nil return nil
} }
func (r *SysOps) CleanupRouting(*statemanager.Manager) error { func (r *SysOps) CleanupRouting(*statemanager.Manager) error {

View File

@@ -72,7 +72,7 @@ func getSetupRules() []ruleParams {
// Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table. // Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table.
// This table is where a default route or other specific routes received from the management server are configured, // This table is where a default route or other specific routes received from the management server are configured,
// enabling VPN connectivity. // enabling VPN connectivity.
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (_ nbnet.AddHookFunc, _ nbnet.RemoveHookFunc, err error) { func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (err error) {
if !nbnet.AdvancedRouting() { if !nbnet.AdvancedRouting() {
log.Infof("Using legacy routing setup") log.Infof("Using legacy routing setup")
return r.setupRefCounter(initAddresses, stateManager) return r.setupRefCounter(initAddresses, stateManager)
@@ -89,7 +89,7 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager
rules := getSetupRules() rules := getSetupRules()
for _, rule := range rules { for _, rule := range rules {
if err := addRule(rule); err != nil { if err := addRule(rule); err != nil {
return nil, nil, fmt.Errorf("%s: %w", rule.description, err) return fmt.Errorf("%s: %w", rule.description, err)
} }
} }
@@ -104,7 +104,7 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager
} }
originalSysctl = originalValues originalSysctl = originalValues
return nil, nil, nil return nil
} }
// CleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'. // CleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'.

View File

@@ -18,10 +18,9 @@ import (
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/util/net"
) )
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) error {
return r.setupRefCounter(initAddresses, stateManager) return r.setupRefCounter(initAddresses, stateManager)
} }

View File

@@ -19,7 +19,6 @@ import (
"golang.org/x/sys/windows" "golang.org/x/sys/windows"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/util/net"
) )
const InfiniteLifetime = 0xffffffff const InfiniteLifetime = 0xffffffff
@@ -137,7 +136,7 @@ const (
RouteDeleted RouteDeleted
) )
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) error {
return r.setupRefCounter(initAddresses, stateManager) return r.setupRefCounter(initAddresses, stateManager)
} }

View File

@@ -1330,6 +1330,13 @@ func (x *PeerState) GetRelayAddress() string {
return "" return ""
} }
func (x *PeerState) GetConnectionType() string {
if x.Relayed {
return "Relayed"
}
return "P2P"
}
// LocalPeerState contains the latest state of the local peer // LocalPeerState contains the latest state of the local peer
type LocalPeerState struct { type LocalPeerState struct {
state protoimpl.MessageState `protogen:"open.v1"` state protoimpl.MessageState `protogen:"open.v1"`

View File

@@ -212,7 +212,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
} }
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager) secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager)
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil) mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &server.MockIntegratedValidator{})
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }

View File

@@ -100,7 +100,7 @@ type OutputOverview struct {
LazyConnectionEnabled bool `json:"lazyConnectionEnabled" yaml:"lazyConnectionEnabled"` LazyConnectionEnabled bool `json:"lazyConnectionEnabled" yaml:"lazyConnectionEnabled"`
} }
func ConvertToStatusOutputOverview(resp *proto.StatusResponse, anon bool, statusFilter string, prefixNamesFilter []string, prefixNamesFilterMap map[string]struct{}, ipsFilter map[string]struct{}) OutputOverview { func ConvertToStatusOutputOverview(resp *proto.StatusResponse, anon bool, statusFilter string, prefixNamesFilter []string, prefixNamesFilterMap map[string]struct{}, ipsFilter map[string]struct{}, connectionTypeFilter string) OutputOverview {
pbFullStatus := resp.GetFullStatus() pbFullStatus := resp.GetFullStatus()
managementState := pbFullStatus.GetManagementState() managementState := pbFullStatus.GetManagementState()
@@ -118,7 +118,7 @@ func ConvertToStatusOutputOverview(resp *proto.StatusResponse, anon bool, status
} }
relayOverview := mapRelays(pbFullStatus.GetRelays()) relayOverview := mapRelays(pbFullStatus.GetRelays())
peersOverview := mapPeers(resp.GetFullStatus().GetPeers(), statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter) peersOverview := mapPeers(resp.GetFullStatus().GetPeers(), statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter, connectionTypeFilter)
overview := OutputOverview{ overview := OutputOverview{
Peers: peersOverview, Peers: peersOverview,
@@ -193,6 +193,7 @@ func mapPeers(
prefixNamesFilter []string, prefixNamesFilter []string,
prefixNamesFilterMap map[string]struct{}, prefixNamesFilterMap map[string]struct{},
ipsFilter map[string]struct{}, ipsFilter map[string]struct{},
connectionTypeFilter string,
) PeersStateOutput { ) PeersStateOutput {
var peersStateDetail []PeerStateDetailOutput var peersStateDetail []PeerStateDetailOutput
peersConnected := 0 peersConnected := 0
@@ -208,7 +209,7 @@ func mapPeers(
transferSent := int64(0) transferSent := int64(0)
isPeerConnected := pbPeerState.ConnStatus == peer.StatusConnected.String() isPeerConnected := pbPeerState.ConnStatus == peer.StatusConnected.String()
if skipDetailByFilters(pbPeerState, pbPeerState.ConnStatus, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter) { if skipDetailByFilters(pbPeerState, pbPeerState.ConnStatus, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter, connectionTypeFilter) {
continue continue
} }
if isPeerConnected { if isPeerConnected {
@@ -218,10 +219,7 @@ func mapPeers(
remoteICE = pbPeerState.GetRemoteIceCandidateType() remoteICE = pbPeerState.GetRemoteIceCandidateType()
localICEEndpoint = pbPeerState.GetLocalIceCandidateEndpoint() localICEEndpoint = pbPeerState.GetLocalIceCandidateEndpoint()
remoteICEEndpoint = pbPeerState.GetRemoteIceCandidateEndpoint() remoteICEEndpoint = pbPeerState.GetRemoteIceCandidateEndpoint()
connType = "P2P" connType = pbPeerState.GetConnectionType()
if pbPeerState.Relayed {
connType = "Relayed"
}
relayServerAddress = pbPeerState.GetRelayAddress() relayServerAddress = pbPeerState.GetRelayAddress()
lastHandshake = pbPeerState.GetLastWireguardHandshake().AsTime().Local() lastHandshake = pbPeerState.GetLastWireguardHandshake().AsTime().Local()
transferReceived = pbPeerState.GetBytesRx() transferReceived = pbPeerState.GetBytesRx()
@@ -542,10 +540,11 @@ func parsePeers(peers PeersStateOutput, rosenpassEnabled, rosenpassPermissive bo
return peersString return peersString
} }
func skipDetailByFilters(peerState *proto.PeerState, peerStatus string, statusFilter string, prefixNamesFilter []string, prefixNamesFilterMap map[string]struct{}, ipsFilter map[string]struct{}) bool { func skipDetailByFilters(peerState *proto.PeerState, peerStatus string, statusFilter string, prefixNamesFilter []string, prefixNamesFilterMap map[string]struct{}, ipsFilter map[string]struct{}, connectionTypeFilter string) bool {
statusEval := false statusEval := false
ipEval := false ipEval := false
nameEval := true nameEval := true
connectionTypeEval := false
if statusFilter != "" { if statusFilter != "" {
if !strings.EqualFold(peerStatus, statusFilter) { if !strings.EqualFold(peerStatus, statusFilter) {
@@ -570,8 +569,11 @@ func skipDetailByFilters(peerState *proto.PeerState, peerStatus string, statusFi
} else { } else {
nameEval = false nameEval = false
} }
if connectionTypeFilter != "" && !strings.EqualFold(peerState.GetConnectionType(), connectionTypeFilter) {
connectionTypeEval = true
}
return statusEval || ipEval || nameEval return statusEval || ipEval || nameEval || connectionTypeEval
} }
func toIEC(b int64) string { func toIEC(b int64) string {

View File

@@ -234,7 +234,7 @@ var overview = OutputOverview{
} }
func TestConversionFromFullStatusToOutputOverview(t *testing.T) { func TestConversionFromFullStatusToOutputOverview(t *testing.T) {
convertedResult := ConvertToStatusOutputOverview(resp, false, "", nil, nil, nil) convertedResult := ConvertToStatusOutputOverview(resp, false, "", nil, nil, nil, "")
assert.Equal(t, overview, convertedResult) assert.Equal(t, overview, convertedResult)
} }

View File

@@ -433,7 +433,7 @@ func (s *serviceClient) collectDebugData(
var postUpStatusOutput string var postUpStatusOutput string
if postUpStatus != nil { if postUpStatus != nil {
overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus, params.anonymize, "", nil, nil, nil) overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus, params.anonymize, "", nil, nil, nil, "")
postUpStatusOutput = nbstatus.ParseToFullDetailSummary(overview) postUpStatusOutput = nbstatus.ParseToFullDetailSummary(overview)
} }
headerPostUp := fmt.Sprintf("----- NetBird post-up - Timestamp: %s", time.Now().Format(time.RFC3339)) headerPostUp := fmt.Sprintf("----- NetBird post-up - Timestamp: %s", time.Now().Format(time.RFC3339))
@@ -450,7 +450,7 @@ func (s *serviceClient) collectDebugData(
var preDownStatusOutput string var preDownStatusOutput string
if preDownStatus != nil { if preDownStatus != nil {
overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus, params.anonymize, "", nil, nil, nil) overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus, params.anonymize, "", nil, nil, nil, "")
preDownStatusOutput = nbstatus.ParseToFullDetailSummary(overview) preDownStatusOutput = nbstatus.ParseToFullDetailSummary(overview)
} }
headerPreDown := fmt.Sprintf("----- NetBird pre-down - Timestamp: %s - Duration: %s", headerPreDown := fmt.Sprintf("----- NetBird pre-down - Timestamp: %s - Duration: %s",
@@ -581,7 +581,7 @@ func (s *serviceClient) createDebugBundle(anonymize bool, systemInfo bool, uploa
var statusOutput string var statusOutput string
if statusResp != nil { if statusResp != nil {
overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anonymize, "", nil, nil, nil) overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anonymize, "", nil, nil, nil, "")
statusOutput = nbstatus.ParseToFullDetailSummary(overview) statusOutput = nbstatus.ParseToFullDetailSummary(overview)
} }

2
go.mod
View File

@@ -63,7 +63,7 @@ require (
github.com/miekg/dns v1.1.59 github.com/miekg/dns v1.1.59
github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/nadoo/ipset v0.5.0 github.com/nadoo/ipset v0.5.0
github.com/netbirdio/management-integrations/integrations v0.0.0-20250612164546-6bd7e2338d65 github.com/netbirdio/management-integrations/integrations v0.0.0-20250718071730-f4d133556ff5
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250514131221-a464fd5f30cb github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250514131221-a464fd5f30cb
github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/okta/okta-sdk-golang/v2 v2.18.0
github.com/oschwald/maxminddb-golang v1.12.0 github.com/oschwald/maxminddb-golang v1.12.0

4
go.sum
View File

@@ -503,8 +503,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c=
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q=
github.com/netbirdio/management-integrations/integrations v0.0.0-20250612164546-6bd7e2338d65 h1:5OfYiLjpr4dbQYJI5ouZaylkVdi2KlErLFOwBeBo5Hw= github.com/netbirdio/management-integrations/integrations v0.0.0-20250718071730-f4d133556ff5 h1:Zfn8d83OVyELCdxgprcyXR3D8uqoxHtXE9PUxVXDx/w=
github.com/netbirdio/management-integrations/integrations v0.0.0-20250612164546-6bd7e2338d65/go.mod h1:Gi9raplYzCCyh07Olw/DVfCJTFgpr1WCXJ/Q+8TSA9Q= github.com/netbirdio/management-integrations/integrations v0.0.0-20250718071730-f4d133556ff5/go.mod h1:Gi9raplYzCCyh07Olw/DVfCJTFgpr1WCXJ/Q+8TSA9Q=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250514131221-a464fd5f30cb h1:Cr6age+ePALqlSvtp7wc6lYY97XN7rkD1K4XEDmY+TU= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250514131221-a464fd5f30cb h1:Cr6age+ePALqlSvtp7wc6lYY97XN7rkD1K4XEDmY+TU=

View File

@@ -112,7 +112,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
} }
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager) secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager)
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil) mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, mgmt.MockIntegratedValidator{})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -292,7 +292,7 @@ var (
ephemeralManager.LoadInitialPeers(ctx) ephemeralManager.LoadInitialPeers(ctx)
gRPCAPIHandler := grpc.NewServer(gRPCOpts...) gRPCAPIHandler := grpc.NewServer(gRPCOpts...)
srv, err := server.NewServer(ctx, config, accountManager, settingsManager, peersUpdateManager, secretsManager, appMetrics, ephemeralManager, authManager) srv, err := server.NewServer(ctx, config, accountManager, settingsManager, peersUpdateManager, secretsManager, appMetrics, ephemeralManager, authManager, integratedPeerValidator)
if err != nil { if err != nil {
return fmt.Errorf("failed creating gRPC API handler: %v", err) return fmt.Errorf("failed creating gRPC API handler: %v", err)
} }

View File

@@ -2887,7 +2887,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, error) {
permissionsManager := permissions.NewManager(store) permissionsManager := permissions.NewManager(store)
manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -219,7 +219,7 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
// return empty extra settings for expected calls to UpdateAccountPeers // return empty extra settings for expected calls to UpdateAccountPeers
settingsMockManager.EXPECT().GetExtraSettings(gomock.Any(), gomock.Any()).Return(&types.ExtraSettings{}, nil).AnyTimes() settingsMockManager.EXPECT().GetExtraSettings(gomock.Any(), gomock.Any()).Return(&types.ExtraSettings{}, nil).AnyTimes()
permissionsManager := permissions.NewManager(store) permissionsManager := permissions.NewManager(store)
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
} }
func createDNSStore(t *testing.T) (store.Store, error) { func createDNSStore(t *testing.T) (store.Store, error) {

View File

@@ -19,6 +19,7 @@ import (
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config" integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
"github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/proto"
@@ -47,6 +48,7 @@ type GRPCServer struct {
ephemeralManager *EphemeralManager ephemeralManager *EphemeralManager
peerLocks sync.Map peerLocks sync.Map
authManager auth.Manager authManager auth.Manager
integratedPeerValidator integrated_validator.IntegratedValidator
} }
// NewServer creates a new Management server // NewServer creates a new Management server
@@ -60,6 +62,7 @@ func NewServer(
appMetrics telemetry.AppMetrics, appMetrics telemetry.AppMetrics,
ephemeralManager *EphemeralManager, ephemeralManager *EphemeralManager,
authManager auth.Manager, authManager auth.Manager,
integratedPeerValidator integrated_validator.IntegratedValidator,
) (*GRPCServer, error) { ) (*GRPCServer, error) {
key, err := wgtypes.GeneratePrivateKey() key, err := wgtypes.GeneratePrivateKey()
if err != nil { if err != nil {
@@ -87,6 +90,7 @@ func NewServer(
authManager: authManager, authManager: authManager,
appMetrics: appMetrics, appMetrics: appMetrics,
ephemeralManager: ephemeralManager, ephemeralManager: ephemeralManager,
integratedPeerValidator: integratedPeerValidator,
}, nil }, nil
} }
@@ -850,7 +854,7 @@ func (s *GRPCServer) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.En
return nil, status.Error(codes.NotFound, "no pkce authorization flow information available") return nil, status.Error(codes.NotFound, "no pkce authorization flow information available")
} }
flowInfoResp := &proto.PKCEAuthorizationFlow{ initInfoFlow := &proto.PKCEAuthorizationFlow{
ProviderConfig: &proto.ProviderConfig{ ProviderConfig: &proto.ProviderConfig{
Audience: s.config.PKCEAuthorizationFlow.ProviderConfig.Audience, Audience: s.config.PKCEAuthorizationFlow.ProviderConfig.Audience,
ClientID: s.config.PKCEAuthorizationFlow.ProviderConfig.ClientID, ClientID: s.config.PKCEAuthorizationFlow.ProviderConfig.ClientID,
@@ -865,6 +869,8 @@ func (s *GRPCServer) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.En
}, },
} }
flowInfoResp := s.integratedPeerValidator.ValidateFlowResponse(ctx, peerKey.String(), initInfoFlow)
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, flowInfoResp) encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, flowInfoResp)
if err != nil { if err != nil {
return nil, status.Error(codes.Internal, "failed to encrypt no pkce authorization flow information") return nil, status.Error(codes.Internal, "failed to encrypt no pkce authorization flow information")

View File

@@ -427,6 +427,7 @@ func toPolicyResponse(groups []*types.Group, policy *types.Policy) *api.Policy {
Id: group.ID, Id: group.ID,
Name: group.Name, Name: group.Name,
PeersCount: len(group.Peers), PeersCount: len(group.Peers),
ResourcesCount: len(group.Resources),
} }
destinations = append(destinations, minimum) destinations = append(destinations, minimum)
cache[gid] = minimum cache[gid] = minimum

View File

@@ -1,4 +1,5 @@
package testing_tools package testing_tools
import ( import (
"bytes" "bytes"
"context" "context"
@@ -132,7 +133,7 @@ func BuildApiBlackBoxWithDBState(t TB, sqlFile string, expectedPeerUpdate *serve
} }
geoMock := &geolocation.Mock{} geoMock := &geolocation.Mock{}
validatorMock := server.MocIntegratedValidator{} validatorMock := server.MockIntegratedValidator{}
proxyController := integrations.NewController(store) proxyController := integrations.NewController(store)
userManager := users.NewManager(store) userManager := users.NewManager(store)
permissionsManager := permissions.NewManager(store) permissionsManager := permissions.NewManager(store)

View File

@@ -6,6 +6,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
@@ -101,22 +102,23 @@ func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountI
return am.integratedPeerValidator.GetValidatedPeers(accountID, groups, peers, settings.Extra) return am.integratedPeerValidator.GetValidatedPeers(accountID, groups, peers, settings.Extra)
} }
type MocIntegratedValidator struct { type MockIntegratedValidator struct {
integrated_validator.IntegratedValidator
ValidatePeerFunc func(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *types.ExtraSettings) (*nbpeer.Peer, bool, error) ValidatePeerFunc func(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *types.ExtraSettings) (*nbpeer.Peer, bool, error)
} }
func (a MocIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *types.ExtraSettings, oldExtraSettings *types.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error { func (a MockIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *types.ExtraSettings, oldExtraSettings *types.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error {
return nil return nil
} }
func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *types.ExtraSettings) (*nbpeer.Peer, bool, error) { func (a MockIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *types.ExtraSettings) (*nbpeer.Peer, bool, error) {
if a.ValidatePeerFunc != nil { if a.ValidatePeerFunc != nil {
return a.ValidatePeerFunc(context.Background(), update, peer, userID, accountID, dnsDomain, peersGroup, extraSettings) return a.ValidatePeerFunc(context.Background(), update, peer, userID, accountID, dnsDomain, peersGroup, extraSettings)
} }
return update, false, nil return update, false, nil
} }
func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups []*types.Group, peers []*nbpeer.Peer, extraSettings *types.ExtraSettings) (map[string]struct{}, error) { func (a MockIntegratedValidator) GetValidatedPeers(accountID string, groups []*types.Group, peers []*nbpeer.Peer, extraSettings *types.ExtraSettings) (map[string]struct{}, error) {
validatedPeers := make(map[string]struct{}) validatedPeers := make(map[string]struct{})
for _, peer := range peers { for _, peer := range peers {
validatedPeers[peer.ID] = struct{}{} validatedPeers[peer.ID] = struct{}{}
@@ -124,22 +126,22 @@ func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups []*ty
return validatedPeers, nil return validatedPeers, nil
} }
func (MocIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) *nbpeer.Peer { func (MockIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) *nbpeer.Peer {
return peer return peer
} }
func (MocIntegratedValidator) IsNotValidPeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) (bool, bool, error) { func (MockIntegratedValidator) IsNotValidPeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) (bool, bool, error) {
return false, false, nil return false, false, nil
} }
func (MocIntegratedValidator) PeerDeleted(_ context.Context, _, _ string) error { func (MockIntegratedValidator) PeerDeleted(_ context.Context, _, _ string) error {
return nil return nil
} }
func (MocIntegratedValidator) SetPeerInvalidationListener(func(accountID string)) { func (MockIntegratedValidator) SetPeerInvalidationListener(func(accountID string)) {
// just a dummy // just a dummy
} }
func (MocIntegratedValidator) Stop(_ context.Context) { func (MockIntegratedValidator) Stop(_ context.Context) {
// just a dummy // just a dummy
} }

View File

@@ -3,6 +3,7 @@ package integrated_validator
import ( import (
"context" "context"
"github.com/netbirdio/netbird/management/proto"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
) )
@@ -17,4 +18,5 @@ type IntegratedValidator interface {
PeerDeleted(ctx context.Context, accountID, peerID string) error PeerDeleted(ctx context.Context, accountID, peerID string) error
SetPeerInvalidationListener(fn func(accountID string)) SetPeerInvalidationListener(fn func(accountID string))
Stop(ctx context.Context) Stop(ctx context.Context)
ValidateFlowResponse(ctx context.Context, peerKey string, flowResponse *proto.PKCEAuthorizationFlow) *proto.PKCEAuthorizationFlow
} }

View File

@@ -448,7 +448,7 @@ func startManagementForTest(t *testing.T, testFile string, config *types.Config)
permissionsManager := permissions.NewManager(store) permissionsManager := permissions.NewManager(store)
accountManager, err := BuildManager(ctx, store, peersUpdateManager, nil, "", "netbird.selfhosted", accountManager, err := BuildManager(ctx, store, peersUpdateManager, nil, "", "netbird.selfhosted",
eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
if err != nil { if err != nil {
cleanup() cleanup()
@@ -458,7 +458,7 @@ func startManagementForTest(t *testing.T, testFile string, config *types.Config)
secretsManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager) secretsManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager)
ephemeralMgr := NewEphemeralManager(store, accountManager) ephemeralMgr := NewEphemeralManager(store, accountManager)
mgmtServer, err := NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, ephemeralMgr, nil) mgmtServer, err := NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, ephemeralMgr, nil, MockIntegratedValidator{})
if err != nil { if err != nil {
return nil, nil, "", cleanup, err return nil, nil, "", cleanup, err
} }

View File

@@ -206,7 +206,7 @@ func startServer(
eventStore, eventStore,
nil, nil,
false, false,
server.MocIntegratedValidator{}, server.MockIntegratedValidator{},
metrics, metrics,
port_forwarding.NewControllerMock(), port_forwarding.NewControllerMock(),
settingsMockManager, settingsMockManager,
@@ -227,6 +227,7 @@ func startServer(
nil, nil,
nil, nil,
nil, nil,
server.MockIntegratedValidator{},
) )
if err != nil { if err != nil {
t.Fatalf("failed creating management server: %v", err) t.Fatalf("failed creating management server: %v", err)

View File

@@ -283,7 +283,7 @@ func MigrateSetupKeyToHashedSetupKey[T any](ctx context.Context, db *gorm.DB) er
} }
} }
if err := tx.Exec(fmt.Sprintf("ALTER TABLE %s DROP COLUMN %s", "peers", "setup_key")).Error; err != nil { if err := tx.Exec(fmt.Sprintf("ALTER TABLE %s DROP COLUMN IF EXISTS %s", "peers", "setup_key")).Error; err != nil {
log.WithContext(ctx).Errorf("Failed to drop column %s: %v", "setup_key", err) log.WithContext(ctx).Errorf("Failed to drop column %s: %v", "setup_key", err)
} }
@@ -377,6 +377,11 @@ func DropIndex[T any](ctx context.Context, db *gorm.DB, indexName string) error
func CreateIndexIfNotExists[T any](ctx context.Context, db *gorm.DB, indexName string, columns ...string) error { func CreateIndexIfNotExists[T any](ctx context.Context, db *gorm.DB, indexName string, columns ...string) error {
var model T var model T
if !db.Migrator().HasTable(&model) {
log.WithContext(ctx).Debugf("table for %T does not exist, no migration needed", model)
return nil
}
stmt := &gorm.Statement{DB: db} stmt := &gorm.Statement{DB: db}
if err := stmt.Parse(&model); err != nil { if err := stmt.Parse(&model); err != nil {
return fmt.Errorf("failed to parse model schema: %w", err) return fmt.Errorf("failed to parse model schema: %w", err)
@@ -384,6 +389,11 @@ func CreateIndexIfNotExists[T any](ctx context.Context, db *gorm.DB, indexName s
tableName := stmt.Schema.Table tableName := stmt.Schema.Table
dialect := db.Dialector.Name() dialect := db.Dialector.Name()
if db.Migrator().HasIndex(&model, indexName) {
log.WithContext(ctx).Infof("index %s already exists on table %s", indexName, tableName)
return nil
}
var columnClause string var columnClause string
if dialect == "mysql" { if dialect == "mysql" {
var withLength []string var withLength []string

View File

@@ -4,16 +4,21 @@ import (
"context" "context"
"encoding/gob" "encoding/gob"
"net" "net"
"os"
"strings" "strings"
"testing" "testing"
"time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
"gorm.io/driver/sqlite" "gorm.io/driver/sqlite"
"gorm.io/gorm" "gorm.io/gorm"
"github.com/netbirdio/netbird/management/server/migration" "github.com/netbirdio/netbird/management/server/migration"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/testutil"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
@@ -21,7 +26,41 @@ import (
func setupDatabase(t *testing.T) *gorm.DB { func setupDatabase(t *testing.T) *gorm.DB {
t.Helper() t.Helper()
db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{}) var db *gorm.DB
var err error
var dsn string
var cleanup func()
switch os.Getenv("NETBIRD_STORE_ENGINE") {
case "mysql":
cleanup, dsn, err = testutil.CreateMysqlTestContainer()
if err != nil {
t.Fatalf("Failed to create MySQL test container: %v", err)
}
if dsn == "" {
t.Fatal("MySQL connection string is empty, ensure the test container is running")
}
db, err = gorm.Open(mysql.Open(dsn+"?charset=utf8&parseTime=True&loc=Local"), &gorm.Config{})
case "postgres":
cleanup, dsn, err = testutil.CreatePostgresTestContainer()
if err != nil {
t.Fatalf("Failed to create PostgreSQL test container: %v", err)
}
if dsn == "" {
t.Fatalf("PostgreSQL connection string is empty, ensure the test container is running")
}
db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{})
case "sqlite":
db, err = gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{})
default:
db, err = gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{})
}
if cleanup != nil {
t.Cleanup(cleanup)
}
require.NoError(t, err, "Failed to open database") require.NoError(t, err, "Failed to open database")
return db return db
@@ -34,6 +73,7 @@ func TestMigrateFieldFromGobToJSON_EmptyDB(t *testing.T) {
} }
func TestMigrateFieldFromGobToJSON_WithGobData(t *testing.T) { func TestMigrateFieldFromGobToJSON_WithGobData(t *testing.T) {
t.Setenv("NETBIRD_STORE_ENGINE", "sqlite")
db := setupDatabase(t) db := setupDatabase(t)
err := db.AutoMigrate(&types.Account{}, &route.Route{}) err := db.AutoMigrate(&types.Account{}, &route.Route{})
@@ -97,6 +137,7 @@ func TestMigrateNetIPFieldFromBlobToJSON_EmptyDB(t *testing.T) {
} }
func TestMigrateNetIPFieldFromBlobToJSON_WithBlobData(t *testing.T) { func TestMigrateNetIPFieldFromBlobToJSON_WithBlobData(t *testing.T) {
t.Setenv("NETBIRD_STORE_ENGINE", "sqlite")
db := setupDatabase(t) db := setupDatabase(t)
err := db.AutoMigrate(&types.Account{}, &nbpeer.Peer{}) err := db.AutoMigrate(&types.Account{}, &nbpeer.Peer{})
@@ -117,12 +158,18 @@ func TestMigrateNetIPFieldFromBlobToJSON_WithBlobData(t *testing.T) {
Peers []peer `gorm:"foreignKey:AccountID;references:id"` Peers []peer `gorm:"foreignKey:AccountID;references:id"`
} }
err = db.Save(&account{ a := &account{
Account: types.Account{Id: "123"}, Account: types.Account{Id: "123"},
Peers: []peer{ }
err = db.Save(a).Error
require.NoError(t, err, "Failed to insert account")
a.Peers = []peer{
{Location: location{ConnectionIP: net.IP{10, 0, 0, 1}}}, {Location: location{ConnectionIP: net.IP{10, 0, 0, 1}}},
}}, }
).Error
err = db.Save(a).Error
require.NoError(t, err, "Failed to insert blob data") require.NoError(t, err, "Failed to insert blob data")
var blobValue string var blobValue string
@@ -143,12 +190,18 @@ func TestMigrateNetIPFieldFromBlobToJSON_WithJSONData(t *testing.T) {
err := db.AutoMigrate(&types.Account{}, &nbpeer.Peer{}) err := db.AutoMigrate(&types.Account{}, &nbpeer.Peer{})
require.NoError(t, err, "Failed to auto-migrate tables") require.NoError(t, err, "Failed to auto-migrate tables")
err = db.Save(&types.Account{ account := &types.Account{
Id: "1234", Id: "1234",
PeersG: []nbpeer.Peer{ }
{Location: nbpeer.Location{ConnectionIP: net.IP{10, 0, 0, 1}}},
}}, err = db.Save(account).Error
).Error require.NoError(t, err, "Failed to insert account")
account.PeersG = []nbpeer.Peer{
{AccountID: "1234", Location: nbpeer.Location{ConnectionIP: net.IP{10, 0, 0, 1}}},
}
err = db.Save(account).Error
require.NoError(t, err, "Failed to insert JSON data") require.NoError(t, err, "Failed to insert JSON data")
err = migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](context.Background(), db, "location_connection_ip", "") err = migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](context.Background(), db, "location_connection_ip", "")
@@ -162,12 +215,13 @@ func TestMigrateNetIPFieldFromBlobToJSON_WithJSONData(t *testing.T) {
func TestMigrateSetupKeyToHashedSetupKey_ForPlainKey(t *testing.T) { func TestMigrateSetupKeyToHashedSetupKey_ForPlainKey(t *testing.T) {
db := setupDatabase(t) db := setupDatabase(t)
err := db.AutoMigrate(&types.SetupKey{}) err := db.AutoMigrate(&types.SetupKey{}, &nbpeer.Peer{})
require.NoError(t, err, "Failed to auto-migrate tables") require.NoError(t, err, "Failed to auto-migrate tables")
err = db.Save(&types.SetupKey{ err = db.Save(&types.SetupKey{
Id: "1", Id: "1",
Key: "EEFDAB47-C1A5-4472-8C05-71DE9A1E8382", Key: "EEFDAB47-C1A5-4472-8C05-71DE9A1E8382",
UpdatedAt: time.Now(),
}).Error }).Error
require.NoError(t, err, "Failed to insert setup key") require.NoError(t, err, "Failed to insert setup key")
@@ -192,6 +246,7 @@ func TestMigrateSetupKeyToHashedSetupKey_ForAlreadyMigratedKey_Case1(t *testing.
Id: "1", Id: "1",
Key: "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=", Key: "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=",
KeySecret: "EEFDA****", KeySecret: "EEFDA****",
UpdatedAt: time.Now(),
}).Error }).Error
require.NoError(t, err, "Failed to insert setup key") require.NoError(t, err, "Failed to insert setup key")
@@ -215,6 +270,7 @@ func TestMigrateSetupKeyToHashedSetupKey_ForAlreadyMigratedKey_Case2(t *testing.
err = db.Save(&types.SetupKey{ err = db.Save(&types.SetupKey{
Id: "1", Id: "1",
Key: "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=", Key: "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=",
UpdatedAt: time.Now(),
}).Error }).Error
require.NoError(t, err, "Failed to insert setup key") require.NoError(t, err, "Failed to insert setup key")
@@ -237,6 +293,7 @@ func TestDropIndex(t *testing.T) {
err = db.Save(&types.SetupKey{ err = db.Save(&types.SetupKey{
Id: "1", Id: "1",
Key: "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=", Key: "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=",
UpdatedAt: time.Now(),
}).Error }).Error
require.NoError(t, err, "Failed to insert setup key") require.NoError(t, err, "Failed to insert setup key")
@@ -249,3 +306,37 @@ func TestDropIndex(t *testing.T) {
exist = db.Migrator().HasIndex(&types.SetupKey{}, "idx_setup_keys_account_id") exist = db.Migrator().HasIndex(&types.SetupKey{}, "idx_setup_keys_account_id")
assert.False(t, exist, "Should not have the index") assert.False(t, exist, "Should not have the index")
} }
func TestCreateIndex(t *testing.T) {
db := setupDatabase(t)
err := db.AutoMigrate(&nbpeer.Peer{})
assert.NoError(t, err, "Failed to auto-migrate tables")
indexName := "idx_account_ip"
err = migration.CreateIndexIfNotExists[nbpeer.Peer](context.Background(), db, indexName, "account_id", "ip")
assert.NoError(t, err, "Migration should not fail to create index")
exist := db.Migrator().HasIndex(&nbpeer.Peer{}, indexName)
assert.True(t, exist, "Should have the index")
}
func TestCreateIndexIfExists(t *testing.T) {
db := setupDatabase(t)
err := db.AutoMigrate(&nbpeer.Peer{})
assert.NoError(t, err, "Failed to auto-migrate tables")
indexName := "idx_account_ip"
err = migration.CreateIndexIfNotExists[nbpeer.Peer](context.Background(), db, indexName, "account_id", "ip")
assert.NoError(t, err, "Migration should not fail to create index")
exist := db.Migrator().HasIndex(&nbpeer.Peer{}, indexName)
assert.True(t, exist, "Should have the index")
err = migration.CreateIndexIfNotExists[nbpeer.Peer](context.Background(), db, indexName, "account_id", "ip")
assert.NoError(t, err, "Create index should not fail if index exists")
exist = db.Migrator().HasIndex(&nbpeer.Peer{}, indexName)
assert.True(t, exist, "Should have the index")
}

View File

@@ -785,7 +785,7 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) {
AnyTimes() AnyTimes()
permissionsManager := permissions.NewManager(store) permissionsManager := permissions.NewManager(store)
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
} }
func createNSStore(t *testing.T) (store.Store, error) { func createNSStore(t *testing.T) (store.Store, error) {

View File

@@ -1273,7 +1273,7 @@ func Test_RegisterPeerByUser(t *testing.T) {
settingsMockManager := settings.NewMockManager(ctrl) settingsMockManager := settings.NewMockManager(ctrl)
permissionsManager := permissions.NewManager(s) permissionsManager := permissions.NewManager(s)
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
assert.NoError(t, err) assert.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
@@ -1353,7 +1353,7 @@ func Test_RegisterPeerBySetupKey(t *testing.T) {
AnyTimes() AnyTimes()
permissionsManager := permissions.NewManager(s) permissionsManager := permissions.NewManager(s)
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
assert.NoError(t, err) assert.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
@@ -1496,7 +1496,7 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
permissionsManager := permissions.NewManager(s) permissionsManager := permissions.NewManager(s)
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
assert.NoError(t, err) assert.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
@@ -1570,7 +1570,7 @@ func Test_LoginPeer(t *testing.T) {
AnyTimes() AnyTimes()
permissionsManager := permissions.NewManager(s) permissionsManager := permissions.NewManager(s)
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
assert.NoError(t, err) assert.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
@@ -1848,7 +1848,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
return update, true, nil return update, true, nil
} }
manager.integratedPeerValidator = MocIntegratedValidator{ValidatePeerFunc: requireUpdateFunc} manager.integratedPeerValidator = MockIntegratedValidator{ValidatePeerFunc: requireUpdateFunc}
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
peerShouldReceiveUpdate(t, updMsg) peerShouldReceiveUpdate(t, updMsg)
@@ -1870,7 +1870,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
return update, false, nil return update, false, nil
} }
manager.integratedPeerValidator = MocIntegratedValidator{ValidatePeerFunc: requireNoUpdateFunc} manager.integratedPeerValidator = MockIntegratedValidator{ValidatePeerFunc: requireNoUpdateFunc}
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
peerShouldNotReceiveUpdate(t, updMsg) peerShouldNotReceiveUpdate(t, updMsg)

View File

@@ -1284,7 +1284,7 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, error) {
permissionsManager := permissions.NewManager(store) permissionsManager := permissions.NewManager(store)
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
} }
func createRouterStore(t *testing.T) (store.Store, error) { func createRouterStore(t *testing.T) (store.Store, error) {

View File

@@ -852,7 +852,7 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) {
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
eventStore: &activity.InMemoryEventStore{}, eventStore: &activity.InMemoryEventStore{},
integratedPeerValidator: MocIntegratedValidator{}, integratedPeerValidator: MockIntegratedValidator{},
permissionsManager: permissionsManager, permissionsManager: permissionsManager,
} }

View File

@@ -7,13 +7,6 @@ import (
authv2 "github.com/netbirdio/netbird/relay/auth/hmac/v2" authv2 "github.com/netbirdio/netbird/relay/auth/hmac/v2"
) )
// Validator is an interface that defines the Validate method.
type Validator interface {
Validate(any) error
// Deprecated: Use Validate instead.
ValidateHelloMsgType(any) error
}
type TimedHMACValidator struct { type TimedHMACValidator struct {
authenticatorV2 *authv2.Validator authenticatorV2 *authv2.Validator
authenticator *auth.TimedHMACValidator authenticator *auth.TimedHMACValidator

View File

@@ -124,15 +124,14 @@ func (cc *connContainer) close() {
// While the Connect is in progress, the OpenConn function will block until the connection is established with relay server. // While the Connect is in progress, the OpenConn function will block until the connection is established with relay server.
type Client struct { type Client struct {
log *log.Entry log *log.Entry
parentCtx context.Context
connectionURL string connectionURL string
authTokenStore *auth.TokenStore authTokenStore *auth.TokenStore
hashedID []byte hashedID messages.PeerID
bufPool *sync.Pool bufPool *sync.Pool
relayConn net.Conn relayConn net.Conn
conns map[string]*connContainer conns map[messages.PeerID]*connContainer
serviceIsRunning bool serviceIsRunning bool
mu sync.Mutex // protect serviceIsRunning and conns mu sync.Mutex // protect serviceIsRunning and conns
readLoopMutex sync.Mutex readLoopMutex sync.Mutex
@@ -142,14 +141,17 @@ type Client struct {
onDisconnectListener func(string) onDisconnectListener func(string)
listenerMutex sync.Mutex listenerMutex sync.Mutex
stateSubscription *PeersStateSubscription
} }
// NewClient creates a new client for the relay server. The client is not connected to the server until the Connect // NewClient creates a new client for the relay server. The client is not connected to the server until the Connect
func NewClient(ctx context.Context, serverURL string, authTokenStore *auth.TokenStore, peerID string) *Client { func NewClient(serverURL string, authTokenStore *auth.TokenStore, peerID string) *Client {
hashedID, hashedStringId := messages.HashID(peerID) hashedID := messages.HashID(peerID)
relayLog := log.WithFields(log.Fields{"relay": serverURL})
c := &Client{ c := &Client{
log: log.WithFields(log.Fields{"relay": serverURL}), log: relayLog,
parentCtx: ctx,
connectionURL: serverURL, connectionURL: serverURL,
authTokenStore: authTokenStore, authTokenStore: authTokenStore,
hashedID: hashedID, hashedID: hashedID,
@@ -159,14 +161,15 @@ func NewClient(ctx context.Context, serverURL string, authTokenStore *auth.Token
return &buf return &buf
}, },
}, },
conns: make(map[string]*connContainer), conns: make(map[messages.PeerID]*connContainer),
} }
c.log.Infof("create new relay connection: local peerID: %s, local peer hashedID: %s", peerID, hashedStringId)
c.log.Infof("create new relay connection: local peerID: %s, local peer hashedID: %s", peerID, hashedID)
return c return c
} }
// Connect establishes a connection to the relay server. It blocks until the connection is established or an error occurs. // Connect establishes a connection to the relay server. It blocks until the connection is established or an error occurs.
func (c *Client) Connect() error { func (c *Client) Connect(ctx context.Context) error {
c.log.Infof("connecting to relay server") c.log.Infof("connecting to relay server")
c.readLoopMutex.Lock() c.readLoopMutex.Lock()
defer c.readLoopMutex.Unlock() defer c.readLoopMutex.Unlock()
@@ -178,17 +181,27 @@ func (c *Client) Connect() error {
return nil return nil
} }
if err := c.connect(); err != nil { instanceURL, err := c.connect(ctx)
if err != nil {
return err return err
} }
c.muInstanceURL.Lock()
c.instanceURL = instanceURL
c.muInstanceURL.Unlock()
c.log = c.log.WithField("relay", c.instanceURL.String()) c.stateSubscription = NewPeersStateSubscription(c.log, c.relayConn, c.closeConnsByPeerID)
c.log = c.log.WithField("relay", instanceURL.String())
c.log.Infof("relay connection established") c.log.Infof("relay connection established")
c.serviceIsRunning = true c.serviceIsRunning = true
internallyStoppedFlag := newInternalStopFlag()
hc := healthcheck.NewReceiver(c.log)
go c.listenForStopEvents(ctx, hc, c.relayConn, internallyStoppedFlag)
c.wgReadLoop.Add(1) c.wgReadLoop.Add(1)
go c.readLoop(c.relayConn) go c.readLoop(hc, c.relayConn, internallyStoppedFlag)
return nil return nil
} }
@@ -196,26 +209,50 @@ func (c *Client) Connect() error {
// OpenConn create a new net.Conn for the destination peer ID. In case if the connection is in progress // OpenConn create a new net.Conn for the destination peer ID. In case if the connection is in progress
// to the relay server, the function will block until the connection is established or timed out. Otherwise, // to the relay server, the function will block until the connection is established or timed out. Otherwise,
// it will return immediately. // it will return immediately.
// It block until the server confirm the peer is online.
// todo: what should happen if call with the same peerID with multiple times? // todo: what should happen if call with the same peerID with multiple times?
func (c *Client) OpenConn(dstPeerID string) (net.Conn, error) { func (c *Client) OpenConn(ctx context.Context, dstPeerID string) (net.Conn, error) {
c.mu.Lock() peerID := messages.HashID(dstPeerID)
defer c.mu.Unlock()
c.mu.Lock()
if !c.serviceIsRunning { if !c.serviceIsRunning {
c.mu.Unlock()
return nil, fmt.Errorf("relay connection is not established")
}
_, ok := c.conns[peerID]
if ok {
c.mu.Unlock()
return nil, ErrConnAlreadyExists
}
c.mu.Unlock()
if err := c.stateSubscription.WaitToBeOnlineAndSubscribe(ctx, peerID); err != nil {
c.log.Errorf("peer not available: %s, %s", peerID, err)
return nil, err
}
c.log.Infof("remote peer is available, prepare the relayed connection: %s", peerID)
msgChannel := make(chan Msg, 100)
c.mu.Lock()
if !c.serviceIsRunning {
c.mu.Unlock()
return nil, fmt.Errorf("relay connection is not established") return nil, fmt.Errorf("relay connection is not established")
} }
hashedID, hashedStringID := messages.HashID(dstPeerID) c.muInstanceURL.Lock()
_, ok := c.conns[hashedStringID] instanceURL := c.instanceURL
c.muInstanceURL.Unlock()
conn := NewConn(c, peerID, msgChannel, instanceURL)
_, ok = c.conns[peerID]
if ok { if ok {
c.mu.Unlock()
_ = conn.Close()
return nil, ErrConnAlreadyExists return nil, ErrConnAlreadyExists
} }
c.conns[peerID] = newConnContainer(c.log, conn, msgChannel)
c.log.Infof("open connection to peer: %s", hashedStringID) c.mu.Unlock()
msgChannel := make(chan Msg, 100)
conn := NewConn(c, hashedID, hashedStringID, msgChannel, c.instanceURL)
c.conns[hashedStringID] = newConnContainer(c.log, conn, msgChannel)
return conn, nil return conn, nil
} }
@@ -254,76 +291,70 @@ func (c *Client) Close() error {
return c.close(true) return c.close(true)
} }
func (c *Client) connect() error { func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
rd := dialer.NewRaceDial(c.log, c.connectionURL, quic.Dialer{}, ws.Dialer{}) rd := dialer.NewRaceDial(c.log, dialer.DefaultConnectionTimeout, c.connectionURL, quic.Dialer{}, ws.Dialer{})
conn, err := rd.Dial() conn, err := rd.Dial()
if err != nil { if err != nil {
return err return nil, err
} }
c.relayConn = conn c.relayConn = conn
if err = c.handShake(); err != nil { instanceURL, err := c.handShake(ctx)
if err != nil {
cErr := conn.Close() cErr := conn.Close()
if cErr != nil { if cErr != nil {
c.log.Errorf("failed to close connection: %s", cErr) c.log.Errorf("failed to close connection: %s", cErr)
} }
return err return nil, err
} }
return nil return instanceURL, nil
} }
func (c *Client) handShake() error { func (c *Client) handShake(ctx context.Context) (*RelayAddr, error) {
msg, err := messages.MarshalAuthMsg(c.hashedID, c.authTokenStore.TokenBinary()) msg, err := messages.MarshalAuthMsg(c.hashedID, c.authTokenStore.TokenBinary())
if err != nil { if err != nil {
c.log.Errorf("failed to marshal auth message: %s", err) c.log.Errorf("failed to marshal auth message: %s", err)
return err return nil, err
} }
_, err = c.relayConn.Write(msg) _, err = c.relayConn.Write(msg)
if err != nil { if err != nil {
c.log.Errorf("failed to send auth message: %s", err) c.log.Errorf("failed to send auth message: %s", err)
return err return nil, err
} }
buf := make([]byte, messages.MaxHandshakeRespSize) buf := make([]byte, messages.MaxHandshakeRespSize)
n, err := c.readWithTimeout(buf) n, err := c.readWithTimeout(ctx, buf)
if err != nil { if err != nil {
c.log.Errorf("failed to read auth response: %s", err) c.log.Errorf("failed to read auth response: %s", err)
return err return nil, err
} }
_, err = messages.ValidateVersion(buf[:n]) _, err = messages.ValidateVersion(buf[:n])
if err != nil { if err != nil {
return fmt.Errorf("validate version: %w", err) return nil, fmt.Errorf("validate version: %w", err)
} }
msgType, err := messages.DetermineServerMessageType(buf[:n]) msgType, err := messages.DetermineServerMessageType(buf[:n])
if err != nil { if err != nil {
c.log.Errorf("failed to determine message type: %s", err) c.log.Errorf("failed to determine message type: %s", err)
return err return nil, err
} }
if msgType != messages.MsgTypeAuthResponse { if msgType != messages.MsgTypeAuthResponse {
c.log.Errorf("unexpected message type: %s", msgType) c.log.Errorf("unexpected message type: %s", msgType)
return fmt.Errorf("unexpected message type") return nil, fmt.Errorf("unexpected message type")
} }
addr, err := messages.UnmarshalAuthResponse(buf[:n]) addr, err := messages.UnmarshalAuthResponse(buf[:n])
if err != nil { if err != nil {
return err return nil, err
} }
c.muInstanceURL.Lock() return &RelayAddr{addr: addr}, nil
c.instanceURL = &RelayAddr{addr: addr}
c.muInstanceURL.Unlock()
return nil
} }
func (c *Client) readLoop(relayConn net.Conn) { func (c *Client) readLoop(hc *healthcheck.Receiver, relayConn net.Conn, internallyStoppedFlag *internalStopFlag) {
internallyStoppedFlag := newInternalStopFlag()
hc := healthcheck.NewReceiver(c.log)
go c.listenForStopEvents(hc, relayConn, internallyStoppedFlag)
var ( var (
errExit error errExit error
n int n int
@@ -366,10 +397,7 @@ func (c *Client) readLoop(relayConn net.Conn) {
hc.Stop() hc.Stop()
c.muInstanceURL.Lock() c.stateSubscription.Cleanup()
c.instanceURL = nil
c.muInstanceURL.Unlock()
c.wgReadLoop.Done() c.wgReadLoop.Done()
_ = c.close(false) _ = c.close(false)
c.notifyDisconnected() c.notifyDisconnected()
@@ -382,6 +410,14 @@ func (c *Client) handleMsg(msgType messages.MsgType, buf []byte, bufPtr *[]byte,
c.bufPool.Put(bufPtr) c.bufPool.Put(bufPtr)
case messages.MsgTypeTransport: case messages.MsgTypeTransport:
return c.handleTransportMsg(buf, bufPtr, internallyStoppedFlag) return c.handleTransportMsg(buf, bufPtr, internallyStoppedFlag)
case messages.MsgTypePeersOnline:
c.handlePeersOnlineMsg(buf)
c.bufPool.Put(bufPtr)
return true
case messages.MsgTypePeersWentOffline:
c.handlePeersWentOfflineMsg(buf)
c.bufPool.Put(bufPtr)
return true
case messages.MsgTypeClose: case messages.MsgTypeClose:
c.log.Debugf("relay connection close by server") c.log.Debugf("relay connection close by server")
c.bufPool.Put(bufPtr) c.bufPool.Put(bufPtr)
@@ -413,18 +449,16 @@ func (c *Client) handleTransportMsg(buf []byte, bufPtr *[]byte, internallyStoppe
return true return true
} }
stringID := messages.HashIDToString(peerID)
c.mu.Lock() c.mu.Lock()
if !c.serviceIsRunning { if !c.serviceIsRunning {
c.mu.Unlock() c.mu.Unlock()
c.bufPool.Put(bufPtr) c.bufPool.Put(bufPtr)
return false return false
} }
container, ok := c.conns[stringID] container, ok := c.conns[*peerID]
c.mu.Unlock() c.mu.Unlock()
if !ok { if !ok {
c.log.Errorf("peer not found: %s", stringID) c.log.Errorf("peer not found: %s", peerID.String())
c.bufPool.Put(bufPtr) c.bufPool.Put(bufPtr)
return true return true
} }
@@ -437,9 +471,9 @@ func (c *Client) handleTransportMsg(buf []byte, bufPtr *[]byte, internallyStoppe
return true return true
} }
func (c *Client) writeTo(connReference *Conn, id string, dstID []byte, payload []byte) (int, error) { func (c *Client) writeTo(connReference *Conn, dstID messages.PeerID, payload []byte) (int, error) {
c.mu.Lock() c.mu.Lock()
conn, ok := c.conns[id] conn, ok := c.conns[dstID]
c.mu.Unlock() c.mu.Unlock()
if !ok { if !ok {
return 0, net.ErrClosed return 0, net.ErrClosed
@@ -464,7 +498,7 @@ func (c *Client) writeTo(connReference *Conn, id string, dstID []byte, payload [
return len(payload), err return len(payload), err
} }
func (c *Client) listenForStopEvents(hc *healthcheck.Receiver, conn net.Conn, internalStopFlag *internalStopFlag) { func (c *Client) listenForStopEvents(ctx context.Context, hc *healthcheck.Receiver, conn net.Conn, internalStopFlag *internalStopFlag) {
for { for {
select { select {
case _, ok := <-hc.OnTimeout: case _, ok := <-hc.OnTimeout:
@@ -478,7 +512,7 @@ func (c *Client) listenForStopEvents(hc *healthcheck.Receiver, conn net.Conn, in
c.log.Warnf("failed to close connection: %s", err) c.log.Warnf("failed to close connection: %s", err)
} }
return return
case <-c.parentCtx.Done(): case <-ctx.Done():
err := c.close(true) err := c.close(true)
if err != nil { if err != nil {
c.log.Errorf("failed to teardown connection: %s", err) c.log.Errorf("failed to teardown connection: %s", err)
@@ -492,10 +526,31 @@ func (c *Client) closeAllConns() {
for _, container := range c.conns { for _, container := range c.conns {
container.close() container.close()
} }
c.conns = make(map[string]*connContainer) c.conns = make(map[messages.PeerID]*connContainer)
} }
func (c *Client) closeConn(connReference *Conn, id string) error { func (c *Client) closeConnsByPeerID(peerIDs []messages.PeerID) {
c.mu.Lock()
defer c.mu.Unlock()
for _, peerID := range peerIDs {
container, ok := c.conns[peerID]
if !ok {
c.log.Warnf("can not close connection, peer not found: %s", peerID)
continue
}
container.log.Infof("remote peer has been disconnected, free up connection: %s", peerID)
container.close()
delete(c.conns, peerID)
}
if err := c.stateSubscription.UnsubscribeStateChange(peerIDs); err != nil {
c.log.Errorf("failed to unsubscribe from peer state change: %s, %s", peerIDs, err)
}
}
func (c *Client) closeConn(connReference *Conn, id messages.PeerID) error {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
@@ -507,6 +562,11 @@ func (c *Client) closeConn(connReference *Conn, id string) error {
if container.conn != connReference { if container.conn != connReference {
return fmt.Errorf("conn reference mismatch") return fmt.Errorf("conn reference mismatch")
} }
if err := c.stateSubscription.UnsubscribeStateChange([]messages.PeerID{id}); err != nil {
container.log.Errorf("failed to unsubscribe from peer state change: %s", err)
}
c.log.Infof("free up connection to peer: %s", id) c.log.Infof("free up connection to peer: %s", id)
delete(c.conns, id) delete(c.conns, id)
container.close() container.close()
@@ -525,8 +585,12 @@ func (c *Client) close(gracefullyExit bool) error {
c.log.Warn("relay connection was already marked as not running") c.log.Warn("relay connection was already marked as not running")
return nil return nil
} }
c.serviceIsRunning = false c.serviceIsRunning = false
c.muInstanceURL.Lock()
c.instanceURL = nil
c.muInstanceURL.Unlock()
c.log.Infof("closing all peer connections") c.log.Infof("closing all peer connections")
c.closeAllConns() c.closeAllConns()
if gracefullyExit { if gracefullyExit {
@@ -559,8 +623,8 @@ func (c *Client) writeCloseMsg() {
} }
} }
func (c *Client) readWithTimeout(buf []byte) (int, error) { func (c *Client) readWithTimeout(ctx context.Context, buf []byte) (int, error) {
ctx, cancel := context.WithTimeout(c.parentCtx, serverResponseTimeout) ctx, cancel := context.WithTimeout(ctx, serverResponseTimeout)
defer cancel() defer cancel()
readDone := make(chan struct{}) readDone := make(chan struct{})
@@ -581,3 +645,21 @@ func (c *Client) readWithTimeout(buf []byte) (int, error) {
return n, err return n, err
} }
} }
func (c *Client) handlePeersOnlineMsg(buf []byte) {
peersID, err := messages.UnmarshalPeersOnlineMsg(buf)
if err != nil {
c.log.Errorf("failed to unmarshal peers online msg: %s", err)
return
}
c.stateSubscription.OnPeersOnline(peersID)
}
func (c *Client) handlePeersWentOfflineMsg(buf []byte) {
peersID, err := messages.UnMarshalPeersWentOffline(buf)
if err != nil {
c.log.Errorf("failed to unmarshal peers went offline msg: %s", err)
return
}
c.stateSubscription.OnPeersWentOffline(peersID)
}

View File

@@ -18,14 +18,19 @@ import (
) )
var ( var (
av = &allow.Auth{}
hmacTokenStore = &hmac.TokenStore{} hmacTokenStore = &hmac.TokenStore{}
serverListenAddr = "127.0.0.1:1234" serverListenAddr = "127.0.0.1:1234"
serverURL = "rel://127.0.0.1:1234" serverURL = "rel://127.0.0.1:1234"
serverCfg = server.Config{
Meter: otel.Meter(""),
ExposedAddress: serverURL,
TLSSupport: false,
AuthValidator: &allow.Auth{},
}
) )
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
_ = util.InitLog("error", "console") _ = util.InitLog("debug", "console")
code := m.Run() code := m.Run()
os.Exit(code) os.Exit(code)
} }
@@ -33,7 +38,7 @@ func TestMain(m *testing.M) {
func TestClient(t *testing.T) { func TestClient(t *testing.T) {
ctx := context.Background() ctx := context.Background()
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av) srv, err := server.NewServer(serverCfg)
if err != nil { if err != nil {
t.Fatalf("failed to create server: %s", err) t.Fatalf("failed to create server: %s", err)
} }
@@ -58,37 +63,37 @@ func TestClient(t *testing.T) {
t.Fatalf("failed to start server: %s", err) t.Fatalf("failed to start server: %s", err)
} }
t.Log("alice connecting to server") t.Log("alice connecting to server")
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice") clientAlice := NewClient(serverURL, hmacTokenStore, "alice")
err = clientAlice.Connect() err = clientAlice.Connect(ctx)
if err != nil { if err != nil {
t.Fatalf("failed to connect to server: %s", err) t.Fatalf("failed to connect to server: %s", err)
} }
defer clientAlice.Close() defer clientAlice.Close()
t.Log("placeholder connecting to server") t.Log("placeholder connecting to server")
clientPlaceHolder := NewClient(ctx, serverURL, hmacTokenStore, "clientPlaceHolder") clientPlaceHolder := NewClient(serverURL, hmacTokenStore, "clientPlaceHolder")
err = clientPlaceHolder.Connect() err = clientPlaceHolder.Connect(ctx)
if err != nil { if err != nil {
t.Fatalf("failed to connect to server: %s", err) t.Fatalf("failed to connect to server: %s", err)
} }
defer clientPlaceHolder.Close() defer clientPlaceHolder.Close()
t.Log("Bob connecting to server") t.Log("Bob connecting to server")
clientBob := NewClient(ctx, serverURL, hmacTokenStore, "bob") clientBob := NewClient(serverURL, hmacTokenStore, "bob")
err = clientBob.Connect() err = clientBob.Connect(ctx)
if err != nil { if err != nil {
t.Fatalf("failed to connect to server: %s", err) t.Fatalf("failed to connect to server: %s", err)
} }
defer clientBob.Close() defer clientBob.Close()
t.Log("Alice open connection to Bob") t.Log("Alice open connection to Bob")
connAliceToBob, err := clientAlice.OpenConn("bob") connAliceToBob, err := clientAlice.OpenConn(ctx, "bob")
if err != nil { if err != nil {
t.Fatalf("failed to bind channel: %s", err) t.Fatalf("failed to bind channel: %s", err)
} }
t.Log("Bob open connection to Alice") t.Log("Bob open connection to Alice")
connBobToAlice, err := clientBob.OpenConn("alice") connBobToAlice, err := clientBob.OpenConn(ctx, "alice")
if err != nil { if err != nil {
t.Fatalf("failed to bind channel: %s", err) t.Fatalf("failed to bind channel: %s", err)
} }
@@ -115,7 +120,7 @@ func TestClient(t *testing.T) {
func TestRegistration(t *testing.T) { func TestRegistration(t *testing.T) {
ctx := context.Background() ctx := context.Background()
srvCfg := server.ListenerConfig{Address: serverListenAddr} srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av) srv, err := server.NewServer(serverCfg)
if err != nil { if err != nil {
t.Fatalf("failed to create server: %s", err) t.Fatalf("failed to create server: %s", err)
} }
@@ -132,8 +137,8 @@ func TestRegistration(t *testing.T) {
t.Fatalf("failed to start server: %s", err) t.Fatalf("failed to start server: %s", err)
} }
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice") clientAlice := NewClient(serverURL, hmacTokenStore, "alice")
err = clientAlice.Connect() err = clientAlice.Connect(ctx)
if err != nil { if err != nil {
_ = srv.Shutdown(ctx) _ = srv.Shutdown(ctx)
t.Fatalf("failed to connect to server: %s", err) t.Fatalf("failed to connect to server: %s", err)
@@ -172,8 +177,8 @@ func TestRegistrationTimeout(t *testing.T) {
_ = fakeTCPListener.Close() _ = fakeTCPListener.Close()
}(fakeTCPListener) }(fakeTCPListener)
clientAlice := NewClient(ctx, "127.0.0.1:1234", hmacTokenStore, "alice") clientAlice := NewClient("127.0.0.1:1234", hmacTokenStore, "alice")
err = clientAlice.Connect() err = clientAlice.Connect(ctx)
if err == nil { if err == nil {
t.Errorf("failed to connect to server: %s", err) t.Errorf("failed to connect to server: %s", err)
} }
@@ -189,7 +194,7 @@ func TestEcho(t *testing.T) {
idAlice := "alice" idAlice := "alice"
idBob := "bob" idBob := "bob"
srvCfg := server.ListenerConfig{Address: serverListenAddr} srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av) srv, err := server.NewServer(serverCfg)
if err != nil { if err != nil {
t.Fatalf("failed to create server: %s", err) t.Fatalf("failed to create server: %s", err)
} }
@@ -213,8 +218,8 @@ func TestEcho(t *testing.T) {
t.Fatalf("failed to start server: %s", err) t.Fatalf("failed to start server: %s", err)
} }
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, idAlice) clientAlice := NewClient(serverURL, hmacTokenStore, idAlice)
err = clientAlice.Connect() err = clientAlice.Connect(ctx)
if err != nil { if err != nil {
t.Fatalf("failed to connect to server: %s", err) t.Fatalf("failed to connect to server: %s", err)
} }
@@ -225,8 +230,8 @@ func TestEcho(t *testing.T) {
} }
}() }()
clientBob := NewClient(ctx, serverURL, hmacTokenStore, idBob) clientBob := NewClient(serverURL, hmacTokenStore, idBob)
err = clientBob.Connect() err = clientBob.Connect(ctx)
if err != nil { if err != nil {
t.Fatalf("failed to connect to server: %s", err) t.Fatalf("failed to connect to server: %s", err)
} }
@@ -237,12 +242,12 @@ func TestEcho(t *testing.T) {
} }
}() }()
connAliceToBob, err := clientAlice.OpenConn(idBob) connAliceToBob, err := clientAlice.OpenConn(ctx, idBob)
if err != nil { if err != nil {
t.Fatalf("failed to bind channel: %s", err) t.Fatalf("failed to bind channel: %s", err)
} }
connBobToAlice, err := clientBob.OpenConn(idAlice) connBobToAlice, err := clientBob.OpenConn(ctx, idAlice)
if err != nil { if err != nil {
t.Fatalf("failed to bind channel: %s", err) t.Fatalf("failed to bind channel: %s", err)
} }
@@ -278,7 +283,7 @@ func TestBindToUnavailabePeer(t *testing.T) {
ctx := context.Background() ctx := context.Background()
srvCfg := server.ListenerConfig{Address: serverListenAddr} srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av) srv, err := server.NewServer(serverCfg)
if err != nil { if err != nil {
t.Fatalf("failed to create server: %s", err) t.Fatalf("failed to create server: %s", err)
} }
@@ -303,14 +308,14 @@ func TestBindToUnavailabePeer(t *testing.T) {
t.Fatalf("failed to start server: %s", err) t.Fatalf("failed to start server: %s", err)
} }
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice") clientAlice := NewClient(serverURL, hmacTokenStore, "alice")
err = clientAlice.Connect() err = clientAlice.Connect(ctx)
if err != nil { if err != nil {
t.Errorf("failed to connect to server: %s", err) t.Errorf("failed to connect to server: %s", err)
} }
_, err = clientAlice.OpenConn("bob") _, err = clientAlice.OpenConn(ctx, "bob")
if err != nil { if err == nil {
t.Errorf("failed to bind channel: %s", err) t.Errorf("expected error when binding to unavailable peer, got nil")
} }
log.Infof("closing client") log.Infof("closing client")
@@ -324,7 +329,7 @@ func TestBindReconnect(t *testing.T) {
ctx := context.Background() ctx := context.Background()
srvCfg := server.ListenerConfig{Address: serverListenAddr} srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av) srv, err := server.NewServer(serverCfg)
if err != nil { if err != nil {
t.Fatalf("failed to create server: %s", err) t.Fatalf("failed to create server: %s", err)
} }
@@ -349,24 +354,24 @@ func TestBindReconnect(t *testing.T) {
t.Fatalf("failed to start server: %s", err) t.Fatalf("failed to start server: %s", err)
} }
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice") clientAlice := NewClient(serverURL, hmacTokenStore, "alice")
err = clientAlice.Connect() err = clientAlice.Connect(ctx)
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
}
clientBob := NewClient(serverURL, hmacTokenStore, "bob")
err = clientBob.Connect(ctx)
if err != nil { if err != nil {
t.Errorf("failed to connect to server: %s", err) t.Errorf("failed to connect to server: %s", err)
} }
_, err = clientAlice.OpenConn("bob") _, err = clientAlice.OpenConn(ctx, "bob")
if err != nil { if err != nil {
t.Errorf("failed to bind channel: %s", err) t.Fatalf("failed to bind channel: %s", err)
} }
clientBob := NewClient(ctx, serverURL, hmacTokenStore, "bob") chBob, err := clientBob.OpenConn(ctx, "alice")
err = clientBob.Connect()
if err != nil {
t.Errorf("failed to connect to server: %s", err)
}
chBob, err := clientBob.OpenConn("alice")
if err != nil { if err != nil {
t.Errorf("failed to bind channel: %s", err) t.Errorf("failed to bind channel: %s", err)
} }
@@ -377,18 +382,28 @@ func TestBindReconnect(t *testing.T) {
t.Errorf("failed to close client: %s", err) t.Errorf("failed to close client: %s", err)
} }
clientAlice = NewClient(ctx, serverURL, hmacTokenStore, "alice") clientAlice = NewClient(serverURL, hmacTokenStore, "alice")
err = clientAlice.Connect() err = clientAlice.Connect(ctx)
if err != nil { if err != nil {
t.Errorf("failed to connect to server: %s", err) t.Errorf("failed to connect to server: %s", err)
} }
chAlice, err := clientAlice.OpenConn("bob") chAlice, err := clientAlice.OpenConn(ctx, "bob")
if err != nil { if err != nil {
t.Errorf("failed to bind channel: %s", err) t.Errorf("failed to bind channel: %s", err)
} }
testString := "hello alice, I am bob" testString := "hello alice, I am bob"
_, err = chBob.Write([]byte(testString))
if err == nil {
t.Errorf("expected error when writing to channel, got nil")
}
chBob, err = clientBob.OpenConn(ctx, "alice")
if err != nil {
t.Errorf("failed to bind channel: %s", err)
}
_, err = chBob.Write([]byte(testString)) _, err = chBob.Write([]byte(testString))
if err != nil { if err != nil {
t.Errorf("failed to write to channel: %s", err) t.Errorf("failed to write to channel: %s", err)
@@ -415,7 +430,7 @@ func TestCloseConn(t *testing.T) {
ctx := context.Background() ctx := context.Background()
srvCfg := server.ListenerConfig{Address: serverListenAddr} srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av) srv, err := server.NewServer(serverCfg)
if err != nil { if err != nil {
t.Fatalf("failed to create server: %s", err) t.Fatalf("failed to create server: %s", err)
} }
@@ -440,13 +455,19 @@ func TestCloseConn(t *testing.T) {
t.Fatalf("failed to start server: %s", err) t.Fatalf("failed to start server: %s", err)
} }
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice") bob := NewClient(serverURL, hmacTokenStore, "bob")
err = clientAlice.Connect() err = bob.Connect(ctx)
if err != nil { if err != nil {
t.Errorf("failed to connect to server: %s", err) t.Errorf("failed to connect to server: %s", err)
} }
conn, err := clientAlice.OpenConn("bob") clientAlice := NewClient(serverURL, hmacTokenStore, "alice")
err = clientAlice.Connect(ctx)
if err != nil {
t.Errorf("failed to connect to server: %s", err)
}
conn, err := clientAlice.OpenConn(ctx, "bob")
if err != nil { if err != nil {
t.Errorf("failed to bind channel: %s", err) t.Errorf("failed to bind channel: %s", err)
} }
@@ -472,7 +493,7 @@ func TestCloseRelayConn(t *testing.T) {
ctx := context.Background() ctx := context.Background()
srvCfg := server.ListenerConfig{Address: serverListenAddr} srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av) srv, err := server.NewServer(serverCfg)
if err != nil { if err != nil {
t.Fatalf("failed to create server: %s", err) t.Fatalf("failed to create server: %s", err)
} }
@@ -496,13 +517,19 @@ func TestCloseRelayConn(t *testing.T) {
t.Fatalf("failed to start server: %s", err) t.Fatalf("failed to start server: %s", err)
} }
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice") bob := NewClient(serverURL, hmacTokenStore, "bob")
err = clientAlice.Connect() err = bob.Connect(ctx)
if err != nil { if err != nil {
t.Fatalf("failed to connect to server: %s", err) t.Fatalf("failed to connect to server: %s", err)
} }
conn, err := clientAlice.OpenConn("bob") clientAlice := NewClient(serverURL, hmacTokenStore, "alice")
err = clientAlice.Connect(ctx)
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
}
conn, err := clientAlice.OpenConn(ctx, "bob")
if err != nil { if err != nil {
t.Errorf("failed to bind channel: %s", err) t.Errorf("failed to bind channel: %s", err)
} }
@@ -514,7 +541,7 @@ func TestCloseRelayConn(t *testing.T) {
t.Errorf("unexpected reading from closed connection") t.Errorf("unexpected reading from closed connection")
} }
_, err = clientAlice.OpenConn("bob") _, err = clientAlice.OpenConn(ctx, "bob")
if err == nil { if err == nil {
t.Errorf("unexpected opening connection to closed server") t.Errorf("unexpected opening connection to closed server")
} }
@@ -524,7 +551,7 @@ func TestCloseByServer(t *testing.T) {
ctx := context.Background() ctx := context.Background()
srvCfg := server.ListenerConfig{Address: serverListenAddr} srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv1, err := server.NewServer(otel.Meter(""), serverURL, false, av) srv1, err := server.NewServer(serverCfg)
if err != nil { if err != nil {
t.Fatalf("failed to create server: %s", err) t.Fatalf("failed to create server: %s", err)
} }
@@ -544,11 +571,15 @@ func TestCloseByServer(t *testing.T) {
idAlice := "alice" idAlice := "alice"
log.Debugf("connect by alice") log.Debugf("connect by alice")
relayClient := NewClient(ctx, serverURL, hmacTokenStore, idAlice) relayClient := NewClient(serverURL, hmacTokenStore, idAlice)
err = relayClient.Connect() if err = relayClient.Connect(ctx); err != nil {
if err != nil {
log.Fatalf("failed to connect to server: %s", err) log.Fatalf("failed to connect to server: %s", err)
} }
defer func() {
if err := relayClient.Close(); err != nil {
log.Errorf("failed to close client: %s", err)
}
}()
disconnected := make(chan struct{}) disconnected := make(chan struct{})
relayClient.SetOnDisconnectListener(func(_ string) { relayClient.SetOnDisconnectListener(func(_ string) {
@@ -564,10 +595,10 @@ func TestCloseByServer(t *testing.T) {
select { select {
case <-disconnected: case <-disconnected:
case <-time.After(3 * time.Second): case <-time.After(3 * time.Second):
log.Fatalf("timeout waiting for client to disconnect") log.Errorf("timeout waiting for client to disconnect")
} }
_, err = relayClient.OpenConn("bob") _, err = relayClient.OpenConn(ctx, "bob")
if err == nil { if err == nil {
t.Errorf("unexpected opening connection to closed server") t.Errorf("unexpected opening connection to closed server")
} }
@@ -577,7 +608,7 @@ func TestCloseByClient(t *testing.T) {
ctx := context.Background() ctx := context.Background()
srvCfg := server.ListenerConfig{Address: serverListenAddr} srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av) srv, err := server.NewServer(serverCfg)
if err != nil { if err != nil {
t.Fatalf("failed to create server: %s", err) t.Fatalf("failed to create server: %s", err)
} }
@@ -596,8 +627,8 @@ func TestCloseByClient(t *testing.T) {
idAlice := "alice" idAlice := "alice"
log.Debugf("connect by alice") log.Debugf("connect by alice")
relayClient := NewClient(ctx, serverURL, hmacTokenStore, idAlice) relayClient := NewClient(serverURL, hmacTokenStore, idAlice)
err = relayClient.Connect() err = relayClient.Connect(ctx)
if err != nil { if err != nil {
log.Fatalf("failed to connect to server: %s", err) log.Fatalf("failed to connect to server: %s", err)
} }
@@ -607,7 +638,7 @@ func TestCloseByClient(t *testing.T) {
t.Errorf("failed to close client: %s", err) t.Errorf("failed to close client: %s", err)
} }
_, err = relayClient.OpenConn("bob") _, err = relayClient.OpenConn(ctx, "bob")
if err == nil { if err == nil {
t.Errorf("unexpected opening connection to closed server") t.Errorf("unexpected opening connection to closed server")
} }
@@ -623,7 +654,7 @@ func TestCloseNotDrainedChannel(t *testing.T) {
idAlice := "alice" idAlice := "alice"
idBob := "bob" idBob := "bob"
srvCfg := server.ListenerConfig{Address: serverListenAddr} srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av) srv, err := server.NewServer(serverCfg)
if err != nil { if err != nil {
t.Fatalf("failed to create server: %s", err) t.Fatalf("failed to create server: %s", err)
} }
@@ -647,8 +678,8 @@ func TestCloseNotDrainedChannel(t *testing.T) {
t.Fatalf("failed to start server: %s", err) t.Fatalf("failed to start server: %s", err)
} }
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, idAlice) clientAlice := NewClient(serverURL, hmacTokenStore, idAlice)
err = clientAlice.Connect() err = clientAlice.Connect(ctx)
if err != nil { if err != nil {
t.Fatalf("failed to connect to server: %s", err) t.Fatalf("failed to connect to server: %s", err)
} }
@@ -659,8 +690,8 @@ func TestCloseNotDrainedChannel(t *testing.T) {
} }
}() }()
clientBob := NewClient(ctx, serverURL, hmacTokenStore, idBob) clientBob := NewClient(serverURL, hmacTokenStore, idBob)
err = clientBob.Connect() err = clientBob.Connect(ctx)
if err != nil { if err != nil {
t.Fatalf("failed to connect to server: %s", err) t.Fatalf("failed to connect to server: %s", err)
} }
@@ -671,12 +702,12 @@ func TestCloseNotDrainedChannel(t *testing.T) {
} }
}() }()
connAliceToBob, err := clientAlice.OpenConn(idBob) connAliceToBob, err := clientAlice.OpenConn(ctx, idBob)
if err != nil { if err != nil {
t.Fatalf("failed to bind channel: %s", err) t.Fatalf("failed to bind channel: %s", err)
} }
connBobToAlice, err := clientBob.OpenConn(idAlice) connBobToAlice, err := clientBob.OpenConn(ctx, idAlice)
if err != nil { if err != nil {
t.Fatalf("failed to bind channel: %s", err) t.Fatalf("failed to bind channel: %s", err)
} }

View File

@@ -3,13 +3,14 @@ package client
import ( import (
"net" "net"
"time" "time"
"github.com/netbirdio/netbird/relay/messages"
) )
// Conn represent a connection to a relayed remote peer. // Conn represent a connection to a relayed remote peer.
type Conn struct { type Conn struct {
client *Client client *Client
dstID []byte dstID messages.PeerID
dstStringID string
messageChan chan Msg messageChan chan Msg
instanceURL *RelayAddr instanceURL *RelayAddr
} }
@@ -17,14 +18,12 @@ type Conn struct {
// NewConn creates a new connection to a relayed remote peer. // NewConn creates a new connection to a relayed remote peer.
// client: the client instance, it used to send messages to the destination peer // client: the client instance, it used to send messages to the destination peer
// dstID: the destination peer ID // dstID: the destination peer ID
// dstStringID: the destination peer ID in string format
// messageChan: the channel where the messages will be received // messageChan: the channel where the messages will be received
// instanceURL: the relay instance URL, it used to get the proper server instance address for the remote peer // instanceURL: the relay instance URL, it used to get the proper server instance address for the remote peer
func NewConn(client *Client, dstID []byte, dstStringID string, messageChan chan Msg, instanceURL *RelayAddr) *Conn { func NewConn(client *Client, dstID messages.PeerID, messageChan chan Msg, instanceURL *RelayAddr) *Conn {
c := &Conn{ c := &Conn{
client: client, client: client,
dstID: dstID, dstID: dstID,
dstStringID: dstStringID,
messageChan: messageChan, messageChan: messageChan,
instanceURL: instanceURL, instanceURL: instanceURL,
} }
@@ -33,7 +32,7 @@ func NewConn(client *Client, dstID []byte, dstStringID string, messageChan chan
} }
func (c *Conn) Write(p []byte) (n int, err error) { func (c *Conn) Write(p []byte) (n int, err error) {
return c.client.writeTo(c, c.dstStringID, c.dstID, p) return c.client.writeTo(c, c.dstID, p)
} }
func (c *Conn) Read(b []byte) (n int, err error) { func (c *Conn) Read(b []byte) (n int, err error) {
@@ -48,7 +47,7 @@ func (c *Conn) Read(b []byte) (n int, err error) {
} }
func (c *Conn) Close() error { func (c *Conn) Close() error {
return c.client.closeConn(c, c.dstStringID) return c.client.closeConn(c, c.dstID)
} }
func (c *Conn) LocalAddr() net.Addr { func (c *Conn) LocalAddr() net.Addr {

View File

@@ -9,8 +9,8 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
var ( const (
connectionTimeout = 30 * time.Second DefaultConnectionTimeout = 30 * time.Second
) )
type DialeFn interface { type DialeFn interface {
@@ -28,13 +28,15 @@ type RaceDial struct {
log *log.Entry log *log.Entry
serverURL string serverURL string
dialerFns []DialeFn dialerFns []DialeFn
connectionTimeout time.Duration
} }
func NewRaceDial(log *log.Entry, serverURL string, dialerFns ...DialeFn) *RaceDial { func NewRaceDial(log *log.Entry, connectionTimeout time.Duration, serverURL string, dialerFns ...DialeFn) *RaceDial {
return &RaceDial{ return &RaceDial{
log: log, log: log,
serverURL: serverURL, serverURL: serverURL,
dialerFns: dialerFns, dialerFns: dialerFns,
connectionTimeout: connectionTimeout,
} }
} }
@@ -58,7 +60,7 @@ func (r *RaceDial) Dial() (net.Conn, error) {
} }
func (r *RaceDial) dial(dfn DialeFn, abortCtx context.Context, connChan chan dialResult) { func (r *RaceDial) dial(dfn DialeFn, abortCtx context.Context, connChan chan dialResult) {
ctx, cancel := context.WithTimeout(abortCtx, connectionTimeout) ctx, cancel := context.WithTimeout(abortCtx, r.connectionTimeout)
defer cancel() defer cancel()
r.log.Infof("dialing Relay server via %s", dfn.Protocol()) r.log.Infof("dialing Relay server via %s", dfn.Protocol())

View File

@@ -77,7 +77,7 @@ func TestRaceDialEmptyDialers(t *testing.T) {
logger := logrus.NewEntry(logrus.New()) logger := logrus.NewEntry(logrus.New())
serverURL := "test.server.com" serverURL := "test.server.com"
rd := NewRaceDial(logger, serverURL) rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL)
conn, err := rd.Dial() conn, err := rd.Dial()
if err == nil { if err == nil {
t.Errorf("Expected an error with empty dialers, got nil") t.Errorf("Expected an error with empty dialers, got nil")
@@ -103,7 +103,7 @@ func TestRaceDialSingleSuccessfulDialer(t *testing.T) {
protocolStr: proto, protocolStr: proto,
} }
rd := NewRaceDial(logger, serverURL, mockDialer) rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer)
conn, err := rd.Dial() conn, err := rd.Dial()
if err != nil { if err != nil {
t.Errorf("Expected no error, got %v", err) t.Errorf("Expected no error, got %v", err)
@@ -136,7 +136,7 @@ func TestRaceDialMultipleDialersWithOneSuccess(t *testing.T) {
protocolStr: "proto2", protocolStr: "proto2",
} }
rd := NewRaceDial(logger, serverURL, mockDialer1, mockDialer2) rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer1, mockDialer2)
conn, err := rd.Dial() conn, err := rd.Dial()
if err != nil { if err != nil {
t.Errorf("Expected no error, got %v", err) t.Errorf("Expected no error, got %v", err)
@@ -144,13 +144,13 @@ func TestRaceDialMultipleDialersWithOneSuccess(t *testing.T) {
if conn.RemoteAddr().Network() != proto2 { if conn.RemoteAddr().Network() != proto2 {
t.Errorf("Expected connection with protocol %s, got %s", proto2, conn.RemoteAddr().Network()) t.Errorf("Expected connection with protocol %s, got %s", proto2, conn.RemoteAddr().Network())
} }
_ = conn.Close()
} }
func TestRaceDialTimeout(t *testing.T) { func TestRaceDialTimeout(t *testing.T) {
logger := logrus.NewEntry(logrus.New()) logger := logrus.NewEntry(logrus.New())
serverURL := "test.server.com" serverURL := "test.server.com"
connectionTimeout = 3 * time.Second
mockDialer := &MockDialer{ mockDialer := &MockDialer{
dialFunc: func(ctx context.Context, address string) (net.Conn, error) { dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
<-ctx.Done() <-ctx.Done()
@@ -159,7 +159,7 @@ func TestRaceDialTimeout(t *testing.T) {
protocolStr: "proto1", protocolStr: "proto1",
} }
rd := NewRaceDial(logger, serverURL, mockDialer) rd := NewRaceDial(logger, 3*time.Second, serverURL, mockDialer)
conn, err := rd.Dial() conn, err := rd.Dial()
if err == nil { if err == nil {
t.Errorf("Expected an error, got nil") t.Errorf("Expected an error, got nil")
@@ -187,7 +187,7 @@ func TestRaceDialAllDialersFail(t *testing.T) {
protocolStr: "protocol2", protocolStr: "protocol2",
} }
rd := NewRaceDial(logger, serverURL, mockDialer1, mockDialer2) rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer1, mockDialer2)
conn, err := rd.Dial() conn, err := rd.Dial()
if err == nil { if err == nil {
t.Errorf("Expected an error, got nil") t.Errorf("Expected an error, got nil")
@@ -229,7 +229,7 @@ func TestRaceDialFirstSuccessfulDialerWins(t *testing.T) {
protocolStr: proto2, protocolStr: proto2,
} }
rd := NewRaceDial(logger, serverURL, mockDialer1, mockDialer2) rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer1, mockDialer2)
conn, err := rd.Dial() conn, err := rd.Dial()
if err != nil { if err != nil {
t.Errorf("Expected no error, got %v", err) t.Errorf("Expected no error, got %v", err)

View File

@@ -8,7 +8,8 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
var ( const (
// TODO: make it configurable, the manager should validate all configurable parameters
reconnectingTimeout = 60 * time.Second reconnectingTimeout = 60 * time.Second
) )
@@ -80,7 +81,7 @@ func (g *Guard) tryToQuickReconnect(parentCtx context.Context, rc *Client) bool
log.Infof("try to reconnect to Relay server: %s", rc.connectionURL) log.Infof("try to reconnect to Relay server: %s", rc.connectionURL)
if err := rc.Connect(); err != nil { if err := rc.Connect(parentCtx); err != nil {
log.Errorf("failed to reconnect to relay server: %s", err) log.Errorf("failed to reconnect to relay server: %s", err)
return false return false
} }

View File

@@ -39,17 +39,6 @@ func NewRelayTrack() *RelayTrack {
type OnServerCloseListener func() type OnServerCloseListener func()
// ManagerService is the interface for the relay manager.
type ManagerService interface {
Serve() error
OpenConn(serverAddress, peerKey string) (net.Conn, error)
AddCloseListener(serverAddress string, onClosedListener OnServerCloseListener) error
RelayInstanceAddress() (string, error)
ServerURLs() []string
HasRelayAddress() bool
UpdateToken(token *relayAuth.Token) error
}
// Manager is a manager for the relay client instances. It establishes one persistent connection to the given relay URL // Manager is a manager for the relay client instances. It establishes one persistent connection to the given relay URL
// and automatically reconnect to them in case disconnection. // and automatically reconnect to them in case disconnection.
// The manager also manage temporary relay connection. If a client wants to communicate with a client on a // The manager also manage temporary relay connection. If a client wants to communicate with a client on a
@@ -65,7 +54,7 @@ type Manager struct {
relayClient *Client relayClient *Client
// the guard logic can overwrite the relayClient variable, this mutex protect the usage of the variable // the guard logic can overwrite the relayClient variable, this mutex protect the usage of the variable
relayClientMu sync.Mutex relayClientMu sync.RWMutex
reconnectGuard *Guard reconnectGuard *Guard
relayClients map[string]*RelayTrack relayClients map[string]*RelayTrack
@@ -123,9 +112,9 @@ func (m *Manager) Serve() error {
// OpenConn opens a connection to the given peer key. If the peer is on the same relay server, the connection will be // OpenConn opens a connection to the given peer key. If the peer is on the same relay server, the connection will be
// established via the relay server. If the peer is on a different relay server, the manager will establish a new // established via the relay server. If the peer is on a different relay server, the manager will establish a new
// connection to the relay server. It returns back with a net.Conn what represent the remote peer connection. // connection to the relay server. It returns back with a net.Conn what represent the remote peer connection.
func (m *Manager) OpenConn(serverAddress, peerKey string) (net.Conn, error) { func (m *Manager) OpenConn(ctx context.Context, serverAddress, peerKey string) (net.Conn, error) {
m.relayClientMu.Lock() m.relayClientMu.RLock()
defer m.relayClientMu.Unlock() defer m.relayClientMu.RUnlock()
if m.relayClient == nil { if m.relayClient == nil {
return nil, ErrRelayClientNotConnected return nil, ErrRelayClientNotConnected
@@ -141,10 +130,10 @@ func (m *Manager) OpenConn(serverAddress, peerKey string) (net.Conn, error) {
) )
if !foreign { if !foreign {
log.Debugf("open peer connection via permanent server: %s", peerKey) log.Debugf("open peer connection via permanent server: %s", peerKey)
netConn, err = m.relayClient.OpenConn(peerKey) netConn, err = m.relayClient.OpenConn(ctx, peerKey)
} else { } else {
log.Debugf("open peer connection via foreign server: %s", serverAddress) log.Debugf("open peer connection via foreign server: %s", serverAddress)
netConn, err = m.openConnVia(serverAddress, peerKey) netConn, err = m.openConnVia(ctx, serverAddress, peerKey)
} }
if err != nil { if err != nil {
return nil, err return nil, err
@@ -155,8 +144,8 @@ func (m *Manager) OpenConn(serverAddress, peerKey string) (net.Conn, error) {
// Ready returns true if the home Relay client is connected to the relay server. // Ready returns true if the home Relay client is connected to the relay server.
func (m *Manager) Ready() bool { func (m *Manager) Ready() bool {
m.relayClientMu.Lock() m.relayClientMu.RLock()
defer m.relayClientMu.Unlock() defer m.relayClientMu.RUnlock()
if m.relayClient == nil { if m.relayClient == nil {
return false return false
@@ -174,8 +163,8 @@ func (m *Manager) SetOnReconnectedListener(f func()) {
// AddCloseListener adds a listener to the given server instance address. The listener will be called if the connection // AddCloseListener adds a listener to the given server instance address. The listener will be called if the connection
// closed. // closed.
func (m *Manager) AddCloseListener(serverAddress string, onClosedListener OnServerCloseListener) error { func (m *Manager) AddCloseListener(serverAddress string, onClosedListener OnServerCloseListener) error {
m.relayClientMu.Lock() m.relayClientMu.RLock()
defer m.relayClientMu.Unlock() defer m.relayClientMu.RUnlock()
if m.relayClient == nil { if m.relayClient == nil {
return ErrRelayClientNotConnected return ErrRelayClientNotConnected
@@ -199,8 +188,8 @@ func (m *Manager) AddCloseListener(serverAddress string, onClosedListener OnServ
// RelayInstanceAddress returns the address of the permanent relay server. It could change if the network connection is // RelayInstanceAddress returns the address of the permanent relay server. It could change if the network connection is
// lost. This address will be sent to the target peer to choose the common relay server for the communication. // lost. This address will be sent to the target peer to choose the common relay server for the communication.
func (m *Manager) RelayInstanceAddress() (string, error) { func (m *Manager) RelayInstanceAddress() (string, error) {
m.relayClientMu.Lock() m.relayClientMu.RLock()
defer m.relayClientMu.Unlock() defer m.relayClientMu.RUnlock()
if m.relayClient == nil { if m.relayClient == nil {
return "", ErrRelayClientNotConnected return "", ErrRelayClientNotConnected
@@ -229,7 +218,7 @@ func (m *Manager) UpdateToken(token *relayAuth.Token) error {
return m.tokenStore.UpdateToken(token) return m.tokenStore.UpdateToken(token)
} }
func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) { func (m *Manager) openConnVia(ctx context.Context, serverAddress, peerKey string) (net.Conn, error) {
// check if already has a connection to the desired relay server // check if already has a connection to the desired relay server
m.relayClientsMutex.RLock() m.relayClientsMutex.RLock()
rt, ok := m.relayClients[serverAddress] rt, ok := m.relayClients[serverAddress]
@@ -240,7 +229,7 @@ func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) {
if rt.err != nil { if rt.err != nil {
return nil, rt.err return nil, rt.err
} }
return rt.relayClient.OpenConn(peerKey) return rt.relayClient.OpenConn(ctx, peerKey)
} }
m.relayClientsMutex.RUnlock() m.relayClientsMutex.RUnlock()
@@ -255,7 +244,7 @@ func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) {
if rt.err != nil { if rt.err != nil {
return nil, rt.err return nil, rt.err
} }
return rt.relayClient.OpenConn(peerKey) return rt.relayClient.OpenConn(ctx, peerKey)
} }
// create a new relay client and store it in the relayClients map // create a new relay client and store it in the relayClients map
@@ -264,8 +253,8 @@ func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) {
m.relayClients[serverAddress] = rt m.relayClients[serverAddress] = rt
m.relayClientsMutex.Unlock() m.relayClientsMutex.Unlock()
relayClient := NewClient(m.ctx, serverAddress, m.tokenStore, m.peerID) relayClient := NewClient(serverAddress, m.tokenStore, m.peerID)
err := relayClient.Connect() err := relayClient.Connect(m.ctx)
if err != nil { if err != nil {
rt.err = err rt.err = err
rt.Unlock() rt.Unlock()
@@ -279,7 +268,7 @@ func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) {
rt.relayClient = relayClient rt.relayClient = relayClient
rt.Unlock() rt.Unlock()
conn, err := relayClient.OpenConn(peerKey) conn, err := relayClient.OpenConn(ctx, peerKey)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -300,7 +289,9 @@ func (m *Manager) onServerConnected() {
func (m *Manager) onServerDisconnected(serverAddress string) { func (m *Manager) onServerDisconnected(serverAddress string) {
m.relayClientMu.Lock() m.relayClientMu.Lock()
if serverAddress == m.relayClient.connectionURL { if serverAddress == m.relayClient.connectionURL {
go m.reconnectGuard.StartReconnectTrys(m.ctx, m.relayClient) go func(client *Client) {
m.reconnectGuard.StartReconnectTrys(m.ctx, client)
}(m.relayClient)
} }
m.relayClientMu.Unlock() m.relayClientMu.Unlock()

View File

@@ -8,11 +8,14 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"go.opentelemetry.io/otel" "go.opentelemetry.io/otel"
"github.com/netbirdio/netbird/relay/auth/allow"
"github.com/netbirdio/netbird/relay/server" "github.com/netbirdio/netbird/relay/server"
) )
func TestEmptyURL(t *testing.T) { func TestEmptyURL(t *testing.T) {
mgr := NewManager(context.Background(), nil, "alice") ctx, cancel := context.WithCancel(context.Background())
defer cancel()
mgr := NewManager(ctx, nil, "alice")
err := mgr.Serve() err := mgr.Serve()
if err == nil { if err == nil {
t.Errorf("expected error, got nil") t.Errorf("expected error, got nil")
@@ -22,16 +25,22 @@ func TestEmptyURL(t *testing.T) {
func TestForeignConn(t *testing.T) { func TestForeignConn(t *testing.T) {
ctx := context.Background() ctx := context.Background()
srvCfg1 := server.ListenerConfig{ lstCfg1 := server.ListenerConfig{
Address: "localhost:1234", Address: "localhost:1234",
} }
srv1, err := server.NewServer(otel.Meter(""), srvCfg1.Address, false, av)
srv1, err := server.NewServer(server.Config{
Meter: otel.Meter(""),
ExposedAddress: lstCfg1.Address,
TLSSupport: false,
AuthValidator: &allow.Auth{},
})
if err != nil { if err != nil {
t.Fatalf("failed to create server: %s", err) t.Fatalf("failed to create server: %s", err)
} }
errChan := make(chan error, 1) errChan := make(chan error, 1)
go func() { go func() {
err := srv1.Listen(srvCfg1) err := srv1.Listen(lstCfg1)
if err != nil { if err != nil {
errChan <- err errChan <- err
} }
@@ -51,7 +60,12 @@ func TestForeignConn(t *testing.T) {
srvCfg2 := server.ListenerConfig{ srvCfg2 := server.ListenerConfig{
Address: "localhost:2234", Address: "localhost:2234",
} }
srv2, err := server.NewServer(otel.Meter(""), srvCfg2.Address, false, av) srv2, err := server.NewServer(server.Config{
Meter: otel.Meter(""),
ExposedAddress: srvCfg2.Address,
TLSSupport: false,
AuthValidator: &allow.Auth{},
})
if err != nil { if err != nil {
t.Fatalf("failed to create server: %s", err) t.Fatalf("failed to create server: %s", err)
} }
@@ -74,32 +88,26 @@ func TestForeignConn(t *testing.T) {
t.Fatalf("failed to start server: %s", err) t.Fatalf("failed to start server: %s", err)
} }
idAlice := "alice"
log.Debugf("connect by alice")
mCtx, cancel := context.WithCancel(ctx) mCtx, cancel := context.WithCancel(ctx)
defer cancel() defer cancel()
clientAlice := NewManager(mCtx, toURL(srvCfg1), idAlice) clientAlice := NewManager(mCtx, toURL(lstCfg1), "alice")
err = clientAlice.Serve() if err := clientAlice.Serve(); err != nil {
if err != nil {
t.Fatalf("failed to serve manager: %s", err) t.Fatalf("failed to serve manager: %s", err)
} }
idBob := "bob" clientBob := NewManager(mCtx, toURL(srvCfg2), "bob")
log.Debugf("connect by bob") if err := clientBob.Serve(); err != nil {
clientBob := NewManager(mCtx, toURL(srvCfg2), idBob)
err = clientBob.Serve()
if err != nil {
t.Fatalf("failed to serve manager: %s", err) t.Fatalf("failed to serve manager: %s", err)
} }
bobsSrvAddr, err := clientBob.RelayInstanceAddress() bobsSrvAddr, err := clientBob.RelayInstanceAddress()
if err != nil { if err != nil {
t.Fatalf("failed to get relay address: %s", err) t.Fatalf("failed to get relay address: %s", err)
} }
connAliceToBob, err := clientAlice.OpenConn(bobsSrvAddr, idBob) connAliceToBob, err := clientAlice.OpenConn(ctx, bobsSrvAddr, "bob")
if err != nil { if err != nil {
t.Fatalf("failed to bind channel: %s", err) t.Fatalf("failed to bind channel: %s", err)
} }
connBobToAlice, err := clientBob.OpenConn(bobsSrvAddr, idAlice) connBobToAlice, err := clientBob.OpenConn(ctx, bobsSrvAddr, "alice")
if err != nil { if err != nil {
t.Fatalf("failed to bind channel: %s", err) t.Fatalf("failed to bind channel: %s", err)
} }
@@ -137,7 +145,7 @@ func TestForeginConnClose(t *testing.T) {
srvCfg1 := server.ListenerConfig{ srvCfg1 := server.ListenerConfig{
Address: "localhost:1234", Address: "localhost:1234",
} }
srv1, err := server.NewServer(otel.Meter(""), srvCfg1.Address, false, av) srv1, err := server.NewServer(serverCfg)
if err != nil { if err != nil {
t.Fatalf("failed to create server: %s", err) t.Fatalf("failed to create server: %s", err)
} }
@@ -163,7 +171,7 @@ func TestForeginConnClose(t *testing.T) {
srvCfg2 := server.ListenerConfig{ srvCfg2 := server.ListenerConfig{
Address: "localhost:2234", Address: "localhost:2234",
} }
srv2, err := server.NewServer(otel.Meter(""), srvCfg2.Address, false, av) srv2, err := server.NewServer(serverCfg)
if err != nil { if err != nil {
t.Fatalf("failed to create server: %s", err) t.Fatalf("failed to create server: %s", err)
} }
@@ -186,16 +194,20 @@ func TestForeginConnClose(t *testing.T) {
t.Fatalf("failed to start server: %s", err) t.Fatalf("failed to start server: %s", err)
} }
idAlice := "alice"
log.Debugf("connect by alice")
mCtx, cancel := context.WithCancel(ctx) mCtx, cancel := context.WithCancel(ctx)
defer cancel() defer cancel()
mgr := NewManager(mCtx, toURL(srvCfg1), idAlice)
mgrBob := NewManager(mCtx, toURL(srvCfg2), "bob")
if err := mgrBob.Serve(); err != nil {
t.Fatalf("failed to serve manager: %s", err)
}
mgr := NewManager(mCtx, toURL(srvCfg1), "alice")
err = mgr.Serve() err = mgr.Serve()
if err != nil { if err != nil {
t.Fatalf("failed to serve manager: %s", err) t.Fatalf("failed to serve manager: %s", err)
} }
conn, err := mgr.OpenConn(toURL(srvCfg2)[0], "anotherpeer") conn, err := mgr.OpenConn(ctx, toURL(srvCfg2)[0], "bob")
if err != nil { if err != nil {
t.Fatalf("failed to bind channel: %s", err) t.Fatalf("failed to bind channel: %s", err)
} }
@@ -206,29 +218,29 @@ func TestForeginConnClose(t *testing.T) {
} }
} }
func TestForeginAutoClose(t *testing.T) { func TestForeignAutoClose(t *testing.T) {
ctx := context.Background() ctx := context.Background()
relayCleanupInterval = 1 * time.Second relayCleanupInterval = 1 * time.Second
keepUnusedServerTime = 2 * time.Second
srvCfg1 := server.ListenerConfig{ srvCfg1 := server.ListenerConfig{
Address: "localhost:1234", Address: "localhost:1234",
} }
srv1, err := server.NewServer(otel.Meter(""), srvCfg1.Address, false, av) srv1, err := server.NewServer(serverCfg)
if err != nil { if err != nil {
t.Fatalf("failed to create server: %s", err) t.Fatalf("failed to create server: %s", err)
} }
errChan := make(chan error, 1) errChan := make(chan error, 1)
go func() { go func() {
t.Log("binding server 1.") t.Log("binding server 1.")
err := srv1.Listen(srvCfg1) if err := srv1.Listen(srvCfg1); err != nil {
if err != nil {
errChan <- err errChan <- err
} }
}() }()
defer func() { defer func() {
t.Logf("closing server 1.") t.Logf("closing server 1.")
err := srv1.Shutdown(ctx) if err := srv1.Shutdown(ctx); err != nil {
if err != nil {
t.Errorf("failed to close server: %s", err) t.Errorf("failed to close server: %s", err)
} }
t.Logf("server 1. closed") t.Logf("server 1. closed")
@@ -241,7 +253,7 @@ func TestForeginAutoClose(t *testing.T) {
srvCfg2 := server.ListenerConfig{ srvCfg2 := server.ListenerConfig{
Address: "localhost:2234", Address: "localhost:2234",
} }
srv2, err := server.NewServer(otel.Meter(""), srvCfg2.Address, false, av) srv2, err := server.NewServer(serverCfg)
if err != nil { if err != nil {
t.Fatalf("failed to create server: %s", err) t.Fatalf("failed to create server: %s", err)
} }
@@ -276,23 +288,35 @@ func TestForeginAutoClose(t *testing.T) {
t.Fatalf("failed to serve manager: %s", err) t.Fatalf("failed to serve manager: %s", err)
} }
// Set up a disconnect listener to track when foreign server disconnects
foreignServerURL := toURL(srvCfg2)[0]
disconnected := make(chan struct{})
onDisconnect := func() {
select {
case disconnected <- struct{}{}:
default:
}
}
t.Log("open connection to another peer") t.Log("open connection to another peer")
conn, err := mgr.OpenConn(toURL(srvCfg2)[0], "anotherpeer") if _, err = mgr.OpenConn(ctx, foreignServerURL, "anotherpeer"); err == nil {
if err != nil { t.Fatalf("should have failed to open connection to another peer")
t.Fatalf("failed to bind channel: %s", err)
} }
t.Log("close conn") // Add the disconnect listener after the connection attempt
err = conn.Close() if err := mgr.AddCloseListener(foreignServerURL, onDisconnect); err != nil {
if err != nil { t.Logf("failed to add close listener (expected if connection failed): %s", err)
t.Fatalf("failed to close connection: %s", err)
} }
timeout := relayCleanupInterval + keepUnusedServerTime + 1*time.Second // Wait for cleanup to happen
timeout := relayCleanupInterval + keepUnusedServerTime + 2*time.Second
t.Logf("waiting for relay cleanup: %s", timeout) t.Logf("waiting for relay cleanup: %s", timeout)
time.Sleep(timeout)
if len(mgr.relayClients) != 0 { select {
t.Errorf("expected 0, got %d", len(mgr.relayClients)) case <-disconnected:
t.Log("foreign relay connection cleaned up successfully")
case <-time.After(timeout):
t.Log("timeout waiting for cleanup - this might be expected if connection never established")
} }
t.Logf("closing manager") t.Logf("closing manager")
@@ -300,19 +324,17 @@ func TestForeginAutoClose(t *testing.T) {
func TestAutoReconnect(t *testing.T) { func TestAutoReconnect(t *testing.T) {
ctx := context.Background() ctx := context.Background()
reconnectingTimeout = 2 * time.Second
srvCfg := server.ListenerConfig{ srvCfg := server.ListenerConfig{
Address: "localhost:1234", Address: "localhost:1234",
} }
srv, err := server.NewServer(otel.Meter(""), srvCfg.Address, false, av) srv, err := server.NewServer(serverCfg)
if err != nil { if err != nil {
t.Fatalf("failed to create server: %s", err) t.Fatalf("failed to create server: %s", err)
} }
errChan := make(chan error, 1) errChan := make(chan error, 1)
go func() { go func() {
err := srv.Listen(srvCfg) if err := srv.Listen(srvCfg); err != nil {
if err != nil {
errChan <- err errChan <- err
} }
}() }()
@@ -330,6 +352,13 @@ func TestAutoReconnect(t *testing.T) {
mCtx, cancel := context.WithCancel(ctx) mCtx, cancel := context.WithCancel(ctx)
defer cancel() defer cancel()
clientBob := NewManager(mCtx, toURL(srvCfg), "bob")
err = clientBob.Serve()
if err != nil {
t.Fatalf("failed to serve manager: %s", err)
}
clientAlice := NewManager(mCtx, toURL(srvCfg), "alice") clientAlice := NewManager(mCtx, toURL(srvCfg), "alice")
err = clientAlice.Serve() err = clientAlice.Serve()
if err != nil { if err != nil {
@@ -339,7 +368,7 @@ func TestAutoReconnect(t *testing.T) {
if err != nil { if err != nil {
t.Errorf("failed to get relay address: %s", err) t.Errorf("failed to get relay address: %s", err)
} }
conn, err := clientAlice.OpenConn(ra, "bob") conn, err := clientAlice.OpenConn(ctx, ra, "bob")
if err != nil { if err != nil {
t.Errorf("failed to bind channel: %s", err) t.Errorf("failed to bind channel: %s", err)
} }
@@ -357,7 +386,7 @@ func TestAutoReconnect(t *testing.T) {
time.Sleep(reconnectingTimeout + 1*time.Second) time.Sleep(reconnectingTimeout + 1*time.Second)
log.Infof("reopent the connection") log.Infof("reopent the connection")
_, err = clientAlice.OpenConn(ra, "bob") _, err = clientAlice.OpenConn(ctx, ra, "bob")
if err != nil { if err != nil {
t.Errorf("failed to open channel: %s", err) t.Errorf("failed to open channel: %s", err)
} }
@@ -366,24 +395,27 @@ func TestAutoReconnect(t *testing.T) {
func TestNotifierDoubleAdd(t *testing.T) { func TestNotifierDoubleAdd(t *testing.T) {
ctx := context.Background() ctx := context.Background()
srvCfg1 := server.ListenerConfig{ listenerCfg1 := server.ListenerConfig{
Address: "localhost:1234", Address: "localhost:1234",
} }
srv1, err := server.NewServer(otel.Meter(""), srvCfg1.Address, false, av) srv, err := server.NewServer(server.Config{
Meter: otel.Meter(""),
ExposedAddress: listenerCfg1.Address,
TLSSupport: false,
AuthValidator: &allow.Auth{},
})
if err != nil { if err != nil {
t.Fatalf("failed to create server: %s", err) t.Fatalf("failed to create server: %s", err)
} }
errChan := make(chan error, 1) errChan := make(chan error, 1)
go func() { go func() {
err := srv1.Listen(srvCfg1) if err := srv.Listen(listenerCfg1); err != nil {
if err != nil {
errChan <- err errChan <- err
} }
}() }()
defer func() { defer func() {
err := srv1.Shutdown(ctx) if err := srv.Shutdown(ctx); err != nil {
if err != nil {
t.Errorf("failed to close server: %s", err) t.Errorf("failed to close server: %s", err)
} }
}() }()
@@ -392,17 +424,21 @@ func TestNotifierDoubleAdd(t *testing.T) {
t.Fatalf("failed to start server: %s", err) t.Fatalf("failed to start server: %s", err)
} }
idAlice := "alice"
log.Debugf("connect by alice") log.Debugf("connect by alice")
mCtx, cancel := context.WithCancel(ctx) mCtx, cancel := context.WithCancel(ctx)
defer cancel() defer cancel()
clientAlice := NewManager(mCtx, toURL(srvCfg1), idAlice)
err = clientAlice.Serve() clientBob := NewManager(mCtx, toURL(listenerCfg1), "bob")
if err != nil { if err = clientBob.Serve(); err != nil {
t.Fatalf("failed to serve manager: %s", err) t.Fatalf("failed to serve manager: %s", err)
} }
conn1, err := clientAlice.OpenConn(clientAlice.ServerURLs()[0], "idBob") clientAlice := NewManager(mCtx, toURL(listenerCfg1), "alice")
if err = clientAlice.Serve(); err != nil {
t.Fatalf("failed to serve manager: %s", err)
}
conn1, err := clientAlice.OpenConn(ctx, clientAlice.ServerURLs()[0], "bob")
if err != nil { if err != nil {
t.Fatalf("failed to bind channel: %s", err) t.Fatalf("failed to bind channel: %s", err)
} }

View File

@@ -0,0 +1,191 @@
package client
import (
"context"
"errors"
"fmt"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/relay/messages"
)
const (
OpenConnectionTimeout = 30 * time.Second
)
type relayedConnWriter interface {
Write(p []byte) (n int, err error)
}
// PeersStateSubscription manages subscriptions to peer state changes (online/offline)
// over a relay connection. It allows tracking peers' availability and handling offline
// events via a callback. We get online notification from the server only once.
type PeersStateSubscription struct {
log *log.Entry
relayConn relayedConnWriter
offlineCallback func(peerIDs []messages.PeerID)
listenForOfflinePeers map[messages.PeerID]struct{}
waitingPeers map[messages.PeerID]chan struct{}
mu sync.Mutex // Mutex to protect access to waitingPeers and listenForOfflinePeers
}
func NewPeersStateSubscription(log *log.Entry, relayConn relayedConnWriter, offlineCallback func(peerIDs []messages.PeerID)) *PeersStateSubscription {
return &PeersStateSubscription{
log: log,
relayConn: relayConn,
offlineCallback: offlineCallback,
listenForOfflinePeers: make(map[messages.PeerID]struct{}),
waitingPeers: make(map[messages.PeerID]chan struct{}),
}
}
// OnPeersOnline should be called when a notification is received that certain peers have come online.
// It checks if any of the peers are being waited on and signals their availability.
func (s *PeersStateSubscription) OnPeersOnline(peersID []messages.PeerID) {
s.mu.Lock()
defer s.mu.Unlock()
for _, peerID := range peersID {
waitCh, ok := s.waitingPeers[peerID]
if !ok {
// If meanwhile the peer was unsubscribed, we don't need to signal it
continue
}
waitCh <- struct{}{}
delete(s.waitingPeers, peerID)
close(waitCh)
}
}
func (s *PeersStateSubscription) OnPeersWentOffline(peersID []messages.PeerID) {
s.mu.Lock()
relevantPeers := make([]messages.PeerID, 0, len(peersID))
for _, peerID := range peersID {
if _, ok := s.listenForOfflinePeers[peerID]; ok {
relevantPeers = append(relevantPeers, peerID)
}
}
s.mu.Unlock()
if len(relevantPeers) > 0 {
s.offlineCallback(relevantPeers)
}
}
// WaitToBeOnlineAndSubscribe waits for a specific peer to come online and subscribes to its state changes.
func (s *PeersStateSubscription) WaitToBeOnlineAndSubscribe(ctx context.Context, peerID messages.PeerID) error {
// Check if already waiting for this peer
s.mu.Lock()
if _, exists := s.waitingPeers[peerID]; exists {
s.mu.Unlock()
return errors.New("already waiting for peer to come online")
}
// Create a channel to wait for the peer to come online
waitCh := make(chan struct{}, 1)
s.waitingPeers[peerID] = waitCh
s.listenForOfflinePeers[peerID] = struct{}{}
s.mu.Unlock()
if err := s.subscribeStateChange(peerID); err != nil {
s.log.Errorf("failed to subscribe to peer state: %s", err)
s.mu.Lock()
if ch, exists := s.waitingPeers[peerID]; exists && ch == waitCh {
close(waitCh)
delete(s.waitingPeers, peerID)
delete(s.listenForOfflinePeers, peerID)
}
s.mu.Unlock()
return err
}
// Wait for peer to come online or context to be cancelled
timeoutCtx, cancel := context.WithTimeout(ctx, OpenConnectionTimeout)
defer cancel()
select {
case _, ok := <-waitCh:
if !ok {
return fmt.Errorf("wait for peer to come online has been cancelled")
}
s.log.Debugf("peer %s is now online", peerID)
return nil
case <-timeoutCtx.Done():
s.log.Debugf("context timed out while waiting for peer %s to come online", peerID)
if err := s.unsubscribeStateChange([]messages.PeerID{peerID}); err != nil {
s.log.Errorf("failed to unsubscribe from peer state: %s", err)
}
s.mu.Lock()
if ch, exists := s.waitingPeers[peerID]; exists && ch == waitCh {
close(waitCh)
delete(s.waitingPeers, peerID)
delete(s.listenForOfflinePeers, peerID)
}
s.mu.Unlock()
return timeoutCtx.Err()
}
}
func (s *PeersStateSubscription) UnsubscribeStateChange(peerIDs []messages.PeerID) error {
msgErr := s.unsubscribeStateChange(peerIDs)
s.mu.Lock()
for _, peerID := range peerIDs {
if wch, ok := s.waitingPeers[peerID]; ok {
close(wch)
delete(s.waitingPeers, peerID)
}
delete(s.listenForOfflinePeers, peerID)
}
s.mu.Unlock()
return msgErr
}
func (s *PeersStateSubscription) Cleanup() {
s.mu.Lock()
defer s.mu.Unlock()
for _, waitCh := range s.waitingPeers {
close(waitCh)
}
s.waitingPeers = make(map[messages.PeerID]chan struct{})
s.listenForOfflinePeers = make(map[messages.PeerID]struct{})
}
func (s *PeersStateSubscription) subscribeStateChange(peerID messages.PeerID) error {
msgs, err := messages.MarshalSubPeerStateMsg([]messages.PeerID{peerID})
if err != nil {
return err
}
for _, msg := range msgs {
if _, err := s.relayConn.Write(msg); err != nil {
return err
}
}
return nil
}
func (s *PeersStateSubscription) unsubscribeStateChange(peerIDs []messages.PeerID) error {
msgs, err := messages.MarshalUnsubPeerStateMsg(peerIDs)
if err != nil {
return err
}
var connWriteErr error
for _, msg := range msgs {
if _, err := s.relayConn.Write(msg); err != nil {
connWriteErr = err
}
}
return connWriteErr
}

View File

@@ -0,0 +1,99 @@
package client
import (
"bytes"
"context"
"testing"
"time"
"github.com/netbirdio/netbird/relay/messages"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type mockRelayedConn struct {
}
func (m *mockRelayedConn) Write(p []byte) (n int, err error) {
return len(p), nil
}
func TestWaitToBeOnlineAndSubscribe_Success(t *testing.T) {
peerID := messages.HashID("peer1")
mockConn := &mockRelayedConn{}
logger := logrus.New()
logger.SetOutput(&bytes.Buffer{}) // discard log output
sub := NewPeersStateSubscription(logrus.NewEntry(logger), mockConn, nil)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Launch wait in background
go func() {
time.Sleep(100 * time.Millisecond)
sub.OnPeersOnline([]messages.PeerID{peerID})
}()
err := sub.WaitToBeOnlineAndSubscribe(ctx, peerID)
assert.NoError(t, err)
}
func TestWaitToBeOnlineAndSubscribe_Timeout(t *testing.T) {
peerID := messages.HashID("peer2")
mockConn := &mockRelayedConn{}
logger := logrus.New()
logger.SetOutput(&bytes.Buffer{})
sub := NewPeersStateSubscription(logrus.NewEntry(logger), mockConn, nil)
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
err := sub.WaitToBeOnlineAndSubscribe(ctx, peerID)
assert.Error(t, err)
assert.Equal(t, context.DeadlineExceeded, err)
}
func TestWaitToBeOnlineAndSubscribe_Duplicate(t *testing.T) {
peerID := messages.HashID("peer3")
mockConn := &mockRelayedConn{}
logger := logrus.New()
logger.SetOutput(&bytes.Buffer{})
sub := NewPeersStateSubscription(logrus.NewEntry(logger), mockConn, nil)
ctx := context.Background()
go func() {
_ = sub.WaitToBeOnlineAndSubscribe(ctx, peerID)
}()
time.Sleep(100 * time.Millisecond)
err := sub.WaitToBeOnlineAndSubscribe(ctx, peerID)
require.Error(t, err)
assert.Contains(t, err.Error(), "already waiting")
}
func TestUnsubscribeStateChange(t *testing.T) {
peerID := messages.HashID("peer4")
mockConn := &mockRelayedConn{}
logger := logrus.New()
logger.SetOutput(&bytes.Buffer{})
sub := NewPeersStateSubscription(logrus.NewEntry(logger), mockConn, nil)
doneChan := make(chan struct{})
go func() {
_ = sub.WaitToBeOnlineAndSubscribe(context.Background(), peerID)
close(doneChan)
}()
time.Sleep(100 * time.Millisecond)
err := sub.UnsubscribeStateChange([]messages.PeerID{peerID})
assert.NoError(t, err)
select {
case <-doneChan:
case <-time.After(200 * time.Millisecond):
// Expected timeout, meaning the subscription was successfully unsubscribed
t.Errorf("timeout")
}
}

View File

@@ -70,8 +70,8 @@ func (sp *ServerPicker) PickServer(parentCtx context.Context) (*Client, error) {
func (sp *ServerPicker) startConnection(ctx context.Context, resultChan chan connResult, url string) { func (sp *ServerPicker) startConnection(ctx context.Context, resultChan chan connResult, url string) {
log.Infof("try to connecting to relay server: %s", url) log.Infof("try to connecting to relay server: %s", url)
relayClient := NewClient(ctx, url, sp.TokenStore, sp.PeerID) relayClient := NewClient(url, sp.TokenStore, sp.PeerID)
err := relayClient.Connect() err := relayClient.Connect(ctx)
resultChan <- connResult{ resultChan <- connResult{
RelayClient: relayClient, RelayClient: relayClient,
Url: url, Url: url,

View File

@@ -141,7 +141,14 @@ func execute(cmd *cobra.Command, args []string) error {
hashedSecret := sha256.Sum256([]byte(cobraConfig.AuthSecret)) hashedSecret := sha256.Sum256([]byte(cobraConfig.AuthSecret))
authenticator := auth.NewTimedHMACValidator(hashedSecret[:], 24*time.Hour) authenticator := auth.NewTimedHMACValidator(hashedSecret[:], 24*time.Hour)
srv, err := server.NewServer(metricsServer.Meter, cobraConfig.ExposedAddress, tlsSupport, authenticator) cfg := server.Config{
Meter: metricsServer.Meter,
ExposedAddress: cobraConfig.ExposedAddress,
AuthValidator: authenticator,
TLSSupport: tlsSupport,
}
srv, err := server.NewServer(cfg)
if err != nil { if err != nil {
log.Debugf("failed to create relay server: %v", err) log.Debugf("failed to create relay server: %v", err)
return fmt.Errorf("failed to create relay server: %v", err) return fmt.Errorf("failed to create relay server: %v", err)

View File

@@ -4,38 +4,76 @@ import (
"context" "context"
"fmt" "fmt"
"os" "os"
"sync"
"testing" "testing"
"time" "time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
// Mutex to protect global variable access in tests
var testMutex sync.Mutex
func TestNewReceiver(t *testing.T) { func TestNewReceiver(t *testing.T) {
testMutex.Lock()
originalTimeout := heartbeatTimeout
heartbeatTimeout = 5 * time.Second heartbeatTimeout = 5 * time.Second
testMutex.Unlock()
defer func() {
testMutex.Lock()
heartbeatTimeout = originalTimeout
testMutex.Unlock()
}()
r := NewReceiver(log.WithContext(context.Background())) r := NewReceiver(log.WithContext(context.Background()))
defer r.Stop()
select { select {
case <-r.OnTimeout: case <-r.OnTimeout:
t.Error("unexpected timeout") t.Error("unexpected timeout")
case <-time.After(1 * time.Second): case <-time.After(1 * time.Second):
// Test passes if no timeout received
} }
} }
func TestNewReceiverNotReceive(t *testing.T) { func TestNewReceiverNotReceive(t *testing.T) {
testMutex.Lock()
originalTimeout := heartbeatTimeout
heartbeatTimeout = 1 * time.Second heartbeatTimeout = 1 * time.Second
testMutex.Unlock()
defer func() {
testMutex.Lock()
heartbeatTimeout = originalTimeout
testMutex.Unlock()
}()
r := NewReceiver(log.WithContext(context.Background())) r := NewReceiver(log.WithContext(context.Background()))
defer r.Stop()
select { select {
case <-r.OnTimeout: case <-r.OnTimeout:
// Test passes if timeout is received
case <-time.After(2 * time.Second): case <-time.After(2 * time.Second):
t.Error("timeout not received") t.Error("timeout not received")
} }
} }
func TestNewReceiverAck(t *testing.T) { func TestNewReceiverAck(t *testing.T) {
testMutex.Lock()
originalTimeout := heartbeatTimeout
heartbeatTimeout = 2 * time.Second heartbeatTimeout = 2 * time.Second
testMutex.Unlock()
defer func() {
testMutex.Lock()
heartbeatTimeout = originalTimeout
testMutex.Unlock()
}()
r := NewReceiver(log.WithContext(context.Background())) r := NewReceiver(log.WithContext(context.Background()))
defer r.Stop()
r.Heartbeat() r.Heartbeat()
@@ -59,13 +97,18 @@ func TestReceiverHealthCheckAttemptThreshold(t *testing.T) {
for _, tc := range testsCases { for _, tc := range testsCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
testMutex.Lock()
originalInterval := healthCheckInterval originalInterval := healthCheckInterval
originalTimeout := heartbeatTimeout originalTimeout := heartbeatTimeout
healthCheckInterval = 1 * time.Second healthCheckInterval = 1 * time.Second
heartbeatTimeout = healthCheckInterval + 500*time.Millisecond heartbeatTimeout = healthCheckInterval + 500*time.Millisecond
testMutex.Unlock()
defer func() { defer func() {
testMutex.Lock()
healthCheckInterval = originalInterval healthCheckInterval = originalInterval
heartbeatTimeout = originalTimeout heartbeatTimeout = originalTimeout
testMutex.Unlock()
}() }()
//nolint:tenv //nolint:tenv
os.Setenv(defaultAttemptThresholdEnv, fmt.Sprintf("%d", tc.threshold)) os.Setenv(defaultAttemptThresholdEnv, fmt.Sprintf("%d", tc.threshold))

View File

@@ -135,7 +135,11 @@ func TestSenderHealthCheckAttemptThreshold(t *testing.T) {
defer cancel() defer cancel()
sender := NewSender(log.WithField("test_name", tc.name)) sender := NewSender(log.WithField("test_name", tc.name))
go sender.StartHealthCheck(ctx) senderExit := make(chan struct{})
go func() {
sender.StartHealthCheck(ctx)
close(senderExit)
}()
go func() { go func() {
responded := false responded := false
@@ -169,6 +173,11 @@ func TestSenderHealthCheckAttemptThreshold(t *testing.T) {
t.Fatalf("should have timed out before %s", testTimeout) t.Fatalf("should have timed out before %s", testTimeout)
} }
select {
case <-senderExit:
case <-time.After(2 * time.Second):
t.Fatalf("sender did not exit in time")
}
}) })
} }

View File

@@ -8,24 +8,24 @@ import (
const ( const (
prefixLength = 4 prefixLength = 4
IDSize = prefixLength + sha256.Size peerIDSize = prefixLength + sha256.Size
) )
var ( var (
prefix = []byte("sha-") // 4 bytes prefix = []byte("sha-") // 4 bytes
) )
// HashID generates a sha256 hash from the peerID and returns the hash and the human-readable string type PeerID [peerIDSize]byte
func HashID(peerID string) ([]byte, string) {
idHash := sha256.Sum256([]byte(peerID)) func (p PeerID) String() string {
idHashString := string(prefix) + base64.StdEncoding.EncodeToString(idHash[:]) return fmt.Sprintf("%s%s", p[:prefixLength], base64.StdEncoding.EncodeToString(p[prefixLength:]))
var prefixedHash []byte
prefixedHash = append(prefixedHash, prefix...)
prefixedHash = append(prefixedHash, idHash[:]...)
return prefixedHash, idHashString
} }
// HashIDToString converts a hash to a human-readable string // HashID generates a sha256 hash from the peerID and returns the hash and the human-readable string
func HashIDToString(idHash []byte) string { func HashID(peerID string) PeerID {
return fmt.Sprintf("%s%s", idHash[:prefixLength], base64.StdEncoding.EncodeToString(idHash[prefixLength:])) idHash := sha256.Sum256([]byte(peerID))
var prefixedHash [peerIDSize]byte
copy(prefixedHash[:prefixLength], prefix)
copy(prefixedHash[prefixLength:], idHash[:])
return prefixedHash
} }

View File

@@ -1,13 +0,0 @@
package messages
import (
"testing"
)
func TestHashID(t *testing.T) {
hashedID, hashedStringId := HashID("alice")
enc := HashIDToString(hashedID)
if enc != hashedStringId {
t.Errorf("expected %s, got %s", hashedStringId, enc)
}
}

View File

@@ -9,20 +9,27 @@ import (
const ( const (
MaxHandshakeSize = 212 MaxHandshakeSize = 212
MaxHandshakeRespSize = 8192 MaxHandshakeRespSize = 8192
MaxMessageSize = 8820
CurrentProtocolVersion = 1 CurrentProtocolVersion = 1
MsgTypeUnknown MsgType = 0 MsgTypeUnknown MsgType = 0
// Deprecated: Use MsgTypeAuth instead. // Deprecated: Use MsgTypeAuth instead.
MsgTypeHello MsgType = 1 MsgTypeHello = 1
// Deprecated: Use MsgTypeAuthResponse instead. // Deprecated: Use MsgTypeAuthResponse instead.
MsgTypeHelloResponse MsgType = 2 MsgTypeHelloResponse = 2
MsgTypeTransport MsgType = 3 MsgTypeTransport = 3
MsgTypeClose MsgType = 4 MsgTypeClose = 4
MsgTypeHealthCheck MsgType = 5 MsgTypeHealthCheck = 5
MsgTypeAuth = 6 MsgTypeAuth = 6
MsgTypeAuthResponse = 7 MsgTypeAuthResponse = 7
// Peers state messages
MsgTypeSubscribePeerState = 8
MsgTypeUnsubscribePeerState = 9
MsgTypePeersOnline = 10
MsgTypePeersWentOffline = 11
// base size of the message // base size of the message
sizeOfVersionByte = 1 sizeOfVersionByte = 1
sizeOfMsgType = 1 sizeOfMsgType = 1
@@ -30,17 +37,17 @@ const (
// auth message // auth message
sizeOfMagicByte = 4 sizeOfMagicByte = 4
headerSizeAuth = sizeOfMagicByte + IDSize headerSizeAuth = sizeOfMagicByte + peerIDSize
offsetMagicByte = sizeOfProtoHeader offsetMagicByte = sizeOfProtoHeader
offsetAuthPeerID = sizeOfProtoHeader + sizeOfMagicByte offsetAuthPeerID = sizeOfProtoHeader + sizeOfMagicByte
headerTotalSizeAuth = sizeOfProtoHeader + headerSizeAuth headerTotalSizeAuth = sizeOfProtoHeader + headerSizeAuth
// hello message // hello message
headerSizeHello = sizeOfMagicByte + IDSize headerSizeHello = sizeOfMagicByte + peerIDSize
headerSizeHelloResp = 0 headerSizeHelloResp = 0
// transport // transport
headerSizeTransport = IDSize headerSizeTransport = peerIDSize
offsetTransportID = sizeOfProtoHeader offsetTransportID = sizeOfProtoHeader
headerTotalSizeTransport = sizeOfProtoHeader + headerSizeTransport headerTotalSizeTransport = sizeOfProtoHeader + headerSizeTransport
) )
@@ -72,6 +79,14 @@ func (m MsgType) String() string {
return "close" return "close"
case MsgTypeHealthCheck: case MsgTypeHealthCheck:
return "health check" return "health check"
case MsgTypeSubscribePeerState:
return "subscribe peer state"
case MsgTypeUnsubscribePeerState:
return "unsubscribe peer state"
case MsgTypePeersOnline:
return "peers online"
case MsgTypePeersWentOffline:
return "peers went offline"
default: default:
return "unknown" return "unknown"
} }
@@ -102,7 +117,9 @@ func DetermineClientMessageType(msg []byte) (MsgType, error) {
MsgTypeAuth, MsgTypeAuth,
MsgTypeTransport, MsgTypeTransport,
MsgTypeClose, MsgTypeClose,
MsgTypeHealthCheck: MsgTypeHealthCheck,
MsgTypeSubscribePeerState,
MsgTypeUnsubscribePeerState:
return msgType, nil return msgType, nil
default: default:
return MsgTypeUnknown, fmt.Errorf("invalid msg type %d", msgType) return MsgTypeUnknown, fmt.Errorf("invalid msg type %d", msgType)
@@ -122,7 +139,9 @@ func DetermineServerMessageType(msg []byte) (MsgType, error) {
MsgTypeAuthResponse, MsgTypeAuthResponse,
MsgTypeTransport, MsgTypeTransport,
MsgTypeClose, MsgTypeClose,
MsgTypeHealthCheck: MsgTypeHealthCheck,
MsgTypePeersOnline,
MsgTypePeersWentOffline:
return msgType, nil return msgType, nil
default: default:
return MsgTypeUnknown, fmt.Errorf("invalid msg type %d", msgType) return MsgTypeUnknown, fmt.Errorf("invalid msg type %d", msgType)
@@ -135,11 +154,7 @@ func DetermineServerMessageType(msg []byte) (MsgType, error) {
// message is used to authenticate the client with the server. The authentication is done using an HMAC method. // message is used to authenticate the client with the server. The authentication is done using an HMAC method.
// The protocol does not limit to use HMAC, it can be any other method. If the authentication failed the server will // The protocol does not limit to use HMAC, it can be any other method. If the authentication failed the server will
// close the network connection without any response. // close the network connection without any response.
func MarshalHelloMsg(peerID []byte, additions []byte) ([]byte, error) { func MarshalHelloMsg(peerID PeerID, additions []byte) ([]byte, error) {
if len(peerID) != IDSize {
return nil, fmt.Errorf("invalid peerID length: %d", len(peerID))
}
msg := make([]byte, sizeOfProtoHeader+sizeOfMagicByte, sizeOfProtoHeader+headerSizeHello+len(additions)) msg := make([]byte, sizeOfProtoHeader+sizeOfMagicByte, sizeOfProtoHeader+headerSizeHello+len(additions))
msg[0] = byte(CurrentProtocolVersion) msg[0] = byte(CurrentProtocolVersion)
@@ -147,7 +162,7 @@ func MarshalHelloMsg(peerID []byte, additions []byte) ([]byte, error) {
copy(msg[sizeOfProtoHeader:sizeOfProtoHeader+sizeOfMagicByte], magicHeader) copy(msg[sizeOfProtoHeader:sizeOfProtoHeader+sizeOfMagicByte], magicHeader)
msg = append(msg, peerID...) msg = append(msg, peerID[:]...)
msg = append(msg, additions...) msg = append(msg, additions...)
return msg, nil return msg, nil
@@ -156,7 +171,7 @@ func MarshalHelloMsg(peerID []byte, additions []byte) ([]byte, error) {
// Deprecated: Use UnmarshalAuthMsg instead. // Deprecated: Use UnmarshalAuthMsg instead.
// UnmarshalHelloMsg extracts peerID and the additional data from the hello message. The Additional data is used to // UnmarshalHelloMsg extracts peerID and the additional data from the hello message. The Additional data is used to
// authenticate the client with the server. // authenticate the client with the server.
func UnmarshalHelloMsg(msg []byte) ([]byte, []byte, error) { func UnmarshalHelloMsg(msg []byte) (*PeerID, []byte, error) {
if len(msg) < sizeOfProtoHeader+headerSizeHello { if len(msg) < sizeOfProtoHeader+headerSizeHello {
return nil, nil, ErrInvalidMessageLength return nil, nil, ErrInvalidMessageLength
} }
@@ -164,7 +179,9 @@ func UnmarshalHelloMsg(msg []byte) ([]byte, []byte, error) {
return nil, nil, errors.New("invalid magic header") return nil, nil, errors.New("invalid magic header")
} }
return msg[sizeOfProtoHeader+sizeOfMagicByte : sizeOfProtoHeader+headerSizeHello], msg[headerSizeHello:], nil peerID := PeerID(msg[sizeOfProtoHeader+sizeOfMagicByte : sizeOfProtoHeader+headerSizeHello])
return &peerID, msg[headerSizeHello:], nil
} }
// Deprecated: Use MarshalAuthResponse instead. // Deprecated: Use MarshalAuthResponse instead.
@@ -197,34 +214,33 @@ func UnmarshalHelloResponse(msg []byte) ([]byte, error) {
// message is used to authenticate the client with the server. The authentication is done using an HMAC method. // message is used to authenticate the client with the server. The authentication is done using an HMAC method.
// The protocol does not limit to use HMAC, it can be any other method. If the authentication failed the server will // The protocol does not limit to use HMAC, it can be any other method. If the authentication failed the server will
// close the network connection without any response. // close the network connection without any response.
func MarshalAuthMsg(peerID []byte, authPayload []byte) ([]byte, error) { func MarshalAuthMsg(peerID PeerID, authPayload []byte) ([]byte, error) {
if len(peerID) != IDSize { if headerTotalSizeAuth+len(authPayload) > MaxHandshakeSize {
return nil, fmt.Errorf("invalid peerID length: %d", len(peerID)) return nil, fmt.Errorf("too large auth payload")
} }
msg := make([]byte, sizeOfProtoHeader+sizeOfMagicByte, headerTotalSizeAuth+len(authPayload)) msg := make([]byte, headerTotalSizeAuth+len(authPayload))
msg[0] = byte(CurrentProtocolVersion) msg[0] = byte(CurrentProtocolVersion)
msg[1] = byte(MsgTypeAuth) msg[1] = byte(MsgTypeAuth)
copy(msg[sizeOfProtoHeader:], magicHeader) copy(msg[sizeOfProtoHeader:], magicHeader)
copy(msg[offsetAuthPeerID:], peerID[:])
msg = append(msg, peerID...) copy(msg[headerTotalSizeAuth:], authPayload)
msg = append(msg, authPayload...)
return msg, nil return msg, nil
} }
// UnmarshalAuthMsg extracts peerID and the auth payload from the message // UnmarshalAuthMsg extracts peerID and the auth payload from the message
func UnmarshalAuthMsg(msg []byte) ([]byte, []byte, error) { func UnmarshalAuthMsg(msg []byte) (*PeerID, []byte, error) {
if len(msg) < headerTotalSizeAuth { if len(msg) < headerTotalSizeAuth {
return nil, nil, ErrInvalidMessageLength return nil, nil, ErrInvalidMessageLength
} }
// Validate the magic header
if !bytes.Equal(msg[offsetMagicByte:offsetMagicByte+sizeOfMagicByte], magicHeader) { if !bytes.Equal(msg[offsetMagicByte:offsetMagicByte+sizeOfMagicByte], magicHeader) {
return nil, nil, errors.New("invalid magic header") return nil, nil, errors.New("invalid magic header")
} }
return msg[offsetAuthPeerID:headerTotalSizeAuth], msg[headerTotalSizeAuth:], nil peerID := PeerID(msg[offsetAuthPeerID:headerTotalSizeAuth])
return &peerID, msg[headerTotalSizeAuth:], nil
} }
// MarshalAuthResponse creates a response message to the auth. // MarshalAuthResponse creates a response message to the auth.
@@ -268,45 +284,48 @@ func MarshalCloseMsg() []byte {
// MarshalTransportMsg creates a transport message. // MarshalTransportMsg creates a transport message.
// The transport message is used to exchange data between peers. The message contains the data to be exchanged and the // The transport message is used to exchange data between peers. The message contains the data to be exchanged and the
// destination peer hashed ID. // destination peer hashed ID.
func MarshalTransportMsg(peerID, payload []byte) ([]byte, error) { func MarshalTransportMsg(peerID PeerID, payload []byte) ([]byte, error) {
if len(peerID) != IDSize { // todo validate size
return nil, fmt.Errorf("invalid peerID length: %d", len(peerID)) msg := make([]byte, headerTotalSizeTransport+len(payload))
}
msg := make([]byte, headerTotalSizeTransport, headerTotalSizeTransport+len(payload))
msg[0] = byte(CurrentProtocolVersion) msg[0] = byte(CurrentProtocolVersion)
msg[1] = byte(MsgTypeTransport) msg[1] = byte(MsgTypeTransport)
copy(msg[sizeOfProtoHeader:], peerID) copy(msg[sizeOfProtoHeader:], peerID[:])
msg = append(msg, payload...) copy(msg[sizeOfProtoHeader+peerIDSize:], payload)
return msg, nil return msg, nil
} }
// UnmarshalTransportMsg extracts the peerID and the payload from the transport message. // UnmarshalTransportMsg extracts the peerID and the payload from the transport message.
func UnmarshalTransportMsg(buf []byte) ([]byte, []byte, error) { func UnmarshalTransportMsg(buf []byte) (*PeerID, []byte, error) {
if len(buf) < headerTotalSizeTransport { if len(buf) < headerTotalSizeTransport {
return nil, nil, ErrInvalidMessageLength return nil, nil, ErrInvalidMessageLength
} }
return buf[offsetTransportID:headerTotalSizeTransport], buf[headerTotalSizeTransport:], nil const offsetEnd = offsetTransportID + peerIDSize
var peerID PeerID
copy(peerID[:], buf[offsetTransportID:offsetEnd])
return &peerID, buf[headerTotalSizeTransport:], nil
} }
// UnmarshalTransportID extracts the peerID from the transport message. // UnmarshalTransportID extracts the peerID from the transport message.
func UnmarshalTransportID(buf []byte) ([]byte, error) { func UnmarshalTransportID(buf []byte) (*PeerID, error) {
if len(buf) < headerTotalSizeTransport { if len(buf) < headerTotalSizeTransport {
return nil, ErrInvalidMessageLength return nil, ErrInvalidMessageLength
} }
return buf[offsetTransportID:headerTotalSizeTransport], nil
const offsetEnd = offsetTransportID + peerIDSize
var id PeerID
copy(id[:], buf[offsetTransportID:offsetEnd])
return &id, nil
} }
// UpdateTransportMsg updates the peerID in the transport message. // UpdateTransportMsg updates the peerID in the transport message.
// With this function the server can reuse the given byte slice to update the peerID in the transport message. So do // With this function the server can reuse the given byte slice to update the peerID in the transport message. So do
// need to allocate a new byte slice. // need to allocate a new byte slice.
func UpdateTransportMsg(msg []byte, peerID []byte) error { func UpdateTransportMsg(msg []byte, peerID PeerID) error {
if len(msg) < offsetTransportID+len(peerID) { if len(msg) < offsetTransportID+peerIDSize {
return ErrInvalidMessageLength return ErrInvalidMessageLength
} }
copy(msg[offsetTransportID:], peerID) copy(msg[offsetTransportID:], peerID[:])
return nil return nil
} }

View File

@@ -5,7 +5,7 @@ import (
) )
func TestMarshalHelloMsg(t *testing.T) { func TestMarshalHelloMsg(t *testing.T) {
peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+") peerID := HashID("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
msg, err := MarshalHelloMsg(peerID, nil) msg, err := MarshalHelloMsg(peerID, nil)
if err != nil { if err != nil {
t.Fatalf("error: %v", err) t.Fatalf("error: %v", err)
@@ -24,13 +24,13 @@ func TestMarshalHelloMsg(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("error: %v", err) t.Fatalf("error: %v", err)
} }
if string(receivedPeerID) != string(peerID) { if receivedPeerID.String() != peerID.String() {
t.Errorf("expected %s, got %s", peerID, receivedPeerID) t.Errorf("expected %s, got %s", peerID, receivedPeerID)
} }
} }
func TestMarshalAuthMsg(t *testing.T) { func TestMarshalAuthMsg(t *testing.T) {
peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+") peerID := HashID("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
msg, err := MarshalAuthMsg(peerID, []byte{}) msg, err := MarshalAuthMsg(peerID, []byte{})
if err != nil { if err != nil {
t.Fatalf("error: %v", err) t.Fatalf("error: %v", err)
@@ -49,7 +49,7 @@ func TestMarshalAuthMsg(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("error: %v", err) t.Fatalf("error: %v", err)
} }
if string(receivedPeerID) != string(peerID) { if receivedPeerID.String() != peerID.String() {
t.Errorf("expected %s, got %s", peerID, receivedPeerID) t.Errorf("expected %s, got %s", peerID, receivedPeerID)
} }
} }
@@ -80,7 +80,7 @@ func TestMarshalAuthResponse(t *testing.T) {
} }
func TestMarshalTransportMsg(t *testing.T) { func TestMarshalTransportMsg(t *testing.T) {
peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+") peerID := HashID("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
payload := []byte("payload") payload := []byte("payload")
msg, err := MarshalTransportMsg(peerID, payload) msg, err := MarshalTransportMsg(peerID, payload)
if err != nil { if err != nil {
@@ -101,7 +101,7 @@ func TestMarshalTransportMsg(t *testing.T) {
t.Fatalf("failed to unmarshal transport id: %v", err) t.Fatalf("failed to unmarshal transport id: %v", err)
} }
if string(uPeerID) != string(peerID) { if uPeerID.String() != peerID.String() {
t.Errorf("expected %s, got %s", peerID, uPeerID) t.Errorf("expected %s, got %s", peerID, uPeerID)
} }
@@ -110,8 +110,8 @@ func TestMarshalTransportMsg(t *testing.T) {
t.Fatalf("error: %v", err) t.Fatalf("error: %v", err)
} }
if string(id) != string(peerID) { if id.String() != peerID.String() {
t.Errorf("expected %s, got %s", peerID, id) t.Errorf("expected: '%s', got: '%s'", peerID, id)
} }
if string(respPayload) != string(payload) { if string(respPayload) != string(payload) {

View File

@@ -0,0 +1,92 @@
package messages
import (
"fmt"
)
func MarshalSubPeerStateMsg(ids []PeerID) ([][]byte, error) {
return marshalPeerIDs(ids, byte(MsgTypeSubscribePeerState))
}
func UnmarshalSubPeerStateMsg(buf []byte) ([]PeerID, error) {
return unmarshalPeerIDs(buf)
}
func MarshalUnsubPeerStateMsg(ids []PeerID) ([][]byte, error) {
return marshalPeerIDs(ids, byte(MsgTypeUnsubscribePeerState))
}
func UnmarshalUnsubPeerStateMsg(buf []byte) ([]PeerID, error) {
return unmarshalPeerIDs(buf)
}
func MarshalPeersOnline(ids []PeerID) ([][]byte, error) {
return marshalPeerIDs(ids, byte(MsgTypePeersOnline))
}
func UnmarshalPeersOnlineMsg(buf []byte) ([]PeerID, error) {
return unmarshalPeerIDs(buf)
}
func MarshalPeersWentOffline(ids []PeerID) ([][]byte, error) {
return marshalPeerIDs(ids, byte(MsgTypePeersWentOffline))
}
func UnMarshalPeersWentOffline(buf []byte) ([]PeerID, error) {
return unmarshalPeerIDs(buf)
}
// marshalPeerIDs is a generic function to marshal peer IDs with a specific message type
func marshalPeerIDs(ids []PeerID, msgType byte) ([][]byte, error) {
if len(ids) == 0 {
return nil, fmt.Errorf("no list of peer ids provided")
}
const maxPeersPerMessage = (MaxMessageSize - sizeOfProtoHeader) / peerIDSize
var messages [][]byte
for i := 0; i < len(ids); i += maxPeersPerMessage {
end := i + maxPeersPerMessage
if end > len(ids) {
end = len(ids)
}
chunk := ids[i:end]
totalSize := sizeOfProtoHeader + len(chunk)*peerIDSize
buf := make([]byte, totalSize)
buf[0] = byte(CurrentProtocolVersion)
buf[1] = msgType
offset := sizeOfProtoHeader
for _, id := range chunk {
copy(buf[offset:], id[:])
offset += peerIDSize
}
messages = append(messages, buf)
}
return messages, nil
}
// unmarshalPeerIDs is a generic function to unmarshal peer IDs from a buffer
func unmarshalPeerIDs(buf []byte) ([]PeerID, error) {
if len(buf) < sizeOfProtoHeader {
return nil, fmt.Errorf("invalid message format")
}
if (len(buf)-sizeOfProtoHeader)%peerIDSize != 0 {
return nil, fmt.Errorf("invalid peer list size: %d", len(buf)-sizeOfProtoHeader)
}
numIDs := (len(buf) - sizeOfProtoHeader) / peerIDSize
ids := make([]PeerID, numIDs)
offset := sizeOfProtoHeader
for i := 0; i < numIDs; i++ {
copy(ids[i][:], buf[offset:offset+peerIDSize])
offset += peerIDSize
}
return ids, nil
}

View File

@@ -0,0 +1,144 @@
package messages
import (
"bytes"
"testing"
)
const (
testPeerCount = 10
)
// Helper function to generate test PeerIDs
func generateTestPeerIDs(n int) []PeerID {
ids := make([]PeerID, n)
for i := 0; i < n; i++ {
for j := 0; j < peerIDSize; j++ {
ids[i][j] = byte(i + j)
}
}
return ids
}
// Helper function to compare slices of PeerID
func peerIDEqual(a, b []PeerID) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if !bytes.Equal(a[i][:], b[i][:]) {
return false
}
}
return true
}
func TestMarshalUnmarshalSubPeerState(t *testing.T) {
ids := generateTestPeerIDs(testPeerCount)
msgs, err := MarshalSubPeerStateMsg(ids)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
var allIDs []PeerID
for _, msg := range msgs {
decoded, err := UnmarshalSubPeerStateMsg(msg)
if err != nil {
t.Fatalf("unmarshal failed: %v", err)
}
allIDs = append(allIDs, decoded...)
}
if !peerIDEqual(ids, allIDs) {
t.Errorf("expected %v, got %v", ids, allIDs)
}
}
func TestMarshalSubPeerState_EmptyInput(t *testing.T) {
_, err := MarshalSubPeerStateMsg([]PeerID{})
if err == nil {
t.Errorf("expected error for empty input")
}
}
func TestUnmarshalSubPeerState_Invalid(t *testing.T) {
// Too short
_, err := UnmarshalSubPeerStateMsg([]byte{1})
if err == nil {
t.Errorf("expected error for short input")
}
// Misaligned length
buf := make([]byte, sizeOfProtoHeader+1)
_, err = UnmarshalSubPeerStateMsg(buf)
if err == nil {
t.Errorf("expected error for misaligned input")
}
}
func TestMarshalUnmarshalPeersOnline(t *testing.T) {
ids := generateTestPeerIDs(testPeerCount)
msgs, err := MarshalPeersOnline(ids)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
var allIDs []PeerID
for _, msg := range msgs {
decoded, err := UnmarshalPeersOnlineMsg(msg)
if err != nil {
t.Fatalf("unmarshal failed: %v", err)
}
allIDs = append(allIDs, decoded...)
}
if !peerIDEqual(ids, allIDs) {
t.Errorf("expected %v, got %v", ids, allIDs)
}
}
func TestMarshalPeersOnline_EmptyInput(t *testing.T) {
_, err := MarshalPeersOnline([]PeerID{})
if err == nil {
t.Errorf("expected error for empty input")
}
}
func TestUnmarshalPeersOnline_Invalid(t *testing.T) {
_, err := UnmarshalPeersOnlineMsg([]byte{1})
if err == nil {
t.Errorf("expected error for short input")
}
}
func TestMarshalUnmarshalPeersWentOffline(t *testing.T) {
ids := generateTestPeerIDs(testPeerCount)
msgs, err := MarshalPeersWentOffline(ids)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
var allIDs []PeerID
for _, msg := range msgs {
// MarshalPeersWentOffline shares no unmarshal function, so reuse PeersOnline
decoded, err := UnmarshalPeersOnlineMsg(msg)
if err != nil {
t.Fatalf("unmarshal failed: %v", err)
}
allIDs = append(allIDs, decoded...)
}
if !peerIDEqual(ids, allIDs) {
t.Errorf("expected %v, got %v", ids, allIDs)
}
}
func TestMarshalPeersWentOffline_EmptyInput(t *testing.T) {
_, err := MarshalPeersWentOffline([]PeerID{})
if err == nil {
t.Errorf("expected error for empty input")
}
}

View File

@@ -20,7 +20,7 @@ type Metrics struct {
TransferBytesRecv metric.Int64Counter TransferBytesRecv metric.Int64Counter
AuthenticationTime metric.Float64Histogram AuthenticationTime metric.Float64Histogram
PeerStoreTime metric.Float64Histogram PeerStoreTime metric.Float64Histogram
peerReconnections metric.Int64Counter
peers metric.Int64UpDownCounter peers metric.Int64UpDownCounter
peerActivityChan chan string peerActivityChan chan string
peerLastActive map[string]time.Time peerLastActive map[string]time.Time
@@ -80,6 +80,13 @@ func NewMetrics(ctx context.Context, meter metric.Meter) (*Metrics, error) {
return nil, err return nil, err
} }
peerReconnections, err := meter.Int64Counter("relay_peer_reconnections_total",
metric.WithDescription("Total number of times peers have reconnected and closed old connections"),
)
if err != nil {
return nil, err
}
m := &Metrics{ m := &Metrics{
Meter: meter, Meter: meter,
TransferBytesSent: bytesSent, TransferBytesSent: bytesSent,
@@ -87,6 +94,7 @@ func NewMetrics(ctx context.Context, meter metric.Meter) (*Metrics, error) {
AuthenticationTime: authTime, AuthenticationTime: authTime,
PeerStoreTime: peerStoreTime, PeerStoreTime: peerStoreTime,
peers: peers, peers: peers,
peerReconnections: peerReconnections,
ctx: ctx, ctx: ctx,
peerActivityChan: make(chan string, 10), peerActivityChan: make(chan string, 10),
@@ -138,6 +146,10 @@ func (m *Metrics) PeerDisconnected(id string) {
delete(m.peerLastActive, id) delete(m.peerLastActive, id)
} }
func (m *Metrics) RecordPeerReconnection() {
m.peerReconnections.Add(m.ctx, 1)
}
// PeerActivity increases the active connections // PeerActivity increases the active connections
func (m *Metrics) PeerActivity(peerID string) { func (m *Metrics) PeerActivity(peerID string) {
select { select {

View File

@@ -6,7 +6,6 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/relay/auth"
"github.com/netbirdio/netbird/relay/messages" "github.com/netbirdio/netbird/relay/messages"
//nolint:staticcheck //nolint:staticcheck
"github.com/netbirdio/netbird/relay/messages/address" "github.com/netbirdio/netbird/relay/messages/address"
@@ -14,6 +13,12 @@ import (
authmsg "github.com/netbirdio/netbird/relay/messages/auth" authmsg "github.com/netbirdio/netbird/relay/messages/auth"
) )
type Validator interface {
Validate(any) error
// Deprecated: Use Validate instead.
ValidateHelloMsgType(any) error
}
// preparedMsg contains the marshalled success response messages // preparedMsg contains the marshalled success response messages
type preparedMsg struct { type preparedMsg struct {
responseHelloMsg []byte responseHelloMsg []byte
@@ -54,14 +59,14 @@ func marshalResponseHelloMsg(instanceURL string) ([]byte, error) {
type handshake struct { type handshake struct {
conn net.Conn conn net.Conn
validator auth.Validator validator Validator
preparedMsg *preparedMsg preparedMsg *preparedMsg
handshakeMethodAuth bool handshakeMethodAuth bool
peerID string peerID *messages.PeerID
} }
func (h *handshake) handshakeReceive() ([]byte, error) { func (h *handshake) handshakeReceive() (*messages.PeerID, error) {
buf := make([]byte, messages.MaxHandshakeSize) buf := make([]byte, messages.MaxHandshakeSize)
n, err := h.conn.Read(buf) n, err := h.conn.Read(buf)
if err != nil { if err != nil {
@@ -80,17 +85,14 @@ func (h *handshake) handshakeReceive() ([]byte, error) {
return nil, fmt.Errorf("determine message type from %s: %w", h.conn.RemoteAddr(), err) return nil, fmt.Errorf("determine message type from %s: %w", h.conn.RemoteAddr(), err)
} }
var ( var peerID *messages.PeerID
bytePeerID []byte
peerID string
)
switch msgType { switch msgType {
//nolint:staticcheck //nolint:staticcheck
case messages.MsgTypeHello: case messages.MsgTypeHello:
bytePeerID, peerID, err = h.handleHelloMsg(buf) peerID, err = h.handleHelloMsg(buf)
case messages.MsgTypeAuth: case messages.MsgTypeAuth:
h.handshakeMethodAuth = true h.handshakeMethodAuth = true
bytePeerID, peerID, err = h.handleAuthMsg(buf) peerID, err = h.handleAuthMsg(buf)
default: default:
return nil, fmt.Errorf("invalid message type %d from %s", msgType, h.conn.RemoteAddr()) return nil, fmt.Errorf("invalid message type %d from %s", msgType, h.conn.RemoteAddr())
} }
@@ -98,7 +100,7 @@ func (h *handshake) handshakeReceive() ([]byte, error) {
return nil, err return nil, err
} }
h.peerID = peerID h.peerID = peerID
return bytePeerID, nil return peerID, nil
} }
func (h *handshake) handshakeResponse() error { func (h *handshake) handshakeResponse() error {
@@ -116,40 +118,37 @@ func (h *handshake) handshakeResponse() error {
return nil return nil
} }
func (h *handshake) handleHelloMsg(buf []byte) ([]byte, string, error) { func (h *handshake) handleHelloMsg(buf []byte) (*messages.PeerID, error) {
//nolint:staticcheck //nolint:staticcheck
rawPeerID, authData, err := messages.UnmarshalHelloMsg(buf) peerID, authData, err := messages.UnmarshalHelloMsg(buf)
if err != nil { if err != nil {
return nil, "", fmt.Errorf("unmarshal hello message: %w", err) return nil, fmt.Errorf("unmarshal hello message: %w", err)
} }
peerID := messages.HashIDToString(rawPeerID)
log.Warnf("peer %s (%s) is using deprecated initial message type", peerID, h.conn.RemoteAddr()) log.Warnf("peer %s (%s) is using deprecated initial message type", peerID, h.conn.RemoteAddr())
authMsg, err := authmsg.UnmarshalMsg(authData) authMsg, err := authmsg.UnmarshalMsg(authData)
if err != nil { if err != nil {
return nil, "", fmt.Errorf("unmarshal auth message: %w", err) return nil, fmt.Errorf("unmarshal auth message: %w", err)
} }
//nolint:staticcheck //nolint:staticcheck
if err := h.validator.ValidateHelloMsgType(authMsg.AdditionalData); err != nil { if err := h.validator.ValidateHelloMsgType(authMsg.AdditionalData); err != nil {
return nil, "", fmt.Errorf("validate %s (%s): %w", peerID, h.conn.RemoteAddr(), err) return nil, fmt.Errorf("validate %s (%s): %w", peerID, h.conn.RemoteAddr(), err)
} }
return rawPeerID, peerID, nil return peerID, nil
} }
func (h *handshake) handleAuthMsg(buf []byte) ([]byte, string, error) { func (h *handshake) handleAuthMsg(buf []byte) (*messages.PeerID, error) {
rawPeerID, authPayload, err := messages.UnmarshalAuthMsg(buf) rawPeerID, authPayload, err := messages.UnmarshalAuthMsg(buf)
if err != nil { if err != nil {
return nil, "", fmt.Errorf("unmarshal hello message: %w", err) return nil, fmt.Errorf("unmarshal hello message: %w", err)
} }
peerID := messages.HashIDToString(rawPeerID)
if err := h.validator.Validate(authPayload); err != nil { if err := h.validator.Validate(authPayload); err != nil {
return nil, "", fmt.Errorf("validate %s (%s): %w", peerID, h.conn.RemoteAddr(), err) return nil, fmt.Errorf("validate %s (%s): %w", rawPeerID.String(), h.conn.RemoteAddr(), err)
} }
return rawPeerID, peerID, nil return rawPeerID, nil
} }

View File

@@ -18,12 +18,9 @@ type Listener struct {
TLSConfig *tls.Config TLSConfig *tls.Config
listener *quic.Listener listener *quic.Listener
acceptFn func(conn net.Conn)
} }
func (l *Listener) Listen(acceptFn func(conn net.Conn)) error { func (l *Listener) Listen(acceptFn func(conn net.Conn)) error {
l.acceptFn = acceptFn
quicCfg := &quic.Config{ quicCfg := &quic.Config{
EnableDatagrams: true, EnableDatagrams: true,
InitialPacketSize: 1452, InitialPacketSize: 1452,
@@ -49,7 +46,7 @@ func (l *Listener) Listen(acceptFn func(conn net.Conn)) error {
log.Infof("QUIC client connected from: %s", session.RemoteAddr()) log.Infof("QUIC client connected from: %s", session.RemoteAddr())
conn := NewConn(session) conn := NewConn(session)
l.acceptFn(conn) acceptFn(conn)
} }
} }

View File

@@ -12,10 +12,11 @@ import (
"github.com/netbirdio/netbird/relay/healthcheck" "github.com/netbirdio/netbird/relay/healthcheck"
"github.com/netbirdio/netbird/relay/messages" "github.com/netbirdio/netbird/relay/messages"
"github.com/netbirdio/netbird/relay/metrics" "github.com/netbirdio/netbird/relay/metrics"
"github.com/netbirdio/netbird/relay/server/store"
) )
const ( const (
bufferSize = 8820 bufferSize = messages.MaxMessageSize
errCloseConn = "failed to close connection to peer: %s" errCloseConn = "failed to close connection to peer: %s"
) )
@@ -24,31 +25,40 @@ const (
type Peer struct { type Peer struct {
metrics *metrics.Metrics metrics *metrics.Metrics
log *log.Entry log *log.Entry
idS string id messages.PeerID
idB []byte
conn net.Conn conn net.Conn
connMu sync.RWMutex connMu sync.RWMutex
store *Store store *store.Store
notifier *store.PeerNotifier
peersListener *store.Listener
// between the online peer collection step and the notification sending should not be sent offline notifications from another thread
notificationMutex sync.Mutex
} }
// NewPeer creates a new Peer instance and prepare custom logging // NewPeer creates a new Peer instance and prepare custom logging
func NewPeer(metrics *metrics.Metrics, id []byte, conn net.Conn, store *Store) *Peer { func NewPeer(metrics *metrics.Metrics, id messages.PeerID, conn net.Conn, store *store.Store, notifier *store.PeerNotifier) *Peer {
stringID := messages.HashIDToString(id) p := &Peer{
return &Peer{
metrics: metrics, metrics: metrics,
log: log.WithField("peer_id", stringID), log: log.WithField("peer_id", id.String()),
idS: stringID, id: id,
idB: id,
conn: conn, conn: conn,
store: store, store: store,
notifier: notifier,
} }
return p
} }
// Work reads data from the connection // Work reads data from the connection
// It manages the protocol (healthcheck, transport, close). Read the message and determine the message type and handle // It manages the protocol (healthcheck, transport, close). Read the message and determine the message type and handle
// the message accordingly. // the message accordingly.
func (p *Peer) Work() { func (p *Peer) Work() {
p.peersListener = p.notifier.NewListener(p.sendPeersOnline, p.sendPeersWentOffline)
defer func() { defer func() {
p.notifier.RemoveListener(p.peersListener)
if err := p.conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) { if err := p.conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
p.log.Errorf(errCloseConn, err) p.log.Errorf(errCloseConn, err)
} }
@@ -94,6 +104,10 @@ func (p *Peer) Work() {
} }
} }
func (p *Peer) ID() messages.PeerID {
return p.id
}
func (p *Peer) handleMsgType(ctx context.Context, msgType messages.MsgType, hc *healthcheck.Sender, n int, msg []byte) { func (p *Peer) handleMsgType(ctx context.Context, msgType messages.MsgType, hc *healthcheck.Sender, n int, msg []byte) {
switch msgType { switch msgType {
case messages.MsgTypeHealthCheck: case messages.MsgTypeHealthCheck:
@@ -107,6 +121,10 @@ func (p *Peer) handleMsgType(ctx context.Context, msgType messages.MsgType, hc *
if err := p.conn.Close(); err != nil { if err := p.conn.Close(); err != nil {
log.Errorf(errCloseConn, err) log.Errorf(errCloseConn, err)
} }
case messages.MsgTypeSubscribePeerState:
p.handleSubscribePeerState(msg)
case messages.MsgTypeUnsubscribePeerState:
p.handleUnsubscribePeerState(msg)
default: default:
p.log.Warnf("received unexpected message type: %s", msgType) p.log.Warnf("received unexpected message type: %s", msgType)
} }
@@ -145,7 +163,7 @@ func (p *Peer) Close() {
// String returns the peer ID // String returns the peer ID
func (p *Peer) String() string { func (p *Peer) String() string {
return p.idS return p.id.String()
} }
func (p *Peer) writeWithTimeout(ctx context.Context, buf []byte) error { func (p *Peer) writeWithTimeout(ctx context.Context, buf []byte) error {
@@ -197,14 +215,14 @@ func (p *Peer) handleTransportMsg(msg []byte) {
return return
} }
stringPeerID := messages.HashIDToString(peerID) item, ok := p.store.Peer(*peerID)
dp, ok := p.store.Peer(stringPeerID)
if !ok { if !ok {
p.log.Debugf("peer not found: %s", stringPeerID) p.log.Debugf("peer not found: %s", peerID)
return return
} }
dp := item.(*Peer)
err = messages.UpdateTransportMsg(msg, p.idB) err = messages.UpdateTransportMsg(msg, p.id)
if err != nil { if err != nil {
p.log.Errorf("failed to update transport message: %s", err) p.log.Errorf("failed to update transport message: %s", err)
return return
@@ -217,3 +235,66 @@ func (p *Peer) handleTransportMsg(msg []byte) {
} }
p.metrics.TransferBytesSent.Add(context.Background(), int64(n)) p.metrics.TransferBytesSent.Add(context.Background(), int64(n))
} }
func (p *Peer) handleSubscribePeerState(msg []byte) {
peerIDs, err := messages.UnmarshalSubPeerStateMsg(msg)
if err != nil {
p.log.Errorf("failed to unmarshal open connection message: %s", err)
return
}
p.log.Debugf("received subscription message for %d peers", len(peerIDs))
// collect online peers to response back to the caller
p.notificationMutex.Lock()
defer p.notificationMutex.Unlock()
onlinePeers := p.store.GetOnlinePeersAndRegisterInterest(peerIDs, p.peersListener)
if len(onlinePeers) == 0 {
return
}
p.log.Debugf("response with %d online peers", len(onlinePeers))
p.sendPeersOnline(onlinePeers)
}
func (p *Peer) handleUnsubscribePeerState(msg []byte) {
peerIDs, err := messages.UnmarshalUnsubPeerStateMsg(msg)
if err != nil {
p.log.Errorf("failed to unmarshal open connection message: %s", err)
return
}
p.peersListener.RemoveInterestedPeer(peerIDs)
}
func (p *Peer) sendPeersOnline(peers []messages.PeerID) {
msgs, err := messages.MarshalPeersOnline(peers)
if err != nil {
p.log.Errorf("failed to marshal peer location message: %s", err)
return
}
for n, msg := range msgs {
if _, err := p.Write(msg); err != nil {
p.log.Errorf("failed to write %d. peers offline message: %s", n, err)
}
}
}
func (p *Peer) sendPeersWentOffline(peers []messages.PeerID) {
p.notificationMutex.Lock()
defer p.notificationMutex.Unlock()
msgs, err := messages.MarshalPeersWentOffline(peers)
if err != nil {
p.log.Errorf("failed to marshal peer location message: %s", err)
return
}
for n, msg := range msgs {
if _, err := p.Write(msg); err != nil {
p.log.Errorf("failed to write %d. peers offline message: %s", n, err)
}
}
}

View File

@@ -4,26 +4,55 @@ import (
"context" "context"
"fmt" "fmt"
"net" "net"
"net/url"
"strings"
"sync" "sync"
"time" "time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/metric" "go.opentelemetry.io/otel/metric"
"github.com/netbirdio/netbird/relay/auth"
//nolint:staticcheck //nolint:staticcheck
"github.com/netbirdio/netbird/relay/metrics" "github.com/netbirdio/netbird/relay/metrics"
"github.com/netbirdio/netbird/relay/server/store"
) )
type Config struct {
Meter metric.Meter
ExposedAddress string
TLSSupport bool
AuthValidator Validator
instanceURL string
}
func (c *Config) validate() error {
if c.Meter == nil {
c.Meter = otel.Meter("")
}
if c.ExposedAddress == "" {
return fmt.Errorf("exposed address is required")
}
instanceURL, err := getInstanceURL(c.ExposedAddress, c.TLSSupport)
if err != nil {
return fmt.Errorf("invalid url: %v", err)
}
c.instanceURL = instanceURL
if c.AuthValidator == nil {
return fmt.Errorf("auth validator is required")
}
return nil
}
// Relay represents the relay server // Relay represents the relay server
type Relay struct { type Relay struct {
metrics *metrics.Metrics metrics *metrics.Metrics
metricsCancel context.CancelFunc metricsCancel context.CancelFunc
validator auth.Validator validator Validator
store *Store store *store.Store
notifier *store.PeerNotifier
instanceURL string instanceURL string
preparedMsg *preparedMsg preparedMsg *preparedMsg
@@ -31,24 +60,27 @@ type Relay struct {
closeMu sync.RWMutex closeMu sync.RWMutex
} }
// NewRelay creates a new Relay instance // NewRelay creates and returns a new Relay instance.
// //
// Parameters: // Parameters:
// meter: An instance of metric.Meter from the go.opentelemetry.io/otel/metric package. It is used to create and manage //
// metrics for the relay server. // config: A Config struct that holds the configuration needed to initialize the relay server.
// exposedAddress: A string representing the address that the relay server is exposed on. The client will use this // - Meter: A metric.Meter used for emitting metrics. If not set, a default no-op meter will be used.
// address as the relay server's instance URL. // - ExposedAddress: The external address clients use to reach this relay. Required.
// tlsSupport: A boolean indicating whether the relay server supports TLS (Transport Layer Security) or not. The // - TLSSupport: A boolean indicating if the relay uses TLS. Affects the generated instance URL.
// instance URL depends on this value. // - AuthValidator: A Validator implementation used to authenticate peers. Required.
// validator: An instance of auth.Validator from the auth package. It is used to validate the authentication of the
// peers.
// //
// Returns: // Returns:
// A pointer to a Relay instance and an error. If the Relay instance is successfully created, the error is nil. //
// Otherwise, the error contains the details of what went wrong. // A pointer to a Relay instance and an error. If initialization is successful, the error will be nil;
func NewRelay(meter metric.Meter, exposedAddress string, tlsSupport bool, validator auth.Validator) (*Relay, error) { // otherwise, it will contain the reason the relay could not be created (e.g., invalid configuration).
func NewRelay(config Config) (*Relay, error) {
if err := config.validate(); err != nil {
return nil, fmt.Errorf("invalid config: %v", err)
}
ctx, metricsCancel := context.WithCancel(context.Background()) ctx, metricsCancel := context.WithCancel(context.Background())
m, err := metrics.NewMetrics(ctx, meter) m, err := metrics.NewMetrics(ctx, config.Meter)
if err != nil { if err != nil {
metricsCancel() metricsCancel()
return nil, fmt.Errorf("creating app metrics: %v", err) return nil, fmt.Errorf("creating app metrics: %v", err)
@@ -57,14 +89,10 @@ func NewRelay(meter metric.Meter, exposedAddress string, tlsSupport bool, valida
r := &Relay{ r := &Relay{
metrics: m, metrics: m,
metricsCancel: metricsCancel, metricsCancel: metricsCancel,
validator: validator, validator: config.AuthValidator,
store: NewStore(), instanceURL: config.instanceURL,
} store: store.NewStore(),
notifier: store.NewPeerNotifier(),
r.instanceURL, err = getInstanceURL(exposedAddress, tlsSupport)
if err != nil {
metricsCancel()
return nil, fmt.Errorf("get instance URL: %v", err)
} }
r.preparedMsg, err = newPreparedMsg(r.instanceURL) r.preparedMsg, err = newPreparedMsg(r.instanceURL)
@@ -76,32 +104,6 @@ func NewRelay(meter metric.Meter, exposedAddress string, tlsSupport bool, valida
return r, nil return r, nil
} }
// getInstanceURL checks if user supplied a URL scheme otherwise adds to the
// provided address according to TLS definition and parses the address before returning it
func getInstanceURL(exposedAddress string, tlsSupported bool) (string, error) {
addr := exposedAddress
split := strings.Split(exposedAddress, "://")
switch {
case len(split) == 1 && tlsSupported:
addr = "rels://" + exposedAddress
case len(split) == 1 && !tlsSupported:
addr = "rel://" + exposedAddress
case len(split) > 2:
return "", fmt.Errorf("invalid exposed address: %s", exposedAddress)
}
parsedURL, err := url.ParseRequestURI(addr)
if err != nil {
return "", fmt.Errorf("invalid exposed address: %v", err)
}
if parsedURL.Scheme != "rel" && parsedURL.Scheme != "rels" {
return "", fmt.Errorf("invalid scheme: %s", parsedURL.Scheme)
}
return parsedURL.String(), nil
}
// Accept start to handle a new peer connection // Accept start to handle a new peer connection
func (r *Relay) Accept(conn net.Conn) { func (r *Relay) Accept(conn net.Conn) {
acceptTime := time.Now() acceptTime := time.Now()
@@ -125,15 +127,21 @@ func (r *Relay) Accept(conn net.Conn) {
return return
} }
peer := NewPeer(r.metrics, peerID, conn, r.store) peer := NewPeer(r.metrics, *peerID, conn, r.store, r.notifier)
peer.log.Infof("peer connected from: %s", conn.RemoteAddr()) peer.log.Infof("peer connected from: %s", conn.RemoteAddr())
storeTime := time.Now() storeTime := time.Now()
r.store.AddPeer(peer) if isReconnection := r.store.AddPeer(peer); isReconnection {
r.metrics.RecordPeerReconnection()
}
r.notifier.PeerCameOnline(peer.ID())
r.metrics.RecordPeerStoreTime(time.Since(storeTime)) r.metrics.RecordPeerStoreTime(time.Since(storeTime))
r.metrics.PeerConnected(peer.String()) r.metrics.PeerConnected(peer.String())
go func() { go func() {
peer.Work() peer.Work()
r.store.DeletePeer(peer) if deleted := r.store.DeletePeer(peer); deleted {
r.notifier.PeerWentOffline(peer.ID())
}
peer.log.Debugf("relay connection closed") peer.log.Debugf("relay connection closed")
r.metrics.PeerDisconnected(peer.String()) r.metrics.PeerDisconnected(peer.String())
}() }()
@@ -154,12 +162,12 @@ func (r *Relay) Shutdown(ctx context.Context) {
wg := sync.WaitGroup{} wg := sync.WaitGroup{}
peers := r.store.Peers() peers := r.store.Peers()
for _, peer := range peers { for _, v := range peers {
wg.Add(1) wg.Add(1)
go func(p *Peer) { go func(p *Peer) {
p.CloseGracefully(ctx) p.CloseGracefully(ctx)
wg.Done() wg.Done()
}(peer) }(v.(*Peer))
} }
wg.Wait() wg.Wait()
r.metricsCancel() r.metricsCancel()

View File

@@ -6,15 +6,12 @@ import (
"sync" "sync"
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
"go.opentelemetry.io/otel/metric"
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/relay/auth"
"github.com/netbirdio/netbird/relay/server/listener" "github.com/netbirdio/netbird/relay/server/listener"
"github.com/netbirdio/netbird/relay/server/listener/quic" "github.com/netbirdio/netbird/relay/server/listener/quic"
"github.com/netbirdio/netbird/relay/server/listener/ws" "github.com/netbirdio/netbird/relay/server/listener/ws"
quictls "github.com/netbirdio/netbird/relay/tls" quictls "github.com/netbirdio/netbird/relay/tls"
log "github.com/sirupsen/logrus"
) )
// ListenerConfig is the configuration for the listener. // ListenerConfig is the configuration for the listener.
@@ -33,13 +30,22 @@ type Server struct {
listeners []listener.Listener listeners []listener.Listener
} }
// NewServer creates a new relay server instance. // NewServer creates and returns a new relay server instance.
// meter: the OpenTelemetry meter //
// exposedAddress: this address will be used as the instance URL. It should be a domain:port format. // Parameters:
// tlsSupport: if true, the server will support TLS //
// authValidator: the auth validator to use for the server // config: A Config struct containing the necessary configuration:
func NewServer(meter metric.Meter, exposedAddress string, tlsSupport bool, authValidator auth.Validator) (*Server, error) { // - Meter: An OpenTelemetry metric.Meter used for recording metrics. If nil, a default no-op meter is used.
relay, err := NewRelay(meter, exposedAddress, tlsSupport, authValidator) // - ExposedAddress: The public address (in domain:port format) used as the server's instance URL. Required.
// - TLSSupport: A boolean indicating whether TLS is enabled for the server.
// - AuthValidator: A Validator used to authenticate peers. Required.
//
// Returns:
//
// A pointer to a Server instance and an error. If the configuration is valid and initialization succeeds,
// the returned error will be nil. Otherwise, the error will describe the problem.
func NewServer(config Config) (*Server, error) {
relay, err := NewRelay(config)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -1,68 +0,0 @@
package server
import (
"sync"
)
// Store is a thread-safe store of peers
// It is used to store the peers that are connected to the relay server
type Store struct {
peers map[string]*Peer // consider to use [32]byte as key. The Peer(id string) would be faster
peersLock sync.RWMutex
}
// NewStore creates a new Store instance
func NewStore() *Store {
return &Store{
peers: make(map[string]*Peer),
}
}
// AddPeer adds a peer to the store
func (s *Store) AddPeer(peer *Peer) {
s.peersLock.Lock()
defer s.peersLock.Unlock()
odlPeer, ok := s.peers[peer.String()]
if ok {
odlPeer.Close()
}
s.peers[peer.String()] = peer
}
// DeletePeer deletes a peer from the store
func (s *Store) DeletePeer(peer *Peer) {
s.peersLock.Lock()
defer s.peersLock.Unlock()
dp, ok := s.peers[peer.String()]
if !ok {
return
}
if dp != peer {
return
}
delete(s.peers, peer.String())
}
// Peer returns a peer by its ID
func (s *Store) Peer(id string) (*Peer, bool) {
s.peersLock.RLock()
defer s.peersLock.RUnlock()
p, ok := s.peers[id]
return p, ok
}
// Peers returns all the peers in the store
func (s *Store) Peers() []*Peer {
s.peersLock.RLock()
defer s.peersLock.RUnlock()
peers := make([]*Peer, 0, len(s.peers))
for _, p := range s.peers {
peers = append(peers, p)
}
return peers
}

View File

@@ -0,0 +1,122 @@
package store
import (
"context"
"sync"
"github.com/netbirdio/netbird/relay/messages"
)
type event struct {
peerID messages.PeerID
online bool
}
type Listener struct {
ctx context.Context
eventChan chan *event
interestedPeersForOffline map[messages.PeerID]struct{}
interestedPeersForOnline map[messages.PeerID]struct{}
mu sync.RWMutex
}
func newListener(ctx context.Context) *Listener {
l := &Listener{
ctx: ctx,
// important to use a single channel for offline and online events because with it we can ensure all events
// will be processed in the order they were sent
eventChan: make(chan *event, 244), //244 is the message size limit in the relay protocol
interestedPeersForOffline: make(map[messages.PeerID]struct{}),
interestedPeersForOnline: make(map[messages.PeerID]struct{}),
}
return l
}
func (l *Listener) AddInterestedPeers(peerIDs []messages.PeerID) {
l.mu.Lock()
defer l.mu.Unlock()
for _, id := range peerIDs {
l.interestedPeersForOnline[id] = struct{}{}
l.interestedPeersForOffline[id] = struct{}{}
}
}
func (l *Listener) RemoveInterestedPeer(peerIDs []messages.PeerID) {
l.mu.Lock()
defer l.mu.Unlock()
for _, id := range peerIDs {
delete(l.interestedPeersForOffline, id)
delete(l.interestedPeersForOnline, id)
}
}
func (l *Listener) listenForEvents(onPeersComeOnline, onPeersWentOffline func([]messages.PeerID)) {
for {
select {
case <-l.ctx.Done():
return
case e := <-l.eventChan:
peersOffline := make([]messages.PeerID, 0)
peersOnline := make([]messages.PeerID, 0)
if e.online {
peersOnline = append(peersOnline, e.peerID)
} else {
peersOffline = append(peersOffline, e.peerID)
}
// Drain the channel to collect all events
for len(l.eventChan) > 0 {
e = <-l.eventChan
if e.online {
peersOnline = append(peersOnline, e.peerID)
} else {
peersOffline = append(peersOffline, e.peerID)
}
}
if len(peersOnline) > 0 {
onPeersComeOnline(peersOnline)
}
if len(peersOffline) > 0 {
onPeersWentOffline(peersOffline)
}
}
}
}
func (l *Listener) peerWentOffline(peerID messages.PeerID) {
l.mu.RLock()
defer l.mu.RUnlock()
if _, ok := l.interestedPeersForOffline[peerID]; ok {
select {
case l.eventChan <- &event{
peerID: peerID,
online: false,
}:
case <-l.ctx.Done():
}
}
}
func (l *Listener) peerComeOnline(peerID messages.PeerID) {
l.mu.Lock()
defer l.mu.Unlock()
if _, ok := l.interestedPeersForOnline[peerID]; ok {
select {
case l.eventChan <- &event{
peerID: peerID,
online: true,
}:
case <-l.ctx.Done():
}
delete(l.interestedPeersForOnline, peerID)
}
}

View File

@@ -0,0 +1,61 @@
package store
import (
"context"
"sync"
"github.com/netbirdio/netbird/relay/messages"
)
type PeerNotifier struct {
listeners map[*Listener]context.CancelFunc
listenersMutex sync.RWMutex
}
func NewPeerNotifier() *PeerNotifier {
pn := &PeerNotifier{
listeners: make(map[*Listener]context.CancelFunc),
}
return pn
}
func (pn *PeerNotifier) NewListener(onPeersComeOnline, onPeersWentOffline func([]messages.PeerID)) *Listener {
ctx, cancel := context.WithCancel(context.Background())
listener := newListener(ctx)
go listener.listenForEvents(onPeersComeOnline, onPeersWentOffline)
pn.listenersMutex.Lock()
pn.listeners[listener] = cancel
pn.listenersMutex.Unlock()
return listener
}
func (pn *PeerNotifier) RemoveListener(listener *Listener) {
pn.listenersMutex.Lock()
defer pn.listenersMutex.Unlock()
cancel, ok := pn.listeners[listener]
if !ok {
return
}
cancel()
delete(pn.listeners, listener)
}
func (pn *PeerNotifier) PeerWentOffline(peerID messages.PeerID) {
pn.listenersMutex.RLock()
defer pn.listenersMutex.RUnlock()
for listener := range pn.listeners {
listener.peerWentOffline(peerID)
}
}
func (pn *PeerNotifier) PeerCameOnline(peerID messages.PeerID) {
pn.listenersMutex.RLock()
defer pn.listenersMutex.RUnlock()
for listener := range pn.listeners {
listener.peerComeOnline(peerID)
}
}

View File

@@ -0,0 +1,97 @@
package store
import (
"sync"
"github.com/netbirdio/netbird/relay/messages"
)
type IPeer interface {
Close()
ID() messages.PeerID
}
// Store is a thread-safe store of peers
// It is used to store the peers that are connected to the relay server
type Store struct {
peers map[messages.PeerID]IPeer
peersLock sync.RWMutex
}
// NewStore creates a new Store instance
func NewStore() *Store {
return &Store{
peers: make(map[messages.PeerID]IPeer),
}
}
// AddPeer adds a peer to the store
// If the peer already exists, it will be replaced and the old peer will be closed
// Returns true if the peer was replaced, false if it was added for the first time.
func (s *Store) AddPeer(peer IPeer) bool {
s.peersLock.Lock()
defer s.peersLock.Unlock()
odlPeer, ok := s.peers[peer.ID()]
if ok {
odlPeer.Close()
}
s.peers[peer.ID()] = peer
return ok
}
// DeletePeer deletes a peer from the store
func (s *Store) DeletePeer(peer IPeer) bool {
s.peersLock.Lock()
defer s.peersLock.Unlock()
dp, ok := s.peers[peer.ID()]
if !ok {
return false
}
if dp != peer {
return false
}
delete(s.peers, peer.ID())
return true
}
// Peer returns a peer by its ID
func (s *Store) Peer(id messages.PeerID) (IPeer, bool) {
s.peersLock.RLock()
defer s.peersLock.RUnlock()
p, ok := s.peers[id]
return p, ok
}
// Peers returns all the peers in the store
func (s *Store) Peers() []IPeer {
s.peersLock.RLock()
defer s.peersLock.RUnlock()
peers := make([]IPeer, 0, len(s.peers))
for _, p := range s.peers {
peers = append(peers, p)
}
return peers
}
func (s *Store) GetOnlinePeersAndRegisterInterest(peerIDs []messages.PeerID, listener *Listener) []messages.PeerID {
s.peersLock.RLock()
defer s.peersLock.RUnlock()
onlinePeers := make([]messages.PeerID, 0, len(peerIDs))
listener.AddInterestedPeers(peerIDs)
// Check for currently online peers
for _, id := range peerIDs {
if _, ok := s.peers[id]; ok {
onlinePeers = append(onlinePeers, id)
}
}
return onlinePeers
}

View File

@@ -0,0 +1,49 @@
package store
import (
"testing"
"github.com/netbirdio/netbird/relay/messages"
)
type MocPeer struct {
id messages.PeerID
}
func (m *MocPeer) Close() {
}
func (m *MocPeer) ID() messages.PeerID {
return m.id
}
func TestStore_DeletePeer(t *testing.T) {
s := NewStore()
pID := messages.HashID("peer_one")
p := &MocPeer{id: pID}
s.AddPeer(p)
s.DeletePeer(p)
if _, ok := s.Peer(pID); ok {
t.Errorf("peer was not deleted")
}
}
func TestStore_DeleteDeprecatedPeer(t *testing.T) {
s := NewStore()
pID1 := messages.HashID("peer_one")
pID2 := messages.HashID("peer_one")
p1 := &MocPeer{id: pID1}
p2 := &MocPeer{id: pID2}
s.AddPeer(p1)
s.AddPeer(p2)
s.DeletePeer(p1)
if _, ok := s.Peer(pID2); !ok {
t.Errorf("second peer was deleted")
}
}

View File

@@ -1,85 +0,0 @@
package server
import (
"context"
"net"
"testing"
"time"
"go.opentelemetry.io/otel"
"github.com/netbirdio/netbird/relay/metrics"
)
type mockConn struct {
}
func (m mockConn) Read(b []byte) (n int, err error) {
//TODO implement me
panic("implement me")
}
func (m mockConn) Write(b []byte) (n int, err error) {
//TODO implement me
panic("implement me")
}
func (m mockConn) Close() error {
return nil
}
func (m mockConn) LocalAddr() net.Addr {
//TODO implement me
panic("implement me")
}
func (m mockConn) RemoteAddr() net.Addr {
//TODO implement me
panic("implement me")
}
func (m mockConn) SetDeadline(t time.Time) error {
//TODO implement me
panic("implement me")
}
func (m mockConn) SetReadDeadline(t time.Time) error {
//TODO implement me
panic("implement me")
}
func (m mockConn) SetWriteDeadline(t time.Time) error {
//TODO implement me
panic("implement me")
}
func TestStore_DeletePeer(t *testing.T) {
s := NewStore()
m, _ := metrics.NewMetrics(context.Background(), otel.Meter(""))
p := NewPeer(m, []byte("peer_one"), nil, nil)
s.AddPeer(p)
s.DeletePeer(p)
if _, ok := s.Peer(p.String()); ok {
t.Errorf("peer was not deleted")
}
}
func TestStore_DeleteDeprecatedPeer(t *testing.T) {
s := NewStore()
m, _ := metrics.NewMetrics(context.Background(), otel.Meter(""))
conn := &mockConn{}
p1 := NewPeer(m, []byte("peer_id"), conn, nil)
p2 := NewPeer(m, []byte("peer_id"), conn, nil)
s.AddPeer(p1)
s.AddPeer(p2)
s.DeletePeer(p1)
if _, ok := s.Peer(p2.String()); !ok {
t.Errorf("second peer was deleted")
}
}

33
relay/server/url.go Normal file
View File

@@ -0,0 +1,33 @@
package server
import (
"fmt"
"net/url"
"strings"
)
// getInstanceURL checks if user supplied a URL scheme otherwise adds to the
// provided address according to TLS definition and parses the address before returning it
func getInstanceURL(exposedAddress string, tlsSupported bool) (string, error) {
addr := exposedAddress
split := strings.Split(exposedAddress, "://")
switch {
case len(split) == 1 && tlsSupported:
addr = "rels://" + exposedAddress
case len(split) == 1 && !tlsSupported:
addr = "rel://" + exposedAddress
case len(split) > 2:
return "", fmt.Errorf("invalid exposed address: %s", exposedAddress)
}
parsedURL, err := url.ParseRequestURI(addr)
if err != nil {
return "", fmt.Errorf("invalid exposed address: %v", err)
}
if parsedURL.Scheme != "rel" && parsedURL.Scheme != "rels" {
return "", fmt.Errorf("invalid scheme: %s", parsedURL.Scheme)
}
return parsedURL.String(), nil
}

Some files were not shown because too many files have changed in this diff Show More