Compare commits

...

8 Commits

Author SHA1 Message Date
Viktor Liu
036a3020fe Batch wireguard update operations 2025-07-22 14:44:26 +02:00
Zoltan Papp
86c16cf651 [server, relay] Fix/relay race disconnection (#4174)
Avoid invalid disconnection notifications in case the closed race dials.
In this PR resolve multiple race condition questions. Easier to understand the fix based on commit by commit.

- Remove store dependency from notifier
- Enforce the notification orders
- Fix invalid disconnection notification
- Ensure the order of the events on the consumer side
2025-07-21 19:58:17 +02:00
Bethuel Mmbaga
a7af15c4fc [management] Fix group resource count mismatch in policy (#4182) 2025-07-21 15:26:06 +03:00
Viktor Liu
d6ed9c037e [client] Fix bind exclusion routes (#4154) 2025-07-21 12:13:21 +02:00
Ali Amer
40fdeda838 [client] add new filter-by-connection-type flag (#4010)
introduces a new flag --filter-by-connection-type to the status command.
It allows users to filter peers by connection type (P2P or Relayed) in both JSON and detailed views.

Input validation is added in parseFilters() to ensure proper usage, and --detail is auto-enabled if no output format is specified (consistent with other filters).
2025-07-21 11:55:17 +02:00
Zoltan Papp
f6e9d755e4 [client, relay] The openConn function no longer blocks the relayAddress function call (#4180)
The openConn function no longer blocks the relayAddress function call in manager layer
2025-07-21 09:46:53 +02:00
Maycon Santos
08fd460867 [management] Add validate flow response (#4172)
This PR adds a validate flow response feature to the management server by integrating an IntegratedValidator component. The main purpose is to enable validation of PKCE authorization flows through an integrated validator interface.

- Adds a new ValidateFlowResponse method to the IntegratedValidator interface
- Integrates the validator into the management server to validate PKCE authorization flows
- Updates dependency version for management-integrations
2025-07-18 12:18:52 +02:00
Pascal Fischer
4f74509d55 [management] fix index creation if exist on mysql (#4150) 2025-07-16 15:07:31 +02:00
75 changed files with 1059 additions and 367 deletions

View File

@@ -211,7 +211,11 @@ jobs:
strategy: strategy:
fail-fast: false fail-fast: false
matrix: matrix:
arch: [ '386','amd64' ] include:
- arch: "386"
raceFlag: ""
- arch: "amd64"
raceFlag: "-race"
runs-on: ubuntu-22.04 runs-on: ubuntu-22.04
steps: steps:
- name: Install Go - name: Install Go
@@ -251,9 +255,9 @@ jobs:
- name: Test - name: Test
run: | run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \ CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
go test \ go test ${{ matrix.raceFlag }} \
-exec 'sudo' \ -exec 'sudo' \
-timeout 10m ./signal/... -timeout 10m ./relay/...
test_signal: test_signal:
name: "Signal / Unit" name: "Signal / Unit"

View File

@@ -307,7 +307,7 @@ func getStatusOutput(cmd *cobra.Command, anon bool) string {
cmd.PrintErrf("Failed to get status: %v\n", err) cmd.PrintErrf("Failed to get status: %v\n", err)
} else { } else {
statusOutputString = nbstatus.ParseToFullDetailSummary( statusOutputString = nbstatus.ParseToFullDetailSummary(
nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil), nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, ""),
) )
} }
return statusOutputString return statusOutputString

View File

@@ -26,6 +26,7 @@ var (
statusFilter string statusFilter string
ipsFilterMap map[string]struct{} ipsFilterMap map[string]struct{}
prefixNamesFilterMap map[string]struct{} prefixNamesFilterMap map[string]struct{}
connectionTypeFilter string
) )
var statusCmd = &cobra.Command{ var statusCmd = &cobra.Command{
@@ -45,6 +46,7 @@ func init() {
statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs, e.g., --filter-by-ips 100.64.0.100,100.64.0.200") statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs, e.g., --filter-by-ips 100.64.0.100,100.64.0.200")
statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud") statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud")
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(idle|connecting|connected), e.g., --filter-by-status connected") statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(idle|connecting|connected), e.g., --filter-by-status connected")
statusCmd.PersistentFlags().StringVar(&connectionTypeFilter, "filter-by-connection-type", "", "filters the detailed output by connection type (P2P|Relayed), e.g., --filter-by-connection-type P2P")
} }
func statusFunc(cmd *cobra.Command, args []string) error { func statusFunc(cmd *cobra.Command, args []string) error {
@@ -89,7 +91,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
return nil return nil
} }
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp, anonymizeFlag, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap) var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp, anonymizeFlag, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap, connectionTypeFilter)
var statusOutputString string var statusOutputString string
switch { switch {
case detailFlag: case detailFlag:
@@ -156,6 +158,15 @@ func parseFilters() error {
enableDetailFlagWhenFilterFlag() enableDetailFlagWhenFilterFlag()
} }
switch strings.ToLower(connectionTypeFilter) {
case "", "p2p", "relayed":
if strings.ToLower(connectionTypeFilter) != "" {
enableDetailFlagWhenFilterFlag()
}
default:
return fmt.Errorf("wrong connection-type filter, should be one of P2P|Relayed, got: %s", connectionTypeFilter)
}
return nil return nil
} }

View File

@@ -109,7 +109,7 @@ func startManagement(t *testing.T, config *types.Config, testFile string) (*grpc
} }
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager) secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager)
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil) mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &mgmt.MockIntegratedValidator{})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

338
client/iface/batcher.go Normal file
View File

@@ -0,0 +1,338 @@
package iface
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"os"
"strconv"
"sync"
"time"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device"
)
const (
// DefaultBatchFlushInterval is the default maximum time to wait before flushing batched operations
DefaultBatchFlushInterval = 300 * time.Millisecond
// DefaultBatchSizeThreshold is the default number of operations to trigger an immediate flush
DefaultBatchSizeThreshold = 100
// AllowedIPOpAdd represents an add operation
AllowedIPOpAdd = "add"
// AllowedIPOpRemove represents a remove operation
AllowedIPOpRemove = "remove"
EnvDisableWGBatching = "NB_DISABLE_WG_BATCHING"
EnvWGBatchFlushIntervalMS = "NB_WG_BATCH_FLUSH_INTERVAL_MS"
EnvWGBatchSizeThreshold = "NB_WG_BATCH_SIZE_THRESHOLD"
)
// AllowedIPOperation represents a pending allowed IP operation
type AllowedIPOperation struct {
PeerKey string
Prefix netip.Prefix
Operation string
}
// PeerUpdateOperation represents a pending peer update operation
type PeerUpdateOperation struct {
PeerKey string
AllowedIPs []netip.Prefix
KeepAlive time.Duration
Endpoint *net.UDPAddr
PreSharedKey *wgtypes.Key
}
// WGBatcher batches WireGuard configuration updates to reduce syscall overhead
type WGBatcher struct {
configurer device.WGConfigurer
mu sync.Mutex
allowedIPOps []AllowedIPOperation
peerUpdates map[string]*PeerUpdateOperation
flushTimer *time.Timer
flushChan chan struct{}
ctx context.Context
cancel context.CancelFunc
wg sync.WaitGroup
batchFlushInterval time.Duration
batchSizeThreshold int
}
// NewWGBatcher creates a new WireGuard operation batcher
func NewWGBatcher(configurer device.WGConfigurer) *WGBatcher {
if os.Getenv(EnvDisableWGBatching) != "" {
log.Infof("WireGuard allowed IP batching disabled via %s", EnvDisableWGBatching)
return nil
}
flushInterval := DefaultBatchFlushInterval
sizeThreshold := DefaultBatchSizeThreshold
if intervalMs := os.Getenv(EnvWGBatchFlushIntervalMS); intervalMs != "" {
if ms, err := strconv.Atoi(intervalMs); err == nil && ms > 0 {
flushInterval = time.Duration(ms) * time.Millisecond
log.Infof("WireGuard batch flush interval set to %v", flushInterval)
}
}
if threshold := os.Getenv(EnvWGBatchSizeThreshold); threshold != "" {
if size, err := strconv.Atoi(threshold); err == nil && size > 0 {
sizeThreshold = size
log.Infof("WireGuard batch size threshold set to %d", sizeThreshold)
}
}
log.Info("WireGuard allowed IP batching enabled")
ctx, cancel := context.WithCancel(context.Background())
b := &WGBatcher{
configurer: configurer,
peerUpdates: make(map[string]*PeerUpdateOperation),
flushChan: make(chan struct{}, 1),
ctx: ctx,
cancel: cancel,
batchFlushInterval: flushInterval,
batchSizeThreshold: sizeThreshold,
}
b.wg.Add(1)
go b.flushLoop()
return b
}
// Close stops the batcher and flushes any pending operations
func (b *WGBatcher) Close() error {
b.mu.Lock()
if b.flushTimer != nil {
b.flushTimer.Stop()
}
b.mu.Unlock()
b.cancel()
if err := b.Flush(); err != nil {
log.Errorf("failed to flush pending operations on close: %v", err)
}
b.wg.Wait()
return nil
}
// UpdatePeer batches a peer update operation
func (b *WGBatcher) UpdatePeer(peerKey string, allowedIPs []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
b.mu.Lock()
defer b.mu.Unlock()
b.peerUpdates[peerKey] = &PeerUpdateOperation{
PeerKey: peerKey,
AllowedIPs: allowedIPs,
KeepAlive: keepAlive,
Endpoint: endpoint,
PreSharedKey: preSharedKey,
}
b.scheduleFlush()
return nil
}
// AddAllowedIP batches an allowed IP addition
func (b *WGBatcher) AddAllowedIP(peerKey string, allowedIP netip.Prefix) error {
b.mu.Lock()
defer b.mu.Unlock()
b.allowedIPOps = append(b.allowedIPOps, AllowedIPOperation{
PeerKey: peerKey,
Prefix: allowedIP,
Operation: AllowedIPOpAdd,
})
b.scheduleFlush()
return nil
}
// RemoveAllowedIP batches an allowed IP removal
func (b *WGBatcher) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error {
b.mu.Lock()
defer b.mu.Unlock()
b.allowedIPOps = append(b.allowedIPOps, AllowedIPOperation{
PeerKey: peerKey,
Prefix: allowedIP,
Operation: AllowedIPOpRemove,
})
b.scheduleFlush()
return nil
}
// Flush immediately processes all batched operations
func (b *WGBatcher) Flush() error {
b.mu.Lock()
if b.flushTimer != nil {
b.flushTimer.Stop()
b.flushTimer = nil
}
peerUpdates := b.peerUpdates
allowedIPOps := b.allowedIPOps
b.peerUpdates = make(map[string]*PeerUpdateOperation)
b.allowedIPOps = nil
b.mu.Unlock()
return b.processBatch(peerUpdates, allowedIPOps)
}
// scheduleFlush schedules a batch flush if not already scheduled
func (b *WGBatcher) scheduleFlush() {
shouldFlushNow := len(b.allowedIPOps)+len(b.peerUpdates) >= b.batchSizeThreshold
if shouldFlushNow {
select {
case b.flushChan <- struct{}{}:
default:
}
return
}
if b.flushTimer == nil {
b.flushTimer = time.AfterFunc(b.batchFlushInterval, func() {
select {
case b.flushChan <- struct{}{}:
default:
}
})
}
}
// flushLoop handles periodic flushing of batched operations
func (b *WGBatcher) flushLoop() {
defer b.wg.Done()
for {
select {
case <-b.flushChan:
if err := b.Flush(); err != nil {
log.Errorf("Error flushing WireGuard operations: %v", err)
}
case <-b.ctx.Done():
return
}
}
}
// processBatch processes a batch of operations
func (b *WGBatcher) processBatch(peerUpdates map[string]*PeerUpdateOperation, allowedIPOps []AllowedIPOperation) error {
if len(peerUpdates) == 0 && len(allowedIPOps) == 0 {
return nil
}
start := time.Now()
defer func() {
duration := time.Since(start)
log.Debugf("Processed batch of %d peer updates and %d allowed IP operations in %v",
len(peerUpdates), len(allowedIPOps), duration)
}()
var merr *multierror.Error
if err := b.processPeerUpdates(peerUpdates); err != nil {
merr = multierror.Append(merr, err)
}
if err := b.processAllowedIPOps(allowedIPOps); err != nil {
merr = multierror.Append(merr, err)
}
return nberrors.FormatErrorOrNil(merr)
}
// processPeerUpdates processes peer update operations
func (b *WGBatcher) processPeerUpdates(peerUpdates map[string]*PeerUpdateOperation) error {
var merr *multierror.Error
for _, update := range peerUpdates {
if err := b.configurer.UpdatePeer(
update.PeerKey,
update.AllowedIPs,
update.KeepAlive,
update.Endpoint,
update.PreSharedKey,
); err != nil {
merr = multierror.Append(merr, fmt.Errorf("update peer %s: %w", update.PeerKey, err))
}
}
return nberrors.FormatErrorOrNil(merr)
}
// processAllowedIPOps processes allowed IP add/remove operations
func (b *WGBatcher) processAllowedIPOps(allowedIPOps []AllowedIPOperation) error {
peerChanges := b.groupAllowedIPChanges(allowedIPOps)
return b.applyAllowedIPChanges(peerChanges)
}
// groupAllowedIPChanges groups allowed IP operations by peer
func (b *WGBatcher) groupAllowedIPChanges(allowedIPOps []AllowedIPOperation) map[string]struct {
toAdd []netip.Prefix
toRemove []netip.Prefix
} {
peerChanges := make(map[string]struct {
toAdd []netip.Prefix
toRemove []netip.Prefix
})
for _, op := range allowedIPOps {
changes := peerChanges[op.PeerKey]
if op.Operation == AllowedIPOpAdd {
changes.toAdd = append(changes.toAdd, op.Prefix)
} else {
changes.toRemove = append(changes.toRemove, op.Prefix)
}
peerChanges[op.PeerKey] = changes
}
return peerChanges
}
// applyAllowedIPChanges applies allowed IP changes for each peer
func (b *WGBatcher) applyAllowedIPChanges(peerChanges map[string]struct {
toAdd []netip.Prefix
toRemove []netip.Prefix
}) error {
var merr *multierror.Error
for peerKey, changes := range peerChanges {
for _, prefix := range changes.toRemove {
if err := b.configurer.RemoveAllowedIP(peerKey, prefix); err != nil {
if errors.Is(err, configurer.ErrPeerNotFound) || errors.Is(err, configurer.ErrAllowedIPNotFound) {
log.Debugf("remove allowed IP %s for peer %s: %v", prefix, peerKey, err)
} else {
merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s for peer %s: %w", prefix, peerKey, err))
}
}
}
for _, prefix := range changes.toAdd {
if err := b.configurer.AddAllowedIP(peerKey, prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add allowed IP %s for peer %s: %w", prefix, peerKey, err))
}
}
}
return nberrors.FormatErrorOrNil(merr)
}

View File

@@ -0,0 +1,15 @@
package bind
import (
wireguard "golang.zx2c4.com/wireguard/conn"
nbnet "github.com/netbirdio/netbird/util/net"
)
// TODO: This is most likely obsolete since the control fns should be called by the wrapped udpconn (ice_bind.go)
func init() {
listener := nbnet.NewListener()
if listener.ListenConfig.Control != nil {
*wireguard.ControlFns = append(*wireguard.ControlFns, listener.ListenConfig.Control)
}
}

View File

@@ -1,12 +0,0 @@
package bind
import (
wireguard "golang.zx2c4.com/wireguard/conn"
nbnet "github.com/netbirdio/netbird/util/net"
)
func init() {
// ControlFns is not thread safe and should only be modified during init.
*wireguard.ControlFns = append(*wireguard.ControlFns, nbnet.ControlProtectSocket)
}

View File

@@ -16,6 +16,7 @@ import (
wgConn "golang.zx2c4.com/wireguard/conn" wgConn "golang.zx2c4.com/wireguard/conn"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
nbnet "github.com/netbirdio/netbird/util/net"
) )
type RecvMessage struct { type RecvMessage struct {
@@ -153,7 +154,7 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
s.udpMux = NewUniversalUDPMuxDefault( s.udpMux = NewUniversalUDPMuxDefault(
UniversalUDPMuxParams{ UniversalUDPMuxParams{
UDPConn: conn, UDPConn: nbnet.WrapUDPConn(conn),
Net: s.transportNet, Net: s.transportNet,
FilterFn: s.filterFn, FilterFn: s.filterFn,
WGAddress: s.address, WGAddress: s.address,

View File

@@ -296,14 +296,20 @@ func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) {
return return
} }
m.addressMapMu.Lock() var allAddresses []string
defer m.addressMapMu.Unlock()
for _, c := range removedConns { for _, c := range removedConns {
addresses := c.getAddresses() addresses := c.getAddresses()
for _, addr := range addresses { allAddresses = append(allAddresses, addresses...)
delete(m.addressMap, addr) }
}
m.addressMapMu.Lock()
for _, addr := range allAddresses {
delete(m.addressMap, addr)
}
m.addressMapMu.Unlock()
for _, addr := range allAddresses {
m.notifyAddressRemoval(addr)
} }
} }
@@ -351,14 +357,13 @@ func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr string)
} }
m.addressMapMu.Lock() m.addressMapMu.Lock()
defer m.addressMapMu.Unlock()
existing, ok := m.addressMap[addr] existing, ok := m.addressMap[addr]
if !ok { if !ok {
existing = []*udpMuxedConn{} existing = []*udpMuxedConn{}
} }
existing = append(existing, conn) existing = append(existing, conn)
m.addressMap[addr] = existing m.addressMap[addr] = existing
m.addressMapMu.Unlock()
log.Debugf("ICE: registered %s for %s", addr, conn.params.Key) log.Debugf("ICE: registered %s for %s", addr, conn.params.Key)
} }
@@ -386,12 +391,12 @@ func (m *UDPMuxDefault) HandleSTUNMessage(msg *stun.Message, addr net.Addr) erro
// If you are using the same socket for the Host and SRFLX candidates, it might be that there are more than one // If you are using the same socket for the Host and SRFLX candidates, it might be that there are more than one
// muxed connection - one for the SRFLX candidate and the other one for the HOST one. // muxed connection - one for the SRFLX candidate and the other one for the HOST one.
// We will then forward STUN packets to each of these connections. // We will then forward STUN packets to each of these connections.
m.addressMapMu.Lock() m.addressMapMu.RLock()
var destinationConnList []*udpMuxedConn var destinationConnList []*udpMuxedConn
if storedConns, ok := m.addressMap[addr.String()]; ok { if storedConns, ok := m.addressMap[addr.String()]; ok {
destinationConnList = append(destinationConnList, storedConns...) destinationConnList = append(destinationConnList, storedConns...)
} }
m.addressMapMu.Unlock() m.addressMapMu.RUnlock()
var isIPv6 bool var isIPv6 bool
if udpAddr, _ := addr.(*net.UDPAddr); udpAddr != nil && udpAddr.IP.To4() == nil { if udpAddr, _ := addr.(*net.UDPAddr); udpAddr != nil && udpAddr.IP.To4() == nil {

View File

@@ -0,0 +1,21 @@
//go:build !ios
package bind
import (
nbnet "github.com/netbirdio/netbird/util/net"
)
func (m *UDPMuxDefault) notifyAddressRemoval(addr string) {
wrapped, ok := m.params.UDPConn.(*UDPConn)
if !ok {
return
}
nbnetConn, ok := wrapped.GetPacketConn().(*nbnet.UDPConn)
if !ok {
return
}
nbnetConn.RemoveAddress(addr)
}

View File

@@ -0,0 +1,7 @@
//go:build ios
package bind
func (m *UDPMuxDefault) notifyAddressRemoval(addr string) {
// iOS doesn't support nbnet hooks, so this is a no-op
}

View File

@@ -62,7 +62,7 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
// wrap UDP connection, process server reflexive messages // wrap UDP connection, process server reflexive messages
// before they are passed to the UDPMux connection handler (connWorker) // before they are passed to the UDPMux connection handler (connWorker)
m.params.UDPConn = &udpConn{ m.params.UDPConn = &UDPConn{
PacketConn: params.UDPConn, PacketConn: params.UDPConn,
mux: m, mux: m,
logger: params.Logger, logger: params.Logger,
@@ -70,7 +70,6 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
address: params.WGAddress, address: params.WGAddress,
} }
// embed UDPMux
udpMuxParams := UDPMuxParams{ udpMuxParams := UDPMuxParams{
Logger: params.Logger, Logger: params.Logger,
UDPConn: m.params.UDPConn, UDPConn: m.params.UDPConn,
@@ -114,8 +113,8 @@ func (m *UniversalUDPMuxDefault) ReadFromConn(ctx context.Context) {
} }
} }
// udpConn is a wrapper around UDPMux conn that overrides ReadFrom and handles STUN/TURN packets // UDPConn is a wrapper around UDPMux conn that overrides ReadFrom and handles STUN/TURN packets
type udpConn struct { type UDPConn struct {
net.PacketConn net.PacketConn
mux *UniversalUDPMuxDefault mux *UniversalUDPMuxDefault
logger logging.LeveledLogger logger logging.LeveledLogger
@@ -125,7 +124,12 @@ type udpConn struct {
address wgaddr.Address address wgaddr.Address
} }
func (u *udpConn) WriteTo(b []byte, addr net.Addr) (int, error) { // GetPacketConn returns the underlying PacketConn
func (u *UDPConn) GetPacketConn() net.PacketConn {
return u.PacketConn
}
func (u *UDPConn) WriteTo(b []byte, addr net.Addr) (int, error) {
if u.filterFn == nil { if u.filterFn == nil {
return u.PacketConn.WriteTo(b, addr) return u.PacketConn.WriteTo(b, addr)
} }
@@ -137,21 +141,21 @@ func (u *udpConn) WriteTo(b []byte, addr net.Addr) (int, error) {
return u.handleUncachedAddress(b, addr) return u.handleUncachedAddress(b, addr)
} }
func (u *udpConn) handleCachedAddress(isRouted bool, b []byte, addr net.Addr) (int, error) { func (u *UDPConn) handleCachedAddress(isRouted bool, b []byte, addr net.Addr) (int, error) {
if isRouted { if isRouted {
return 0, fmt.Errorf("address %s is part of a routed network, refusing to write", addr) return 0, fmt.Errorf("address %s is part of a routed network, refusing to write", addr)
} }
return u.PacketConn.WriteTo(b, addr) return u.PacketConn.WriteTo(b, addr)
} }
func (u *udpConn) handleUncachedAddress(b []byte, addr net.Addr) (int, error) { func (u *UDPConn) handleUncachedAddress(b []byte, addr net.Addr) (int, error) {
if err := u.performFilterCheck(addr); err != nil { if err := u.performFilterCheck(addr); err != nil {
return 0, err return 0, err
} }
return u.PacketConn.WriteTo(b, addr) return u.PacketConn.WriteTo(b, addr)
} }
func (u *udpConn) performFilterCheck(addr net.Addr) error { func (u *UDPConn) performFilterCheck(addr net.Addr) error {
host, err := getHostFromAddr(addr) host, err := getHostFromAddr(addr)
if err != nil { if err != nil {
log.Errorf("Failed to get host from address %s: %v", addr, err) log.Errorf("Failed to get host from address %s: %v", addr, err)

View File

@@ -59,6 +59,7 @@ type WGIface struct {
mu sync.Mutex mu sync.Mutex
configurer device.WGConfigurer configurer device.WGConfigurer
batcher *WGBatcher
filter device.PacketFilter filter device.PacketFilter
wgProxyFactory wgProxyFactory wgProxyFactory wgProxyFactory
} }
@@ -128,6 +129,12 @@ func (w *WGIface) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAliv
} }
log.Debugf("updating interface %s peer %s, endpoint %s, allowedIPs %v", w.tun.DeviceName(), peerKey, endpoint, allowedIps) log.Debugf("updating interface %s peer %s, endpoint %s, allowedIPs %v", w.tun.DeviceName(), peerKey, endpoint, allowedIps)
if endpoint != nil && w.batcher != nil {
if err := w.batcher.Flush(); err != nil {
log.Warnf("failed to flush batched operations: %v", err)
}
}
return w.configurer.UpdatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey) return w.configurer.UpdatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
} }
@@ -152,6 +159,10 @@ func (w *WGIface) AddAllowedIP(peerKey string, allowedIP netip.Prefix) error {
} }
log.Debugf("Adding allowed IP to interface %s and peer %s: allowed IP %s ", w.tun.DeviceName(), peerKey, allowedIP) log.Debugf("Adding allowed IP to interface %s and peer %s: allowed IP %s ", w.tun.DeviceName(), peerKey, allowedIP)
if w.batcher != nil {
return w.batcher.AddAllowedIP(peerKey, allowedIP)
}
return w.configurer.AddAllowedIP(peerKey, allowedIP) return w.configurer.AddAllowedIP(peerKey, allowedIP)
} }
@@ -164,6 +175,10 @@ func (w *WGIface) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error
} }
log.Debugf("Removing allowed IP from interface %s and peer %s: allowed IP %s ", w.tun.DeviceName(), peerKey, allowedIP) log.Debugf("Removing allowed IP from interface %s and peer %s: allowed IP %s ", w.tun.DeviceName(), peerKey, allowedIP)
if w.batcher != nil {
return w.batcher.RemoveAllowedIP(peerKey, allowedIP)
}
return w.configurer.RemoveAllowedIP(peerKey, allowedIP) return w.configurer.RemoveAllowedIP(peerKey, allowedIP)
} }
@@ -174,6 +189,12 @@ func (w *WGIface) Close() error {
var result *multierror.Error var result *multierror.Error
if w.batcher != nil {
if err := w.batcher.Close(); err != nil {
result = multierror.Append(result, fmt.Errorf("failed to close WireGuard batcher: %w", err))
}
}
if err := w.wgProxyFactory.Free(); err != nil { if err := w.wgProxyFactory.Free(); err != nil {
result = multierror.Append(result, fmt.Errorf("failed to free WireGuard proxy: %w", err)) result = multierror.Append(result, fmt.Errorf("failed to free WireGuard proxy: %w", err))
} }

View File

@@ -17,6 +17,7 @@ func (w *WGIface) Create() error {
} }
w.configurer = cfgr w.configurer = cfgr
w.batcher = NewWGBatcher(cfgr)
return nil return nil
} }

View File

@@ -1,8 +1,6 @@
package iface package iface
import ( import "fmt"
"fmt"
)
// CreateOnAndroid creates a new Wireguard interface, sets a given IP and brings it up. // CreateOnAndroid creates a new Wireguard interface, sets a given IP and brings it up.
// Will reuse an existing one. // Will reuse an existing one.
@@ -15,6 +13,7 @@ func (w *WGIface) CreateOnAndroid(routes []string, dns string, searchDomains []s
return err return err
} }
w.configurer = cfgr w.configurer = cfgr
w.batcher = NewWGBatcher(cfgr)
return nil return nil
} }

View File

@@ -29,6 +29,7 @@ func (w *WGIface) Create() error {
return err return err
} }
w.configurer = cfgr w.configurer = cfgr
w.batcher = NewWGBatcher(cfgr)
return nil return nil
} }

View File

@@ -61,7 +61,6 @@ import (
signal "github.com/netbirdio/netbird/signal/client" signal "github.com/netbirdio/netbird/signal/client"
sProto "github.com/netbirdio/netbird/signal/proto" sProto "github.com/netbirdio/netbird/signal/proto"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
nbnet "github.com/netbirdio/netbird/util/net"
) )
// PeerConnectionTimeoutMax is a timeout of an initial connection attempt to a remote peer. // PeerConnectionTimeoutMax is a timeout of an initial connection attempt to a remote peer.
@@ -138,9 +137,6 @@ type Engine struct {
connMgr *ConnMgr connMgr *ConnMgr
beforePeerHook nbnet.AddHookFunc
afterPeerHook nbnet.RemoveHookFunc
// rpManager is a Rosenpass manager // rpManager is a Rosenpass manager
rpManager *rosenpass.Manager rpManager *rosenpass.Manager
@@ -409,12 +405,8 @@ func (e *Engine) Start() error {
DisableClientRoutes: e.config.DisableClientRoutes, DisableClientRoutes: e.config.DisableClientRoutes,
DisableServerRoutes: e.config.DisableServerRoutes, DisableServerRoutes: e.config.DisableServerRoutes,
}) })
beforePeerHook, afterPeerHook, err := e.routeManager.Init() if err := e.routeManager.Init(); err != nil {
if err != nil {
log.Errorf("Failed to initialize route manager: %s", err) log.Errorf("Failed to initialize route manager: %s", err)
} else {
e.beforePeerHook = beforePeerHook
e.afterPeerHook = afterPeerHook
} }
e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener) e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener)
@@ -1261,10 +1253,6 @@ func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error {
return fmt.Errorf("peer already exists: %s", peerKey) return fmt.Errorf("peer already exists: %s", peerKey)
} }
if e.beforePeerHook != nil && e.afterPeerHook != nil {
conn.AddBeforeAddPeerHook(e.beforePeerHook)
conn.AddAfterRemovePeerHook(e.afterPeerHook)
}
return nil return nil
} }

View File

@@ -400,7 +400,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
StatusRecorder: engine.statusRecorder, StatusRecorder: engine.statusRecorder,
RelayManager: relayMgr, RelayManager: relayMgr,
}) })
_, _, err = engine.routeManager.Init() err = engine.routeManager.Init()
require.NoError(t, err) require.NoError(t, err)
engine.dnsServer = &dns.MockServer{ engine.dnsServer = &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil }, UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
@@ -1494,7 +1494,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
} }
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager) secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager)
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil) mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &server.MockIntegratedValidator{})
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }

View File

@@ -26,7 +26,6 @@ import (
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
relayClient "github.com/netbirdio/netbird/relay/client" relayClient "github.com/netbirdio/netbird/relay/client"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
nbnet "github.com/netbirdio/netbird/util/net"
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group" semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
) )
@@ -106,10 +105,6 @@ type Conn struct {
workerRelay *WorkerRelay workerRelay *WorkerRelay
wgWatcherWg sync.WaitGroup wgWatcherWg sync.WaitGroup
connIDRelay nbnet.ConnectionID
connIDICE nbnet.ConnectionID
beforeAddPeerHooks []nbnet.AddHookFunc
afterRemovePeerHooks []nbnet.RemoveHookFunc
// used to store the remote Rosenpass key for Relayed connection in case of connection update from ice // used to store the remote Rosenpass key for Relayed connection in case of connection update from ice
rosenpassRemoteKey []byte rosenpassRemoteKey []byte
@@ -267,8 +262,6 @@ func (conn *Conn) Close(signalToRemote bool) {
conn.Log.Errorf("failed to remove wg endpoint: %v", err) conn.Log.Errorf("failed to remove wg endpoint: %v", err)
} }
conn.freeUpConnID()
if conn.evalStatus() == StatusConnected && conn.onDisconnected != nil { if conn.evalStatus() == StatusConnected && conn.onDisconnected != nil {
conn.onDisconnected(conn.config.WgConfig.RemoteKey) conn.onDisconnected(conn.config.WgConfig.RemoteKey)
} }
@@ -293,13 +286,6 @@ func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMa
conn.workerICE.OnRemoteCandidate(candidate, haRoutes) conn.workerICE.OnRemoteCandidate(candidate, haRoutes)
} }
func (conn *Conn) AddBeforeAddPeerHook(hook nbnet.AddHookFunc) {
conn.beforeAddPeerHooks = append(conn.beforeAddPeerHooks, hook)
}
func (conn *Conn) AddAfterRemovePeerHook(hook nbnet.RemoveHookFunc) {
conn.afterRemovePeerHooks = append(conn.afterRemovePeerHooks, hook)
}
// SetOnConnected sets a handler function to be triggered by Conn when a new connection to a remote peer established // SetOnConnected sets a handler function to be triggered by Conn when a new connection to a remote peer established
func (conn *Conn) SetOnConnected(handler func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)) { func (conn *Conn) SetOnConnected(handler func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)) {
conn.onConnected = handler conn.onConnected = handler
@@ -387,10 +373,6 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
ep = directEp ep = directEp
} }
if err := conn.runBeforeAddPeerHooks(ep.IP); err != nil {
conn.Log.Errorf("Before add peer hook failed: %v", err)
}
conn.workerRelay.DisableWgWatcher() conn.workerRelay.DisableWgWatcher()
// todo consider to run conn.wgWatcherWg.Wait() here // todo consider to run conn.wgWatcherWg.Wait() here
@@ -503,10 +485,6 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
return return
} }
if err := conn.runBeforeAddPeerHooks(wgProxy.EndpointAddr().IP); err != nil {
conn.Log.Errorf("Before add peer hook failed: %v", err)
}
wgProxy.Work() wgProxy.Work()
if err := conn.configureWGEndpoint(wgProxy.EndpointAddr(), rci.rosenpassPubKey); err != nil { if err := conn.configureWGEndpoint(wgProxy.EndpointAddr(), rci.rosenpassPubKey); err != nil {
if err := wgProxy.CloseConn(); err != nil { if err := wgProxy.CloseConn(); err != nil {
@@ -707,36 +685,6 @@ func (conn *Conn) isConnectedOnAllWay() (connected bool) {
return true return true
} }
func (conn *Conn) runBeforeAddPeerHooks(ip net.IP) error {
conn.connIDICE = nbnet.GenerateConnID()
for _, hook := range conn.beforeAddPeerHooks {
if err := hook(conn.connIDICE, ip); err != nil {
return err
}
}
return nil
}
func (conn *Conn) freeUpConnID() {
if conn.connIDRelay != "" {
for _, hook := range conn.afterRemovePeerHooks {
if err := hook(conn.connIDRelay); err != nil {
conn.Log.Errorf("After remove peer hook failed: %v", err)
}
}
conn.connIDRelay = ""
}
if conn.connIDICE != "" {
for _, hook := range conn.afterRemovePeerHooks {
if err := hook(conn.connIDICE); err != nil {
conn.Log.Errorf("After remove peer hook failed: %v", err)
}
}
conn.connIDICE = ""
}
}
func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) { func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) {
conn.Log.Debugf("setup proxied WireGuard connection") conn.Log.Debugf("setup proxied WireGuard connection")
udpAddr := &net.UDPAddr{ udpAddr := &net.UDPAddr{

View File

@@ -24,7 +24,7 @@ type WorkerRelay struct {
isController bool isController bool
config ConnConfig config ConnConfig
conn *Conn conn *Conn
relayManager relayClient.ManagerService relayManager *relayClient.Manager
relayedConn net.Conn relayedConn net.Conn
relayLock sync.Mutex relayLock sync.Mutex
@@ -34,7 +34,7 @@ type WorkerRelay struct {
wgWatcher *WGWatcher wgWatcher *WGWatcher
} }
func NewWorkerRelay(ctx context.Context, log *log.Entry, ctrl bool, config ConnConfig, conn *Conn, relayManager relayClient.ManagerService, stateDump *stateDump) *WorkerRelay { func NewWorkerRelay(ctx context.Context, log *log.Entry, ctrl bool, config ConnConfig, conn *Conn, relayManager *relayClient.Manager, stateDump *stateDump) *WorkerRelay {
r := &WorkerRelay{ r := &WorkerRelay{
peerCtx: ctx, peerCtx: ctx,
log: log, log: log,

View File

@@ -812,7 +812,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
} }
params := common.HandlerParams{ params := common.HandlerParams{
Route: &route.Route{Network: netip.MustParsePrefix("192.168.0.0/24")}, Route: &route.Route{Network: netip.MustParsePrefix("192.168.0.0/24")},
} }
// create new clientNetwork // create new clientNetwork
client := &Watcher{ client := &Watcher{

View File

@@ -44,7 +44,7 @@ import (
// Manager is a route manager interface // Manager is a route manager interface
type Manager interface { type Manager interface {
Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) Init() error
UpdateRoutes(updateSerial uint64, serverRoutes map[route.ID]*route.Route, clientRoutes route.HAMap, useNewDNSRoute bool) error UpdateRoutes(updateSerial uint64, serverRoutes map[route.ID]*route.Route, clientRoutes route.HAMap, useNewDNSRoute bool) error
ClassifyRoutes(newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap) ClassifyRoutes(newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap)
TriggerSelection(route.HAMap) TriggerSelection(route.HAMap)
@@ -201,11 +201,11 @@ func (m *DefaultManager) setupRefCounters(useNoop bool) {
} }
// Init sets up the routing // Init sets up the routing
func (m *DefaultManager) Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { func (m *DefaultManager) Init() error {
m.routeSelector = m.initSelector() m.routeSelector = m.initSelector()
if nbnet.CustomRoutingDisabled() || m.disableClientRoutes { if nbnet.CustomRoutingDisabled() || m.disableClientRoutes {
return nil, nil, nil return nil
} }
if err := m.sysOps.CleanupRouting(nil); err != nil { if err := m.sysOps.CleanupRouting(nil); err != nil {
@@ -219,13 +219,12 @@ func (m *DefaultManager) Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error)
ips := resolveURLsToIPs(initialAddresses) ips := resolveURLsToIPs(initialAddresses)
beforePeerHook, afterPeerHook, err := m.sysOps.SetupRouting(ips, m.stateManager) if err := m.sysOps.SetupRouting(ips, m.stateManager); err != nil {
if err != nil { return fmt.Errorf("setup routing: %w", err)
return nil, nil, fmt.Errorf("setup routing: %w", err)
} }
log.Info("Routing setup complete") log.Info("Routing setup complete")
return beforePeerHook, afterPeerHook, nil return nil
} }
func (m *DefaultManager) initSelector() *routeselector.RouteSelector { func (m *DefaultManager) initSelector() *routeselector.RouteSelector {

View File

@@ -430,7 +430,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
StatusRecorder: statusRecorder, StatusRecorder: statusRecorder,
}) })
_, _, err = routeManager.Init() err = routeManager.Init()
require.NoError(t, err, "should init route manager") require.NoError(t, err, "should init route manager")
defer routeManager.Stop(nil) defer routeManager.Stop(nil)

View File

@@ -9,7 +9,6 @@ import (
"github.com/netbirdio/netbird/client/internal/routeselector" "github.com/netbirdio/netbird/client/internal/routeselector"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/util/net"
) )
// MockManager is the mock instance of a route manager // MockManager is the mock instance of a route manager
@@ -23,8 +22,8 @@ type MockManager struct {
StopFunc func(manager *statemanager.Manager) StopFunc func(manager *statemanager.Manager)
} }
func (m *MockManager) Init() (net.AddHookFunc, net.RemoveHookFunc, error) { func (m *MockManager) Init() error {
return nil, nil, nil return nil
} }
// InitialRouteRange mock implementation of InitialRouteRange from Manager interface // InitialRouteRange mock implementation of InitialRouteRange from Manager interface

View File

@@ -6,6 +6,7 @@ import (
"net/netip" "net/netip"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/routemanager/notifier" "github.com/netbirdio/netbird/client/internal/routemanager/notifier"
@@ -56,6 +57,10 @@ type SysOps struct {
// seq is an atomic counter for generating unique sequence numbers for route messages // seq is an atomic counter for generating unique sequence numbers for route messages
//nolint:unused // only used on BSD systems //nolint:unused // only used on BSD systems
seq atomic.Uint32 seq atomic.Uint32
localSubnetsCache []*net.IPNet
localSubnetsCacheMu sync.RWMutex
localSubnetsCacheTime time.Time
} }
func NewSysOps(wgInterface wgIface, notifier *notifier.Notifier) *SysOps { func NewSysOps(wgInterface wgIface, notifier *notifier.Notifier) *SysOps {

View File

@@ -10,11 +10,10 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/util/net"
) )
func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) error {
return nil, nil, nil return nil
} }
func (r *SysOps) CleanupRouting(*statemanager.Manager) error { func (r *SysOps) CleanupRouting(*statemanager.Manager) error {

View File

@@ -10,6 +10,7 @@ import (
"net/netip" "net/netip"
"runtime" "runtime"
"strconv" "strconv"
"time"
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
"github.com/libp2p/go-netroute" "github.com/libp2p/go-netroute"
@@ -24,6 +25,8 @@ import (
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/util/net"
) )
const localSubnetsCacheTTL = 15 * time.Minute
var splitDefaultv4_1 = netip.PrefixFrom(netip.IPv4Unspecified(), 1) var splitDefaultv4_1 = netip.PrefixFrom(netip.IPv4Unspecified(), 1)
var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1) var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1)
var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1) var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1)
@@ -31,7 +34,7 @@ var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1)
var ErrRoutingIsSeparate = errors.New("routing is separate") var ErrRoutingIsSeparate = errors.New("routing is separate")
func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemanager.Manager) error {
stateManager.RegisterState(&ShutdownState{}) stateManager.RegisterState(&ShutdownState{})
initialNextHopV4, err := GetNextHop(netip.IPv4Unspecified()) initialNextHopV4, err := GetNextHop(netip.IPv4Unspecified())
@@ -75,7 +78,10 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemana
r.refCounter = refCounter r.refCounter = refCounter
return r.setupHooks(initAddresses, stateManager) if err := r.setupHooks(initAddresses, stateManager); err != nil {
return fmt.Errorf("setup hooks: %w", err)
}
return nil
} }
// updateState updates state on every change so it will be persisted regularly // updateState updates state on every change so it will be persisted regularly
@@ -128,18 +134,14 @@ func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf wgIface, init
return Nexthop{}, fmt.Errorf("get next hop: %w", err) return Nexthop{}, fmt.Errorf("get next hop: %w", err)
} }
log.Debugf("Found next hop %s for prefix %s with interface %v", nexthop.IP, prefix, nexthop.IP) log.Debugf("Found next hop %s for prefix %s with interface %v", nexthop.IP, prefix, nexthop.Intf)
exitNextHop := Nexthop{ exitNextHop := nexthop
IP: nexthop.IP,
Intf: nexthop.Intf,
}
vpnAddr := vpnIntf.Address().IP vpnAddr := vpnIntf.Address().IP
// if next hop is the VPN address or the interface is the VPN interface, we should use the initial values // if next hop is the VPN address or the interface is the VPN interface, we should use the initial values
if exitNextHop.IP == vpnAddr || exitNextHop.Intf != nil && exitNextHop.Intf.Name == vpnIntf.Name() { if exitNextHop.IP == vpnAddr || exitNextHop.Intf != nil && exitNextHop.Intf.Name == vpnIntf.Name() {
log.Debugf("Route for prefix %s is pointing to the VPN interface, using initial next hop %v", prefix, initialNextHop) log.Debugf("Route for prefix %s is pointing to the VPN interface, using initial next hop %v", prefix, initialNextHop)
exitNextHop = initialNextHop exitNextHop = initialNextHop
} }
@@ -152,12 +154,37 @@ func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf wgIface, init
} }
func (r *SysOps) isPrefixInLocalSubnets(prefix netip.Prefix) (bool, *net.IPNet) { func (r *SysOps) isPrefixInLocalSubnets(prefix netip.Prefix) (bool, *net.IPNet) {
r.localSubnetsCacheMu.RLock()
cacheAge := time.Since(r.localSubnetsCacheTime)
subnets := r.localSubnetsCache
r.localSubnetsCacheMu.RUnlock()
if cacheAge > localSubnetsCacheTTL || subnets == nil {
r.localSubnetsCacheMu.Lock()
if time.Since(r.localSubnetsCacheTime) > localSubnetsCacheTTL || r.localSubnetsCache == nil {
r.refreshLocalSubnetsCache()
}
subnets = r.localSubnetsCache
r.localSubnetsCacheMu.Unlock()
}
for _, subnet := range subnets {
if subnet.Contains(prefix.Addr().AsSlice()) {
return true, subnet
}
}
return false, nil
}
func (r *SysOps) refreshLocalSubnetsCache() {
localInterfaces, err := net.Interfaces() localInterfaces, err := net.Interfaces()
if err != nil { if err != nil {
log.Errorf("Failed to get local interfaces: %v", err) log.Errorf("Failed to get local interfaces: %v", err)
return false, nil return
} }
var newSubnets []*net.IPNet
for _, intf := range localInterfaces { for _, intf := range localInterfaces {
addrs, err := intf.Addrs() addrs, err := intf.Addrs()
if err != nil { if err != nil {
@@ -171,14 +198,12 @@ func (r *SysOps) isPrefixInLocalSubnets(prefix netip.Prefix) (bool, *net.IPNet)
log.Errorf("Failed to convert address to IPNet: %v", addr) log.Errorf("Failed to convert address to IPNet: %v", addr)
continue continue
} }
newSubnets = append(newSubnets, ipnet)
if ipnet.Contains(prefix.Addr().AsSlice()) {
return true, ipnet
}
} }
} }
return false, nil r.localSubnetsCache = newSubnets
r.localSubnetsCacheTime = time.Now()
} }
// genericAddVPNRoute adds a new route to the vpn interface, it splits the default prefix // genericAddVPNRoute adds a new route to the vpn interface, it splits the default prefix
@@ -264,7 +289,7 @@ func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface)
return r.removeFromRouteTable(prefix, nextHop) return r.removeFromRouteTable(prefix, nextHop)
} }
func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.Manager) error {
beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error { beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error {
prefix, err := util.GetPrefixFromIP(ip) prefix, err := util.GetPrefixFromIP(ip)
if err != nil { if err != nil {
@@ -289,9 +314,11 @@ func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.M
return nil return nil
} }
var merr *multierror.Error
for _, ip := range initAddresses { for _, ip := range initAddresses {
if err := beforeHook("init", ip); err != nil { if err := beforeHook("init", ip); err != nil {
log.Errorf("Failed to add route reference: %v", err) merr = multierror.Append(merr, fmt.Errorf("add initial route for %s: %w", ip, err))
} }
} }
@@ -300,11 +327,11 @@ func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.M
return ctx.Err() return ctx.Err()
} }
var result *multierror.Error var merr *multierror.Error
for _, ip := range resolvedIPs { for _, ip := range resolvedIPs {
result = multierror.Append(result, beforeHook(connID, ip.IP)) merr = multierror.Append(merr, beforeHook(connID, ip.IP))
} }
return nberrors.FormatErrorOrNil(result) return nberrors.FormatErrorOrNil(merr)
}) })
nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error { nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error {
@@ -319,7 +346,16 @@ func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.M
return afterHook(connID) return afterHook(connID)
}) })
return beforeHook, afterHook, nil nbnet.AddListenerAddressRemoveHook(func(connID nbnet.ConnectionID, prefix netip.Prefix) error {
if _, err := r.refCounter.Decrement(prefix); err != nil {
return fmt.Errorf("remove route reference: %w", err)
}
r.updateState(stateManager)
return nil
})
return nberrors.FormatErrorOrNil(merr)
} }
func GetNextHop(ip netip.Addr) (Nexthop, error) { func GetNextHop(ip netip.Addr) (Nexthop, error) {

View File

@@ -143,7 +143,7 @@ func TestAddVPNRoute(t *testing.T) {
wgInterface := createWGInterface(t, fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100+n) wgInterface := createWGInterface(t, fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100+n)
r := NewSysOps(wgInterface, nil) r := NewSysOps(wgInterface, nil)
_, _, err := r.SetupRouting(nil, nil) err := r.SetupRouting(nil, nil)
require.NoError(t, err) require.NoError(t, err)
t.Cleanup(func() { t.Cleanup(func() {
assert.NoError(t, r.CleanupRouting(nil)) assert.NoError(t, r.CleanupRouting(nil))
@@ -341,7 +341,7 @@ func TestAddRouteToNonVPNIntf(t *testing.T) {
wgInterface := createWGInterface(t, fmt.Sprintf("utun54%d", n), "100.65.75.2/24", 33200+n) wgInterface := createWGInterface(t, fmt.Sprintf("utun54%d", n), "100.65.75.2/24", 33200+n)
r := NewSysOps(wgInterface, nil) r := NewSysOps(wgInterface, nil)
_, _, err := r.SetupRouting(nil, nil) err := r.SetupRouting(nil, nil)
require.NoError(t, err) require.NoError(t, err)
t.Cleanup(func() { t.Cleanup(func() {
assert.NoError(t, r.CleanupRouting(nil)) assert.NoError(t, r.CleanupRouting(nil))
@@ -484,7 +484,7 @@ func setupTestEnv(t *testing.T) {
}) })
r := NewSysOps(wgInterface, nil) r := NewSysOps(wgInterface, nil)
_, _, err := r.SetupRouting(nil, nil) err := r.SetupRouting(nil, nil)
require.NoError(t, err, "setupRouting should not return err") require.NoError(t, err, "setupRouting should not return err")
t.Cleanup(func() { t.Cleanup(func() {
assert.NoError(t, r.CleanupRouting(nil)) assert.NoError(t, r.CleanupRouting(nil))

View File

@@ -10,14 +10,13 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/util/net"
) )
func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) error {
r.mu.Lock() r.mu.Lock()
defer r.mu.Unlock() defer r.mu.Unlock()
r.prefixes = make(map[netip.Prefix]struct{}) r.prefixes = make(map[netip.Prefix]struct{})
return nil, nil, nil return nil
} }
func (r *SysOps) CleanupRouting(*statemanager.Manager) error { func (r *SysOps) CleanupRouting(*statemanager.Manager) error {

View File

@@ -72,7 +72,7 @@ func getSetupRules() []ruleParams {
// Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table. // Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table.
// This table is where a default route or other specific routes received from the management server are configured, // This table is where a default route or other specific routes received from the management server are configured,
// enabling VPN connectivity. // enabling VPN connectivity.
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (_ nbnet.AddHookFunc, _ nbnet.RemoveHookFunc, err error) { func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (err error) {
if !nbnet.AdvancedRouting() { if !nbnet.AdvancedRouting() {
log.Infof("Using legacy routing setup") log.Infof("Using legacy routing setup")
return r.setupRefCounter(initAddresses, stateManager) return r.setupRefCounter(initAddresses, stateManager)
@@ -89,7 +89,7 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager
rules := getSetupRules() rules := getSetupRules()
for _, rule := range rules { for _, rule := range rules {
if err := addRule(rule); err != nil { if err := addRule(rule); err != nil {
return nil, nil, fmt.Errorf("%s: %w", rule.description, err) return fmt.Errorf("%s: %w", rule.description, err)
} }
} }
@@ -104,7 +104,7 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager
} }
originalSysctl = originalValues originalSysctl = originalValues
return nil, nil, nil return nil
} }
// CleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'. // CleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'.

View File

@@ -18,10 +18,9 @@ import (
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/util/net"
) )
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) error {
return r.setupRefCounter(initAddresses, stateManager) return r.setupRefCounter(initAddresses, stateManager)
} }

View File

@@ -19,7 +19,6 @@ import (
"golang.org/x/sys/windows" "golang.org/x/sys/windows"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/util/net"
) )
const InfiniteLifetime = 0xffffffff const InfiniteLifetime = 0xffffffff
@@ -137,7 +136,7 @@ const (
RouteDeleted RouteDeleted
) )
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) error {
return r.setupRefCounter(initAddresses, stateManager) return r.setupRefCounter(initAddresses, stateManager)
} }

View File

@@ -1330,6 +1330,13 @@ func (x *PeerState) GetRelayAddress() string {
return "" return ""
} }
func (x *PeerState) GetConnectionType() string {
if x.Relayed {
return "Relayed"
}
return "P2P"
}
// LocalPeerState contains the latest state of the local peer // LocalPeerState contains the latest state of the local peer
type LocalPeerState struct { type LocalPeerState struct {
state protoimpl.MessageState `protogen:"open.v1"` state protoimpl.MessageState `protogen:"open.v1"`

View File

@@ -212,7 +212,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
} }
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager) secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager)
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil) mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &server.MockIntegratedValidator{})
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }

View File

@@ -100,7 +100,7 @@ type OutputOverview struct {
LazyConnectionEnabled bool `json:"lazyConnectionEnabled" yaml:"lazyConnectionEnabled"` LazyConnectionEnabled bool `json:"lazyConnectionEnabled" yaml:"lazyConnectionEnabled"`
} }
func ConvertToStatusOutputOverview(resp *proto.StatusResponse, anon bool, statusFilter string, prefixNamesFilter []string, prefixNamesFilterMap map[string]struct{}, ipsFilter map[string]struct{}) OutputOverview { func ConvertToStatusOutputOverview(resp *proto.StatusResponse, anon bool, statusFilter string, prefixNamesFilter []string, prefixNamesFilterMap map[string]struct{}, ipsFilter map[string]struct{}, connectionTypeFilter string) OutputOverview {
pbFullStatus := resp.GetFullStatus() pbFullStatus := resp.GetFullStatus()
managementState := pbFullStatus.GetManagementState() managementState := pbFullStatus.GetManagementState()
@@ -118,7 +118,7 @@ func ConvertToStatusOutputOverview(resp *proto.StatusResponse, anon bool, status
} }
relayOverview := mapRelays(pbFullStatus.GetRelays()) relayOverview := mapRelays(pbFullStatus.GetRelays())
peersOverview := mapPeers(resp.GetFullStatus().GetPeers(), statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter) peersOverview := mapPeers(resp.GetFullStatus().GetPeers(), statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter, connectionTypeFilter)
overview := OutputOverview{ overview := OutputOverview{
Peers: peersOverview, Peers: peersOverview,
@@ -193,6 +193,7 @@ func mapPeers(
prefixNamesFilter []string, prefixNamesFilter []string,
prefixNamesFilterMap map[string]struct{}, prefixNamesFilterMap map[string]struct{},
ipsFilter map[string]struct{}, ipsFilter map[string]struct{},
connectionTypeFilter string,
) PeersStateOutput { ) PeersStateOutput {
var peersStateDetail []PeerStateDetailOutput var peersStateDetail []PeerStateDetailOutput
peersConnected := 0 peersConnected := 0
@@ -208,7 +209,7 @@ func mapPeers(
transferSent := int64(0) transferSent := int64(0)
isPeerConnected := pbPeerState.ConnStatus == peer.StatusConnected.String() isPeerConnected := pbPeerState.ConnStatus == peer.StatusConnected.String()
if skipDetailByFilters(pbPeerState, pbPeerState.ConnStatus, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter) { if skipDetailByFilters(pbPeerState, pbPeerState.ConnStatus, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter, connectionTypeFilter) {
continue continue
} }
if isPeerConnected { if isPeerConnected {
@@ -218,10 +219,7 @@ func mapPeers(
remoteICE = pbPeerState.GetRemoteIceCandidateType() remoteICE = pbPeerState.GetRemoteIceCandidateType()
localICEEndpoint = pbPeerState.GetLocalIceCandidateEndpoint() localICEEndpoint = pbPeerState.GetLocalIceCandidateEndpoint()
remoteICEEndpoint = pbPeerState.GetRemoteIceCandidateEndpoint() remoteICEEndpoint = pbPeerState.GetRemoteIceCandidateEndpoint()
connType = "P2P" connType = pbPeerState.GetConnectionType()
if pbPeerState.Relayed {
connType = "Relayed"
}
relayServerAddress = pbPeerState.GetRelayAddress() relayServerAddress = pbPeerState.GetRelayAddress()
lastHandshake = pbPeerState.GetLastWireguardHandshake().AsTime().Local() lastHandshake = pbPeerState.GetLastWireguardHandshake().AsTime().Local()
transferReceived = pbPeerState.GetBytesRx() transferReceived = pbPeerState.GetBytesRx()
@@ -542,10 +540,11 @@ func parsePeers(peers PeersStateOutput, rosenpassEnabled, rosenpassPermissive bo
return peersString return peersString
} }
func skipDetailByFilters(peerState *proto.PeerState, peerStatus string, statusFilter string, prefixNamesFilter []string, prefixNamesFilterMap map[string]struct{}, ipsFilter map[string]struct{}) bool { func skipDetailByFilters(peerState *proto.PeerState, peerStatus string, statusFilter string, prefixNamesFilter []string, prefixNamesFilterMap map[string]struct{}, ipsFilter map[string]struct{}, connectionTypeFilter string) bool {
statusEval := false statusEval := false
ipEval := false ipEval := false
nameEval := true nameEval := true
connectionTypeEval := false
if statusFilter != "" { if statusFilter != "" {
if !strings.EqualFold(peerStatus, statusFilter) { if !strings.EqualFold(peerStatus, statusFilter) {
@@ -570,8 +569,11 @@ func skipDetailByFilters(peerState *proto.PeerState, peerStatus string, statusFi
} else { } else {
nameEval = false nameEval = false
} }
if connectionTypeFilter != "" && !strings.EqualFold(peerState.GetConnectionType(), connectionTypeFilter) {
connectionTypeEval = true
}
return statusEval || ipEval || nameEval return statusEval || ipEval || nameEval || connectionTypeEval
} }
func toIEC(b int64) string { func toIEC(b int64) string {

View File

@@ -234,7 +234,7 @@ var overview = OutputOverview{
} }
func TestConversionFromFullStatusToOutputOverview(t *testing.T) { func TestConversionFromFullStatusToOutputOverview(t *testing.T) {
convertedResult := ConvertToStatusOutputOverview(resp, false, "", nil, nil, nil) convertedResult := ConvertToStatusOutputOverview(resp, false, "", nil, nil, nil, "")
assert.Equal(t, overview, convertedResult) assert.Equal(t, overview, convertedResult)
} }

View File

@@ -433,7 +433,7 @@ func (s *serviceClient) collectDebugData(
var postUpStatusOutput string var postUpStatusOutput string
if postUpStatus != nil { if postUpStatus != nil {
overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus, params.anonymize, "", nil, nil, nil) overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus, params.anonymize, "", nil, nil, nil, "")
postUpStatusOutput = nbstatus.ParseToFullDetailSummary(overview) postUpStatusOutput = nbstatus.ParseToFullDetailSummary(overview)
} }
headerPostUp := fmt.Sprintf("----- NetBird post-up - Timestamp: %s", time.Now().Format(time.RFC3339)) headerPostUp := fmt.Sprintf("----- NetBird post-up - Timestamp: %s", time.Now().Format(time.RFC3339))
@@ -450,7 +450,7 @@ func (s *serviceClient) collectDebugData(
var preDownStatusOutput string var preDownStatusOutput string
if preDownStatus != nil { if preDownStatus != nil {
overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus, params.anonymize, "", nil, nil, nil) overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus, params.anonymize, "", nil, nil, nil, "")
preDownStatusOutput = nbstatus.ParseToFullDetailSummary(overview) preDownStatusOutput = nbstatus.ParseToFullDetailSummary(overview)
} }
headerPreDown := fmt.Sprintf("----- NetBird pre-down - Timestamp: %s - Duration: %s", headerPreDown := fmt.Sprintf("----- NetBird pre-down - Timestamp: %s - Duration: %s",
@@ -581,7 +581,7 @@ func (s *serviceClient) createDebugBundle(anonymize bool, systemInfo bool, uploa
var statusOutput string var statusOutput string
if statusResp != nil { if statusResp != nil {
overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anonymize, "", nil, nil, nil) overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anonymize, "", nil, nil, nil, "")
statusOutput = nbstatus.ParseToFullDetailSummary(overview) statusOutput = nbstatus.ParseToFullDetailSummary(overview)
} }

2
go.mod
View File

@@ -63,7 +63,7 @@ require (
github.com/miekg/dns v1.1.59 github.com/miekg/dns v1.1.59
github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/nadoo/ipset v0.5.0 github.com/nadoo/ipset v0.5.0
github.com/netbirdio/management-integrations/integrations v0.0.0-20250612164546-6bd7e2338d65 github.com/netbirdio/management-integrations/integrations v0.0.0-20250718071730-f4d133556ff5
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250514131221-a464fd5f30cb github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250514131221-a464fd5f30cb
github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/okta/okta-sdk-golang/v2 v2.18.0
github.com/oschwald/maxminddb-golang v1.12.0 github.com/oschwald/maxminddb-golang v1.12.0

4
go.sum
View File

@@ -503,8 +503,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c=
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q=
github.com/netbirdio/management-integrations/integrations v0.0.0-20250612164546-6bd7e2338d65 h1:5OfYiLjpr4dbQYJI5ouZaylkVdi2KlErLFOwBeBo5Hw= github.com/netbirdio/management-integrations/integrations v0.0.0-20250718071730-f4d133556ff5 h1:Zfn8d83OVyELCdxgprcyXR3D8uqoxHtXE9PUxVXDx/w=
github.com/netbirdio/management-integrations/integrations v0.0.0-20250612164546-6bd7e2338d65/go.mod h1:Gi9raplYzCCyh07Olw/DVfCJTFgpr1WCXJ/Q+8TSA9Q= github.com/netbirdio/management-integrations/integrations v0.0.0-20250718071730-f4d133556ff5/go.mod h1:Gi9raplYzCCyh07Olw/DVfCJTFgpr1WCXJ/Q+8TSA9Q=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250514131221-a464fd5f30cb h1:Cr6age+ePALqlSvtp7wc6lYY97XN7rkD1K4XEDmY+TU= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250514131221-a464fd5f30cb h1:Cr6age+ePALqlSvtp7wc6lYY97XN7rkD1K4XEDmY+TU=

View File

@@ -112,7 +112,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
} }
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager) secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager)
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil) mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, mgmt.MockIntegratedValidator{})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -292,7 +292,7 @@ var (
ephemeralManager.LoadInitialPeers(ctx) ephemeralManager.LoadInitialPeers(ctx)
gRPCAPIHandler := grpc.NewServer(gRPCOpts...) gRPCAPIHandler := grpc.NewServer(gRPCOpts...)
srv, err := server.NewServer(ctx, config, accountManager, settingsManager, peersUpdateManager, secretsManager, appMetrics, ephemeralManager, authManager) srv, err := server.NewServer(ctx, config, accountManager, settingsManager, peersUpdateManager, secretsManager, appMetrics, ephemeralManager, authManager, integratedPeerValidator)
if err != nil { if err != nil {
return fmt.Errorf("failed creating gRPC API handler: %v", err) return fmt.Errorf("failed creating gRPC API handler: %v", err)
} }

View File

@@ -2887,7 +2887,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, error) {
permissionsManager := permissions.NewManager(store) permissionsManager := permissions.NewManager(store)
manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -219,7 +219,7 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
// return empty extra settings for expected calls to UpdateAccountPeers // return empty extra settings for expected calls to UpdateAccountPeers
settingsMockManager.EXPECT().GetExtraSettings(gomock.Any(), gomock.Any()).Return(&types.ExtraSettings{}, nil).AnyTimes() settingsMockManager.EXPECT().GetExtraSettings(gomock.Any(), gomock.Any()).Return(&types.ExtraSettings{}, nil).AnyTimes()
permissionsManager := permissions.NewManager(store) permissionsManager := permissions.NewManager(store)
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
} }
func createDNSStore(t *testing.T) (store.Store, error) { func createDNSStore(t *testing.T) (store.Store, error) {

View File

@@ -19,6 +19,7 @@ import (
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config" integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
"github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/proto"
@@ -40,13 +41,14 @@ type GRPCServer struct {
settingsManager settings.Manager settingsManager settings.Manager
wgKey wgtypes.Key wgKey wgtypes.Key
proto.UnimplementedManagementServiceServer proto.UnimplementedManagementServiceServer
peersUpdateManager *PeersUpdateManager peersUpdateManager *PeersUpdateManager
config *types.Config config *types.Config
secretsManager SecretsManager secretsManager SecretsManager
appMetrics telemetry.AppMetrics appMetrics telemetry.AppMetrics
ephemeralManager *EphemeralManager ephemeralManager *EphemeralManager
peerLocks sync.Map peerLocks sync.Map
authManager auth.Manager authManager auth.Manager
integratedPeerValidator integrated_validator.IntegratedValidator
} }
// NewServer creates a new Management server // NewServer creates a new Management server
@@ -60,6 +62,7 @@ func NewServer(
appMetrics telemetry.AppMetrics, appMetrics telemetry.AppMetrics,
ephemeralManager *EphemeralManager, ephemeralManager *EphemeralManager,
authManager auth.Manager, authManager auth.Manager,
integratedPeerValidator integrated_validator.IntegratedValidator,
) (*GRPCServer, error) { ) (*GRPCServer, error) {
key, err := wgtypes.GeneratePrivateKey() key, err := wgtypes.GeneratePrivateKey()
if err != nil { if err != nil {
@@ -79,14 +82,15 @@ func NewServer(
return &GRPCServer{ return &GRPCServer{
wgKey: key, wgKey: key,
// peerKey -> event channel // peerKey -> event channel
peersUpdateManager: peersUpdateManager, peersUpdateManager: peersUpdateManager,
accountManager: accountManager, accountManager: accountManager,
settingsManager: settingsManager, settingsManager: settingsManager,
config: config, config: config,
secretsManager: secretsManager, secretsManager: secretsManager,
authManager: authManager, authManager: authManager,
appMetrics: appMetrics, appMetrics: appMetrics,
ephemeralManager: ephemeralManager, ephemeralManager: ephemeralManager,
integratedPeerValidator: integratedPeerValidator,
}, nil }, nil
} }
@@ -850,7 +854,7 @@ func (s *GRPCServer) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.En
return nil, status.Error(codes.NotFound, "no pkce authorization flow information available") return nil, status.Error(codes.NotFound, "no pkce authorization flow information available")
} }
flowInfoResp := &proto.PKCEAuthorizationFlow{ initInfoFlow := &proto.PKCEAuthorizationFlow{
ProviderConfig: &proto.ProviderConfig{ ProviderConfig: &proto.ProviderConfig{
Audience: s.config.PKCEAuthorizationFlow.ProviderConfig.Audience, Audience: s.config.PKCEAuthorizationFlow.ProviderConfig.Audience,
ClientID: s.config.PKCEAuthorizationFlow.ProviderConfig.ClientID, ClientID: s.config.PKCEAuthorizationFlow.ProviderConfig.ClientID,
@@ -865,6 +869,8 @@ func (s *GRPCServer) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.En
}, },
} }
flowInfoResp := s.integratedPeerValidator.ValidateFlowResponse(ctx, peerKey.String(), initInfoFlow)
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, flowInfoResp) encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, flowInfoResp)
if err != nil { if err != nil {
return nil, status.Error(codes.Internal, "failed to encrypt no pkce authorization flow information") return nil, status.Error(codes.Internal, "failed to encrypt no pkce authorization flow information")

View File

@@ -424,9 +424,10 @@ func toPolicyResponse(groups []*types.Group, policy *types.Policy) *api.Policy {
} }
if group, ok := groupsMap[gid]; ok { if group, ok := groupsMap[gid]; ok {
minimum := api.GroupMinimum{ minimum := api.GroupMinimum{
Id: group.ID, Id: group.ID,
Name: group.Name, Name: group.Name,
PeersCount: len(group.Peers), PeersCount: len(group.Peers),
ResourcesCount: len(group.Resources),
} }
destinations = append(destinations, minimum) destinations = append(destinations, minimum)
cache[gid] = minimum cache[gid] = minimum

View File

@@ -1,4 +1,5 @@
package testing_tools package testing_tools
import ( import (
"bytes" "bytes"
"context" "context"
@@ -132,7 +133,7 @@ func BuildApiBlackBoxWithDBState(t TB, sqlFile string, expectedPeerUpdate *serve
} }
geoMock := &geolocation.Mock{} geoMock := &geolocation.Mock{}
validatorMock := server.MocIntegratedValidator{} validatorMock := server.MockIntegratedValidator{}
proxyController := integrations.NewController(store) proxyController := integrations.NewController(store)
userManager := users.NewManager(store) userManager := users.NewManager(store)
permissionsManager := permissions.NewManager(store) permissionsManager := permissions.NewManager(store)

View File

@@ -6,6 +6,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
@@ -101,22 +102,23 @@ func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountI
return am.integratedPeerValidator.GetValidatedPeers(accountID, groups, peers, settings.Extra) return am.integratedPeerValidator.GetValidatedPeers(accountID, groups, peers, settings.Extra)
} }
type MocIntegratedValidator struct { type MockIntegratedValidator struct {
integrated_validator.IntegratedValidator
ValidatePeerFunc func(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *types.ExtraSettings) (*nbpeer.Peer, bool, error) ValidatePeerFunc func(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *types.ExtraSettings) (*nbpeer.Peer, bool, error)
} }
func (a MocIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *types.ExtraSettings, oldExtraSettings *types.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error { func (a MockIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *types.ExtraSettings, oldExtraSettings *types.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error {
return nil return nil
} }
func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *types.ExtraSettings) (*nbpeer.Peer, bool, error) { func (a MockIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *types.ExtraSettings) (*nbpeer.Peer, bool, error) {
if a.ValidatePeerFunc != nil { if a.ValidatePeerFunc != nil {
return a.ValidatePeerFunc(context.Background(), update, peer, userID, accountID, dnsDomain, peersGroup, extraSettings) return a.ValidatePeerFunc(context.Background(), update, peer, userID, accountID, dnsDomain, peersGroup, extraSettings)
} }
return update, false, nil return update, false, nil
} }
func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups []*types.Group, peers []*nbpeer.Peer, extraSettings *types.ExtraSettings) (map[string]struct{}, error) { func (a MockIntegratedValidator) GetValidatedPeers(accountID string, groups []*types.Group, peers []*nbpeer.Peer, extraSettings *types.ExtraSettings) (map[string]struct{}, error) {
validatedPeers := make(map[string]struct{}) validatedPeers := make(map[string]struct{})
for _, peer := range peers { for _, peer := range peers {
validatedPeers[peer.ID] = struct{}{} validatedPeers[peer.ID] = struct{}{}
@@ -124,22 +126,22 @@ func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups []*ty
return validatedPeers, nil return validatedPeers, nil
} }
func (MocIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) *nbpeer.Peer { func (MockIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) *nbpeer.Peer {
return peer return peer
} }
func (MocIntegratedValidator) IsNotValidPeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) (bool, bool, error) { func (MockIntegratedValidator) IsNotValidPeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) (bool, bool, error) {
return false, false, nil return false, false, nil
} }
func (MocIntegratedValidator) PeerDeleted(_ context.Context, _, _ string) error { func (MockIntegratedValidator) PeerDeleted(_ context.Context, _, _ string) error {
return nil return nil
} }
func (MocIntegratedValidator) SetPeerInvalidationListener(func(accountID string)) { func (MockIntegratedValidator) SetPeerInvalidationListener(func(accountID string)) {
// just a dummy // just a dummy
} }
func (MocIntegratedValidator) Stop(_ context.Context) { func (MockIntegratedValidator) Stop(_ context.Context) {
// just a dummy // just a dummy
} }

View File

@@ -3,6 +3,7 @@ package integrated_validator
import ( import (
"context" "context"
"github.com/netbirdio/netbird/management/proto"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
) )
@@ -17,4 +18,5 @@ type IntegratedValidator interface {
PeerDeleted(ctx context.Context, accountID, peerID string) error PeerDeleted(ctx context.Context, accountID, peerID string) error
SetPeerInvalidationListener(fn func(accountID string)) SetPeerInvalidationListener(fn func(accountID string))
Stop(ctx context.Context) Stop(ctx context.Context)
ValidateFlowResponse(ctx context.Context, peerKey string, flowResponse *proto.PKCEAuthorizationFlow) *proto.PKCEAuthorizationFlow
} }

View File

@@ -448,7 +448,7 @@ func startManagementForTest(t *testing.T, testFile string, config *types.Config)
permissionsManager := permissions.NewManager(store) permissionsManager := permissions.NewManager(store)
accountManager, err := BuildManager(ctx, store, peersUpdateManager, nil, "", "netbird.selfhosted", accountManager, err := BuildManager(ctx, store, peersUpdateManager, nil, "", "netbird.selfhosted",
eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
if err != nil { if err != nil {
cleanup() cleanup()
@@ -458,7 +458,7 @@ func startManagementForTest(t *testing.T, testFile string, config *types.Config)
secretsManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager) secretsManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager)
ephemeralMgr := NewEphemeralManager(store, accountManager) ephemeralMgr := NewEphemeralManager(store, accountManager)
mgmtServer, err := NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, ephemeralMgr, nil) mgmtServer, err := NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, ephemeralMgr, nil, MockIntegratedValidator{})
if err != nil { if err != nil {
return nil, nil, "", cleanup, err return nil, nil, "", cleanup, err
} }

View File

@@ -206,7 +206,7 @@ func startServer(
eventStore, eventStore,
nil, nil,
false, false,
server.MocIntegratedValidator{}, server.MockIntegratedValidator{},
metrics, metrics,
port_forwarding.NewControllerMock(), port_forwarding.NewControllerMock(),
settingsMockManager, settingsMockManager,
@@ -227,6 +227,7 @@ func startServer(
nil, nil,
nil, nil,
nil, nil,
server.MockIntegratedValidator{},
) )
if err != nil { if err != nil {
t.Fatalf("failed creating management server: %v", err) t.Fatalf("failed creating management server: %v", err)

View File

@@ -283,7 +283,7 @@ func MigrateSetupKeyToHashedSetupKey[T any](ctx context.Context, db *gorm.DB) er
} }
} }
if err := tx.Exec(fmt.Sprintf("ALTER TABLE %s DROP COLUMN %s", "peers", "setup_key")).Error; err != nil { if err := tx.Exec(fmt.Sprintf("ALTER TABLE %s DROP COLUMN IF EXISTS %s", "peers", "setup_key")).Error; err != nil {
log.WithContext(ctx).Errorf("Failed to drop column %s: %v", "setup_key", err) log.WithContext(ctx).Errorf("Failed to drop column %s: %v", "setup_key", err)
} }
@@ -377,6 +377,11 @@ func DropIndex[T any](ctx context.Context, db *gorm.DB, indexName string) error
func CreateIndexIfNotExists[T any](ctx context.Context, db *gorm.DB, indexName string, columns ...string) error { func CreateIndexIfNotExists[T any](ctx context.Context, db *gorm.DB, indexName string, columns ...string) error {
var model T var model T
if !db.Migrator().HasTable(&model) {
log.WithContext(ctx).Debugf("table for %T does not exist, no migration needed", model)
return nil
}
stmt := &gorm.Statement{DB: db} stmt := &gorm.Statement{DB: db}
if err := stmt.Parse(&model); err != nil { if err := stmt.Parse(&model); err != nil {
return fmt.Errorf("failed to parse model schema: %w", err) return fmt.Errorf("failed to parse model schema: %w", err)
@@ -384,6 +389,11 @@ func CreateIndexIfNotExists[T any](ctx context.Context, db *gorm.DB, indexName s
tableName := stmt.Schema.Table tableName := stmt.Schema.Table
dialect := db.Dialector.Name() dialect := db.Dialector.Name()
if db.Migrator().HasIndex(&model, indexName) {
log.WithContext(ctx).Infof("index %s already exists on table %s", indexName, tableName)
return nil
}
var columnClause string var columnClause string
if dialect == "mysql" { if dialect == "mysql" {
var withLength []string var withLength []string

View File

@@ -4,16 +4,21 @@ import (
"context" "context"
"encoding/gob" "encoding/gob"
"net" "net"
"os"
"strings" "strings"
"testing" "testing"
"time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
"gorm.io/driver/sqlite" "gorm.io/driver/sqlite"
"gorm.io/gorm" "gorm.io/gorm"
"github.com/netbirdio/netbird/management/server/migration" "github.com/netbirdio/netbird/management/server/migration"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/testutil"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
@@ -21,7 +26,41 @@ import (
func setupDatabase(t *testing.T) *gorm.DB { func setupDatabase(t *testing.T) *gorm.DB {
t.Helper() t.Helper()
db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{}) var db *gorm.DB
var err error
var dsn string
var cleanup func()
switch os.Getenv("NETBIRD_STORE_ENGINE") {
case "mysql":
cleanup, dsn, err = testutil.CreateMysqlTestContainer()
if err != nil {
t.Fatalf("Failed to create MySQL test container: %v", err)
}
if dsn == "" {
t.Fatal("MySQL connection string is empty, ensure the test container is running")
}
db, err = gorm.Open(mysql.Open(dsn+"?charset=utf8&parseTime=True&loc=Local"), &gorm.Config{})
case "postgres":
cleanup, dsn, err = testutil.CreatePostgresTestContainer()
if err != nil {
t.Fatalf("Failed to create PostgreSQL test container: %v", err)
}
if dsn == "" {
t.Fatalf("PostgreSQL connection string is empty, ensure the test container is running")
}
db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{})
case "sqlite":
db, err = gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{})
default:
db, err = gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{})
}
if cleanup != nil {
t.Cleanup(cleanup)
}
require.NoError(t, err, "Failed to open database") require.NoError(t, err, "Failed to open database")
return db return db
@@ -34,6 +73,7 @@ func TestMigrateFieldFromGobToJSON_EmptyDB(t *testing.T) {
} }
func TestMigrateFieldFromGobToJSON_WithGobData(t *testing.T) { func TestMigrateFieldFromGobToJSON_WithGobData(t *testing.T) {
t.Setenv("NETBIRD_STORE_ENGINE", "sqlite")
db := setupDatabase(t) db := setupDatabase(t)
err := db.AutoMigrate(&types.Account{}, &route.Route{}) err := db.AutoMigrate(&types.Account{}, &route.Route{})
@@ -97,6 +137,7 @@ func TestMigrateNetIPFieldFromBlobToJSON_EmptyDB(t *testing.T) {
} }
func TestMigrateNetIPFieldFromBlobToJSON_WithBlobData(t *testing.T) { func TestMigrateNetIPFieldFromBlobToJSON_WithBlobData(t *testing.T) {
t.Setenv("NETBIRD_STORE_ENGINE", "sqlite")
db := setupDatabase(t) db := setupDatabase(t)
err := db.AutoMigrate(&types.Account{}, &nbpeer.Peer{}) err := db.AutoMigrate(&types.Account{}, &nbpeer.Peer{})
@@ -117,12 +158,18 @@ func TestMigrateNetIPFieldFromBlobToJSON_WithBlobData(t *testing.T) {
Peers []peer `gorm:"foreignKey:AccountID;references:id"` Peers []peer `gorm:"foreignKey:AccountID;references:id"`
} }
err = db.Save(&account{ a := &account{
Account: types.Account{Id: "123"}, Account: types.Account{Id: "123"},
Peers: []peer{ }
{Location: location{ConnectionIP: net.IP{10, 0, 0, 1}}},
}}, err = db.Save(a).Error
).Error require.NoError(t, err, "Failed to insert account")
a.Peers = []peer{
{Location: location{ConnectionIP: net.IP{10, 0, 0, 1}}},
}
err = db.Save(a).Error
require.NoError(t, err, "Failed to insert blob data") require.NoError(t, err, "Failed to insert blob data")
var blobValue string var blobValue string
@@ -143,12 +190,18 @@ func TestMigrateNetIPFieldFromBlobToJSON_WithJSONData(t *testing.T) {
err := db.AutoMigrate(&types.Account{}, &nbpeer.Peer{}) err := db.AutoMigrate(&types.Account{}, &nbpeer.Peer{})
require.NoError(t, err, "Failed to auto-migrate tables") require.NoError(t, err, "Failed to auto-migrate tables")
err = db.Save(&types.Account{ account := &types.Account{
Id: "1234", Id: "1234",
PeersG: []nbpeer.Peer{ }
{Location: nbpeer.Location{ConnectionIP: net.IP{10, 0, 0, 1}}},
}}, err = db.Save(account).Error
).Error require.NoError(t, err, "Failed to insert account")
account.PeersG = []nbpeer.Peer{
{AccountID: "1234", Location: nbpeer.Location{ConnectionIP: net.IP{10, 0, 0, 1}}},
}
err = db.Save(account).Error
require.NoError(t, err, "Failed to insert JSON data") require.NoError(t, err, "Failed to insert JSON data")
err = migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](context.Background(), db, "location_connection_ip", "") err = migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](context.Background(), db, "location_connection_ip", "")
@@ -162,12 +215,13 @@ func TestMigrateNetIPFieldFromBlobToJSON_WithJSONData(t *testing.T) {
func TestMigrateSetupKeyToHashedSetupKey_ForPlainKey(t *testing.T) { func TestMigrateSetupKeyToHashedSetupKey_ForPlainKey(t *testing.T) {
db := setupDatabase(t) db := setupDatabase(t)
err := db.AutoMigrate(&types.SetupKey{}) err := db.AutoMigrate(&types.SetupKey{}, &nbpeer.Peer{})
require.NoError(t, err, "Failed to auto-migrate tables") require.NoError(t, err, "Failed to auto-migrate tables")
err = db.Save(&types.SetupKey{ err = db.Save(&types.SetupKey{
Id: "1", Id: "1",
Key: "EEFDAB47-C1A5-4472-8C05-71DE9A1E8382", Key: "EEFDAB47-C1A5-4472-8C05-71DE9A1E8382",
UpdatedAt: time.Now(),
}).Error }).Error
require.NoError(t, err, "Failed to insert setup key") require.NoError(t, err, "Failed to insert setup key")
@@ -192,6 +246,7 @@ func TestMigrateSetupKeyToHashedSetupKey_ForAlreadyMigratedKey_Case1(t *testing.
Id: "1", Id: "1",
Key: "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=", Key: "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=",
KeySecret: "EEFDA****", KeySecret: "EEFDA****",
UpdatedAt: time.Now(),
}).Error }).Error
require.NoError(t, err, "Failed to insert setup key") require.NoError(t, err, "Failed to insert setup key")
@@ -213,8 +268,9 @@ func TestMigrateSetupKeyToHashedSetupKey_ForAlreadyMigratedKey_Case2(t *testing.
require.NoError(t, err, "Failed to auto-migrate tables") require.NoError(t, err, "Failed to auto-migrate tables")
err = db.Save(&types.SetupKey{ err = db.Save(&types.SetupKey{
Id: "1", Id: "1",
Key: "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=", Key: "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=",
UpdatedAt: time.Now(),
}).Error }).Error
require.NoError(t, err, "Failed to insert setup key") require.NoError(t, err, "Failed to insert setup key")
@@ -235,8 +291,9 @@ func TestDropIndex(t *testing.T) {
require.NoError(t, err, "Failed to auto-migrate tables") require.NoError(t, err, "Failed to auto-migrate tables")
err = db.Save(&types.SetupKey{ err = db.Save(&types.SetupKey{
Id: "1", Id: "1",
Key: "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=", Key: "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=",
UpdatedAt: time.Now(),
}).Error }).Error
require.NoError(t, err, "Failed to insert setup key") require.NoError(t, err, "Failed to insert setup key")
@@ -249,3 +306,37 @@ func TestDropIndex(t *testing.T) {
exist = db.Migrator().HasIndex(&types.SetupKey{}, "idx_setup_keys_account_id") exist = db.Migrator().HasIndex(&types.SetupKey{}, "idx_setup_keys_account_id")
assert.False(t, exist, "Should not have the index") assert.False(t, exist, "Should not have the index")
} }
func TestCreateIndex(t *testing.T) {
db := setupDatabase(t)
err := db.AutoMigrate(&nbpeer.Peer{})
assert.NoError(t, err, "Failed to auto-migrate tables")
indexName := "idx_account_ip"
err = migration.CreateIndexIfNotExists[nbpeer.Peer](context.Background(), db, indexName, "account_id", "ip")
assert.NoError(t, err, "Migration should not fail to create index")
exist := db.Migrator().HasIndex(&nbpeer.Peer{}, indexName)
assert.True(t, exist, "Should have the index")
}
func TestCreateIndexIfExists(t *testing.T) {
db := setupDatabase(t)
err := db.AutoMigrate(&nbpeer.Peer{})
assert.NoError(t, err, "Failed to auto-migrate tables")
indexName := "idx_account_ip"
err = migration.CreateIndexIfNotExists[nbpeer.Peer](context.Background(), db, indexName, "account_id", "ip")
assert.NoError(t, err, "Migration should not fail to create index")
exist := db.Migrator().HasIndex(&nbpeer.Peer{}, indexName)
assert.True(t, exist, "Should have the index")
err = migration.CreateIndexIfNotExists[nbpeer.Peer](context.Background(), db, indexName, "account_id", "ip")
assert.NoError(t, err, "Create index should not fail if index exists")
exist = db.Migrator().HasIndex(&nbpeer.Peer{}, indexName)
assert.True(t, exist, "Should have the index")
}

View File

@@ -785,7 +785,7 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) {
AnyTimes() AnyTimes()
permissionsManager := permissions.NewManager(store) permissionsManager := permissions.NewManager(store)
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
} }
func createNSStore(t *testing.T) (store.Store, error) { func createNSStore(t *testing.T) (store.Store, error) {

View File

@@ -1273,7 +1273,7 @@ func Test_RegisterPeerByUser(t *testing.T) {
settingsMockManager := settings.NewMockManager(ctrl) settingsMockManager := settings.NewMockManager(ctrl)
permissionsManager := permissions.NewManager(s) permissionsManager := permissions.NewManager(s)
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
assert.NoError(t, err) assert.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
@@ -1353,7 +1353,7 @@ func Test_RegisterPeerBySetupKey(t *testing.T) {
AnyTimes() AnyTimes()
permissionsManager := permissions.NewManager(s) permissionsManager := permissions.NewManager(s)
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
assert.NoError(t, err) assert.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
@@ -1496,7 +1496,7 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
permissionsManager := permissions.NewManager(s) permissionsManager := permissions.NewManager(s)
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
assert.NoError(t, err) assert.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
@@ -1570,7 +1570,7 @@ func Test_LoginPeer(t *testing.T) {
AnyTimes() AnyTimes()
permissionsManager := permissions.NewManager(s) permissionsManager := permissions.NewManager(s)
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
assert.NoError(t, err) assert.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
@@ -1848,7 +1848,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
return update, true, nil return update, true, nil
} }
manager.integratedPeerValidator = MocIntegratedValidator{ValidatePeerFunc: requireUpdateFunc} manager.integratedPeerValidator = MockIntegratedValidator{ValidatePeerFunc: requireUpdateFunc}
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
peerShouldReceiveUpdate(t, updMsg) peerShouldReceiveUpdate(t, updMsg)
@@ -1870,7 +1870,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
return update, false, nil return update, false, nil
} }
manager.integratedPeerValidator = MocIntegratedValidator{ValidatePeerFunc: requireNoUpdateFunc} manager.integratedPeerValidator = MockIntegratedValidator{ValidatePeerFunc: requireNoUpdateFunc}
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
peerShouldNotReceiveUpdate(t, updMsg) peerShouldNotReceiveUpdate(t, updMsg)

View File

@@ -1284,7 +1284,7 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, error) {
permissionsManager := permissions.NewManager(store) permissionsManager := permissions.NewManager(store)
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
} }
func createRouterStore(t *testing.T) (store.Store, error) { func createRouterStore(t *testing.T) (store.Store, error) {

View File

@@ -852,7 +852,7 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) {
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
eventStore: &activity.InMemoryEventStore{}, eventStore: &activity.InMemoryEventStore{},
integratedPeerValidator: MocIntegratedValidator{}, integratedPeerValidator: MockIntegratedValidator{},
permissionsManager: permissionsManager, permissionsManager: permissionsManager,
} }

View File

@@ -292,7 +292,7 @@ func (c *Client) Close() error {
} }
func (c *Client) connect(ctx context.Context) (*RelayAddr, error) { func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
rd := dialer.NewRaceDial(c.log, c.connectionURL, quic.Dialer{}, ws.Dialer{}) rd := dialer.NewRaceDial(c.log, dialer.DefaultConnectionTimeout, c.connectionURL, quic.Dialer{}, ws.Dialer{})
conn, err := rd.Dial() conn, err := rd.Dial()
if err != nil { if err != nil {
return nil, err return nil, err

View File

@@ -572,10 +572,14 @@ func TestCloseByServer(t *testing.T) {
idAlice := "alice" idAlice := "alice"
log.Debugf("connect by alice") log.Debugf("connect by alice")
relayClient := NewClient(serverURL, hmacTokenStore, idAlice) relayClient := NewClient(serverURL, hmacTokenStore, idAlice)
err = relayClient.Connect(ctx) if err = relayClient.Connect(ctx); err != nil {
if err != nil {
log.Fatalf("failed to connect to server: %s", err) log.Fatalf("failed to connect to server: %s", err)
} }
defer func() {
if err := relayClient.Close(); err != nil {
log.Errorf("failed to close client: %s", err)
}
}()
disconnected := make(chan struct{}) disconnected := make(chan struct{})
relayClient.SetOnDisconnectListener(func(_ string) { relayClient.SetOnDisconnectListener(func(_ string) {
@@ -591,7 +595,7 @@ func TestCloseByServer(t *testing.T) {
select { select {
case <-disconnected: case <-disconnected:
case <-time.After(3 * time.Second): case <-time.After(3 * time.Second):
log.Fatalf("timeout waiting for client to disconnect") log.Errorf("timeout waiting for client to disconnect")
} }
_, err = relayClient.OpenConn(ctx, "bob") _, err = relayClient.OpenConn(ctx, "bob")

View File

@@ -9,8 +9,8 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
var ( const (
connectionTimeout = 30 * time.Second DefaultConnectionTimeout = 30 * time.Second
) )
type DialeFn interface { type DialeFn interface {
@@ -25,16 +25,18 @@ type dialResult struct {
} }
type RaceDial struct { type RaceDial struct {
log *log.Entry log *log.Entry
serverURL string serverURL string
dialerFns []DialeFn dialerFns []DialeFn
connectionTimeout time.Duration
} }
func NewRaceDial(log *log.Entry, serverURL string, dialerFns ...DialeFn) *RaceDial { func NewRaceDial(log *log.Entry, connectionTimeout time.Duration, serverURL string, dialerFns ...DialeFn) *RaceDial {
return &RaceDial{ return &RaceDial{
log: log, log: log,
serverURL: serverURL, serverURL: serverURL,
dialerFns: dialerFns, dialerFns: dialerFns,
connectionTimeout: connectionTimeout,
} }
} }
@@ -58,7 +60,7 @@ func (r *RaceDial) Dial() (net.Conn, error) {
} }
func (r *RaceDial) dial(dfn DialeFn, abortCtx context.Context, connChan chan dialResult) { func (r *RaceDial) dial(dfn DialeFn, abortCtx context.Context, connChan chan dialResult) {
ctx, cancel := context.WithTimeout(abortCtx, connectionTimeout) ctx, cancel := context.WithTimeout(abortCtx, r.connectionTimeout)
defer cancel() defer cancel()
r.log.Infof("dialing Relay server via %s", dfn.Protocol()) r.log.Infof("dialing Relay server via %s", dfn.Protocol())

View File

@@ -77,7 +77,7 @@ func TestRaceDialEmptyDialers(t *testing.T) {
logger := logrus.NewEntry(logrus.New()) logger := logrus.NewEntry(logrus.New())
serverURL := "test.server.com" serverURL := "test.server.com"
rd := NewRaceDial(logger, serverURL) rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL)
conn, err := rd.Dial() conn, err := rd.Dial()
if err == nil { if err == nil {
t.Errorf("Expected an error with empty dialers, got nil") t.Errorf("Expected an error with empty dialers, got nil")
@@ -103,7 +103,7 @@ func TestRaceDialSingleSuccessfulDialer(t *testing.T) {
protocolStr: proto, protocolStr: proto,
} }
rd := NewRaceDial(logger, serverURL, mockDialer) rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer)
conn, err := rd.Dial() conn, err := rd.Dial()
if err != nil { if err != nil {
t.Errorf("Expected no error, got %v", err) t.Errorf("Expected no error, got %v", err)
@@ -136,7 +136,7 @@ func TestRaceDialMultipleDialersWithOneSuccess(t *testing.T) {
protocolStr: "proto2", protocolStr: "proto2",
} }
rd := NewRaceDial(logger, serverURL, mockDialer1, mockDialer2) rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer1, mockDialer2)
conn, err := rd.Dial() conn, err := rd.Dial()
if err != nil { if err != nil {
t.Errorf("Expected no error, got %v", err) t.Errorf("Expected no error, got %v", err)
@@ -144,13 +144,13 @@ func TestRaceDialMultipleDialersWithOneSuccess(t *testing.T) {
if conn.RemoteAddr().Network() != proto2 { if conn.RemoteAddr().Network() != proto2 {
t.Errorf("Expected connection with protocol %s, got %s", proto2, conn.RemoteAddr().Network()) t.Errorf("Expected connection with protocol %s, got %s", proto2, conn.RemoteAddr().Network())
} }
_ = conn.Close()
} }
func TestRaceDialTimeout(t *testing.T) { func TestRaceDialTimeout(t *testing.T) {
logger := logrus.NewEntry(logrus.New()) logger := logrus.NewEntry(logrus.New())
serverURL := "test.server.com" serverURL := "test.server.com"
connectionTimeout = 3 * time.Second
mockDialer := &MockDialer{ mockDialer := &MockDialer{
dialFunc: func(ctx context.Context, address string) (net.Conn, error) { dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
<-ctx.Done() <-ctx.Done()
@@ -159,7 +159,7 @@ func TestRaceDialTimeout(t *testing.T) {
protocolStr: "proto1", protocolStr: "proto1",
} }
rd := NewRaceDial(logger, serverURL, mockDialer) rd := NewRaceDial(logger, 3*time.Second, serverURL, mockDialer)
conn, err := rd.Dial() conn, err := rd.Dial()
if err == nil { if err == nil {
t.Errorf("Expected an error, got nil") t.Errorf("Expected an error, got nil")
@@ -187,7 +187,7 @@ func TestRaceDialAllDialersFail(t *testing.T) {
protocolStr: "protocol2", protocolStr: "protocol2",
} }
rd := NewRaceDial(logger, serverURL, mockDialer1, mockDialer2) rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer1, mockDialer2)
conn, err := rd.Dial() conn, err := rd.Dial()
if err == nil { if err == nil {
t.Errorf("Expected an error, got nil") t.Errorf("Expected an error, got nil")
@@ -229,7 +229,7 @@ func TestRaceDialFirstSuccessfulDialerWins(t *testing.T) {
protocolStr: proto2, protocolStr: proto2,
} }
rd := NewRaceDial(logger, serverURL, mockDialer1, mockDialer2) rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer1, mockDialer2)
conn, err := rd.Dial() conn, err := rd.Dial()
if err != nil { if err != nil {
t.Errorf("Expected no error, got %v", err) t.Errorf("Expected no error, got %v", err)

View File

@@ -8,7 +8,8 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
var ( const (
// TODO: make it configurable, the manager should validate all configurable parameters
reconnectingTimeout = 60 * time.Second reconnectingTimeout = 60 * time.Second
) )

View File

@@ -39,17 +39,6 @@ func NewRelayTrack() *RelayTrack {
type OnServerCloseListener func() type OnServerCloseListener func()
// ManagerService is the interface for the relay manager.
type ManagerService interface {
Serve() error
OpenConn(ctx context.Context, serverAddress, peerKey string) (net.Conn, error)
AddCloseListener(serverAddress string, onClosedListener OnServerCloseListener) error
RelayInstanceAddress() (string, error)
ServerURLs() []string
HasRelayAddress() bool
UpdateToken(token *relayAuth.Token) error
}
// Manager is a manager for the relay client instances. It establishes one persistent connection to the given relay URL // Manager is a manager for the relay client instances. It establishes one persistent connection to the given relay URL
// and automatically reconnect to them in case disconnection. // and automatically reconnect to them in case disconnection.
// The manager also manage temporary relay connection. If a client wants to communicate with a client on a // The manager also manage temporary relay connection. If a client wants to communicate with a client on a
@@ -65,7 +54,7 @@ type Manager struct {
relayClient *Client relayClient *Client
// the guard logic can overwrite the relayClient variable, this mutex protect the usage of the variable // the guard logic can overwrite the relayClient variable, this mutex protect the usage of the variable
relayClientMu sync.Mutex relayClientMu sync.RWMutex
reconnectGuard *Guard reconnectGuard *Guard
relayClients map[string]*RelayTrack relayClients map[string]*RelayTrack
@@ -124,8 +113,8 @@ func (m *Manager) Serve() error {
// established via the relay server. If the peer is on a different relay server, the manager will establish a new // established via the relay server. If the peer is on a different relay server, the manager will establish a new
// connection to the relay server. It returns back with a net.Conn what represent the remote peer connection. // connection to the relay server. It returns back with a net.Conn what represent the remote peer connection.
func (m *Manager) OpenConn(ctx context.Context, serverAddress, peerKey string) (net.Conn, error) { func (m *Manager) OpenConn(ctx context.Context, serverAddress, peerKey string) (net.Conn, error) {
m.relayClientMu.Lock() m.relayClientMu.RLock()
defer m.relayClientMu.Unlock() defer m.relayClientMu.RUnlock()
if m.relayClient == nil { if m.relayClient == nil {
return nil, ErrRelayClientNotConnected return nil, ErrRelayClientNotConnected
@@ -155,8 +144,8 @@ func (m *Manager) OpenConn(ctx context.Context, serverAddress, peerKey string) (
// Ready returns true if the home Relay client is connected to the relay server. // Ready returns true if the home Relay client is connected to the relay server.
func (m *Manager) Ready() bool { func (m *Manager) Ready() bool {
m.relayClientMu.Lock() m.relayClientMu.RLock()
defer m.relayClientMu.Unlock() defer m.relayClientMu.RUnlock()
if m.relayClient == nil { if m.relayClient == nil {
return false return false
@@ -174,8 +163,8 @@ func (m *Manager) SetOnReconnectedListener(f func()) {
// AddCloseListener adds a listener to the given server instance address. The listener will be called if the connection // AddCloseListener adds a listener to the given server instance address. The listener will be called if the connection
// closed. // closed.
func (m *Manager) AddCloseListener(serverAddress string, onClosedListener OnServerCloseListener) error { func (m *Manager) AddCloseListener(serverAddress string, onClosedListener OnServerCloseListener) error {
m.relayClientMu.Lock() m.relayClientMu.RLock()
defer m.relayClientMu.Unlock() defer m.relayClientMu.RUnlock()
if m.relayClient == nil { if m.relayClient == nil {
return ErrRelayClientNotConnected return ErrRelayClientNotConnected
@@ -199,8 +188,8 @@ func (m *Manager) AddCloseListener(serverAddress string, onClosedListener OnServ
// RelayInstanceAddress returns the address of the permanent relay server. It could change if the network connection is // RelayInstanceAddress returns the address of the permanent relay server. It could change if the network connection is
// lost. This address will be sent to the target peer to choose the common relay server for the communication. // lost. This address will be sent to the target peer to choose the common relay server for the communication.
func (m *Manager) RelayInstanceAddress() (string, error) { func (m *Manager) RelayInstanceAddress() (string, error) {
m.relayClientMu.Lock() m.relayClientMu.RLock()
defer m.relayClientMu.Unlock() defer m.relayClientMu.RUnlock()
if m.relayClient == nil { if m.relayClient == nil {
return "", ErrRelayClientNotConnected return "", ErrRelayClientNotConnected
@@ -300,7 +289,9 @@ func (m *Manager) onServerConnected() {
func (m *Manager) onServerDisconnected(serverAddress string) { func (m *Manager) onServerDisconnected(serverAddress string) {
m.relayClientMu.Lock() m.relayClientMu.Lock()
if serverAddress == m.relayClient.connectionURL { if serverAddress == m.relayClient.connectionURL {
go m.reconnectGuard.StartReconnectTrys(m.ctx, m.relayClient) go func(client *Client) {
m.reconnectGuard.StartReconnectTrys(m.ctx, client)
}(m.relayClient)
} }
m.relayClientMu.Unlock() m.relayClientMu.Unlock()

View File

@@ -13,7 +13,9 @@ import (
) )
func TestEmptyURL(t *testing.T) { func TestEmptyURL(t *testing.T) {
mgr := NewManager(context.Background(), nil, "alice") ctx, cancel := context.WithCancel(context.Background())
defer cancel()
mgr := NewManager(ctx, nil, "alice")
err := mgr.Serve() err := mgr.Serve()
if err == nil { if err == nil {
t.Errorf("expected error, got nil") t.Errorf("expected error, got nil")
@@ -216,9 +218,11 @@ func TestForeginConnClose(t *testing.T) {
} }
} }
func TestForeginAutoClose(t *testing.T) { func TestForeignAutoClose(t *testing.T) {
ctx := context.Background() ctx := context.Background()
relayCleanupInterval = 1 * time.Second relayCleanupInterval = 1 * time.Second
keepUnusedServerTime = 2 * time.Second
srvCfg1 := server.ListenerConfig{ srvCfg1 := server.ListenerConfig{
Address: "localhost:1234", Address: "localhost:1234",
} }
@@ -284,16 +288,35 @@ func TestForeginAutoClose(t *testing.T) {
t.Fatalf("failed to serve manager: %s", err) t.Fatalf("failed to serve manager: %s", err)
} }
// Set up a disconnect listener to track when foreign server disconnects
foreignServerURL := toURL(srvCfg2)[0]
disconnected := make(chan struct{})
onDisconnect := func() {
select {
case disconnected <- struct{}{}:
default:
}
}
t.Log("open connection to another peer") t.Log("open connection to another peer")
if _, err = mgr.OpenConn(ctx, toURL(srvCfg2)[0], "anotherpeer"); err == nil { if _, err = mgr.OpenConn(ctx, foreignServerURL, "anotherpeer"); err == nil {
t.Fatalf("should have failed to open connection to another peer") t.Fatalf("should have failed to open connection to another peer")
} }
timeout := relayCleanupInterval + keepUnusedServerTime + 1*time.Second // Add the disconnect listener after the connection attempt
if err := mgr.AddCloseListener(foreignServerURL, onDisconnect); err != nil {
t.Logf("failed to add close listener (expected if connection failed): %s", err)
}
// Wait for cleanup to happen
timeout := relayCleanupInterval + keepUnusedServerTime + 2*time.Second
t.Logf("waiting for relay cleanup: %s", timeout) t.Logf("waiting for relay cleanup: %s", timeout)
time.Sleep(timeout)
if len(mgr.relayClients) != 0 { select {
t.Errorf("expected 0, got %d", len(mgr.relayClients)) case <-disconnected:
t.Log("foreign relay connection cleaned up successfully")
case <-time.After(timeout):
t.Log("timeout waiting for cleanup - this might be expected if connection never established")
} }
t.Logf("closing manager") t.Logf("closing manager")
@@ -301,7 +324,6 @@ func TestForeginAutoClose(t *testing.T) {
func TestAutoReconnect(t *testing.T) { func TestAutoReconnect(t *testing.T) {
ctx := context.Background() ctx := context.Background()
reconnectingTimeout = 2 * time.Second
srvCfg := server.ListenerConfig{ srvCfg := server.ListenerConfig{
Address: "localhost:1234", Address: "localhost:1234",
@@ -312,8 +334,7 @@ func TestAutoReconnect(t *testing.T) {
} }
errChan := make(chan error, 1) errChan := make(chan error, 1)
go func() { go func() {
err := srv.Listen(srvCfg) if err := srv.Listen(srvCfg); err != nil {
if err != nil {
errChan <- err errChan <- err
} }
}() }()

View File

@@ -4,38 +4,76 @@ import (
"context" "context"
"fmt" "fmt"
"os" "os"
"sync"
"testing" "testing"
"time" "time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
// Mutex to protect global variable access in tests
var testMutex sync.Mutex
func TestNewReceiver(t *testing.T) { func TestNewReceiver(t *testing.T) {
testMutex.Lock()
originalTimeout := heartbeatTimeout
heartbeatTimeout = 5 * time.Second heartbeatTimeout = 5 * time.Second
testMutex.Unlock()
defer func() {
testMutex.Lock()
heartbeatTimeout = originalTimeout
testMutex.Unlock()
}()
r := NewReceiver(log.WithContext(context.Background())) r := NewReceiver(log.WithContext(context.Background()))
defer r.Stop()
select { select {
case <-r.OnTimeout: case <-r.OnTimeout:
t.Error("unexpected timeout") t.Error("unexpected timeout")
case <-time.After(1 * time.Second): case <-time.After(1 * time.Second):
// Test passes if no timeout received
} }
} }
func TestNewReceiverNotReceive(t *testing.T) { func TestNewReceiverNotReceive(t *testing.T) {
testMutex.Lock()
originalTimeout := heartbeatTimeout
heartbeatTimeout = 1 * time.Second heartbeatTimeout = 1 * time.Second
testMutex.Unlock()
defer func() {
testMutex.Lock()
heartbeatTimeout = originalTimeout
testMutex.Unlock()
}()
r := NewReceiver(log.WithContext(context.Background())) r := NewReceiver(log.WithContext(context.Background()))
defer r.Stop()
select { select {
case <-r.OnTimeout: case <-r.OnTimeout:
// Test passes if timeout is received
case <-time.After(2 * time.Second): case <-time.After(2 * time.Second):
t.Error("timeout not received") t.Error("timeout not received")
} }
} }
func TestNewReceiverAck(t *testing.T) { func TestNewReceiverAck(t *testing.T) {
testMutex.Lock()
originalTimeout := heartbeatTimeout
heartbeatTimeout = 2 * time.Second heartbeatTimeout = 2 * time.Second
testMutex.Unlock()
defer func() {
testMutex.Lock()
heartbeatTimeout = originalTimeout
testMutex.Unlock()
}()
r := NewReceiver(log.WithContext(context.Background())) r := NewReceiver(log.WithContext(context.Background()))
defer r.Stop()
r.Heartbeat() r.Heartbeat()
@@ -59,13 +97,18 @@ func TestReceiverHealthCheckAttemptThreshold(t *testing.T) {
for _, tc := range testsCases { for _, tc := range testsCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
testMutex.Lock()
originalInterval := healthCheckInterval originalInterval := healthCheckInterval
originalTimeout := heartbeatTimeout originalTimeout := heartbeatTimeout
healthCheckInterval = 1 * time.Second healthCheckInterval = 1 * time.Second
heartbeatTimeout = healthCheckInterval + 500*time.Millisecond heartbeatTimeout = healthCheckInterval + 500*time.Millisecond
testMutex.Unlock()
defer func() { defer func() {
testMutex.Lock()
healthCheckInterval = originalInterval healthCheckInterval = originalInterval
heartbeatTimeout = originalTimeout heartbeatTimeout = originalTimeout
testMutex.Unlock()
}() }()
//nolint:tenv //nolint:tenv
os.Setenv(defaultAttemptThresholdEnv, fmt.Sprintf("%d", tc.threshold)) os.Setenv(defaultAttemptThresholdEnv, fmt.Sprintf("%d", tc.threshold))

View File

@@ -135,7 +135,11 @@ func TestSenderHealthCheckAttemptThreshold(t *testing.T) {
defer cancel() defer cancel()
sender := NewSender(log.WithField("test_name", tc.name)) sender := NewSender(log.WithField("test_name", tc.name))
go sender.StartHealthCheck(ctx) senderExit := make(chan struct{})
go func() {
sender.StartHealthCheck(ctx)
close(senderExit)
}()
go func() { go func() {
responded := false responded := false
@@ -169,6 +173,11 @@ func TestSenderHealthCheckAttemptThreshold(t *testing.T) {
t.Fatalf("should have timed out before %s", testTimeout) t.Fatalf("should have timed out before %s", testTimeout)
} }
select {
case <-senderExit:
case <-time.After(2 * time.Second):
t.Fatalf("sender did not exit in time")
}
}) })
} }

View File

@@ -20,12 +20,12 @@ type Metrics struct {
TransferBytesRecv metric.Int64Counter TransferBytesRecv metric.Int64Counter
AuthenticationTime metric.Float64Histogram AuthenticationTime metric.Float64Histogram
PeerStoreTime metric.Float64Histogram PeerStoreTime metric.Float64Histogram
peerReconnections metric.Int64Counter
peers metric.Int64UpDownCounter peers metric.Int64UpDownCounter
peerActivityChan chan string peerActivityChan chan string
peerLastActive map[string]time.Time peerLastActive map[string]time.Time
mutexActivity sync.Mutex mutexActivity sync.Mutex
ctx context.Context ctx context.Context
} }
func NewMetrics(ctx context.Context, meter metric.Meter) (*Metrics, error) { func NewMetrics(ctx context.Context, meter metric.Meter) (*Metrics, error) {
@@ -80,6 +80,13 @@ func NewMetrics(ctx context.Context, meter metric.Meter) (*Metrics, error) {
return nil, err return nil, err
} }
peerReconnections, err := meter.Int64Counter("relay_peer_reconnections_total",
metric.WithDescription("Total number of times peers have reconnected and closed old connections"),
)
if err != nil {
return nil, err
}
m := &Metrics{ m := &Metrics{
Meter: meter, Meter: meter,
TransferBytesSent: bytesSent, TransferBytesSent: bytesSent,
@@ -87,6 +94,7 @@ func NewMetrics(ctx context.Context, meter metric.Meter) (*Metrics, error) {
AuthenticationTime: authTime, AuthenticationTime: authTime,
PeerStoreTime: peerStoreTime, PeerStoreTime: peerStoreTime,
peers: peers, peers: peers,
peerReconnections: peerReconnections,
ctx: ctx, ctx: ctx,
peerActivityChan: make(chan string, 10), peerActivityChan: make(chan string, 10),
@@ -138,6 +146,10 @@ func (m *Metrics) PeerDisconnected(id string) {
delete(m.peerLastActive, id) delete(m.peerLastActive, id)
} }
func (m *Metrics) RecordPeerReconnection() {
m.peerReconnections.Add(m.ctx, 1)
}
// PeerActivity increases the active connections // PeerActivity increases the active connections
func (m *Metrics) PeerActivity(peerID string) { func (m *Metrics) PeerActivity(peerID string) {
select { select {

View File

@@ -18,12 +18,9 @@ type Listener struct {
TLSConfig *tls.Config TLSConfig *tls.Config
listener *quic.Listener listener *quic.Listener
acceptFn func(conn net.Conn)
} }
func (l *Listener) Listen(acceptFn func(conn net.Conn)) error { func (l *Listener) Listen(acceptFn func(conn net.Conn)) error {
l.acceptFn = acceptFn
quicCfg := &quic.Config{ quicCfg := &quic.Config{
EnableDatagrams: true, EnableDatagrams: true,
InitialPacketSize: 1452, InitialPacketSize: 1452,
@@ -49,7 +46,7 @@ func (l *Listener) Listen(acceptFn func(conn net.Conn)) error {
log.Infof("QUIC client connected from: %s", session.RemoteAddr()) log.Infof("QUIC client connected from: %s", session.RemoteAddr())
conn := NewConn(session) conn := NewConn(session)
l.acceptFn(conn) acceptFn(conn)
} }
} }

View File

@@ -32,6 +32,9 @@ type Peer struct {
notifier *store.PeerNotifier notifier *store.PeerNotifier
peersListener *store.Listener peersListener *store.Listener
// between the online peer collection step and the notification sending should not be sent offline notifications from another thread
notificationMutex sync.Mutex
} }
// NewPeer creates a new Peer instance and prepare custom logging // NewPeer creates a new Peer instance and prepare custom logging
@@ -241,10 +244,16 @@ func (p *Peer) handleSubscribePeerState(msg []byte) {
} }
p.log.Debugf("received subscription message for %d peers", len(peerIDs)) p.log.Debugf("received subscription message for %d peers", len(peerIDs))
onlinePeers := p.peersListener.AddInterestedPeers(peerIDs)
// collect online peers to response back to the caller
p.notificationMutex.Lock()
defer p.notificationMutex.Unlock()
onlinePeers := p.store.GetOnlinePeersAndRegisterInterest(peerIDs, p.peersListener)
if len(onlinePeers) == 0 { if len(onlinePeers) == 0 {
return return
} }
p.log.Debugf("response with %d online peers", len(onlinePeers)) p.log.Debugf("response with %d online peers", len(onlinePeers))
p.sendPeersOnline(onlinePeers) p.sendPeersOnline(onlinePeers)
} }
@@ -274,6 +283,9 @@ func (p *Peer) sendPeersOnline(peers []messages.PeerID) {
} }
func (p *Peer) sendPeersWentOffline(peers []messages.PeerID) { func (p *Peer) sendPeersWentOffline(peers []messages.PeerID) {
p.notificationMutex.Lock()
defer p.notificationMutex.Unlock()
msgs, err := messages.MarshalPeersWentOffline(peers) msgs, err := messages.MarshalPeersWentOffline(peers)
if err != nil { if err != nil {
p.log.Errorf("failed to marshal peer location message: %s", err) p.log.Errorf("failed to marshal peer location message: %s", err)

View File

@@ -86,14 +86,13 @@ func NewRelay(config Config) (*Relay, error) {
return nil, fmt.Errorf("creating app metrics: %v", err) return nil, fmt.Errorf("creating app metrics: %v", err)
} }
peerStore := store.NewStore()
r := &Relay{ r := &Relay{
metrics: m, metrics: m,
metricsCancel: metricsCancel, metricsCancel: metricsCancel,
validator: config.AuthValidator, validator: config.AuthValidator,
instanceURL: config.instanceURL, instanceURL: config.instanceURL,
store: peerStore, store: store.NewStore(),
notifier: store.NewPeerNotifier(peerStore), notifier: store.NewPeerNotifier(),
} }
r.preparedMsg, err = newPreparedMsg(r.instanceURL) r.preparedMsg, err = newPreparedMsg(r.instanceURL)
@@ -131,15 +130,18 @@ func (r *Relay) Accept(conn net.Conn) {
peer := NewPeer(r.metrics, *peerID, conn, r.store, r.notifier) peer := NewPeer(r.metrics, *peerID, conn, r.store, r.notifier)
peer.log.Infof("peer connected from: %s", conn.RemoteAddr()) peer.log.Infof("peer connected from: %s", conn.RemoteAddr())
storeTime := time.Now() storeTime := time.Now()
r.store.AddPeer(peer) if isReconnection := r.store.AddPeer(peer); isReconnection {
r.metrics.RecordPeerReconnection()
}
r.notifier.PeerCameOnline(peer.ID()) r.notifier.PeerCameOnline(peer.ID())
r.metrics.RecordPeerStoreTime(time.Since(storeTime)) r.metrics.RecordPeerStoreTime(time.Since(storeTime))
r.metrics.PeerConnected(peer.String()) r.metrics.PeerConnected(peer.String())
go func() { go func() {
peer.Work() peer.Work()
r.notifier.PeerWentOffline(peer.ID()) if deleted := r.store.DeletePeer(peer); deleted {
r.store.DeletePeer(peer) r.notifier.PeerWentOffline(peer.ID())
}
peer.log.Debugf("relay connection closed") peer.log.Debugf("relay connection closed")
r.metrics.PeerDisconnected(peer.String()) r.metrics.PeerDisconnected(peer.String())
}() }()

View File

@@ -7,24 +7,27 @@ import (
"github.com/netbirdio/netbird/relay/messages" "github.com/netbirdio/netbird/relay/messages"
) )
type Listener struct { type event struct {
ctx context.Context peerID messages.PeerID
store *Store online bool
}
onlineChan chan messages.PeerID type Listener struct {
offlineChan chan messages.PeerID ctx context.Context
eventChan chan *event
interestedPeersForOffline map[messages.PeerID]struct{} interestedPeersForOffline map[messages.PeerID]struct{}
interestedPeersForOnline map[messages.PeerID]struct{} interestedPeersForOnline map[messages.PeerID]struct{}
mu sync.RWMutex mu sync.RWMutex
} }
func newListener(ctx context.Context, store *Store) *Listener { func newListener(ctx context.Context) *Listener {
l := &Listener{ l := &Listener{
ctx: ctx, ctx: ctx,
store: store,
onlineChan: make(chan messages.PeerID, 244), //244 is the message size limit in the relay protocol // important to use a single channel for offline and online events because with it we can ensure all events
offlineChan: make(chan messages.PeerID, 244), //244 is the message size limit in the relay protocol // will be processed in the order they were sent
eventChan: make(chan *event, 244), //244 is the message size limit in the relay protocol
interestedPeersForOffline: make(map[messages.PeerID]struct{}), interestedPeersForOffline: make(map[messages.PeerID]struct{}),
interestedPeersForOnline: make(map[messages.PeerID]struct{}), interestedPeersForOnline: make(map[messages.PeerID]struct{}),
} }
@@ -32,8 +35,7 @@ func newListener(ctx context.Context, store *Store) *Listener {
return l return l
} }
func (l *Listener) AddInterestedPeers(peerIDs []messages.PeerID) []messages.PeerID { func (l *Listener) AddInterestedPeers(peerIDs []messages.PeerID) {
availablePeers := make([]messages.PeerID, 0)
l.mu.Lock() l.mu.Lock()
defer l.mu.Unlock() defer l.mu.Unlock()
@@ -41,17 +43,6 @@ func (l *Listener) AddInterestedPeers(peerIDs []messages.PeerID) []messages.Peer
l.interestedPeersForOnline[id] = struct{}{} l.interestedPeersForOnline[id] = struct{}{}
l.interestedPeersForOffline[id] = struct{}{} l.interestedPeersForOffline[id] = struct{}{}
} }
// collect online peers to response back to the caller
for _, id := range peerIDs {
_, ok := l.store.Peer(id)
if !ok {
continue
}
availablePeers = append(availablePeers, id)
}
return availablePeers
} }
func (l *Listener) RemoveInterestedPeer(peerIDs []messages.PeerID) { func (l *Listener) RemoveInterestedPeer(peerIDs []messages.PeerID) {
@@ -61,7 +52,6 @@ func (l *Listener) RemoveInterestedPeer(peerIDs []messages.PeerID) {
for _, id := range peerIDs { for _, id := range peerIDs {
delete(l.interestedPeersForOffline, id) delete(l.interestedPeersForOffline, id)
delete(l.interestedPeersForOnline, id) delete(l.interestedPeersForOnline, id)
} }
} }
@@ -70,26 +60,31 @@ func (l *Listener) listenForEvents(onPeersComeOnline, onPeersWentOffline func([]
select { select {
case <-l.ctx.Done(): case <-l.ctx.Done():
return return
case pID := <-l.onlineChan: case e := <-l.eventChan:
peers := make([]messages.PeerID, 0) peersOffline := make([]messages.PeerID, 0)
peers = append(peers, pID) peersOnline := make([]messages.PeerID, 0)
if e.online {
for len(l.onlineChan) > 0 { peersOnline = append(peersOnline, e.peerID)
pID = <-l.onlineChan } else {
peers = append(peers, pID) peersOffline = append(peersOffline, e.peerID)
} }
onPeersComeOnline(peers) // Drain the channel to collect all events
case pID := <-l.offlineChan: for len(l.eventChan) > 0 {
peers := make([]messages.PeerID, 0) e = <-l.eventChan
peers = append(peers, pID) if e.online {
peersOnline = append(peersOnline, e.peerID)
for len(l.offlineChan) > 0 { } else {
pID = <-l.offlineChan peersOffline = append(peersOffline, e.peerID)
peers = append(peers, pID) }
} }
onPeersWentOffline(peers) if len(peersOnline) > 0 {
onPeersComeOnline(peersOnline)
}
if len(peersOffline) > 0 {
onPeersWentOffline(peersOffline)
}
} }
} }
} }
@@ -100,7 +95,10 @@ func (l *Listener) peerWentOffline(peerID messages.PeerID) {
if _, ok := l.interestedPeersForOffline[peerID]; ok { if _, ok := l.interestedPeersForOffline[peerID]; ok {
select { select {
case l.offlineChan <- peerID: case l.eventChan <- &event{
peerID: peerID,
online: false,
}:
case <-l.ctx.Done(): case <-l.ctx.Done():
} }
} }
@@ -112,9 +110,13 @@ func (l *Listener) peerComeOnline(peerID messages.PeerID) {
if _, ok := l.interestedPeersForOnline[peerID]; ok { if _, ok := l.interestedPeersForOnline[peerID]; ok {
select { select {
case l.onlineChan <- peerID: case l.eventChan <- &event{
peerID: peerID,
online: true,
}:
case <-l.ctx.Done(): case <-l.ctx.Done():
} }
delete(l.interestedPeersForOnline, peerID) delete(l.interestedPeersForOnline, peerID)
} }
} }

View File

@@ -8,15 +8,12 @@ import (
) )
type PeerNotifier struct { type PeerNotifier struct {
store *Store
listeners map[*Listener]context.CancelFunc listeners map[*Listener]context.CancelFunc
listenersMutex sync.RWMutex listenersMutex sync.RWMutex
} }
func NewPeerNotifier(store *Store) *PeerNotifier { func NewPeerNotifier() *PeerNotifier {
pn := &PeerNotifier{ pn := &PeerNotifier{
store: store,
listeners: make(map[*Listener]context.CancelFunc), listeners: make(map[*Listener]context.CancelFunc),
} }
return pn return pn
@@ -24,7 +21,7 @@ func NewPeerNotifier(store *Store) *PeerNotifier {
func (pn *PeerNotifier) NewListener(onPeersComeOnline, onPeersWentOffline func([]messages.PeerID)) *Listener { func (pn *PeerNotifier) NewListener(onPeersComeOnline, onPeersWentOffline func([]messages.PeerID)) *Listener {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
listener := newListener(ctx, pn.store) listener := newListener(ctx)
go listener.listenForEvents(onPeersComeOnline, onPeersWentOffline) go listener.listenForEvents(onPeersComeOnline, onPeersWentOffline)
pn.listenersMutex.Lock() pn.listenersMutex.Lock()

View File

@@ -26,7 +26,9 @@ func NewStore() *Store {
} }
// AddPeer adds a peer to the store // AddPeer adds a peer to the store
func (s *Store) AddPeer(peer IPeer) { // If the peer already exists, it will be replaced and the old peer will be closed
// Returns true if the peer was replaced, false if it was added for the first time.
func (s *Store) AddPeer(peer IPeer) bool {
s.peersLock.Lock() s.peersLock.Lock()
defer s.peersLock.Unlock() defer s.peersLock.Unlock()
odlPeer, ok := s.peers[peer.ID()] odlPeer, ok := s.peers[peer.ID()]
@@ -35,22 +37,24 @@ func (s *Store) AddPeer(peer IPeer) {
} }
s.peers[peer.ID()] = peer s.peers[peer.ID()] = peer
return ok
} }
// DeletePeer deletes a peer from the store // DeletePeer deletes a peer from the store
func (s *Store) DeletePeer(peer IPeer) { func (s *Store) DeletePeer(peer IPeer) bool {
s.peersLock.Lock() s.peersLock.Lock()
defer s.peersLock.Unlock() defer s.peersLock.Unlock()
dp, ok := s.peers[peer.ID()] dp, ok := s.peers[peer.ID()]
if !ok { if !ok {
return return false
} }
if dp != peer { if dp != peer {
return return false
} }
delete(s.peers, peer.ID()) delete(s.peers, peer.ID())
return true
} }
// Peer returns a peer by its ID // Peer returns a peer by its ID
@@ -73,3 +77,21 @@ func (s *Store) Peers() []IPeer {
} }
return peers return peers
} }
func (s *Store) GetOnlinePeersAndRegisterInterest(peerIDs []messages.PeerID, listener *Listener) []messages.PeerID {
s.peersLock.RLock()
defer s.peersLock.RUnlock()
onlinePeers := make([]messages.PeerID, 0, len(peerIDs))
listener.AddInterestedPeers(peerIDs)
// Check for currently online peers
for _, id := range peerIDs {
if _, ok := s.peers[id]; ok {
onlinePeers = append(onlinePeers, id)
}
}
return onlinePeers
}

View File

@@ -6,6 +6,7 @@ import (
"context" "context"
"fmt" "fmt"
"net" "net"
"net/netip"
"sync" "sync"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -17,11 +18,16 @@ type ListenerWriteHookFunc func(connID ConnectionID, ip *net.IPAddr, data []byte
// ListenerCloseHookFunc defines the function signature for close hooks for PacketConn. // ListenerCloseHookFunc defines the function signature for close hooks for PacketConn.
type ListenerCloseHookFunc func(connID ConnectionID, conn net.PacketConn) error type ListenerCloseHookFunc func(connID ConnectionID, conn net.PacketConn) error
// ListenerAddressRemoveHookFunc defines the function signature for hooks called when addresses are removed.
type ListenerAddressRemoveHookFunc func(connID ConnectionID, prefix netip.Prefix) error
var ( var (
listenerWriteHooksMutex sync.RWMutex listenerWriteHooksMutex sync.RWMutex
listenerWriteHooks []ListenerWriteHookFunc listenerWriteHooks []ListenerWriteHookFunc
listenerCloseHooksMutex sync.RWMutex listenerCloseHooksMutex sync.RWMutex
listenerCloseHooks []ListenerCloseHookFunc listenerCloseHooks []ListenerCloseHookFunc
listenerAddressRemoveHooksMutex sync.RWMutex
listenerAddressRemoveHooks []ListenerAddressRemoveHookFunc
) )
// AddListenerWriteHook allows adding a new write hook to be executed before a UDP packet is sent. // AddListenerWriteHook allows adding a new write hook to be executed before a UDP packet is sent.
@@ -38,7 +44,14 @@ func AddListenerCloseHook(hook ListenerCloseHookFunc) {
listenerCloseHooks = append(listenerCloseHooks, hook) listenerCloseHooks = append(listenerCloseHooks, hook)
} }
// RemoveListenerHooks removes all dialer hooks. // AddListenerAddressRemoveHook allows adding a new hook to be executed when an address is removed.
func AddListenerAddressRemoveHook(hook ListenerAddressRemoveHookFunc) {
listenerAddressRemoveHooksMutex.Lock()
defer listenerAddressRemoveHooksMutex.Unlock()
listenerAddressRemoveHooks = append(listenerAddressRemoveHooks, hook)
}
// RemoveListenerHooks removes all listener hooks.
func RemoveListenerHooks() { func RemoveListenerHooks() {
listenerWriteHooksMutex.Lock() listenerWriteHooksMutex.Lock()
defer listenerWriteHooksMutex.Unlock() defer listenerWriteHooksMutex.Unlock()
@@ -47,6 +60,10 @@ func RemoveListenerHooks() {
listenerCloseHooksMutex.Lock() listenerCloseHooksMutex.Lock()
defer listenerCloseHooksMutex.Unlock() defer listenerCloseHooksMutex.Unlock()
listenerCloseHooks = nil listenerCloseHooks = nil
listenerAddressRemoveHooksMutex.Lock()
defer listenerAddressRemoveHooksMutex.Unlock()
listenerAddressRemoveHooks = nil
} }
// ListenPacket listens on the network address and returns a PacketConn // ListenPacket listens on the network address and returns a PacketConn
@@ -61,6 +78,7 @@ func (l *ListenerConfig) ListenPacket(ctx context.Context, network, address stri
return nil, fmt.Errorf("listen packet: %w", err) return nil, fmt.Errorf("listen packet: %w", err)
} }
connID := GenerateConnID() connID := GenerateConnID()
return &PacketConn{PacketConn: pc, ID: connID, seenAddrs: &sync.Map{}}, nil return &PacketConn{PacketConn: pc, ID: connID, seenAddrs: &sync.Map{}}, nil
} }
@@ -102,6 +120,45 @@ func (c *UDPConn) Close() error {
return closeConn(c.ID, c.UDPConn) return closeConn(c.ID, c.UDPConn)
} }
// WrapUDPConn wraps an existing *net.UDPConn with nbnet functionality
func WrapUDPConn(conn *net.UDPConn) *UDPConn {
return &UDPConn{
UDPConn: conn,
ID: GenerateConnID(),
seenAddrs: &sync.Map{},
}
}
// RemoveAddress removes an address from the seen cache and triggers removal hooks.
func (c *UDPConn) RemoveAddress(addr string) {
if _, exists := c.seenAddrs.LoadAndDelete(addr); !exists {
return
}
ipStr, _, err := net.SplitHostPort(addr)
if err != nil {
log.Errorf("Error splitting IP address and port: %v", err)
return
}
ipAddr, err := netip.ParseAddr(ipStr)
if err != nil {
log.Errorf("Error parsing IP address %s: %v", ipStr, err)
return
}
prefix := netip.PrefixFrom(ipAddr, ipAddr.BitLen())
listenerAddressRemoveHooksMutex.RLock()
defer listenerAddressRemoveHooksMutex.RUnlock()
for _, hook := range listenerAddressRemoveHooks {
if err := hook(c.ID, prefix); err != nil {
log.Errorf("Error executing listener address remove hook: %v", err)
}
}
}
func callWriteHooks(id ConnectionID, seenAddrs *sync.Map, b []byte, addr net.Addr) { func callWriteHooks(id ConnectionID, seenAddrs *sync.Map, b []byte, addr net.Addr) {
// Lookup the address in the seenAddrs map to avoid calling the hooks for every write // Lookup the address in the seenAddrs map to avoid calling the hooks for every write
if _, loaded := seenAddrs.LoadOrStore(addr.String(), true); !loaded { if _, loaded := seenAddrs.LoadOrStore(addr.String(), true); !loaded {

View File

@@ -0,0 +1,10 @@
package net
import (
"net"
)
// WrapUDPConn on iOS just returns the original connection since iOS handles its own networking
func WrapUDPConn(conn *net.UDPConn) *net.UDPConn {
return conn
}