mirror of
https://github.com/netbirdio/netbird.git
synced 2026-06-03 14:39:55 +00:00
Compare commits
32 Commits
trigger-pr
...
feature/cl
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7e276a40d9 | ||
|
|
3753bf7fc4 | ||
|
|
bec58b85b1 | ||
|
|
ca3e6d93d3 | ||
|
|
80abddb78a | ||
|
|
6981fdce7e | ||
|
|
08403f64aa | ||
|
|
391221a986 | ||
|
|
7bc85107eb | ||
|
|
3be16d19a0 | ||
|
|
af8f730bda | ||
|
|
c3f176f348 | ||
|
|
0119f3e9f4 | ||
|
|
1b96648d4d | ||
|
|
d2f9653cea | ||
|
|
194a986926 | ||
|
|
f7732557fa | ||
|
|
d488f58311 | ||
|
|
6fdc00ff41 | ||
|
|
b20d484972 | ||
|
|
8931293343 | ||
|
|
7b830d8f72 | ||
|
|
3a0cf230a1 | ||
|
|
138e728427 | ||
|
|
e7283a8198 | ||
|
|
08295e5116 | ||
|
|
cbfde79dd8 | ||
|
|
5169129029 | ||
|
|
5f02a48737 | ||
|
|
bb377a2885 | ||
|
|
e3a5c44d37 | ||
|
|
c5eb5ba1c6 |
@@ -60,8 +60,8 @@
|
||||
|
||||
https://github.com/user-attachments/assets/10cec749-bb56-4ab3-97af-4e38850108d2
|
||||
|
||||
### NetBird on Lawrence Systems (Video)
|
||||
[](https://www.youtube.com/watch?v=Kwrff6h0rEw)
|
||||
### Self-Host NetBird (Video)
|
||||
[](https://youtu.be/bZAgpT6nzaQ)
|
||||
|
||||
### Key features
|
||||
|
||||
|
||||
@@ -71,6 +71,8 @@ type Options struct {
|
||||
DisableClientRoutes bool
|
||||
// BlockInbound blocks all inbound connections from peers
|
||||
BlockInbound bool
|
||||
// WireguardPort is the port for the WireGuard interface. Use 0 for a random port.
|
||||
WireguardPort *int
|
||||
}
|
||||
|
||||
// validateCredentials checks that exactly one credential type is provided
|
||||
@@ -140,6 +142,7 @@ func New(opts Options) (*Client, error) {
|
||||
DisableServerRoutes: &t,
|
||||
DisableClientRoutes: &opts.DisableClientRoutes,
|
||||
BlockInbound: &opts.BlockInbound,
|
||||
WireguardPort: opts.WireguardPort,
|
||||
}
|
||||
if opts.ConfigPath != "" {
|
||||
config, err = profilemanager.UpdateOrCreateConfig(input)
|
||||
|
||||
@@ -3,12 +3,6 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
@@ -17,33 +11,7 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
m.outgoingRules = make(map[netip.Addr]RuleSet)
|
||||
m.incomingDenyRules = make(map[netip.Addr]RuleSet)
|
||||
m.incomingRules = make(map[netip.Addr]RuleSet)
|
||||
|
||||
if m.udpTracker != nil {
|
||||
m.udpTracker.Close()
|
||||
}
|
||||
|
||||
if m.icmpTracker != nil {
|
||||
m.icmpTracker.Close()
|
||||
}
|
||||
|
||||
if m.tcpTracker != nil {
|
||||
m.tcpTracker.Close()
|
||||
}
|
||||
|
||||
if fwder := m.forwarder.Load(); fwder != nil {
|
||||
fwder.Stop()
|
||||
}
|
||||
|
||||
if m.logger != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
if err := m.logger.Stop(ctx); err != nil {
|
||||
log.Errorf("failed to shutdown logger: %v", err)
|
||||
}
|
||||
}
|
||||
m.resetState()
|
||||
|
||||
if m.nativeFirewall != nil {
|
||||
return m.nativeFirewall.Close(stateManager)
|
||||
|
||||
@@ -1,12 +1,9 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
@@ -26,33 +23,7 @@ func (m *Manager) Close(*statemanager.Manager) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
m.outgoingRules = make(map[netip.Addr]RuleSet)
|
||||
m.incomingDenyRules = make(map[netip.Addr]RuleSet)
|
||||
m.incomingRules = make(map[netip.Addr]RuleSet)
|
||||
|
||||
if m.udpTracker != nil {
|
||||
m.udpTracker.Close()
|
||||
}
|
||||
|
||||
if m.icmpTracker != nil {
|
||||
m.icmpTracker.Close()
|
||||
}
|
||||
|
||||
if m.tcpTracker != nil {
|
||||
m.tcpTracker.Close()
|
||||
}
|
||||
|
||||
if fwder := m.forwarder.Load(); fwder != nil {
|
||||
fwder.Stop()
|
||||
}
|
||||
|
||||
if m.logger != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
if err := m.logger.Stop(ctx); err != nil {
|
||||
log.Errorf("failed to shutdown logger: %v", err)
|
||||
}
|
||||
}
|
||||
m.resetState()
|
||||
|
||||
if !isWindowsFirewallReachable() {
|
||||
return nil
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -12,11 +13,13 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/google/uuid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
||||
@@ -24,6 +27,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
|
||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
@@ -89,6 +93,7 @@ type Manager struct {
|
||||
incomingDenyRules map[netip.Addr]RuleSet
|
||||
incomingRules map[netip.Addr]RuleSet
|
||||
routeRules RouteRules
|
||||
routeRulesMap map[nbid.RuleID]*RouteRule
|
||||
decoders sync.Pool
|
||||
wgIface common.IFaceMapper
|
||||
nativeFirewall firewall.Manager
|
||||
@@ -229,6 +234,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
||||
flowLogger: flowLogger,
|
||||
netstack: netstack.IsEnabled(),
|
||||
localForwarding: enableLocalForwarding,
|
||||
routeRulesMap: make(map[nbid.RuleID]*RouteRule),
|
||||
dnatMappings: make(map[netip.Addr]netip.Addr),
|
||||
portDNATRules: []portDNATRule{},
|
||||
netstackServices: make(map[serviceKey]struct{}),
|
||||
@@ -480,11 +486,15 @@ func (m *Manager) addRouteFiltering(
|
||||
return m.nativeFirewall.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||
}
|
||||
|
||||
ruleID := uuid.New().String()
|
||||
ruleKey := nbid.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
|
||||
|
||||
if existingRule, ok := m.routeRulesMap[ruleKey]; ok {
|
||||
return existingRule, nil
|
||||
}
|
||||
|
||||
rule := RouteRule{
|
||||
// TODO: consolidate these IDs
|
||||
id: ruleID,
|
||||
id: string(ruleKey),
|
||||
mgmtId: id,
|
||||
sources: sources,
|
||||
dstSet: destination.Set,
|
||||
@@ -499,6 +509,7 @@ func (m *Manager) addRouteFiltering(
|
||||
|
||||
m.routeRules = append(m.routeRules, &rule)
|
||||
m.routeRules.Sort()
|
||||
m.routeRulesMap[ruleKey] = &rule
|
||||
|
||||
return &rule, nil
|
||||
}
|
||||
@@ -515,15 +526,20 @@ func (m *Manager) deleteRouteRule(rule firewall.Rule) error {
|
||||
return m.nativeFirewall.DeleteRouteRule(rule)
|
||||
}
|
||||
|
||||
ruleID := rule.ID()
|
||||
ruleKey := nbid.RuleID(rule.ID())
|
||||
if _, ok := m.routeRulesMap[ruleKey]; !ok {
|
||||
return fmt.Errorf("route rule not found: %s", ruleKey)
|
||||
}
|
||||
|
||||
idx := slices.IndexFunc(m.routeRules, func(r *RouteRule) bool {
|
||||
return r.id == ruleID
|
||||
return r.id == string(ruleKey)
|
||||
})
|
||||
if idx < 0 {
|
||||
return fmt.Errorf("route rule not found: %s", ruleID)
|
||||
return fmt.Errorf("route rule not found in slice: %s", ruleKey)
|
||||
}
|
||||
|
||||
m.routeRules = slices.Delete(m.routeRules, idx, idx+1)
|
||||
delete(m.routeRulesMap, ruleKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -570,6 +586,40 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
||||
// Flush doesn't need to be implemented for this manager
|
||||
func (m *Manager) Flush() error { return nil }
|
||||
|
||||
// resetState clears all firewall rules and closes connection trackers.
|
||||
// Must be called with m.mutex held.
|
||||
func (m *Manager) resetState() {
|
||||
maps.Clear(m.outgoingRules)
|
||||
maps.Clear(m.incomingDenyRules)
|
||||
maps.Clear(m.incomingRules)
|
||||
maps.Clear(m.routeRulesMap)
|
||||
m.routeRules = m.routeRules[:0]
|
||||
|
||||
if m.udpTracker != nil {
|
||||
m.udpTracker.Close()
|
||||
}
|
||||
|
||||
if m.icmpTracker != nil {
|
||||
m.icmpTracker.Close()
|
||||
}
|
||||
|
||||
if m.tcpTracker != nil {
|
||||
m.tcpTracker.Close()
|
||||
}
|
||||
|
||||
if fwder := m.forwarder.Load(); fwder != nil {
|
||||
fwder.Stop()
|
||||
}
|
||||
|
||||
if m.logger != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
if err := m.logger.Stop(ctx); err != nil {
|
||||
log.Errorf("failed to shutdown logger: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SetupEBPFProxyNoTrack creates notrack rules for eBPF proxy loopback traffic.
|
||||
func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
|
||||
if m.nativeFirewall == nil {
|
||||
|
||||
376
client/firewall/uspfilter/filter_routeacl_test.go
Normal file
376
client/firewall/uspfilter/filter_routeacl_test.go
Normal file
@@ -0,0 +1,376 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/mocks"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
// TestAddRouteFilteringReturnsExistingRule verifies that adding the same route
|
||||
// filtering rule twice returns the same rule ID (idempotent behavior).
|
||||
func TestAddRouteFilteringReturnsExistingRule(t *testing.T) {
|
||||
manager := setupTestManager(t)
|
||||
|
||||
sources := []netip.Prefix{
|
||||
netip.MustParsePrefix("100.64.1.0/24"),
|
||||
netip.MustParsePrefix("100.64.2.0/24"),
|
||||
}
|
||||
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
|
||||
|
||||
// Add rule first time
|
||||
rule1, err := manager.AddRouteFiltering(
|
||||
[]byte("policy-1"),
|
||||
sources,
|
||||
destination,
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
&fw.Port{Values: []uint16{443}},
|
||||
fw.ActionAccept,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rule1)
|
||||
|
||||
// Add the same rule again
|
||||
rule2, err := manager.AddRouteFiltering(
|
||||
[]byte("policy-1"),
|
||||
sources,
|
||||
destination,
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
&fw.Port{Values: []uint16{443}},
|
||||
fw.ActionAccept,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rule2)
|
||||
|
||||
// These should be the same (idempotent) like nftables/iptables implementations
|
||||
assert.Equal(t, rule1.ID(), rule2.ID(),
|
||||
"Adding the same rule twice should return the same rule ID (idempotent)")
|
||||
|
||||
manager.mutex.RLock()
|
||||
ruleCount := len(manager.routeRules)
|
||||
manager.mutex.RUnlock()
|
||||
|
||||
assert.Equal(t, 2, ruleCount,
|
||||
"Should have exactly 2 rules (1 user rule + 1 block rule)")
|
||||
}
|
||||
|
||||
// TestAddRouteFilteringDifferentRulesGetDifferentIDs verifies that rules with
|
||||
// different parameters get distinct IDs.
|
||||
func TestAddRouteFilteringDifferentRulesGetDifferentIDs(t *testing.T) {
|
||||
manager := setupTestManager(t)
|
||||
|
||||
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
|
||||
|
||||
// Add first rule
|
||||
rule1, err := manager.AddRouteFiltering(
|
||||
[]byte("policy-1"),
|
||||
sources,
|
||||
fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
&fw.Port{Values: []uint16{443}},
|
||||
fw.ActionAccept,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add different rule (different destination)
|
||||
rule2, err := manager.AddRouteFiltering(
|
||||
[]byte("policy-2"),
|
||||
sources,
|
||||
fw.Network{Prefix: netip.MustParsePrefix("192.168.2.0/24")}, // Different!
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
&fw.Port{Values: []uint16{443}},
|
||||
fw.ActionAccept,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.NotEqual(t, rule1.ID(), rule2.ID(),
|
||||
"Different rules should have different IDs")
|
||||
|
||||
manager.mutex.RLock()
|
||||
ruleCount := len(manager.routeRules)
|
||||
manager.mutex.RUnlock()
|
||||
|
||||
assert.Equal(t, 3, ruleCount, "Should have 3 rules (2 user rules + 1 block rule)")
|
||||
}
|
||||
|
||||
// TestRouteRuleUpdateDoesNotCauseGap verifies that re-adding the same route
|
||||
// rule during a network map update does not disrupt existing traffic.
|
||||
func TestRouteRuleUpdateDoesNotCauseGap(t *testing.T) {
|
||||
manager := setupTestManager(t)
|
||||
|
||||
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
|
||||
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
|
||||
|
||||
rule1, err := manager.AddRouteFiltering(
|
||||
[]byte("policy-1"),
|
||||
sources,
|
||||
destination,
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
nil,
|
||||
fw.ActionAccept,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
srcIP := netip.MustParseAddr("100.64.1.5")
|
||||
dstIP := netip.MustParseAddr("192.168.1.10")
|
||||
_, pass := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 443)
|
||||
require.True(t, pass, "Traffic should pass with rule in place")
|
||||
|
||||
// Re-add same rule (simulates network map update)
|
||||
rule2, err := manager.AddRouteFiltering(
|
||||
[]byte("policy-1"),
|
||||
sources,
|
||||
destination,
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
nil,
|
||||
fw.ActionAccept,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Idempotent IDs mean rule1.ID() == rule2.ID(), so the ACL manager
|
||||
// won't delete rule1 during cleanup. If IDs differed, deleting rule1
|
||||
// would remove the only matching rule and cause a traffic gap.
|
||||
if rule1.ID() != rule2.ID() {
|
||||
err = manager.DeleteRouteRule(rule1)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
_, passAfter := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 443)
|
||||
assert.True(t, passAfter,
|
||||
"Traffic should still pass after rule update - no gap should occur")
|
||||
}
|
||||
|
||||
// TestBlockInvalidRoutedIdempotent verifies that blockInvalidRouted creates
|
||||
// exactly one drop rule for the WireGuard network prefix, and calling it again
|
||||
// returns the same rule without duplicating.
|
||||
func TestBlockInvalidRoutedIdempotent(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
dev := mocks.NewMockDevice(ctrl)
|
||||
dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
|
||||
|
||||
wgNet := netip.MustParsePrefix("100.64.0.1/16")
|
||||
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: wgNet.Addr(),
|
||||
Network: wgNet,
|
||||
}
|
||||
},
|
||||
GetDeviceFunc: func() *device.FilteredDevice {
|
||||
return &device.FilteredDevice{Device: dev}
|
||||
},
|
||||
GetWGDeviceFunc: func() *wgdevice.Device {
|
||||
return &wgdevice.Device{}
|
||||
},
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
})
|
||||
|
||||
// Call blockInvalidRouted directly multiple times
|
||||
rule1, err := manager.blockInvalidRouted(ifaceMock)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rule1)
|
||||
|
||||
rule2, err := manager.blockInvalidRouted(ifaceMock)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rule2)
|
||||
|
||||
rule3, err := manager.blockInvalidRouted(ifaceMock)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rule3)
|
||||
|
||||
// All should return the same rule
|
||||
assert.Equal(t, rule1.ID(), rule2.ID(), "Second call should return same rule")
|
||||
assert.Equal(t, rule2.ID(), rule3.ID(), "Third call should return same rule")
|
||||
|
||||
// Should have exactly 1 route rule
|
||||
manager.mutex.RLock()
|
||||
ruleCount := len(manager.routeRules)
|
||||
manager.mutex.RUnlock()
|
||||
|
||||
assert.Equal(t, 1, ruleCount, "Should have exactly 1 block rule after 3 calls")
|
||||
|
||||
// Verify the rule blocks traffic to the WG network
|
||||
srcIP := netip.MustParseAddr("10.0.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.50")
|
||||
_, pass := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 80)
|
||||
assert.False(t, pass, "Block rule should deny traffic to WG prefix")
|
||||
}
|
||||
|
||||
// TestBlockRuleNotAccumulatedOnRepeatedEnableRouting verifies that calling
|
||||
// EnableRouting multiple times (as happens on each route update) does not
|
||||
// accumulate duplicate block rules in the routeRules slice.
|
||||
func TestBlockRuleNotAccumulatedOnRepeatedEnableRouting(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
dev := mocks.NewMockDevice(ctrl)
|
||||
dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
|
||||
|
||||
wgNet := netip.MustParsePrefix("100.64.0.1/16")
|
||||
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: wgNet.Addr(),
|
||||
Network: wgNet,
|
||||
}
|
||||
},
|
||||
GetDeviceFunc: func() *device.FilteredDevice {
|
||||
return &device.FilteredDevice{Device: dev}
|
||||
},
|
||||
GetWGDeviceFunc: func() *wgdevice.Device {
|
||||
return &wgdevice.Device{}
|
||||
},
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
})
|
||||
|
||||
// Call EnableRouting multiple times (simulating repeated route updates)
|
||||
for i := 0; i < 5; i++ {
|
||||
require.NoError(t, manager.EnableRouting())
|
||||
}
|
||||
|
||||
manager.mutex.RLock()
|
||||
ruleCount := len(manager.routeRules)
|
||||
manager.mutex.RUnlock()
|
||||
|
||||
assert.Equal(t, 1, ruleCount,
|
||||
"Repeated EnableRouting should not accumulate block rules")
|
||||
}
|
||||
|
||||
// TestRouteRuleCountStableAcrossUpdates verifies that adding the same route
|
||||
// rule multiple times does not create duplicate entries.
|
||||
func TestRouteRuleCountStableAcrossUpdates(t *testing.T) {
|
||||
manager := setupTestManager(t)
|
||||
|
||||
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
|
||||
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
|
||||
|
||||
// Simulate 5 network map updates with the same route rule
|
||||
for i := 0; i < 5; i++ {
|
||||
rule, err := manager.AddRouteFiltering(
|
||||
[]byte("policy-1"),
|
||||
sources,
|
||||
destination,
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
&fw.Port{Values: []uint16{443}},
|
||||
fw.ActionAccept,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rule)
|
||||
}
|
||||
|
||||
manager.mutex.RLock()
|
||||
ruleCount := len(manager.routeRules)
|
||||
manager.mutex.RUnlock()
|
||||
|
||||
assert.Equal(t, 2, ruleCount,
|
||||
"Should have exactly 2 rules (1 user rule + 1 block rule) after 5 updates")
|
||||
}
|
||||
|
||||
// TestDeleteRouteRuleAfterIdempotentAdd verifies that deleting a route rule
|
||||
// after adding it multiple times works correctly.
|
||||
func TestDeleteRouteRuleAfterIdempotentAdd(t *testing.T) {
|
||||
manager := setupTestManager(t)
|
||||
|
||||
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
|
||||
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
|
||||
|
||||
// Add same rule twice
|
||||
rule1, err := manager.AddRouteFiltering(
|
||||
[]byte("policy-1"),
|
||||
sources,
|
||||
destination,
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
nil,
|
||||
fw.ActionAccept,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
rule2, err := manager.AddRouteFiltering(
|
||||
[]byte("policy-1"),
|
||||
sources,
|
||||
destination,
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
nil,
|
||||
fw.ActionAccept,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, rule1.ID(), rule2.ID(), "Should return same rule ID")
|
||||
|
||||
// Delete using first reference
|
||||
err = manager.DeleteRouteRule(rule1)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify traffic no longer passes
|
||||
srcIP := netip.MustParseAddr("100.64.1.5")
|
||||
dstIP := netip.MustParseAddr("192.168.1.10")
|
||||
_, pass := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 443)
|
||||
assert.False(t, pass, "Traffic should not pass after rule deletion")
|
||||
}
|
||||
|
||||
func setupTestManager(t *testing.T) *Manager {
|
||||
t.Helper()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
dev := mocks.NewMockDevice(ctrl)
|
||||
dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
|
||||
|
||||
wgNet := netip.MustParsePrefix("100.64.0.1/16")
|
||||
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: wgNet.Addr(),
|
||||
Network: wgNet,
|
||||
}
|
||||
},
|
||||
GetDeviceFunc: func() *device.FilteredDevice {
|
||||
return &device.FilteredDevice{Device: dev}
|
||||
},
|
||||
GetWGDeviceFunc: func() *wgdevice.Device {
|
||||
return &wgdevice.Device{}
|
||||
},
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, manager.EnableRouting())
|
||||
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
})
|
||||
|
||||
return manager
|
||||
}
|
||||
@@ -263,6 +263,158 @@ func TestAddUDPPacketHook(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestPeerRuleLifecycleDenyRules verifies that deny rules are correctly added
|
||||
// to the deny map and can be cleanly deleted without leaving orphans.
|
||||
func TestPeerRuleLifecycleDenyRules(t *testing.T) {
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, m.Close(nil))
|
||||
}()
|
||||
|
||||
ip := net.ParseIP("192.168.1.1")
|
||||
addr := netip.MustParseAddr("192.168.1.1")
|
||||
|
||||
// Add multiple deny rules for different ports
|
||||
rule1, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
||||
&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
|
||||
require.NoError(t, err)
|
||||
|
||||
rule2, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
||||
&fw.Port{Values: []uint16{80}}, fw.ActionDrop, "")
|
||||
require.NoError(t, err)
|
||||
|
||||
m.mutex.RLock()
|
||||
denyCount := len(m.incomingDenyRules[addr])
|
||||
m.mutex.RUnlock()
|
||||
require.Equal(t, 2, denyCount, "Should have exactly 2 deny rules")
|
||||
|
||||
// Delete the first deny rule
|
||||
err = m.DeletePeerRule(rule1[0])
|
||||
require.NoError(t, err)
|
||||
|
||||
m.mutex.RLock()
|
||||
denyCount = len(m.incomingDenyRules[addr])
|
||||
m.mutex.RUnlock()
|
||||
require.Equal(t, 1, denyCount, "Should have 1 deny rule after deleting first")
|
||||
|
||||
// Delete the second deny rule
|
||||
err = m.DeletePeerRule(rule2[0])
|
||||
require.NoError(t, err)
|
||||
|
||||
m.mutex.RLock()
|
||||
_, exists := m.incomingDenyRules[addr]
|
||||
m.mutex.RUnlock()
|
||||
require.False(t, exists, "Deny rules IP entry should be cleaned up when empty")
|
||||
}
|
||||
|
||||
// TestPeerRuleAddAndDeleteDontLeak verifies that repeatedly adding and deleting
|
||||
// peer rules (simulating network map updates) does not leak rules in the maps.
|
||||
func TestPeerRuleAddAndDeleteDontLeak(t *testing.T) {
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, m.Close(nil))
|
||||
}()
|
||||
|
||||
ip := net.ParseIP("192.168.1.1")
|
||||
addr := netip.MustParseAddr("192.168.1.1")
|
||||
|
||||
// Simulate 10 network map updates: add rule, delete old, add new
|
||||
for i := 0; i < 10; i++ {
|
||||
// Add a deny rule
|
||||
rules, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
||||
&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add an allow rule
|
||||
allowRules, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
||||
&fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Delete them (simulating ACL manager cleanup)
|
||||
for _, r := range rules {
|
||||
require.NoError(t, m.DeletePeerRule(r))
|
||||
}
|
||||
for _, r := range allowRules {
|
||||
require.NoError(t, m.DeletePeerRule(r))
|
||||
}
|
||||
}
|
||||
|
||||
m.mutex.RLock()
|
||||
denyCount := len(m.incomingDenyRules[addr])
|
||||
allowCount := len(m.incomingRules[addr])
|
||||
m.mutex.RUnlock()
|
||||
|
||||
require.Equal(t, 0, denyCount, "No deny rules should remain after cleanup")
|
||||
require.Equal(t, 0, allowCount, "No allow rules should remain after cleanup")
|
||||
}
|
||||
|
||||
// TestMixedAllowDenyRulesSameIP verifies that allow and deny rules for the same
|
||||
// IP are stored in separate maps and don't interfere with each other.
|
||||
func TestMixedAllowDenyRulesSameIP(t *testing.T) {
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, m.Close(nil))
|
||||
}()
|
||||
|
||||
ip := net.ParseIP("192.168.1.1")
|
||||
|
||||
// Add allow rule for port 80
|
||||
allowRule, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
||||
&fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add deny rule for port 22
|
||||
denyRule, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
||||
&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
|
||||
require.NoError(t, err)
|
||||
|
||||
addr := netip.MustParseAddr("192.168.1.1")
|
||||
m.mutex.RLock()
|
||||
allowCount := len(m.incomingRules[addr])
|
||||
denyCount := len(m.incomingDenyRules[addr])
|
||||
m.mutex.RUnlock()
|
||||
|
||||
require.Equal(t, 1, allowCount, "Should have 1 allow rule")
|
||||
require.Equal(t, 1, denyCount, "Should have 1 deny rule")
|
||||
|
||||
// Delete allow rule should not affect deny rule
|
||||
err = m.DeletePeerRule(allowRule[0])
|
||||
require.NoError(t, err)
|
||||
|
||||
m.mutex.RLock()
|
||||
denyCountAfter := len(m.incomingDenyRules[addr])
|
||||
m.mutex.RUnlock()
|
||||
|
||||
require.Equal(t, 1, denyCountAfter, "Deny rule should still exist after deleting allow rule")
|
||||
|
||||
// Delete deny rule
|
||||
err = m.DeletePeerRule(denyRule[0])
|
||||
require.NoError(t, err)
|
||||
|
||||
m.mutex.RLock()
|
||||
_, denyExists := m.incomingDenyRules[addr]
|
||||
_, allowExists := m.incomingRules[addr]
|
||||
m.mutex.RUnlock()
|
||||
|
||||
require.False(t, denyExists, "Deny rules should be empty")
|
||||
require.False(t, allowExists, "Allow rules should be empty")
|
||||
}
|
||||
|
||||
func TestManagerReset(t *testing.T) {
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
@@ -228,6 +229,10 @@ func (w *WGIface) Close() error {
|
||||
result = multierror.Append(result, fmt.Errorf("failed to close wireguard interface %s: %w", w.Name(), err))
|
||||
}
|
||||
|
||||
if nbnetstack.IsEnabled() {
|
||||
return errors.FormatErrorOrNil(result)
|
||||
}
|
||||
|
||||
if err := w.waitUntilRemoved(); err != nil {
|
||||
log.Warnf("failed to remove WireGuard interface %s: %v", w.Name(), err)
|
||||
if err := w.Destroy(); err != nil {
|
||||
|
||||
@@ -189,6 +189,212 @@ func TestDefaultManagerStateless(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
// TestDenyRulesNotAccumulatedOnRepeatedApply verifies that applying the same
|
||||
// deny rules repeatedly does not accumulate duplicate rules in the uspfilter.
|
||||
// This tests the full ACL manager -> uspfilter integration.
|
||||
func TestDenyRulesNotAccumulatedOnRepeatedApply(t *testing.T) {
|
||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||
|
||||
networkMap := &mgmProto.NetworkMap{
|
||||
FirewallRules: []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_DROP,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "22",
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_DROP,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "80",
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
},
|
||||
FirewallRulesIsEmpty: false,
|
||||
}
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
||||
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
||||
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
||||
network := netip.MustParsePrefix("172.0.0.1/32")
|
||||
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
||||
IP: network.Addr(),
|
||||
Network: network,
|
||||
}).AnyTimes()
|
||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||
|
||||
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, fw.Close(nil))
|
||||
}()
|
||||
|
||||
acl := NewDefaultManager(fw)
|
||||
|
||||
// Apply the same rules 5 times (simulating repeated network map updates)
|
||||
for i := 0; i < 5; i++ {
|
||||
acl.ApplyFiltering(networkMap, false)
|
||||
}
|
||||
|
||||
// The ACL manager should track exactly 3 rule pairs (2 deny + 1 accept inbound)
|
||||
assert.Equal(t, 3, len(acl.peerRulesPairs),
|
||||
"Should have exactly 3 rule pairs after 5 identical updates")
|
||||
}
|
||||
|
||||
// TestDenyRulesCleanedUpOnRemoval verifies that deny rules are properly cleaned
|
||||
// up when they're removed from the network map in a subsequent update.
|
||||
func TestDenyRulesCleanedUpOnRemoval(t *testing.T) {
|
||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
||||
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
||||
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
||||
network := netip.MustParsePrefix("172.0.0.1/32")
|
||||
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
||||
IP: network.Addr(),
|
||||
Network: network,
|
||||
}).AnyTimes()
|
||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||
|
||||
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, fw.Close(nil))
|
||||
}()
|
||||
|
||||
acl := NewDefaultManager(fw)
|
||||
|
||||
// First update: add deny and accept rules
|
||||
networkMap1 := &mgmProto.NetworkMap{
|
||||
FirewallRules: []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_DROP,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "22",
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
},
|
||||
FirewallRulesIsEmpty: false,
|
||||
}
|
||||
|
||||
acl.ApplyFiltering(networkMap1, false)
|
||||
assert.Equal(t, 2, len(acl.peerRulesPairs), "Should have 2 rules after first update")
|
||||
|
||||
// Second update: remove the deny rule, keep only accept
|
||||
networkMap2 := &mgmProto.NetworkMap{
|
||||
FirewallRules: []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
},
|
||||
FirewallRulesIsEmpty: false,
|
||||
}
|
||||
|
||||
acl.ApplyFiltering(networkMap2, false)
|
||||
assert.Equal(t, 1, len(acl.peerRulesPairs),
|
||||
"Should have 1 rule after removing deny rule")
|
||||
|
||||
// Third update: remove all rules
|
||||
networkMap3 := &mgmProto.NetworkMap{
|
||||
FirewallRules: []*mgmProto.FirewallRule{},
|
||||
FirewallRulesIsEmpty: true,
|
||||
}
|
||||
|
||||
acl.ApplyFiltering(networkMap3, false)
|
||||
assert.Equal(t, 0, len(acl.peerRulesPairs),
|
||||
"Should have 0 rules after removing all rules")
|
||||
}
|
||||
|
||||
// TestRuleUpdateChangingAction verifies that when a rule's action changes from
|
||||
// accept to deny (or vice versa), the old rule is properly removed and the new
|
||||
// one added without leaking.
|
||||
func TestRuleUpdateChangingAction(t *testing.T) {
|
||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
||||
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
||||
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
||||
network := netip.MustParsePrefix("172.0.0.1/32")
|
||||
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
||||
IP: network.Addr(),
|
||||
Network: network,
|
||||
}).AnyTimes()
|
||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||
|
||||
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, fw.Close(nil))
|
||||
}()
|
||||
|
||||
acl := NewDefaultManager(fw)
|
||||
|
||||
// First update: accept rule
|
||||
networkMap := &mgmProto.NetworkMap{
|
||||
FirewallRules: []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "22",
|
||||
},
|
||||
},
|
||||
FirewallRulesIsEmpty: false,
|
||||
}
|
||||
acl.ApplyFiltering(networkMap, false)
|
||||
assert.Equal(t, 1, len(acl.peerRulesPairs))
|
||||
|
||||
// Second update: change to deny (same IP/port/proto, different action)
|
||||
networkMap.FirewallRules = []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_DROP,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "22",
|
||||
},
|
||||
}
|
||||
acl.ApplyFiltering(networkMap, false)
|
||||
|
||||
// Should still have exactly 1 rule (the old accept removed, new deny added)
|
||||
assert.Equal(t, 1, len(acl.peerRulesPairs),
|
||||
"Changing action should result in exactly 1 rule, not 2")
|
||||
}
|
||||
|
||||
func TestPortInfoEmpty(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
@@ -244,7 +245,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
localPeerState := peer.LocalPeerState{
|
||||
IP: loginResp.GetPeerConfig().GetAddress(),
|
||||
PubKey: myPrivateKey.PublicKey().String(),
|
||||
KernelInterface: device.WireGuardModuleIsLoaded(),
|
||||
KernelInterface: device.WireGuardModuleIsLoaded() && !netstack.IsEnabled(),
|
||||
FQDN: loginResp.GetPeerConfig().GetFqdn(),
|
||||
}
|
||||
c.statusRecorder.UpdateLocalPeerState(localPeerState)
|
||||
|
||||
@@ -219,6 +219,11 @@ const (
|
||||
darwinStdoutLogPath = "/var/log/netbird.err.log"
|
||||
)
|
||||
|
||||
// MetricsExporter is an interface for exporting metrics
|
||||
type MetricsExporter interface {
|
||||
Export(w io.Writer) error
|
||||
}
|
||||
|
||||
type BundleGenerator struct {
|
||||
anonymizer *anonymize.Anonymizer
|
||||
|
||||
@@ -229,6 +234,7 @@ type BundleGenerator struct {
|
||||
logPath string
|
||||
cpuProfile []byte
|
||||
refreshStatus func() // Optional callback to refresh status before bundle generation
|
||||
clientMetrics MetricsExporter
|
||||
|
||||
anonymize bool
|
||||
includeSystemInfo bool
|
||||
@@ -250,6 +256,7 @@ type GeneratorDependencies struct {
|
||||
LogPath string
|
||||
CPUProfile []byte
|
||||
RefreshStatus func() // Optional callback to refresh status before bundle generation
|
||||
ClientMetrics MetricsExporter
|
||||
}
|
||||
|
||||
func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGenerator {
|
||||
@@ -268,6 +275,7 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen
|
||||
logPath: deps.LogPath,
|
||||
cpuProfile: deps.CPUProfile,
|
||||
refreshStatus: deps.RefreshStatus,
|
||||
clientMetrics: deps.ClientMetrics,
|
||||
|
||||
anonymize: cfg.Anonymize,
|
||||
includeSystemInfo: cfg.IncludeSystemInfo,
|
||||
@@ -351,6 +359,10 @@ func (g *BundleGenerator) createArchive() error {
|
||||
log.Errorf("failed to add corrupted state files to debug bundle: %v", err)
|
||||
}
|
||||
|
||||
if err := g.addMetrics(); err != nil {
|
||||
log.Errorf("failed to add metrics to debug bundle: %v", err)
|
||||
}
|
||||
|
||||
if err := g.addWgShow(); err != nil {
|
||||
log.Errorf("failed to add wg show output: %v", err)
|
||||
}
|
||||
@@ -744,6 +756,25 @@ func (g *BundleGenerator) addCorruptedStateFiles() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) addMetrics() error {
|
||||
if g.clientMetrics == nil {
|
||||
log.Debugf("skipping metrics in debug bundle: no metrics collector")
|
||||
return nil
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
if err := g.clientMetrics.Export(&buf); err != nil {
|
||||
return fmt.Errorf("export metrics: %w", err)
|
||||
}
|
||||
|
||||
if err := g.addFileToZip(&buf, "metrics.txt"); err != nil {
|
||||
return fmt.Errorf("add metrics file to zip: %w", err)
|
||||
}
|
||||
|
||||
log.Debugf("added metrics to debug bundle")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) addLogfile() error {
|
||||
if g.logPath == "" {
|
||||
log.Debugf("skipping empty log file in debug bundle")
|
||||
|
||||
@@ -6,7 +6,9 @@ import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"os"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
@@ -27,6 +29,8 @@ import (
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
const envSkipDNSProbe = "NB_SKIP_DNS_PROBE"
|
||||
|
||||
// ReadyListener is a notification mechanism what indicate the server is ready to handle host dns address changes
|
||||
type ReadyListener interface {
|
||||
OnReady()
|
||||
@@ -439,6 +443,17 @@ func (s *DefaultServer) SearchDomains() []string {
|
||||
// ProbeAvailability tests each upstream group's servers for availability
|
||||
// and deactivates the group if no server responds
|
||||
func (s *DefaultServer) ProbeAvailability() {
|
||||
if val := os.Getenv(envSkipDNSProbe); val != "" {
|
||||
skipProbe, err := strconv.ParseBool(val)
|
||||
if err != nil {
|
||||
log.Warnf("failed to parse %s: %v", envSkipDNSProbe, err)
|
||||
}
|
||||
if skipProbe {
|
||||
log.Infof("skipping DNS probe due to %s", envSkipDNSProbe)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for _, mux := range s.dnsMuxMap {
|
||||
wg.Add(1)
|
||||
|
||||
@@ -190,50 +190,75 @@ func (f *DNSForwarder) Close(ctx context.Context) error {
|
||||
return nberrors.FormatErrorOrNil(result)
|
||||
}
|
||||
|
||||
func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, query *dns.Msg) *dns.Msg {
|
||||
func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, query *dns.Msg, startTime time.Time) {
|
||||
if len(query.Question) == 0 {
|
||||
return nil
|
||||
return
|
||||
}
|
||||
question := query.Question[0]
|
||||
logger.Tracef("received DNS request for DNS forwarder: domain=%s type=%s class=%s",
|
||||
question.Name, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass])
|
||||
qname := strings.ToLower(question.Name)
|
||||
|
||||
domain := strings.ToLower(question.Name)
|
||||
logger.Tracef("question: domain=%s type=%s class=%s",
|
||||
qname, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass])
|
||||
|
||||
resp := query.SetReply(query)
|
||||
network := resutil.NetworkForQtype(question.Qtype)
|
||||
if network == "" {
|
||||
resp.Rcode = dns.RcodeNotImplemented
|
||||
if err := w.WriteMsg(resp); err != nil {
|
||||
logger.Errorf("failed to write DNS response: %v", err)
|
||||
}
|
||||
return nil
|
||||
f.writeResponse(logger, w, resp, qname, startTime)
|
||||
return
|
||||
}
|
||||
|
||||
mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(domain, "."))
|
||||
// query doesn't match any configured domain
|
||||
mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(qname, "."))
|
||||
if mostSpecificResId == "" {
|
||||
resp.Rcode = dns.RcodeRefused
|
||||
if err := w.WriteMsg(resp); err != nil {
|
||||
logger.Errorf("failed to write DNS response: %v", err)
|
||||
}
|
||||
return nil
|
||||
f.writeResponse(logger, w, resp, qname, startTime)
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout)
|
||||
defer cancel()
|
||||
|
||||
result := resutil.LookupIP(ctx, f.resolver, network, domain, question.Qtype)
|
||||
result := resutil.LookupIP(ctx, f.resolver, network, qname, question.Qtype)
|
||||
if result.Err != nil {
|
||||
f.handleDNSError(ctx, logger, w, question, resp, domain, result)
|
||||
return nil
|
||||
f.handleDNSError(ctx, logger, w, question, resp, qname, result, startTime)
|
||||
return
|
||||
}
|
||||
|
||||
f.updateInternalState(result.IPs, mostSpecificResId, matchingEntries)
|
||||
resp.Answer = append(resp.Answer, resutil.IPsToRRs(domain, result.IPs, f.ttl)...)
|
||||
f.cache.set(domain, question.Qtype, result.IPs)
|
||||
resp.Answer = append(resp.Answer, resutil.IPsToRRs(qname, result.IPs, f.ttl)...)
|
||||
f.cache.set(qname, question.Qtype, result.IPs)
|
||||
|
||||
return resp
|
||||
f.writeResponse(logger, w, resp, qname, startTime)
|
||||
}
|
||||
|
||||
func (f *DNSForwarder) writeResponse(logger *log.Entry, w dns.ResponseWriter, resp *dns.Msg, qname string, startTime time.Time) {
|
||||
if err := w.WriteMsg(resp); err != nil {
|
||||
logger.Errorf("failed to write DNS response: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
|
||||
qname, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
|
||||
}
|
||||
|
||||
// udpResponseWriter wraps a dns.ResponseWriter to handle UDP-specific truncation.
|
||||
type udpResponseWriter struct {
|
||||
dns.ResponseWriter
|
||||
query *dns.Msg
|
||||
}
|
||||
|
||||
func (u *udpResponseWriter) WriteMsg(resp *dns.Msg) error {
|
||||
opt := u.query.IsEdns0()
|
||||
maxSize := dns.MinMsgSize
|
||||
if opt != nil {
|
||||
maxSize = int(opt.UDPSize())
|
||||
}
|
||||
|
||||
if resp.Len() > maxSize {
|
||||
resp.Truncate(maxSize)
|
||||
}
|
||||
|
||||
return u.ResponseWriter.WriteMsg(resp)
|
||||
}
|
||||
|
||||
func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
|
||||
@@ -243,30 +268,7 @@ func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
|
||||
"dns_id": fmt.Sprintf("%04x", query.Id),
|
||||
})
|
||||
|
||||
resp := f.handleDNSQuery(logger, w, query)
|
||||
if resp == nil {
|
||||
return
|
||||
}
|
||||
|
||||
opt := query.IsEdns0()
|
||||
maxSize := dns.MinMsgSize
|
||||
if opt != nil {
|
||||
// client advertised a larger EDNS0 buffer
|
||||
maxSize = int(opt.UDPSize())
|
||||
}
|
||||
|
||||
// if our response is too big, truncate and set the TC bit
|
||||
if resp.Len() > maxSize {
|
||||
resp.Truncate(maxSize)
|
||||
}
|
||||
|
||||
if err := w.WriteMsg(resp); err != nil {
|
||||
logger.Errorf("failed to write DNS response: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
|
||||
query.Question[0].Name, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
|
||||
f.handleDNSQuery(logger, &udpResponseWriter{ResponseWriter: w, query: query}, query, startTime)
|
||||
}
|
||||
|
||||
func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
|
||||
@@ -276,18 +278,7 @@ func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
|
||||
"dns_id": fmt.Sprintf("%04x", query.Id),
|
||||
})
|
||||
|
||||
resp := f.handleDNSQuery(logger, w, query)
|
||||
if resp == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err := w.WriteMsg(resp); err != nil {
|
||||
logger.Errorf("failed to write DNS response: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
|
||||
query.Question[0].Name, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
|
||||
f.handleDNSQuery(logger, w, query, startTime)
|
||||
}
|
||||
|
||||
func (f *DNSForwarder) updateInternalState(ips []netip.Addr, mostSpecificResId route.ResID, matchingEntries []*ForwarderEntry) {
|
||||
@@ -334,6 +325,7 @@ func (f *DNSForwarder) handleDNSError(
|
||||
resp *dns.Msg,
|
||||
domain string,
|
||||
result resutil.LookupResult,
|
||||
startTime time.Time,
|
||||
) {
|
||||
qType := question.Qtype
|
||||
qTypeName := dns.TypeToString[qType]
|
||||
@@ -343,9 +335,7 @@ func (f *DNSForwarder) handleDNSError(
|
||||
// NotFound: cache negative result and respond
|
||||
if result.Rcode == dns.RcodeNameError || result.Rcode == dns.RcodeSuccess {
|
||||
f.cache.set(domain, question.Qtype, nil)
|
||||
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
||||
logger.Errorf("failed to write failure DNS response: %v", writeErr)
|
||||
}
|
||||
f.writeResponse(logger, w, resp, domain, startTime)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -355,9 +345,7 @@ func (f *DNSForwarder) handleDNSError(
|
||||
logger.Debugf("serving cached DNS response after upstream failure: domain=%s type=%s", domain, qTypeName)
|
||||
resp.Answer = append(resp.Answer, resutil.IPsToRRs(domain, ips, f.ttl)...)
|
||||
resp.Rcode = dns.RcodeSuccess
|
||||
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
||||
logger.Errorf("failed to write cached DNS response: %v", writeErr)
|
||||
}
|
||||
f.writeResponse(logger, w, resp, domain, startTime)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -365,9 +353,7 @@ func (f *DNSForwarder) handleDNSError(
|
||||
verifyResult := resutil.LookupIP(ctx, f.resolver, resutil.NetworkForQtype(qType), domain, qType)
|
||||
if verifyResult.Rcode == dns.RcodeNameError || verifyResult.Rcode == dns.RcodeSuccess {
|
||||
resp.Rcode = verifyResult.Rcode
|
||||
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
||||
logger.Errorf("failed to write failure DNS response: %v", writeErr)
|
||||
}
|
||||
f.writeResponse(logger, w, resp, domain, startTime)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -375,15 +361,12 @@ func (f *DNSForwarder) handleDNSError(
|
||||
// No cache or verification failed. Log with or without the server field for more context.
|
||||
var dnsErr *net.DNSError
|
||||
if errors.As(result.Err, &dnsErr) && dnsErr.Server != "" {
|
||||
logger.Warnf("failed to resolve: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, result.Err)
|
||||
logger.Warnf("upstream failure: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, result.Err)
|
||||
} else {
|
||||
logger.Warnf(errResolveFailed, domain, result.Err)
|
||||
}
|
||||
|
||||
// Write final failure response.
|
||||
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
||||
logger.Errorf("failed to write failure DNS response: %v", writeErr)
|
||||
}
|
||||
f.writeResponse(logger, w, resp, domain, startTime)
|
||||
}
|
||||
|
||||
// getMatchingEntries retrieves the resource IDs for a given domain.
|
||||
|
||||
@@ -318,8 +318,9 @@ func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) {
|
||||
query.SetQuestion(dns.Fqdn(tt.queryDomain), dns.TypeA)
|
||||
|
||||
mockWriter := &test.MockResponseWriter{}
|
||||
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
|
||||
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
|
||||
|
||||
resp := mockWriter.GetLastResponse()
|
||||
if tt.shouldResolve {
|
||||
require.NotNil(t, resp, "Expected response for authorized domain")
|
||||
require.Equal(t, dns.RcodeSuccess, resp.Rcode, "Expected successful response")
|
||||
@@ -329,10 +330,9 @@ func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) {
|
||||
mockFirewall.AssertExpectations(t)
|
||||
mockResolver.AssertExpectations(t)
|
||||
} else {
|
||||
if resp != nil {
|
||||
assert.True(t, len(resp.Answer) == 0 || resp.Rcode != dns.RcodeSuccess,
|
||||
"Unauthorized domain should not return successful answers")
|
||||
}
|
||||
require.NotNil(t, resp, "Expected response")
|
||||
assert.True(t, len(resp.Answer) == 0 || resp.Rcode != dns.RcodeSuccess,
|
||||
"Unauthorized domain should not return successful answers")
|
||||
mockFirewall.AssertNotCalled(t, "UpdateSet")
|
||||
mockResolver.AssertNotCalled(t, "LookupNetIP")
|
||||
}
|
||||
@@ -466,14 +466,16 @@ func TestDNSForwarder_FirewallSetUpdates(t *testing.T) {
|
||||
dnsQuery.SetQuestion(dns.Fqdn(tt.query), dns.TypeA)
|
||||
|
||||
mockWriter := &test.MockResponseWriter{}
|
||||
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, dnsQuery)
|
||||
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, dnsQuery, time.Now())
|
||||
|
||||
// Verify response
|
||||
resp := mockWriter.GetLastResponse()
|
||||
if tt.shouldResolve {
|
||||
require.NotNil(t, resp, "Expected response for authorized domain")
|
||||
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||
require.NotEmpty(t, resp.Answer)
|
||||
} else if resp != nil {
|
||||
} else {
|
||||
require.NotNil(t, resp, "Expected response")
|
||||
assert.True(t, resp.Rcode == dns.RcodeRefused || len(resp.Answer) == 0,
|
||||
"Unauthorized domain should be refused or have no answers")
|
||||
}
|
||||
@@ -528,9 +530,10 @@ func TestDNSForwarder_MultipleIPsInSingleUpdate(t *testing.T) {
|
||||
query.SetQuestion("example.com.", dns.TypeA)
|
||||
|
||||
mockWriter := &test.MockResponseWriter{}
|
||||
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
|
||||
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
|
||||
|
||||
// Verify response contains all IPs
|
||||
resp := mockWriter.GetLastResponse()
|
||||
require.NotNil(t, resp)
|
||||
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||
require.Len(t, resp.Answer, 3, "Should have 3 answer records")
|
||||
@@ -605,7 +608,7 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
_ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
|
||||
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
|
||||
|
||||
// Check the response written to the writer
|
||||
require.NotNil(t, writtenResp, "Expected response to be written")
|
||||
@@ -675,7 +678,8 @@ func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
|
||||
q1 := &dns.Msg{}
|
||||
q1.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
|
||||
w1 := &test.MockResponseWriter{}
|
||||
resp1 := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1)
|
||||
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1, time.Now())
|
||||
resp1 := w1.GetLastResponse()
|
||||
require.NotNil(t, resp1)
|
||||
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
|
||||
require.Len(t, resp1.Answer, 1)
|
||||
@@ -683,13 +687,13 @@ func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
|
||||
// Second query: serve from cache after upstream failure
|
||||
q2 := &dns.Msg{}
|
||||
q2.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
|
||||
var writtenResp *dns.Msg
|
||||
w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
|
||||
_ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2)
|
||||
w2 := &test.MockResponseWriter{}
|
||||
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2, time.Now())
|
||||
|
||||
require.NotNil(t, writtenResp, "expected response to be written")
|
||||
require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
|
||||
require.Len(t, writtenResp.Answer, 1)
|
||||
resp2 := w2.GetLastResponse()
|
||||
require.NotNil(t, resp2, "expected response to be written")
|
||||
require.Equal(t, dns.RcodeSuccess, resp2.Rcode)
|
||||
require.Len(t, resp2.Answer, 1)
|
||||
|
||||
mockResolver.AssertExpectations(t)
|
||||
}
|
||||
@@ -715,7 +719,8 @@ func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) {
|
||||
q1 := &dns.Msg{}
|
||||
q1.SetQuestion(mixedQuery+".", dns.TypeA)
|
||||
w1 := &test.MockResponseWriter{}
|
||||
resp1 := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1)
|
||||
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1, time.Now())
|
||||
resp1 := w1.GetLastResponse()
|
||||
require.NotNil(t, resp1)
|
||||
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
|
||||
require.Len(t, resp1.Answer, 1)
|
||||
@@ -727,13 +732,13 @@ func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) {
|
||||
|
||||
q2 := &dns.Msg{}
|
||||
q2.SetQuestion("EXAMPLE.COM", dns.TypeA)
|
||||
var writtenResp *dns.Msg
|
||||
w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
|
||||
_ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2)
|
||||
w2 := &test.MockResponseWriter{}
|
||||
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2, time.Now())
|
||||
|
||||
require.NotNil(t, writtenResp)
|
||||
require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
|
||||
require.Len(t, writtenResp.Answer, 1)
|
||||
resp2 := w2.GetLastResponse()
|
||||
require.NotNil(t, resp2)
|
||||
require.Equal(t, dns.RcodeSuccess, resp2.Rcode)
|
||||
require.Len(t, resp2.Answer, 1)
|
||||
|
||||
mockResolver.AssertExpectations(t)
|
||||
}
|
||||
@@ -784,8 +789,9 @@ func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) {
|
||||
query.SetQuestion("smtp.mail.example.com.", dns.TypeA)
|
||||
|
||||
mockWriter := &test.MockResponseWriter{}
|
||||
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
|
||||
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
|
||||
|
||||
resp := mockWriter.GetLastResponse()
|
||||
require.NotNil(t, resp)
|
||||
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||
|
||||
@@ -897,26 +903,15 @@ func TestDNSForwarder_NodataVsNxdomain(t *testing.T) {
|
||||
query := &dns.Msg{}
|
||||
query.SetQuestion(dns.Fqdn("example.com"), tt.queryType)
|
||||
|
||||
var writtenResp *dns.Msg
|
||||
mockWriter := &test.MockResponseWriter{
|
||||
WriteMsgFunc: func(m *dns.Msg) error {
|
||||
writtenResp = m
|
||||
return nil
|
||||
},
|
||||
}
|
||||
mockWriter := &test.MockResponseWriter{}
|
||||
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
|
||||
|
||||
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
|
||||
|
||||
// If a response was returned, it means it should be written (happens in wrapper functions)
|
||||
if resp != nil && writtenResp == nil {
|
||||
writtenResp = resp
|
||||
}
|
||||
|
||||
require.NotNil(t, writtenResp, "Expected response to be written")
|
||||
assert.Equal(t, tt.expectedCode, writtenResp.Rcode, tt.description)
|
||||
resp := mockWriter.GetLastResponse()
|
||||
require.NotNil(t, resp, "Expected response to be written")
|
||||
assert.Equal(t, tt.expectedCode, resp.Rcode, tt.description)
|
||||
|
||||
if tt.expectNoAnswer {
|
||||
assert.Empty(t, writtenResp.Answer, "Response should have no answer records")
|
||||
assert.Empty(t, resp.Answer, "Response should have no answer records")
|
||||
}
|
||||
|
||||
mockResolver.AssertExpectations(t)
|
||||
@@ -931,15 +926,8 @@ func TestDNSForwarder_EmptyQuery(t *testing.T) {
|
||||
query := &dns.Msg{}
|
||||
// Don't set any question
|
||||
|
||||
writeCalled := false
|
||||
mockWriter := &test.MockResponseWriter{
|
||||
WriteMsgFunc: func(m *dns.Msg) error {
|
||||
writeCalled = true
|
||||
return nil
|
||||
},
|
||||
}
|
||||
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
|
||||
mockWriter := &test.MockResponseWriter{}
|
||||
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
|
||||
|
||||
assert.Nil(t, resp, "Should return nil for empty query")
|
||||
assert.False(t, writeCalled, "Should not write response for empty query")
|
||||
assert.Nil(t, mockWriter.GetLastResponse(), "Should not write response for empty query")
|
||||
}
|
||||
|
||||
@@ -36,6 +36,7 @@ import (
|
||||
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
|
||||
"github.com/netbirdio/netbird/client/internal/dnsfwd"
|
||||
"github.com/netbirdio/netbird/client/internal/ingressgw"
|
||||
"github.com/netbirdio/netbird/client/internal/metrics"
|
||||
"github.com/netbirdio/netbird/client/internal/netflow"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
"github.com/netbirdio/netbird/client/internal/networkmonitor"
|
||||
@@ -44,7 +45,6 @@ import (
|
||||
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/proxy"
|
||||
"github.com/netbirdio/netbird/client/internal/relay"
|
||||
"github.com/netbirdio/netbird/client/internal/rosenpass"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
@@ -66,6 +66,7 @@ import (
|
||||
signal "github.com/netbirdio/netbird/shared/signal/client"
|
||||
sProto "github.com/netbirdio/netbird/shared/signal/proto"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
// PeerConnectionTimeoutMax is a timeout of an initial connection attempt to a remote peer.
|
||||
@@ -141,11 +142,6 @@ type EngineConfig struct {
|
||||
ProfileConfig *profilemanager.Config
|
||||
|
||||
LogPath string
|
||||
|
||||
// ProxyConfig contains system proxy settings for macOS
|
||||
ProxyEnabled bool
|
||||
ProxyHost string
|
||||
ProxyPort int
|
||||
}
|
||||
|
||||
// Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers.
|
||||
@@ -227,11 +223,11 @@ type Engine struct {
|
||||
|
||||
probeStunTurn *relay.StunTurnProbe
|
||||
|
||||
// clientMetrics collects and pushes metrics
|
||||
clientMetrics *metrics.ClientMetrics
|
||||
|
||||
jobExecutor *jobexec.Executor
|
||||
jobExecutorWG sync.WaitGroup
|
||||
|
||||
// proxyManager manages system-wide browser proxy settings on macOS
|
||||
proxyManager *proxy.Manager
|
||||
}
|
||||
|
||||
// Peer is an instance of the Connection Peer
|
||||
@@ -257,6 +253,12 @@ func NewEngine(
|
||||
checks []*mgmProto.Checks,
|
||||
stateManager *statemanager.Manager,
|
||||
) *Engine {
|
||||
// Initialize metrics based on deployment type
|
||||
var deploymentType metrics.DeploymentType
|
||||
if mgmClient != nil {
|
||||
deploymentType = metrics.DetermineDeploymentType(mgmClient.GetServerURL())
|
||||
}
|
||||
|
||||
engine := &Engine{
|
||||
clientCtx: clientCtx,
|
||||
clientCancel: clientCancel,
|
||||
@@ -279,6 +281,10 @@ func NewEngine(
|
||||
jobExecutor: jobexec.NewExecutor(),
|
||||
}
|
||||
|
||||
engine.clientMetrics = metrics.NewClientMetrics(metrics.AgentInfo{
|
||||
DeploymentType: deploymentType,
|
||||
Version: version.NetbirdVersion()})
|
||||
|
||||
log.Infof("I am: %s", config.WgPrivateKey.PublicKey().String())
|
||||
return engine
|
||||
}
|
||||
@@ -322,12 +328,6 @@ func (e *Engine) Stop() error {
|
||||
e.updateManager.Stop()
|
||||
}
|
||||
|
||||
if e.proxyManager != nil {
|
||||
if err := e.proxyManager.DisableWebProxy(); err != nil {
|
||||
log.Warnf("failed to disable system proxy: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
log.Info("cleaning up status recorder states")
|
||||
e.statusRecorder.ReplaceOfflinePeers([]peer.State{})
|
||||
e.statusRecorder.UpdateDNSStates([]peer.NSGroupState{})
|
||||
@@ -463,10 +463,6 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
}
|
||||
e.stateManager.Start()
|
||||
|
||||
// Initialize proxy manager and register state for cleanup
|
||||
proxy.RegisterState(e.stateManager)
|
||||
e.proxyManager = proxy.NewManager(e.stateManager)
|
||||
|
||||
initialRoutes, dnsConfig, dnsFeatureFlag, err := e.readInitialSettings()
|
||||
if err != nil {
|
||||
e.close()
|
||||
@@ -847,6 +843,12 @@ func (e *Engine) handleAutoUpdateVersion(autoUpdateSettings *mgmProto.AutoUpdate
|
||||
}
|
||||
|
||||
func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||
started := time.Now()
|
||||
defer func() {
|
||||
duration := time.Since(started)
|
||||
log.Infof("sync finished in %s", duration)
|
||||
e.clientMetrics.RecordSyncDuration(e.ctx, duration)
|
||||
}()
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
|
||||
@@ -1036,7 +1038,7 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
||||
state := e.statusRecorder.GetLocalPeerState()
|
||||
state.IP = e.wgInterface.Address().String()
|
||||
state.PubKey = e.config.WgPrivateKey.PublicKey().String()
|
||||
state.KernelInterface = device.WireGuardModuleIsLoaded()
|
||||
state.KernelInterface = !e.wgInterface.IsUserspaceBind()
|
||||
state.FQDN = conf.GetFqdn()
|
||||
|
||||
e.statusRecorder.UpdateLocalPeerState(state)
|
||||
@@ -1092,6 +1094,7 @@ func (e *Engine) handleBundle(params *mgmProto.BundleParameters) (*mgmProto.JobR
|
||||
StatusRecorder: e.statusRecorder,
|
||||
SyncResponse: syncResponse,
|
||||
LogPath: e.config.LogPath,
|
||||
ClientMetrics: e.clientMetrics,
|
||||
RefreshStatus: func() {
|
||||
e.RunHealthProbes(true)
|
||||
},
|
||||
@@ -1331,9 +1334,6 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
// If no server of a server group responds this will disable the respective handler and retry later.
|
||||
e.dnsServer.ProbeAvailability()
|
||||
|
||||
// Update system proxy state based on routes after network map is fully applied
|
||||
e.updateSystemProxy(clientRoutes)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1550,12 +1550,13 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
|
||||
}
|
||||
|
||||
serviceDependencies := peer.ServiceDependencies{
|
||||
StatusRecorder: e.statusRecorder,
|
||||
Signaler: e.signaler,
|
||||
IFaceDiscover: e.mobileDep.IFaceDiscover,
|
||||
RelayManager: e.relayManager,
|
||||
SrWatcher: e.srWatcher,
|
||||
Semaphore: e.connSemaphore,
|
||||
StatusRecorder: e.statusRecorder,
|
||||
Signaler: e.signaler,
|
||||
IFaceDiscover: e.mobileDep.IFaceDiscover,
|
||||
RelayManager: e.relayManager,
|
||||
SrWatcher: e.srWatcher,
|
||||
Semaphore: e.connSemaphore,
|
||||
MetricsRecorder: e.clientMetrics,
|
||||
}
|
||||
peerConn, err := peer.NewConn(config, serviceDependencies)
|
||||
if err != nil {
|
||||
@@ -1841,6 +1842,11 @@ func (e *Engine) GetFirewallManager() firewallManager.Manager {
|
||||
return e.firewall
|
||||
}
|
||||
|
||||
// GetClientMetrics returns the client metrics
|
||||
func (e *Engine) GetClientMetrics() *metrics.ClientMetrics {
|
||||
return e.clientMetrics
|
||||
}
|
||||
|
||||
func findIPFromInterfaceName(ifaceName string) (net.IP, error) {
|
||||
iface, err := net.InterfaceByName(ifaceName)
|
||||
if err != nil {
|
||||
@@ -2325,26 +2331,6 @@ func createFile(path string) error {
|
||||
return file.Close()
|
||||
}
|
||||
|
||||
// updateSystemProxy triggers a proxy enable/disable cycle after the network map is updated.
|
||||
func (e *Engine) updateSystemProxy(clientRoutes route.HAMap) {
|
||||
if runtime.GOOS != "darwin" || e.proxyManager == nil {
|
||||
log.Errorf("not updating proxy")
|
||||
return
|
||||
}
|
||||
|
||||
if err := e.proxyManager.EnableWebProxy(e.config.ProxyHost, e.config.ProxyPort); err != nil {
|
||||
log.Errorf("enable system proxy: %v", err)
|
||||
return
|
||||
}
|
||||
log.Error("system proxy enabled after network map update")
|
||||
|
||||
if err := e.proxyManager.DisableWebProxy(); err != nil {
|
||||
log.Errorf("disable system proxy: %v", err)
|
||||
return
|
||||
}
|
||||
log.Error("system proxy disabled after network map update")
|
||||
}
|
||||
|
||||
func convertToOfferAnswer(msg *sProto.Message) (*peer.OfferAnswer, error) {
|
||||
remoteCred, err := signal.UnMarshalCredential(msg)
|
||||
if err != nil {
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
|
||||
sshconfig "github.com/netbirdio/netbird/client/ssh/config"
|
||||
@@ -94,6 +95,10 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
|
||||
|
||||
// updateSSHClientConfig updates the SSH client configuration with peer information
|
||||
func (e *Engine) updateSSHClientConfig(remotePeers []*mgmProto.RemotePeerConfig) error {
|
||||
if netstack.IsEnabled() {
|
||||
return nil
|
||||
}
|
||||
|
||||
peerInfo := e.extractPeerSSHInfo(remotePeers)
|
||||
if len(peerInfo) == 0 {
|
||||
log.Debug("no SSH-enabled peers found, skipping SSH config update")
|
||||
@@ -216,6 +221,10 @@ func (e *Engine) GetPeerSSHKey(peerAddress string) ([]byte, bool) {
|
||||
|
||||
// cleanupSSHConfig removes NetBird SSH client configuration on shutdown
|
||||
func (e *Engine) cleanupSSHConfig() {
|
||||
if netstack.IsEnabled() {
|
||||
return
|
||||
}
|
||||
|
||||
configMgr := sshconfig.New()
|
||||
|
||||
if err := configMgr.RemoveSSHClientConfig(); err != nil {
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
|
||||
@@ -74,12 +75,13 @@ func (m *Manager) createListener(peerCfg lazyconn.PeerConfig) (listener, error)
|
||||
return NewUDPListener(m.wgIface, peerCfg)
|
||||
}
|
||||
|
||||
// BindListener is only used on Windows and JS platforms:
|
||||
// BindListener is used on Windows, JS, and netstack platforms:
|
||||
// - JS: Cannot listen to UDP sockets
|
||||
// - Windows: IP_UNICAST_IF socket option forces packets out the interface the default
|
||||
// gateway points to, preventing them from reaching the loopback interface.
|
||||
// BindListener bypasses this by passing data directly through the bind.
|
||||
if runtime.GOOS != "windows" && runtime.GOOS != "js" {
|
||||
// - Netstack: Allows multiple instances on the same host without port conflicts.
|
||||
// BindListener bypasses these issues by passing data directly through the bind.
|
||||
if runtime.GOOS != "windows" && runtime.GOOS != "js" && !netstack.IsEnabled() {
|
||||
return NewUDPListener(m.wgIface, peerCfg)
|
||||
}
|
||||
|
||||
|
||||
17
client/internal/metrics/connection_type.go
Normal file
17
client/internal/metrics/connection_type.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package metrics
|
||||
|
||||
// ConnectionType represents the type of peer connection
|
||||
type ConnectionType string
|
||||
|
||||
const (
|
||||
// ConnectionTypeICE represents a direct peer-to-peer connection using ICE
|
||||
ConnectionTypeICE ConnectionType = "ice"
|
||||
|
||||
// ConnectionTypeRelay represents a relayed connection
|
||||
ConnectionTypeRelay ConnectionType = "relay"
|
||||
)
|
||||
|
||||
// String returns the string representation of the connection type
|
||||
func (c ConnectionType) String() string {
|
||||
return string(c)
|
||||
}
|
||||
46
client/internal/metrics/deployment_type.go
Normal file
46
client/internal/metrics/deployment_type.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package metrics
|
||||
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// DeploymentType represents the type of NetBird deployment
|
||||
type DeploymentType int
|
||||
|
||||
const (
|
||||
// DeploymentTypeUnknown represents an unknown or uninitialized deployment type
|
||||
DeploymentTypeUnknown DeploymentType = iota
|
||||
|
||||
// DeploymentTypeCloud represents a cloud-hosted NetBird deployment
|
||||
DeploymentTypeCloud
|
||||
|
||||
// DeploymentTypeSelfHosted represents a self-hosted NetBird deployment
|
||||
DeploymentTypeSelfHosted
|
||||
)
|
||||
|
||||
// String returns the string representation of the deployment type
|
||||
func (d DeploymentType) String() string {
|
||||
switch d {
|
||||
case DeploymentTypeCloud:
|
||||
return "cloud"
|
||||
case DeploymentTypeSelfHosted:
|
||||
return "selfhosted"
|
||||
default:
|
||||
return "selfhosted"
|
||||
}
|
||||
}
|
||||
|
||||
// DetermineDeploymentType determines if the deployment is cloud or self-hosted
|
||||
// based on the management URL string
|
||||
func DetermineDeploymentType(managementURL string) DeploymentType {
|
||||
if managementURL == "" {
|
||||
return DeploymentTypeUnknown
|
||||
}
|
||||
|
||||
// Check for NetBird cloud API domain
|
||||
if strings.Contains(strings.ToLower(managementURL), "api.netbird.io") {
|
||||
return DeploymentTypeCloud
|
||||
}
|
||||
|
||||
return DeploymentTypeSelfHosted
|
||||
}
|
||||
77
client/internal/metrics/metrics.go
Normal file
77
client/internal/metrics/metrics.go
Normal file
@@ -0,0 +1,77 @@
|
||||
package metrics
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"time"
|
||||
)
|
||||
|
||||
// AgentInfo holds static information about the agent
|
||||
type AgentInfo struct {
|
||||
DeploymentType DeploymentType
|
||||
Version string
|
||||
}
|
||||
|
||||
// metricsImplementation defines the internal interface for metrics implementations
|
||||
type metricsImplementation interface {
|
||||
// RecordConnectionStages records connection stage metrics from timestamps
|
||||
RecordConnectionStages(
|
||||
ctx context.Context,
|
||||
connectionType ConnectionType,
|
||||
isReconnection bool,
|
||||
timestamps ConnectionStageTimestamps,
|
||||
)
|
||||
|
||||
// RecordSyncDuration records how long it took to process a sync message
|
||||
RecordSyncDuration(ctx context.Context, duration time.Duration)
|
||||
|
||||
// Export exports metrics in Prometheus format
|
||||
Export(w io.Writer) error
|
||||
}
|
||||
|
||||
type ClientMetrics struct {
|
||||
impl metricsImplementation
|
||||
}
|
||||
|
||||
// ConnectionStageTimestamps holds timestamps for each connection stage
|
||||
type ConnectionStageTimestamps struct {
|
||||
Created time.Time
|
||||
SemaphoreAcquired time.Time
|
||||
Signaling time.Time // First signal sent (initial) or signal received (reconnection)
|
||||
ConnectionReady time.Time
|
||||
WgHandshakeSuccess time.Time
|
||||
}
|
||||
|
||||
// NewClientMetrics creates a new ClientMetrics instance
|
||||
func NewClientMetrics(agentInfo AgentInfo) *ClientMetrics {
|
||||
return &ClientMetrics{impl: newVictoriaMetrics(agentInfo)}
|
||||
}
|
||||
|
||||
// RecordConnectionStages calculates stage durations from timestamps and records them
|
||||
func (c *ClientMetrics) RecordConnectionStages(
|
||||
ctx context.Context,
|
||||
connectionType ConnectionType,
|
||||
isReconnection bool,
|
||||
timestamps ConnectionStageTimestamps,
|
||||
) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
c.impl.RecordConnectionStages(ctx, connectionType, isReconnection, timestamps)
|
||||
}
|
||||
|
||||
// RecordSyncDuration records the duration of sync message processing
|
||||
func (c *ClientMetrics) RecordSyncDuration(ctx context.Context, duration time.Duration) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
c.impl.RecordSyncDuration(ctx, duration)
|
||||
}
|
||||
|
||||
// Export exports metrics to the writer
|
||||
func (c *ClientMetrics) Export(w io.Writer) error {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
return c.impl.Export(w)
|
||||
}
|
||||
129
client/internal/metrics/victoria.go
Normal file
129
client/internal/metrics/victoria.go
Normal file
@@ -0,0 +1,129 @@
|
||||
package metrics
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/VictoriaMetrics/metrics"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// victoriaMetrics is the VictoriaMetrics implementation of ClientMetrics
|
||||
type victoriaMetrics struct {
|
||||
// Static agent information applied to all metrics
|
||||
agentInfo AgentInfo
|
||||
|
||||
// Metrics set for managing all metrics
|
||||
set *metrics.Set
|
||||
}
|
||||
|
||||
func newVictoriaMetrics(agentInfo AgentInfo) metricsImplementation {
|
||||
return &victoriaMetrics{
|
||||
agentInfo: agentInfo,
|
||||
set: metrics.NewSet(),
|
||||
}
|
||||
}
|
||||
|
||||
// RecordConnectionStages records the duration of each connection stage from timestamps
|
||||
func (m *victoriaMetrics) RecordConnectionStages(
|
||||
ctx context.Context,
|
||||
connectionType ConnectionType,
|
||||
isReconnection bool,
|
||||
timestamps ConnectionStageTimestamps,
|
||||
) {
|
||||
// Calculate stage durations
|
||||
var creationToSemaphore, semaphoreToSignaling, signalingToConnection, connectionToHandshake, totalDuration float64
|
||||
|
||||
if !timestamps.Created.IsZero() && !timestamps.SemaphoreAcquired.IsZero() {
|
||||
creationToSemaphore = timestamps.SemaphoreAcquired.Sub(timestamps.Created).Seconds()
|
||||
}
|
||||
|
||||
if !timestamps.SemaphoreAcquired.IsZero() && !timestamps.Signaling.IsZero() {
|
||||
semaphoreToSignaling = timestamps.Signaling.Sub(timestamps.SemaphoreAcquired).Seconds()
|
||||
}
|
||||
|
||||
if !timestamps.Signaling.IsZero() && !timestamps.ConnectionReady.IsZero() {
|
||||
signalingToConnection = timestamps.ConnectionReady.Sub(timestamps.Signaling).Seconds()
|
||||
}
|
||||
|
||||
if !timestamps.ConnectionReady.IsZero() && !timestamps.WgHandshakeSuccess.IsZero() {
|
||||
connectionToHandshake = timestamps.WgHandshakeSuccess.Sub(timestamps.ConnectionReady).Seconds()
|
||||
}
|
||||
|
||||
// Calculate total duration:
|
||||
// For initial connections: Created → WgHandshakeSuccess
|
||||
// For reconnections: Signaling → WgHandshakeSuccess (since Created is not tracked)
|
||||
if !timestamps.Created.IsZero() && !timestamps.WgHandshakeSuccess.IsZero() {
|
||||
totalDuration = timestamps.WgHandshakeSuccess.Sub(timestamps.Created).Seconds()
|
||||
} else if !timestamps.Signaling.IsZero() && !timestamps.WgHandshakeSuccess.IsZero() {
|
||||
totalDuration = timestamps.WgHandshakeSuccess.Sub(timestamps.Signaling).Seconds()
|
||||
}
|
||||
|
||||
// Determine attempt type
|
||||
attemptType := "initial"
|
||||
if isReconnection {
|
||||
attemptType = "reconnection"
|
||||
}
|
||||
|
||||
connTypeStr := connectionType.String()
|
||||
|
||||
// Record observations using histograms
|
||||
m.set.GetOrCreateHistogram(
|
||||
m.getMetricName("netbird_peer_connection_stage_creation_to_semaphore", connTypeStr, attemptType),
|
||||
).Update(creationToSemaphore)
|
||||
|
||||
m.set.GetOrCreateHistogram(
|
||||
m.getMetricName("netbird_peer_connection_stage_semaphore_to_signaling", connTypeStr, attemptType),
|
||||
).Update(semaphoreToSignaling)
|
||||
|
||||
m.set.GetOrCreateHistogram(
|
||||
m.getMetricName("netbird_peer_connection_stage_signaling_to_connection", connTypeStr, attemptType),
|
||||
).Update(signalingToConnection)
|
||||
|
||||
m.set.GetOrCreateHistogram(
|
||||
m.getMetricName("netbird_peer_connection_stage_connection_to_handshake", connTypeStr, attemptType),
|
||||
).Update(connectionToHandshake)
|
||||
|
||||
m.set.GetOrCreateHistogram(
|
||||
m.getMetricName("netbird_peer_connection_total_creation_to_handshake", connTypeStr, attemptType),
|
||||
).Update(totalDuration)
|
||||
|
||||
log.Tracef("peer connection metrics [%s, %s, %s]: creation→semaphore: %.3fs, semaphore→signaling: %.3fs, signaling→connection: %.3fs, connection→handshake: %.3fs, total: %.3fs",
|
||||
m.agentInfo.DeploymentType.String(), connTypeStr, attemptType,
|
||||
creationToSemaphore, semaphoreToSignaling, signalingToConnection, connectionToHandshake,
|
||||
totalDuration)
|
||||
}
|
||||
|
||||
// getMetricName constructs a metric name with labels
|
||||
func (m *victoriaMetrics) getMetricName(baseName, connectionType, attemptType string) string {
|
||||
return fmt.Sprintf(`%s{deployment_type=%q,connection_type=%q,attempt_type=%q,version=%q}`,
|
||||
baseName,
|
||||
m.agentInfo.DeploymentType.String(),
|
||||
connectionType,
|
||||
attemptType,
|
||||
m.agentInfo.Version,
|
||||
)
|
||||
}
|
||||
|
||||
// RecordSyncDuration records the duration of sync message processing
|
||||
func (m *victoriaMetrics) RecordSyncDuration(ctx context.Context, duration time.Duration) {
|
||||
metricName := fmt.Sprintf(`netbird_sync_duration_seconds{deployment_type=%q,version=%q}`,
|
||||
m.agentInfo.DeploymentType.String(),
|
||||
m.agentInfo.Version,
|
||||
)
|
||||
|
||||
m.set.GetOrCreateHistogram(metricName).Update(duration.Seconds())
|
||||
}
|
||||
|
||||
// Export writes metrics in Prometheus text format with HELP comments
|
||||
func (m *victoriaMetrics) Export(w io.Writer) error {
|
||||
if m.set == nil {
|
||||
return fmt.Errorf("metrics set not initialized")
|
||||
}
|
||||
|
||||
// Write metrics in Prometheus format
|
||||
m.set.WritePrometheus(w)
|
||||
return nil
|
||||
}
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
"github.com/netbirdio/netbird/client/internal/metrics"
|
||||
"github.com/netbirdio/netbird/client/internal/peer/conntype"
|
||||
"github.com/netbirdio/netbird/client/internal/peer/dispatcher"
|
||||
"github.com/netbirdio/netbird/client/internal/peer/guard"
|
||||
@@ -28,6 +29,16 @@ import (
|
||||
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
|
||||
)
|
||||
|
||||
// MetricsRecorder is an interface for recording peer connection metrics
|
||||
type MetricsRecorder interface {
|
||||
RecordConnectionStages(
|
||||
ctx context.Context,
|
||||
connectionType metrics.ConnectionType,
|
||||
isReconnection bool,
|
||||
timestamps metrics.ConnectionStageTimestamps,
|
||||
)
|
||||
}
|
||||
|
||||
type ServiceDependencies struct {
|
||||
StatusRecorder *Status
|
||||
Signaler *Signaler
|
||||
@@ -36,6 +47,7 @@ type ServiceDependencies struct {
|
||||
SrWatcher *guard.SRWatcher
|
||||
Semaphore *semaphoregroup.SemaphoreGroup
|
||||
PeerConnDispatcher *dispatcher.ConnectionDispatcher
|
||||
MetricsRecorder MetricsRecorder
|
||||
}
|
||||
|
||||
type WgConfig struct {
|
||||
@@ -119,6 +131,10 @@ type Conn struct {
|
||||
dumpState *stateDump
|
||||
|
||||
endpointUpdater *EndpointUpdater
|
||||
|
||||
// Connection stage timestamps for metrics
|
||||
metricsRecorder MetricsRecorder
|
||||
metricsStages MetricsStages
|
||||
}
|
||||
|
||||
// NewConn creates a new not opened Conn to the remote peer.
|
||||
@@ -144,9 +160,11 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
|
||||
statusICE: worker.NewAtomicStatus(),
|
||||
dumpState: dumpState,
|
||||
endpointUpdater: NewEndpointUpdater(connLog, config.WgConfig, isController(config)),
|
||||
wgWatcher: NewWGWatcher(connLog, config.WgConfig.WgInterface, config.Key, dumpState),
|
||||
metricsRecorder: services.MetricsRecorder,
|
||||
}
|
||||
|
||||
conn.wgWatcher = NewWGWatcher(connLog, config.WgConfig.WgInterface, config.Key, dumpState)
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
@@ -154,6 +172,20 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
|
||||
// It will try to establish a connection using ICE and in parallel with relay. The higher priority connection type will
|
||||
// be used.
|
||||
func (conn *Conn) Open(engineCtx context.Context) error {
|
||||
conn.mu.Lock()
|
||||
if conn.opened {
|
||||
conn.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Record the start time - beginning of connection attempt
|
||||
conn.metricsStages = MetricsStages{}
|
||||
conn.metricsStages.RecordCreated()
|
||||
|
||||
conn.mu.Unlock()
|
||||
|
||||
// Semaphore.Add() blocks here until there's a free slot
|
||||
// todo create common semaphor logic for reconnection and connections too that can remote seats from semaphor on the fly
|
||||
if err := conn.semaphore.Add(engineCtx); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -166,6 +198,9 @@ func (conn *Conn) Open(engineCtx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Record when semaphore was acquired (after the wait)
|
||||
conn.metricsStages.RecordSemaphoreAcquired()
|
||||
|
||||
conn.ctx, conn.ctxCancel = context.WithCancel(engineCtx)
|
||||
|
||||
conn.workerRelay = NewWorkerRelay(conn.ctx, conn.Log, isController(conn.config), conn.config, conn, conn.relayManager)
|
||||
@@ -178,7 +213,7 @@ func (conn *Conn) Open(engineCtx context.Context) error {
|
||||
}
|
||||
conn.workerICE = workerICE
|
||||
|
||||
conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay)
|
||||
conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay, &conn.metricsStages)
|
||||
|
||||
conn.handshaker.AddRelayListener(conn.workerRelay.OnNewOffer)
|
||||
if !isForceRelayed() {
|
||||
@@ -350,7 +385,7 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
|
||||
if conn.currentConnPriority > priority {
|
||||
conn.Log.Infof("current connection priority (%s) is higher than the new one (%s), do not upgrade connection", conn.currentConnPriority, priority)
|
||||
conn.statusICE.SetConnected()
|
||||
conn.updateIceState(iceConnInfo)
|
||||
conn.updateIceState(iceConnInfo, time.Now())
|
||||
return
|
||||
}
|
||||
|
||||
@@ -390,7 +425,8 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
|
||||
}
|
||||
|
||||
conn.Log.Infof("configure WireGuard endpoint to: %s", ep.String())
|
||||
conn.enableWgWatcherIfNeeded()
|
||||
updateTime := time.Now()
|
||||
conn.enableWgWatcherIfNeeded(updateTime)
|
||||
|
||||
presharedKey := conn.presharedKey(iceConnInfo.RosenpassPubKey)
|
||||
if err = conn.endpointUpdater.ConfigureWGEndpoint(ep, presharedKey); err != nil {
|
||||
@@ -406,8 +442,8 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
|
||||
|
||||
conn.currentConnPriority = priority
|
||||
conn.statusICE.SetConnected()
|
||||
conn.updateIceState(iceConnInfo)
|
||||
conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr)
|
||||
conn.updateIceState(iceConnInfo, updateTime)
|
||||
conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr, updateTime)
|
||||
}
|
||||
|
||||
func (conn *Conn) onICEStateDisconnected() {
|
||||
@@ -455,6 +491,10 @@ func (conn *Conn) onICEStateDisconnected() {
|
||||
|
||||
conn.disableWgWatcherIfNeeded()
|
||||
|
||||
if conn.currentConnPriority == conntype.None {
|
||||
conn.metricsStages.Disconnected()
|
||||
}
|
||||
|
||||
peerState := State{
|
||||
PubKey: conn.config.Key,
|
||||
ConnStatus: conn.evalStatus(),
|
||||
@@ -495,14 +535,15 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
|
||||
conn.Log.Debugf("do not switch to relay because current priority is: %s", conn.currentConnPriority.String())
|
||||
conn.setRelayedProxy(wgProxy)
|
||||
conn.statusRelay.SetConnected()
|
||||
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
|
||||
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey, time.Now())
|
||||
return
|
||||
}
|
||||
|
||||
wgProxy.Work()
|
||||
presharedKey := conn.presharedKey(rci.rosenpassPubKey)
|
||||
|
||||
conn.enableWgWatcherIfNeeded()
|
||||
updateTime := time.Now()
|
||||
conn.enableWgWatcherIfNeeded(updateTime)
|
||||
|
||||
if err := conn.endpointUpdater.ConfigureWGEndpoint(wgProxy.EndpointAddr(), presharedKey); err != nil {
|
||||
if err := wgProxy.CloseConn(); err != nil {
|
||||
@@ -513,13 +554,14 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
|
||||
}
|
||||
|
||||
wgConfigWorkaround()
|
||||
|
||||
conn.rosenpassRemoteKey = rci.rosenpassPubKey
|
||||
conn.currentConnPriority = conntype.Relay
|
||||
conn.statusRelay.SetConnected()
|
||||
conn.setRelayedProxy(wgProxy)
|
||||
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
|
||||
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey, updateTime)
|
||||
conn.Log.Infof("start to communicate with peer via relay")
|
||||
conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr)
|
||||
conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr, updateTime)
|
||||
}
|
||||
|
||||
func (conn *Conn) onRelayDisconnected() {
|
||||
@@ -557,6 +599,10 @@ func (conn *Conn) handleRelayDisconnectedLocked() {
|
||||
|
||||
conn.disableWgWatcherIfNeeded()
|
||||
|
||||
if conn.currentConnPriority == conntype.None {
|
||||
conn.metricsStages.Disconnected()
|
||||
}
|
||||
|
||||
peerState := State{
|
||||
PubKey: conn.config.Key,
|
||||
ConnStatus: conn.evalStatus(),
|
||||
@@ -573,6 +619,7 @@ func (conn *Conn) onGuardEvent() {
|
||||
if err := conn.handshaker.SendOffer(); err != nil {
|
||||
conn.Log.Errorf("failed to send offer: %v", err)
|
||||
}
|
||||
conn.metricsStages.RecordSignaling()
|
||||
}
|
||||
|
||||
func (conn *Conn) onWGDisconnected() {
|
||||
@@ -597,10 +644,10 @@ func (conn *Conn) onWGDisconnected() {
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *Conn) updateRelayStatus(relayServerAddr string, rosenpassPubKey []byte) {
|
||||
func (conn *Conn) updateRelayStatus(relayServerAddr string, rosenpassPubKey []byte, updateTime time.Time) {
|
||||
peerState := State{
|
||||
PubKey: conn.config.Key,
|
||||
ConnStatusUpdate: time.Now(),
|
||||
ConnStatusUpdate: updateTime,
|
||||
ConnStatus: conn.evalStatus(),
|
||||
Relayed: conn.isRelayed(),
|
||||
RelayServerAddress: relayServerAddr,
|
||||
@@ -613,10 +660,10 @@ func (conn *Conn) updateRelayStatus(relayServerAddr string, rosenpassPubKey []by
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *Conn) updateIceState(iceConnInfo ICEConnInfo) {
|
||||
func (conn *Conn) updateIceState(iceConnInfo ICEConnInfo, updateTime time.Time) {
|
||||
peerState := State{
|
||||
PubKey: conn.config.Key,
|
||||
ConnStatusUpdate: time.Now(),
|
||||
ConnStatusUpdate: updateTime,
|
||||
ConnStatus: conn.evalStatus(),
|
||||
Relayed: iceConnInfo.Relayed,
|
||||
LocalIceCandidateType: iceConnInfo.LocalIceCandidateType,
|
||||
@@ -654,11 +701,13 @@ func (conn *Conn) setStatusToDisconnected() {
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *Conn) doOnConnected(remoteRosenpassPubKey []byte, remoteRosenpassAddr string) {
|
||||
func (conn *Conn) doOnConnected(remoteRosenpassPubKey []byte, remoteRosenpassAddr string, updateTime time.Time) {
|
||||
if runtime.GOOS == "ios" {
|
||||
runtime.GC()
|
||||
}
|
||||
|
||||
conn.metricsStages.RecordConnectionReady(updateTime)
|
||||
|
||||
if conn.onConnected != nil {
|
||||
conn.onConnected(conn.config.Key, remoteRosenpassPubKey, conn.config.WgConfig.AllowedIps[0].Addr().String(), remoteRosenpassAddr)
|
||||
}
|
||||
@@ -723,14 +772,14 @@ func (conn *Conn) isConnectedOnAllWay() (connected bool) {
|
||||
return true
|
||||
}
|
||||
|
||||
func (conn *Conn) enableWgWatcherIfNeeded() {
|
||||
func (conn *Conn) enableWgWatcherIfNeeded(enabledTime time.Time) {
|
||||
if !conn.wgWatcher.IsEnabled() {
|
||||
wgWatcherCtx, wgWatcherCancel := context.WithCancel(conn.ctx)
|
||||
conn.wgWatcherCancel = wgWatcherCancel
|
||||
conn.wgWatcherWg.Add(1)
|
||||
go func() {
|
||||
defer conn.wgWatcherWg.Done()
|
||||
conn.wgWatcher.EnableWgWatcher(wgWatcherCtx, conn.onWGDisconnected)
|
||||
conn.wgWatcher.EnableWgWatcher(wgWatcherCtx, enabledTime, conn.onWGDisconnected, conn.onWGHandshakeSuccess)
|
||||
}()
|
||||
}
|
||||
}
|
||||
@@ -794,6 +843,36 @@ func (conn *Conn) setRelayedProxy(proxy wgproxy.Proxy) {
|
||||
conn.wgProxyRelay = proxy
|
||||
}
|
||||
|
||||
// onWGHandshakeSuccess is called when the first WireGuard handshake is detected
|
||||
func (conn *Conn) onWGHandshakeSuccess(when time.Time) {
|
||||
conn.metricsStages.RecordWGHandshakeSuccess(when)
|
||||
conn.recordConnectionMetrics()
|
||||
}
|
||||
|
||||
// recordConnectionMetrics records connection stage timestamps as metrics
|
||||
func (conn *Conn) recordConnectionMetrics() {
|
||||
if conn.metricsRecorder == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Determine connection type based on current priority
|
||||
var connType metrics.ConnectionType
|
||||
switch conn.currentConnPriority {
|
||||
case conntype.Relay:
|
||||
connType = metrics.ConnectionTypeRelay
|
||||
default:
|
||||
connType = metrics.ConnectionTypeICE
|
||||
}
|
||||
|
||||
// Record metrics with timestamps - duration calculation happens in metrics package
|
||||
conn.metricsRecorder.RecordConnectionStages(
|
||||
context.Background(),
|
||||
connType,
|
||||
conn.metricsStages.IsReconnection(),
|
||||
conn.metricsStages.GetTimestamps(),
|
||||
)
|
||||
}
|
||||
|
||||
// AllowedIP returns the allowed IP of the remote peer
|
||||
func (conn *Conn) AllowedIP() netip.Addr {
|
||||
return conn.config.WgConfig.AllowedIps[0].Addr()
|
||||
|
||||
@@ -44,12 +44,13 @@ type OfferAnswer struct {
|
||||
}
|
||||
|
||||
type Handshaker struct {
|
||||
mu sync.Mutex
|
||||
log *log.Entry
|
||||
config ConnConfig
|
||||
signaler *Signaler
|
||||
ice *WorkerICE
|
||||
relay *WorkerRelay
|
||||
mu sync.Mutex
|
||||
log *log.Entry
|
||||
config ConnConfig
|
||||
signaler *Signaler
|
||||
ice *WorkerICE
|
||||
relay *WorkerRelay
|
||||
metricsStages *MetricsStages
|
||||
// relayListener is not blocking because the listener is using a goroutine to process the messages
|
||||
// and it will only keep the latest message if multiple offers are received in a short time
|
||||
// this is to avoid blocking the handshaker if the listener is doing some heavy processing
|
||||
@@ -64,13 +65,14 @@ type Handshaker struct {
|
||||
remoteAnswerCh chan OfferAnswer
|
||||
}
|
||||
|
||||
func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *WorkerICE, relay *WorkerRelay) *Handshaker {
|
||||
func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *WorkerICE, relay *WorkerRelay, metricsStages *MetricsStages) *Handshaker {
|
||||
return &Handshaker{
|
||||
log: log,
|
||||
config: config,
|
||||
signaler: signaler,
|
||||
ice: ice,
|
||||
relay: relay,
|
||||
metricsStages: metricsStages,
|
||||
remoteOffersCh: make(chan OfferAnswer),
|
||||
remoteAnswerCh: make(chan OfferAnswer),
|
||||
}
|
||||
@@ -89,6 +91,12 @@ func (h *Handshaker) Listen(ctx context.Context) {
|
||||
select {
|
||||
case remoteOfferAnswer := <-h.remoteOffersCh:
|
||||
h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
|
||||
|
||||
// Record signaling received for reconnection attempts
|
||||
if h.metricsStages != nil {
|
||||
h.metricsStages.RecordSignalingReceived()
|
||||
}
|
||||
|
||||
if h.relayListener != nil {
|
||||
h.relayListener.Notify(&remoteOfferAnswer)
|
||||
}
|
||||
@@ -103,6 +111,12 @@ func (h *Handshaker) Listen(ctx context.Context) {
|
||||
}
|
||||
case remoteOfferAnswer := <-h.remoteAnswerCh:
|
||||
h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
|
||||
|
||||
// Record signaling received for reconnection attempts
|
||||
if h.metricsStages != nil {
|
||||
h.metricsStages.RecordSignalingReceived()
|
||||
}
|
||||
|
||||
if h.relayListener != nil {
|
||||
h.relayListener.Notify(&remoteOfferAnswer)
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package ice
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -32,24 +33,6 @@ type ThreadSafeAgent struct {
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
func (a *ThreadSafeAgent) Close() error {
|
||||
var err error
|
||||
a.once.Do(func() {
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- a.Agent.Close()
|
||||
}()
|
||||
|
||||
select {
|
||||
case err = <-done:
|
||||
case <-time.After(iceAgentCloseTimeout):
|
||||
log.Warnf("ICE agent close timed out after %v, proceeding with cleanup", iceAgentCloseTimeout)
|
||||
err = nil
|
||||
}
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func NewAgent(ctx context.Context, iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ThreadSafeAgent, error) {
|
||||
iceKeepAlive := iceKeepAlive()
|
||||
iceDisconnectedTimeout := iceDisconnectedTimeout()
|
||||
@@ -93,9 +76,41 @@ func NewAgent(ctx context.Context, iFaceDiscover stdnet.ExternalIFaceDiscover, c
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if agent == nil {
|
||||
return nil, fmt.Errorf("ice.NewAgent returned nil agent without error")
|
||||
}
|
||||
|
||||
return &ThreadSafeAgent{Agent: agent}, nil
|
||||
}
|
||||
|
||||
func (a *ThreadSafeAgent) Close() error {
|
||||
var err error
|
||||
a.once.Do(func() {
|
||||
// Defensive check to prevent nil pointer dereference
|
||||
// This can happen during sleep/wake transitions or memory corruption scenarios
|
||||
// github.com/netbirdio/netbird/client/internal/peer/ice.(*ThreadSafeAgent).Close(0x40006883f0?)
|
||||
// [signal 0xc0000005 code=0x0 addr=0x0 pc=0x7ff7e73af83c]
|
||||
agent := a.Agent
|
||||
if agent == nil {
|
||||
log.Warnf("ICE agent is nil during close, skipping")
|
||||
return
|
||||
}
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- agent.Close()
|
||||
}()
|
||||
|
||||
select {
|
||||
case err = <-done:
|
||||
case <-time.After(iceAgentCloseTimeout):
|
||||
log.Warnf("ICE agent close timed out after %v, proceeding with cleanup", iceAgentCloseTimeout)
|
||||
err = nil
|
||||
}
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func GenerateICECredentials() (string, string, error) {
|
||||
ufrag, err := randutil.GenerateCryptoRandomString(lenUFrag, runesAlpha)
|
||||
if err != nil {
|
||||
|
||||
104
client/internal/peer/metrics_saver.go
Normal file
104
client/internal/peer/metrics_saver.go
Normal file
@@ -0,0 +1,104 @@
|
||||
package peer
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/metrics"
|
||||
)
|
||||
|
||||
type MetricsStages struct {
|
||||
isReconnectionAttempt bool // Track if current attempt is a reconnection
|
||||
stageTimestamps metrics.ConnectionStageTimestamps
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func (s *MetricsStages) RecordCreated() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.stageTimestamps.Created = time.Now()
|
||||
}
|
||||
|
||||
func (s *MetricsStages) RecordSemaphoreAcquired() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.stageTimestamps.SemaphoreAcquired = time.Now()
|
||||
}
|
||||
|
||||
// RecordSignaling records the signaling timestamp when sending offers
|
||||
// For initial connections: records when we start sending
|
||||
// For reconnections: does nothing (we wait for RecordSignalingReceived)
|
||||
func (s *MetricsStages) RecordSignaling() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.isReconnectionAttempt {
|
||||
return
|
||||
}
|
||||
|
||||
if s.stageTimestamps.Signaling.IsZero() {
|
||||
s.stageTimestamps.Signaling = time.Now()
|
||||
}
|
||||
}
|
||||
|
||||
// RecordSignalingReceived records the signaling timestamp when receiving offers/answers
|
||||
// For reconnections: records when we receive the first signal
|
||||
// For initial connections: does nothing (already recorded in RecordSignaling)
|
||||
func (s *MetricsStages) RecordSignalingReceived() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Only record for reconnections when we receive a signal
|
||||
if s.isReconnectionAttempt && s.stageTimestamps.Signaling.IsZero() {
|
||||
s.stageTimestamps.Signaling = time.Now()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *MetricsStages) RecordConnectionReady(when time.Time) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.stageTimestamps.ConnectionReady.IsZero() {
|
||||
s.stageTimestamps.ConnectionReady = when
|
||||
}
|
||||
}
|
||||
|
||||
func (s *MetricsStages) RecordWGHandshakeSuccess(handshakeTime time.Time) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if !s.stageTimestamps.ConnectionReady.IsZero() {
|
||||
// WireGuard only reports handshake times with second precision, but ConnectionReady
|
||||
// is captured with microsecond precision. If handshake appears before ConnectionReady
|
||||
// due to truncation (e.g., handshake at 6.042s truncated to 6.000s), normalize to
|
||||
// ConnectionReady to avoid negative duration metrics.
|
||||
if handshakeTime.Before(s.stageTimestamps.ConnectionReady) {
|
||||
s.stageTimestamps.WgHandshakeSuccess = s.stageTimestamps.ConnectionReady
|
||||
} else {
|
||||
s.stageTimestamps.WgHandshakeSuccess = handshakeTime
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Disconnected sets the mode to reconnection. It is called only when both ICE and Relay have been disconnected at the same time.
|
||||
func (s *MetricsStages) Disconnected() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Reset all timestamps for reconnection
|
||||
// For reconnections, we only track from Signaling onwards
|
||||
// This avoids meaningless creation→semaphore and semaphore→signaling metrics
|
||||
s.stageTimestamps = metrics.ConnectionStageTimestamps{}
|
||||
s.isReconnectionAttempt = true
|
||||
}
|
||||
|
||||
func (s *MetricsStages) IsReconnection() bool {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return s.isReconnectionAttempt
|
||||
}
|
||||
|
||||
func (s *MetricsStages) GetTimestamps() metrics.ConnectionStageTimestamps {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return s.stageTimestamps
|
||||
}
|
||||
@@ -45,7 +45,7 @@ func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey strin
|
||||
|
||||
// EnableWgWatcher starts the WireGuard watcher. If it is already enabled, it will return immediately and do nothing.
|
||||
// The watcher runs until ctx is cancelled. Caller is responsible for context lifecycle management.
|
||||
func (w *WGWatcher) EnableWgWatcher(ctx context.Context, onDisconnectedFn func()) {
|
||||
func (w *WGWatcher) EnableWgWatcher(ctx context.Context, enabledTime time.Time, onDisconnectedFn func(), onHandshakeSuccessFn func(when time.Time)) {
|
||||
w.muEnabled.Lock()
|
||||
if w.enabled {
|
||||
w.muEnabled.Unlock()
|
||||
@@ -53,7 +53,6 @@ func (w *WGWatcher) EnableWgWatcher(ctx context.Context, onDisconnectedFn func()
|
||||
}
|
||||
|
||||
w.log.Debugf("enable WireGuard watcher")
|
||||
enabledTime := time.Now()
|
||||
w.enabled = true
|
||||
w.muEnabled.Unlock()
|
||||
|
||||
@@ -62,7 +61,7 @@ func (w *WGWatcher) EnableWgWatcher(ctx context.Context, onDisconnectedFn func()
|
||||
w.log.Warnf("failed to read initial wg stats: %v", err)
|
||||
}
|
||||
|
||||
w.periodicHandshakeCheck(ctx, onDisconnectedFn, enabledTime, initialHandshake)
|
||||
w.periodicHandshakeCheck(ctx, onDisconnectedFn, onHandshakeSuccessFn, enabledTime, initialHandshake)
|
||||
|
||||
w.muEnabled.Lock()
|
||||
w.enabled = false
|
||||
@@ -77,7 +76,7 @@ func (w *WGWatcher) IsEnabled() bool {
|
||||
}
|
||||
|
||||
// wgStateCheck help to check the state of the WireGuard handshake and relay connection
|
||||
func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, onDisconnectedFn func(), enabledTime time.Time, initialHandshake time.Time) {
|
||||
func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, onDisconnectedFn func(), onHandshakeSuccessFn func(when time.Time), enabledTime time.Time, initialHandshake time.Time) {
|
||||
w.log.Infof("WireGuard watcher started")
|
||||
|
||||
timer := time.NewTimer(wgHandshakeOvertime)
|
||||
@@ -96,6 +95,9 @@ func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, onDisconnectedFn
|
||||
if lastHandshake.IsZero() {
|
||||
elapsed := calcElapsed(enabledTime, *handshake)
|
||||
w.log.Infof("first wg handshake detected within: %.2fsec, (%s)", elapsed, handshake)
|
||||
if onHandshakeSuccessFn != nil {
|
||||
onHandshakeSuccessFn(*handshake)
|
||||
}
|
||||
}
|
||||
|
||||
lastHandshake = *handshake
|
||||
|
||||
@@ -35,9 +35,11 @@ func TestWGWatcher_EnableWgWatcher(t *testing.T) {
|
||||
defer cancel()
|
||||
|
||||
onDisconnected := make(chan struct{}, 1)
|
||||
go watcher.EnableWgWatcher(ctx, func() {
|
||||
go watcher.EnableWgWatcher(ctx, time.Now(), func() {
|
||||
mlog.Infof("onDisconnectedFn")
|
||||
onDisconnected <- struct{}{}
|
||||
}, func(when time.Time) {
|
||||
mlog.Infof("onHandshakeSuccess: %v", when)
|
||||
})
|
||||
|
||||
// wait for initial reading
|
||||
@@ -64,7 +66,7 @@ func TestWGWatcher_ReEnable(t *testing.T) {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
watcher.EnableWgWatcher(ctx, func() {})
|
||||
watcher.EnableWgWatcher(ctx, time.Now(), func() {}, func(when time.Time) {})
|
||||
}()
|
||||
cancel()
|
||||
|
||||
@@ -75,9 +77,9 @@ func TestWGWatcher_ReEnable(t *testing.T) {
|
||||
defer cancel()
|
||||
|
||||
onDisconnected := make(chan struct{}, 1)
|
||||
go watcher.EnableWgWatcher(ctx, func() {
|
||||
go watcher.EnableWgWatcher(ctx, time.Now(), func() {
|
||||
onDisconnected <- struct{}{}
|
||||
})
|
||||
}, func(when time.Time) {})
|
||||
|
||||
time.Sleep(2 * time.Second)
|
||||
mocWgIface.disconnect()
|
||||
|
||||
@@ -107,8 +107,10 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
||||
}
|
||||
w.log.Debugf("agent already exists, recreate the connection")
|
||||
w.agentDialerCancel()
|
||||
if err := w.agent.Close(); err != nil {
|
||||
w.log.Warnf("failed to close ICE agent: %s", err)
|
||||
if w.agent != nil {
|
||||
if err := w.agent.Close(); err != nil {
|
||||
w.log.Warnf("failed to close ICE agent: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
sessionID, err := NewICESessionID()
|
||||
|
||||
@@ -252,7 +252,7 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||
}
|
||||
|
||||
if config.AdminURL == nil {
|
||||
log.Infof("using default Admin URL %s", DefaultManagementURL)
|
||||
log.Infof("using default Admin URL %s", DefaultAdminURL)
|
||||
config.AdminURL, err = parseURL("Admin URL", DefaultAdminURL)
|
||||
if err != nil {
|
||||
return false, err
|
||||
|
||||
@@ -1,262 +0,0 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
const networksetupPath = "/usr/sbin/networksetup"
|
||||
|
||||
// Manager handles system-wide proxy configuration on macOS.
|
||||
type Manager struct {
|
||||
mu sync.Mutex
|
||||
stateManager *statemanager.Manager
|
||||
modifiedServices []string
|
||||
enabled bool
|
||||
}
|
||||
|
||||
// NewManager creates a new proxy manager.
|
||||
func NewManager(stateManager *statemanager.Manager) *Manager {
|
||||
return &Manager{
|
||||
stateManager: stateManager,
|
||||
}
|
||||
}
|
||||
|
||||
// GetActiveNetworkServices returns the list of active network services.
|
||||
func GetActiveNetworkServices() ([]string, error) {
|
||||
cmd := exec.Command(networksetupPath, "-listallnetworkservices")
|
||||
out, err := cmd.Output()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list network services: %w", err)
|
||||
}
|
||||
|
||||
lines := strings.Split(string(out), "\n")
|
||||
var services []string
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" || strings.HasPrefix(line, "*") || strings.Contains(line, "asterisk") {
|
||||
continue
|
||||
}
|
||||
services = append(services, line)
|
||||
}
|
||||
return services, nil
|
||||
}
|
||||
|
||||
// EnableWebProxy enables web proxy for all active network services.
|
||||
func (m *Manager) EnableWebProxy(host string, port int) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if m.enabled {
|
||||
log.Debug("web proxy already enabled")
|
||||
return nil
|
||||
}
|
||||
|
||||
services, err := GetActiveNetworkServices()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var modifiedServices []string
|
||||
for _, service := range services {
|
||||
if err := m.enableProxyForService(service, host, port); err != nil {
|
||||
log.Warnf("enable proxy for %s: %v", service, err)
|
||||
continue
|
||||
}
|
||||
modifiedServices = append(modifiedServices, service)
|
||||
}
|
||||
|
||||
m.modifiedServices = modifiedServices
|
||||
m.enabled = true
|
||||
m.updateState()
|
||||
|
||||
log.Infof("enabled web proxy on %d services -> %s:%d", len(modifiedServices), host, port)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) enableProxyForService(service, host string, port int) error {
|
||||
portStr := fmt.Sprintf("%d", port)
|
||||
|
||||
// Set web proxy (HTTP)
|
||||
cmd := exec.Command(networksetupPath, "-setwebproxy", service, host, portStr)
|
||||
if out, err := cmd.CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("set web proxy: %w, output: %s", err, out)
|
||||
}
|
||||
|
||||
// Enable web proxy
|
||||
cmd = exec.Command(networksetupPath, "-setwebproxystate", service, "on")
|
||||
if out, err := cmd.CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("enable web proxy state: %w, output: %s", err, out)
|
||||
}
|
||||
|
||||
// Set secure web proxy (HTTPS)
|
||||
cmd = exec.Command(networksetupPath, "-setsecurewebproxy", service, host, portStr)
|
||||
if out, err := cmd.CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("set secure web proxy: %w, output: %s", err, out)
|
||||
}
|
||||
|
||||
// Enable secure web proxy
|
||||
cmd = exec.Command(networksetupPath, "-setsecurewebproxystate", service, "on")
|
||||
if out, err := cmd.CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("enable secure web proxy state: %w, output: %s", err, out)
|
||||
}
|
||||
|
||||
log.Debugf("enabled proxy for service %s", service)
|
||||
return nil
|
||||
}
|
||||
|
||||
// DisableWebProxy disables web proxy for all modified network services.
|
||||
func (m *Manager) DisableWebProxy() error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if !m.enabled {
|
||||
log.Debug("web proxy already disabled")
|
||||
return nil
|
||||
}
|
||||
|
||||
services := m.modifiedServices
|
||||
if len(services) == 0 {
|
||||
services, _ = GetActiveNetworkServices()
|
||||
}
|
||||
|
||||
for _, service := range services {
|
||||
if err := m.disableProxyForService(service); err != nil {
|
||||
log.Warnf("disable proxy for %s: %v", service, err)
|
||||
}
|
||||
}
|
||||
|
||||
m.modifiedServices = nil
|
||||
m.enabled = false
|
||||
m.updateState()
|
||||
|
||||
log.Info("disabled web proxy")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) disableProxyForService(service string) error {
|
||||
// Disable web proxy (HTTP)
|
||||
cmd := exec.Command(networksetupPath, "-setwebproxystate", service, "off")
|
||||
if out, err := cmd.CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("disable web proxy: %w, output: %s", err, out)
|
||||
}
|
||||
|
||||
// Disable secure web proxy (HTTPS)
|
||||
cmd = exec.Command(networksetupPath, "-setsecurewebproxystate", service, "off")
|
||||
if out, err := cmd.CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("disable secure web proxy: %w, output: %s", err, out)
|
||||
}
|
||||
|
||||
log.Debugf("disabled proxy for service %s", service)
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetAutoproxyURL sets the automatic proxy configuration URL (PAC file).
|
||||
func (m *Manager) SetAutoproxyURL(pacURL string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
services, err := GetActiveNetworkServices()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var modifiedServices []string
|
||||
for _, service := range services {
|
||||
cmd := exec.Command(networksetupPath, "-setautoproxyurl", service, pacURL)
|
||||
if out, err := cmd.CombinedOutput(); err != nil {
|
||||
log.Warnf("set autoproxy for %s: %v, output: %s", service, err, out)
|
||||
continue
|
||||
}
|
||||
|
||||
cmd = exec.Command(networksetupPath, "-setautoproxystate", service, "on")
|
||||
if out, err := cmd.CombinedOutput(); err != nil {
|
||||
log.Warnf("enable autoproxy for %s: %v, output: %s", service, err, out)
|
||||
continue
|
||||
}
|
||||
|
||||
modifiedServices = append(modifiedServices, service)
|
||||
log.Debugf("set autoproxy URL for %s -> %s", service, pacURL)
|
||||
}
|
||||
|
||||
m.modifiedServices = modifiedServices
|
||||
m.enabled = true
|
||||
m.updateState()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DisableAutoproxy disables automatic proxy configuration.
|
||||
func (m *Manager) DisableAutoproxy() error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
services := m.modifiedServices
|
||||
if len(services) == 0 {
|
||||
services, _ = GetActiveNetworkServices()
|
||||
}
|
||||
|
||||
for _, service := range services {
|
||||
cmd := exec.Command(networksetupPath, "-setautoproxystate", service, "off")
|
||||
if out, err := cmd.CombinedOutput(); err != nil {
|
||||
log.Warnf("disable autoproxy for %s: %v, output: %s", service, err, out)
|
||||
}
|
||||
}
|
||||
|
||||
m.modifiedServices = nil
|
||||
m.enabled = false
|
||||
m.updateState()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsEnabled returns whether the proxy is currently enabled.
|
||||
func (m *Manager) IsEnabled() bool {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return m.enabled
|
||||
}
|
||||
|
||||
// Restore restores proxy settings from a previous state.
|
||||
func (m *Manager) Restore(services []string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
for _, service := range services {
|
||||
if err := m.disableProxyForService(service); err != nil {
|
||||
log.Warnf("restore proxy for %s: %v", service, err)
|
||||
}
|
||||
}
|
||||
|
||||
m.modifiedServices = nil
|
||||
m.enabled = false
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) updateState() {
|
||||
if m.stateManager == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if m.enabled && len(m.modifiedServices) > 0 {
|
||||
state := &ShutdownState{
|
||||
ModifiedServices: m.modifiedServices,
|
||||
}
|
||||
if err := m.stateManager.UpdateState(state); err != nil {
|
||||
log.Errorf("update proxy state: %v", err)
|
||||
}
|
||||
} else {
|
||||
if err := m.stateManager.DeleteState(&ShutdownState{}); err != nil {
|
||||
log.Debugf("delete proxy state: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,45 +0,0 @@
|
||||
//go:build !darwin || ios
|
||||
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
// Manager is a no-op proxy manager for non-macOS platforms.
|
||||
type Manager struct{}
|
||||
|
||||
// NewManager creates a new proxy manager (no-op on non-macOS).
|
||||
func NewManager(_ *statemanager.Manager) *Manager {
|
||||
return &Manager{}
|
||||
}
|
||||
|
||||
// EnableWebProxy is a no-op on non-macOS platforms.
|
||||
func (m *Manager) EnableWebProxy(host string, port int) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// DisableWebProxy is a no-op on non-macOS platforms.
|
||||
func (m *Manager) DisableWebProxy() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetAutoproxyURL is a no-op on non-macOS platforms.
|
||||
func (m *Manager) SetAutoproxyURL(pacURL string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// DisableAutoproxy is a no-op on non-macOS platforms.
|
||||
func (m *Manager) DisableAutoproxy() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsEnabled always returns false on non-macOS platforms.
|
||||
func (m *Manager) IsEnabled() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// Restore is a no-op on non-macOS platforms.
|
||||
func (m *Manager) Restore(services []string) error {
|
||||
return nil
|
||||
}
|
||||
@@ -1,88 +0,0 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestGetActiveNetworkServices(t *testing.T) {
|
||||
services, err := GetActiveNetworkServices()
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, services, "should have at least one network service")
|
||||
|
||||
// Check that services don't contain invalid entries
|
||||
for _, service := range services {
|
||||
assert.NotEmpty(t, service)
|
||||
assert.NotContains(t, service, "*")
|
||||
}
|
||||
}
|
||||
|
||||
func TestManager_EnableDisableWebProxy(t *testing.T) {
|
||||
// Skip this test in CI as it requires admin privileges
|
||||
if testing.Short() {
|
||||
t.Skip("skipping proxy test in short mode")
|
||||
}
|
||||
|
||||
m := NewManager(nil)
|
||||
assert.NotNil(t, m)
|
||||
assert.False(t, m.IsEnabled())
|
||||
|
||||
// This test would require admin privileges to actually enable the proxy
|
||||
// So we just test the basic state management
|
||||
}
|
||||
|
||||
func TestShutdownState_Name(t *testing.T) {
|
||||
state := &ShutdownState{}
|
||||
assert.Equal(t, "proxy_state", state.Name())
|
||||
}
|
||||
|
||||
func TestShutdownState_Cleanup_EmptyServices(t *testing.T) {
|
||||
state := &ShutdownState{
|
||||
ModifiedServices: []string{},
|
||||
}
|
||||
err := state.Cleanup()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestContains(t *testing.T) {
|
||||
tests := []struct {
|
||||
s string
|
||||
substr string
|
||||
want bool
|
||||
}{
|
||||
{"Enabled: Yes", "Enabled: Yes", true},
|
||||
{"Enabled: No", "Enabled: Yes", false},
|
||||
{"Server: 127.0.0.1\nEnabled: Yes\nPort: 8080", "Enabled: Yes", true},
|
||||
{"", "Enabled: Yes", false},
|
||||
{"Enabled: Yes", "", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.s+"_"+tt.substr, func(t *testing.T) {
|
||||
got := contains(tt.s, tt.substr)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsProxyEnabled(t *testing.T) {
|
||||
tests := []struct {
|
||||
output string
|
||||
want bool
|
||||
}{
|
||||
{"Enabled: Yes\nServer: 127.0.0.1\nPort: 8080", true},
|
||||
{"Enabled: No\nServer: \nPort: 0", false},
|
||||
{"Server: 127.0.0.1\nEnabled: Yes\nPort: 8080", true},
|
||||
{"", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.output, func(t *testing.T) {
|
||||
got := isProxyEnabled(tt.output)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,105 +0,0 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os/exec"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
// ShutdownState stores proxy state for cleanup on unclean shutdown.
|
||||
type ShutdownState struct {
|
||||
ModifiedServices []string `json:"modified_services"`
|
||||
}
|
||||
|
||||
// Name returns the state name for persistence.
|
||||
func (s *ShutdownState) Name() string {
|
||||
return "proxy_state"
|
||||
}
|
||||
|
||||
// Cleanup restores proxy settings after an unclean shutdown.
|
||||
func (s *ShutdownState) Cleanup() error {
|
||||
if len(s.ModifiedServices) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Infof("cleaning up proxy state for %d services", len(s.ModifiedServices))
|
||||
|
||||
for _, service := range s.ModifiedServices {
|
||||
// Disable web proxy (HTTP)
|
||||
cmd := exec.Command(networksetupPath, "-setwebproxystate", service, "off")
|
||||
if out, err := cmd.CombinedOutput(); err != nil {
|
||||
log.Warnf("cleanup web proxy for %s: %v, output: %s", service, err, out)
|
||||
}
|
||||
|
||||
// Disable secure web proxy (HTTPS)
|
||||
cmd = exec.Command(networksetupPath, "-setsecurewebproxystate", service, "off")
|
||||
if out, err := cmd.CombinedOutput(); err != nil {
|
||||
log.Warnf("cleanup secure web proxy for %s: %v, output: %s", service, err, out)
|
||||
}
|
||||
|
||||
// Disable autoproxy
|
||||
cmd = exec.Command(networksetupPath, "-setautoproxystate", service, "off")
|
||||
if out, err := cmd.CombinedOutput(); err != nil {
|
||||
log.Warnf("cleanup autoproxy for %s: %v, output: %s", service, err, out)
|
||||
}
|
||||
|
||||
log.Debugf("cleaned up proxy for service %s", service)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RegisterState registers the proxy state with the state manager.
|
||||
func RegisterState(stateManager *statemanager.Manager) {
|
||||
if stateManager == nil {
|
||||
return
|
||||
}
|
||||
stateManager.RegisterState(&ShutdownState{})
|
||||
}
|
||||
|
||||
// GetProxyState returns the current proxy state from the command line.
|
||||
func GetProxyState(service string) (webProxy, secureProxy, autoProxy bool, err error) {
|
||||
// Check web proxy state
|
||||
cmd := exec.Command(networksetupPath, "-getwebproxy", service)
|
||||
out, err := cmd.Output()
|
||||
if err != nil {
|
||||
return false, false, false, fmt.Errorf("get web proxy: %w", err)
|
||||
}
|
||||
webProxy = isProxyEnabled(string(out))
|
||||
|
||||
// Check secure web proxy state
|
||||
cmd = exec.Command(networksetupPath, "-getsecurewebproxy", service)
|
||||
out, err = cmd.Output()
|
||||
if err != nil {
|
||||
return false, false, false, fmt.Errorf("get secure web proxy: %w", err)
|
||||
}
|
||||
secureProxy = isProxyEnabled(string(out))
|
||||
|
||||
// Check autoproxy state
|
||||
cmd = exec.Command(networksetupPath, "-getautoproxyurl", service)
|
||||
out, err = cmd.Output()
|
||||
if err != nil {
|
||||
return false, false, false, fmt.Errorf("get autoproxy: %w", err)
|
||||
}
|
||||
autoProxy = isProxyEnabled(string(out))
|
||||
|
||||
return webProxy, secureProxy, autoProxy, nil
|
||||
}
|
||||
|
||||
func isProxyEnabled(output string) bool {
|
||||
return !contains(output, "Enabled: No") && contains(output, "Enabled: Yes")
|
||||
}
|
||||
|
||||
func contains(s, substr string) bool {
|
||||
for i := 0; i <= len(s)-len(substr); i++ {
|
||||
if s[i:i+len(substr)] == substr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -1,24 +0,0 @@
|
||||
//go:build !darwin || ios
|
||||
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
// ShutdownState is a no-op state for non-macOS platforms.
|
||||
type ShutdownState struct{}
|
||||
|
||||
// Name returns the state name.
|
||||
func (s *ShutdownState) Name() string {
|
||||
return "proxy_state"
|
||||
}
|
||||
|
||||
// Cleanup is a no-op on non-macOS platforms.
|
||||
func (s *ShutdownState) Cleanup() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// RegisterState is a no-op on non-macOS platforms.
|
||||
func RegisterState(stateManager *statemanager.Manager) {
|
||||
}
|
||||
@@ -173,12 +173,21 @@ func (m *DefaultManager) setupAndroidRoutes(config ManagerConfig) {
|
||||
}
|
||||
|
||||
func (m *DefaultManager) setupRefCounters(useNoop bool) {
|
||||
var once sync.Once
|
||||
var wgIface *net.Interface
|
||||
toInterface := func() *net.Interface {
|
||||
once.Do(func() {
|
||||
wgIface = m.wgInterface.ToInterface()
|
||||
})
|
||||
return wgIface
|
||||
}
|
||||
|
||||
m.routeRefCounter = refcounter.New(
|
||||
func(prefix netip.Prefix, _ struct{}) (struct{}, error) {
|
||||
return struct{}{}, m.sysOps.AddVPNRoute(prefix, m.wgInterface.ToInterface())
|
||||
return struct{}{}, m.sysOps.AddVPNRoute(prefix, toInterface())
|
||||
},
|
||||
func(prefix netip.Prefix, _ struct{}) error {
|
||||
return m.sysOps.RemoveVPNRoute(prefix, m.wgInterface.ToInterface())
|
||||
return m.sysOps.RemoveVPNRoute(prefix, toInterface())
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -4,16 +4,17 @@ package systemops
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// filterRoutesByFlags returns true if the route message should be ignored based on its flags.
|
||||
func filterRoutesByFlags(routeMessageFlags int) bool {
|
||||
if routeMessageFlags&syscall.RTF_UP == 0 {
|
||||
if routeMessageFlags&unix.RTF_UP == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
if routeMessageFlags&(syscall.RTF_REJECT|syscall.RTF_BLACKHOLE|syscall.RTF_WASCLONED) != 0 {
|
||||
if routeMessageFlags&(unix.RTF_REJECT|unix.RTF_BLACKHOLE|unix.RTF_WASCLONED) != 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -24,42 +25,51 @@ func filterRoutesByFlags(routeMessageFlags int) bool {
|
||||
func formatBSDFlags(flags int) string {
|
||||
var flagStrs []string
|
||||
|
||||
if flags&syscall.RTF_UP != 0 {
|
||||
if flags&unix.RTF_UP != 0 {
|
||||
flagStrs = append(flagStrs, "U")
|
||||
}
|
||||
if flags&syscall.RTF_GATEWAY != 0 {
|
||||
if flags&unix.RTF_GATEWAY != 0 {
|
||||
flagStrs = append(flagStrs, "G")
|
||||
}
|
||||
if flags&syscall.RTF_HOST != 0 {
|
||||
if flags&unix.RTF_HOST != 0 {
|
||||
flagStrs = append(flagStrs, "H")
|
||||
}
|
||||
if flags&syscall.RTF_REJECT != 0 {
|
||||
if flags&unix.RTF_REJECT != 0 {
|
||||
flagStrs = append(flagStrs, "R")
|
||||
}
|
||||
if flags&syscall.RTF_DYNAMIC != 0 {
|
||||
if flags&unix.RTF_DYNAMIC != 0 {
|
||||
flagStrs = append(flagStrs, "D")
|
||||
}
|
||||
if flags&syscall.RTF_MODIFIED != 0 {
|
||||
if flags&unix.RTF_MODIFIED != 0 {
|
||||
flagStrs = append(flagStrs, "M")
|
||||
}
|
||||
if flags&syscall.RTF_STATIC != 0 {
|
||||
if flags&unix.RTF_STATIC != 0 {
|
||||
flagStrs = append(flagStrs, "S")
|
||||
}
|
||||
if flags&syscall.RTF_LLINFO != 0 {
|
||||
if flags&unix.RTF_LLINFO != 0 {
|
||||
flagStrs = append(flagStrs, "L")
|
||||
}
|
||||
if flags&syscall.RTF_LOCAL != 0 {
|
||||
if flags&unix.RTF_LOCAL != 0 {
|
||||
flagStrs = append(flagStrs, "l")
|
||||
}
|
||||
if flags&syscall.RTF_BLACKHOLE != 0 {
|
||||
if flags&unix.RTF_BLACKHOLE != 0 {
|
||||
flagStrs = append(flagStrs, "B")
|
||||
}
|
||||
if flags&syscall.RTF_CLONING != 0 {
|
||||
if flags&unix.RTF_CLONING != 0 {
|
||||
flagStrs = append(flagStrs, "C")
|
||||
}
|
||||
if flags&syscall.RTF_WASCLONED != 0 {
|
||||
if flags&unix.RTF_WASCLONED != 0 {
|
||||
flagStrs = append(flagStrs, "W")
|
||||
}
|
||||
if flags&unix.RTF_PROTO1 != 0 {
|
||||
flagStrs = append(flagStrs, "1")
|
||||
}
|
||||
if flags&unix.RTF_PROTO2 != 0 {
|
||||
flagStrs = append(flagStrs, "2")
|
||||
}
|
||||
if flags&unix.RTF_PROTO3 != 0 {
|
||||
flagStrs = append(flagStrs, "3")
|
||||
}
|
||||
|
||||
if len(flagStrs) == 0 {
|
||||
return "-"
|
||||
|
||||
@@ -4,17 +4,18 @@ package systemops
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// filterRoutesByFlags returns true if the route message should be ignored based on its flags.
|
||||
func filterRoutesByFlags(routeMessageFlags int) bool {
|
||||
if routeMessageFlags&syscall.RTF_UP == 0 {
|
||||
if routeMessageFlags&unix.RTF_UP == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
// NOTE: syscall.RTF_WASCLONED deprecated in FreeBSD 8.0
|
||||
if routeMessageFlags&(syscall.RTF_REJECT|syscall.RTF_BLACKHOLE) != 0 {
|
||||
// NOTE: RTF_WASCLONED deprecated in FreeBSD 8.0
|
||||
if routeMessageFlags&(unix.RTF_REJECT|unix.RTF_BLACKHOLE) != 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -25,37 +26,46 @@ func filterRoutesByFlags(routeMessageFlags int) bool {
|
||||
func formatBSDFlags(flags int) string {
|
||||
var flagStrs []string
|
||||
|
||||
if flags&syscall.RTF_UP != 0 {
|
||||
if flags&unix.RTF_UP != 0 {
|
||||
flagStrs = append(flagStrs, "U")
|
||||
}
|
||||
if flags&syscall.RTF_GATEWAY != 0 {
|
||||
if flags&unix.RTF_GATEWAY != 0 {
|
||||
flagStrs = append(flagStrs, "G")
|
||||
}
|
||||
if flags&syscall.RTF_HOST != 0 {
|
||||
if flags&unix.RTF_HOST != 0 {
|
||||
flagStrs = append(flagStrs, "H")
|
||||
}
|
||||
if flags&syscall.RTF_REJECT != 0 {
|
||||
if flags&unix.RTF_REJECT != 0 {
|
||||
flagStrs = append(flagStrs, "R")
|
||||
}
|
||||
if flags&syscall.RTF_DYNAMIC != 0 {
|
||||
if flags&unix.RTF_DYNAMIC != 0 {
|
||||
flagStrs = append(flagStrs, "D")
|
||||
}
|
||||
if flags&syscall.RTF_MODIFIED != 0 {
|
||||
if flags&unix.RTF_MODIFIED != 0 {
|
||||
flagStrs = append(flagStrs, "M")
|
||||
}
|
||||
if flags&syscall.RTF_STATIC != 0 {
|
||||
if flags&unix.RTF_STATIC != 0 {
|
||||
flagStrs = append(flagStrs, "S")
|
||||
}
|
||||
if flags&syscall.RTF_LLINFO != 0 {
|
||||
if flags&unix.RTF_LLINFO != 0 {
|
||||
flagStrs = append(flagStrs, "L")
|
||||
}
|
||||
if flags&syscall.RTF_LOCAL != 0 {
|
||||
if flags&unix.RTF_LOCAL != 0 {
|
||||
flagStrs = append(flagStrs, "l")
|
||||
}
|
||||
if flags&syscall.RTF_BLACKHOLE != 0 {
|
||||
if flags&unix.RTF_BLACKHOLE != 0 {
|
||||
flagStrs = append(flagStrs, "B")
|
||||
}
|
||||
// Note: RTF_CLONING and RTF_WASCLONED deprecated in FreeBSD 8.0
|
||||
if flags&unix.RTF_PROTO1 != 0 {
|
||||
flagStrs = append(flagStrs, "1")
|
||||
}
|
||||
if flags&unix.RTF_PROTO2 != 0 {
|
||||
flagStrs = append(flagStrs, "2")
|
||||
}
|
||||
if flags&unix.RTF_PROTO3 != 0 {
|
||||
flagStrs = append(flagStrs, "3")
|
||||
}
|
||||
|
||||
if len(flagStrs) == 0 {
|
||||
return "-"
|
||||
|
||||
@@ -26,6 +26,13 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (
|
||||
log.Warnf("failed to get latest sync response: %v", err)
|
||||
}
|
||||
|
||||
var clientMetrics debug.MetricsExporter
|
||||
if s.connectClient != nil {
|
||||
if engine := s.connectClient.Engine(); engine != nil {
|
||||
clientMetrics = engine.GetClientMetrics()
|
||||
}
|
||||
}
|
||||
|
||||
var cpuProfileData []byte
|
||||
if s.cpuProfileBuf != nil && !s.cpuProfiling {
|
||||
cpuProfileData = s.cpuProfileBuf.Bytes()
|
||||
@@ -54,6 +61,7 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (
|
||||
LogPath: s.logFile,
|
||||
CPUProfile: cpuProfileData,
|
||||
RefreshStatus: refreshStatus,
|
||||
ClientMetrics: clientMetrics,
|
||||
},
|
||||
debug.BundleConfig{
|
||||
Anonymize: req.GetAnonymize(),
|
||||
|
||||
3
go.mod
3
go.mod
@@ -33,6 +33,7 @@ require (
|
||||
fyne.io/fyne/v2 v2.7.0
|
||||
fyne.io/systray v1.12.1-0.20260116214250-81f8e1a496f9
|
||||
github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible
|
||||
github.com/VictoriaMetrics/metrics v1.40.2
|
||||
github.com/awnumar/memguard v0.23.0
|
||||
github.com/aws/aws-sdk-go-v2 v1.36.3
|
||||
github.com/aws/aws-sdk-go-v2/config v1.29.14
|
||||
@@ -263,6 +264,8 @@ require (
|
||||
github.com/stretchr/objx v0.5.2 // indirect
|
||||
github.com/tklauser/go-sysconf v0.3.14 // indirect
|
||||
github.com/tklauser/numcpus v0.8.0 // indirect
|
||||
github.com/valyala/fastrand v1.1.0 // indirect
|
||||
github.com/valyala/histogram v1.2.0 // indirect
|
||||
github.com/vishvananda/netns v0.0.5 // indirect
|
||||
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
|
||||
github.com/wlynxg/anet v0.0.5 // indirect
|
||||
|
||||
6
go.sum
6
go.sum
@@ -38,6 +38,8 @@ github.com/Microsoft/hcsshim v0.12.3/go.mod h1:Iyl1WVpZzr+UkzjekHZbV8o5Z9ZkxNGx6
|
||||
github.com/RaveNoX/go-jsoncommentstrip v1.0.0/go.mod h1:78ihd09MekBnJnxpICcwzCMzGrKSKYe4AqU6PDYYpjk=
|
||||
github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible h1:hqcTK6ZISdip65SR792lwYJTa/axESA0889D3UlZbLo=
|
||||
github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible/go.mod h1:6B1nuc1MUs6c62ODZDl7hVE5Pv7O2XGSkgg2olnq34I=
|
||||
github.com/VictoriaMetrics/metrics v1.40.2 h1:OVSjKcQEx6JAwGeu8/KQm9Su5qJ72TMEW4xYn5vw3Ac=
|
||||
github.com/VictoriaMetrics/metrics v1.40.2/go.mod h1:XE4uudAAIRaJE614Tl5HMrtoEU6+GDZO4QTnNSsZRuA=
|
||||
github.com/alexbrainman/sspi v0.0.0-20250919150558-7d374ff0d59e h1:4dAU9FXIyQktpoUAgOJK3OTFc/xug0PCXYCqU0FgDKI=
|
||||
github.com/alexbrainman/sspi v0.0.0-20250919150558-7d374ff0d59e/go.mod h1:cEWa1LVoE5KvSD9ONXsZrj0z6KqySlCCNKHlLzbqAt4=
|
||||
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8=
|
||||
@@ -574,6 +576,10 @@ github.com/tklauser/numcpus v0.8.0 h1:Mx4Wwe/FjZLeQsK/6kt2EOepwwSl7SmJrK5bV/dXYg
|
||||
github.com/tklauser/numcpus v0.8.0/go.mod h1:ZJZlAY+dmR4eut8epnzf0u/VwodKmryxR8txiloSqBE=
|
||||
github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw=
|
||||
github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY=
|
||||
github.com/valyala/fastrand v1.1.0 h1:f+5HkLW4rsgzdNoleUOB69hyT9IlD2ZQh9GyDMfb5G8=
|
||||
github.com/valyala/fastrand v1.1.0/go.mod h1:HWqCzkrkg6QXT8V2EXWvXCoow7vLwOFN002oeRzjapQ=
|
||||
github.com/valyala/histogram v1.2.0 h1:wyYGAZZt3CpwUiIb9AU/Zbllg1llXyrtApRS815OLoQ=
|
||||
github.com/valyala/histogram v1.2.0/go.mod h1:Hb4kBwb4UxsaNbbbh+RRz8ZR6pdodR57tzWUS3BUzXY=
|
||||
github.com/vishvananda/netlink v1.3.1 h1:3AEMt62VKqz90r0tmNhog0r/PpWKmrEShJU0wJW6bV0=
|
||||
github.com/vishvananda/netlink v1.3.1/go.mod h1:ARtKouGSTGchR8aMwmkzC0qiNPrrWO5JS/XMVl45+b4=
|
||||
github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY=
|
||||
|
||||
@@ -327,6 +327,60 @@ func ensureLocalConnector(ctx context.Context, stor storage.Storage) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// HasNonLocalConnectors checks if there are any connectors other than the local connector.
|
||||
func (p *Provider) HasNonLocalConnectors(ctx context.Context) (bool, error) {
|
||||
connectors, err := p.storage.ListConnectors(ctx)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to list connectors: %w", err)
|
||||
}
|
||||
|
||||
p.logger.Info("checking for non-local connectors", "total_connectors", len(connectors))
|
||||
for _, conn := range connectors {
|
||||
p.logger.Info("found connector in storage", "id", conn.ID, "type", conn.Type, "name", conn.Name)
|
||||
if conn.ID != "local" || conn.Type != "local" {
|
||||
p.logger.Info("found non-local connector", "id", conn.ID)
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
p.logger.Info("no non-local connectors found")
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// DisableLocalAuth removes the local (password) connector.
|
||||
// Returns an error if no other connectors are configured.
|
||||
func (p *Provider) DisableLocalAuth(ctx context.Context) error {
|
||||
hasOthers, err := p.HasNonLocalConnectors(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !hasOthers {
|
||||
return fmt.Errorf("cannot disable local authentication: no other identity providers configured")
|
||||
}
|
||||
|
||||
// Check if local connector exists
|
||||
_, err = p.storage.GetConnector(ctx, "local")
|
||||
if errors.Is(err, storage.ErrNotFound) {
|
||||
// Already disabled
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check local connector: %w", err)
|
||||
}
|
||||
|
||||
// Delete the local connector
|
||||
if err := p.storage.DeleteConnector(ctx, "local"); err != nil {
|
||||
return fmt.Errorf("failed to delete local connector: %w", err)
|
||||
}
|
||||
|
||||
p.logger.Info("local authentication disabled")
|
||||
return nil
|
||||
}
|
||||
|
||||
// EnableLocalAuth creates the local (password) connector if it doesn't exist.
|
||||
func (p *Provider) EnableLocalAuth(ctx context.Context) error {
|
||||
return ensureLocalConnector(ctx, p.storage)
|
||||
}
|
||||
|
||||
// ensureStaticConnectors creates or updates static connectors in storage
|
||||
func ensureStaticConnectors(ctx context.Context, stor storage.Storage, connectors []Connector) error {
|
||||
for _, conn := range connectors {
|
||||
|
||||
@@ -247,7 +247,10 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
|
||||
update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort)
|
||||
c.metrics.CountToSyncResponseDuration(time.Since(start))
|
||||
|
||||
c.peersUpdateManager.SendUpdate(ctx, p.ID, &network_map.UpdateMessage{Update: update})
|
||||
c.peersUpdateManager.SendUpdate(ctx, p.ID, &network_map.UpdateMessage{
|
||||
Update: update,
|
||||
MessageType: network_map.MessageTypeNetworkMap,
|
||||
})
|
||||
}(peer)
|
||||
}
|
||||
|
||||
@@ -370,7 +373,10 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
|
||||
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
|
||||
|
||||
update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort)
|
||||
c.peersUpdateManager.SendUpdate(ctx, peer.ID, &network_map.UpdateMessage{Update: update})
|
||||
c.peersUpdateManager.SendUpdate(ctx, peer.ID, &network_map.UpdateMessage{
|
||||
Update: update,
|
||||
MessageType: network_map.MessageTypeNetworkMap,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -778,6 +784,7 @@ func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerI
|
||||
},
|
||||
},
|
||||
},
|
||||
MessageType: network_map.MessageTypeNetworkMap,
|
||||
})
|
||||
c.peersUpdateManager.CloseChannel(ctx, peerID)
|
||||
|
||||
|
||||
@@ -25,11 +25,14 @@ func TestCreateChannel(t *testing.T) {
|
||||
func TestSendUpdate(t *testing.T) {
|
||||
peer := "test-sendupdate"
|
||||
peersUpdater := NewPeersUpdateManager(nil)
|
||||
update1 := &network_map.UpdateMessage{Update: &proto.SyncResponse{
|
||||
NetworkMap: &proto.NetworkMap{
|
||||
Serial: 0,
|
||||
update1 := &network_map.UpdateMessage{
|
||||
Update: &proto.SyncResponse{
|
||||
NetworkMap: &proto.NetworkMap{
|
||||
Serial: 0,
|
||||
},
|
||||
},
|
||||
}}
|
||||
MessageType: network_map.MessageTypeNetworkMap,
|
||||
}
|
||||
_ = peersUpdater.CreateChannel(context.Background(), peer)
|
||||
if _, ok := peersUpdater.peerChannels[peer]; !ok {
|
||||
t.Error("Error creating the channel")
|
||||
@@ -45,11 +48,14 @@ func TestSendUpdate(t *testing.T) {
|
||||
peersUpdater.SendUpdate(context.Background(), peer, update1)
|
||||
}
|
||||
|
||||
update2 := &network_map.UpdateMessage{Update: &proto.SyncResponse{
|
||||
NetworkMap: &proto.NetworkMap{
|
||||
Serial: 10,
|
||||
update2 := &network_map.UpdateMessage{
|
||||
Update: &proto.SyncResponse{
|
||||
NetworkMap: &proto.NetworkMap{
|
||||
Serial: 10,
|
||||
},
|
||||
},
|
||||
}}
|
||||
MessageType: network_map.MessageTypeNetworkMap,
|
||||
}
|
||||
|
||||
peersUpdater.SendUpdate(context.Background(), peer, update2)
|
||||
timeout := time.After(5 * time.Second)
|
||||
|
||||
@@ -4,6 +4,19 @@ import (
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// MessageType indicates the type of update message for debouncing strategy
|
||||
type MessageType int
|
||||
|
||||
const (
|
||||
// MessageTypeNetworkMap represents network map updates (peers, routes, DNS, firewall)
|
||||
// These updates can be safely debounced - only the latest state matters
|
||||
MessageTypeNetworkMap MessageType = iota
|
||||
// MessageTypeControlConfig represents control/config updates (tokens, peer expiration)
|
||||
// These updates should not be dropped as they contain time-sensitive information
|
||||
MessageTypeControlConfig
|
||||
)
|
||||
|
||||
type UpdateMessage struct {
|
||||
Update *proto.SyncResponse
|
||||
Update *proto.SyncResponse
|
||||
MessageType MessageType
|
||||
}
|
||||
|
||||
@@ -69,7 +69,14 @@ func (s *BaseServer) UsersManager() users.Manager {
|
||||
func (s *BaseServer) SettingsManager() settings.Manager {
|
||||
return Create(s, func() settings.Manager {
|
||||
extraSettingsManager := integrations.NewManager(s.EventStore())
|
||||
return settings.NewManager(s.Store(), s.UsersManager(), extraSettingsManager, s.PermissionsManager())
|
||||
|
||||
idpConfig := settings.IdpConfig{}
|
||||
if s.Config.EmbeddedIdP != nil && s.Config.EmbeddedIdP.Enabled {
|
||||
idpConfig.EmbeddedIdpEnabled = true
|
||||
idpConfig.LocalAuthDisabled = s.Config.EmbeddedIdP.LocalAuthDisabled
|
||||
}
|
||||
|
||||
return settings.NewManager(s.Store(), s.UsersManager(), extraSettingsManager, s.PermissionsManager(), idpConfig)
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -77,8 +77,9 @@ type Server struct {
|
||||
|
||||
oAuthConfigProvider idp.OAuthConfigProvider
|
||||
|
||||
syncSem atomic.Int32
|
||||
syncLim int32
|
||||
syncSem atomic.Int32
|
||||
syncLimEnabled bool
|
||||
syncLim int32
|
||||
}
|
||||
|
||||
// NewServer creates a new Management server
|
||||
@@ -108,6 +109,7 @@ func NewServer(
|
||||
blockPeersWithSameConfig := strings.ToLower(os.Getenv(envBlockPeers)) == "true"
|
||||
|
||||
syncLim := int32(defaultSyncLim)
|
||||
syncLimEnabled := true
|
||||
if syncLimStr := os.Getenv(envConcurrentSyncs); syncLimStr != "" {
|
||||
syncLimParsed, err := strconv.Atoi(syncLimStr)
|
||||
if err != nil {
|
||||
@@ -115,6 +117,9 @@ func NewServer(
|
||||
} else {
|
||||
//nolint:gosec
|
||||
syncLim = int32(syncLimParsed)
|
||||
if syncLim < 0 {
|
||||
syncLimEnabled = false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -134,7 +139,8 @@ func NewServer(
|
||||
|
||||
loginFilter: newLoginFilter(),
|
||||
|
||||
syncLim: syncLim,
|
||||
syncLim: syncLim,
|
||||
syncLimEnabled: syncLimEnabled,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -212,7 +218,7 @@ func (s *Server) Job(srv proto.ManagementService_JobServer) error {
|
||||
// Sync validates the existence of a connecting peer, sends an initial state (all available for the connecting peers) and
|
||||
// notifies the connected peer of any updates (e.g. new peers under the same account)
|
||||
func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error {
|
||||
if s.syncSem.Load() >= s.syncLim {
|
||||
if s.syncLimEnabled && s.syncSem.Load() >= s.syncLim {
|
||||
return status.Errorf(codes.ResourceExhausted, "too many concurrent sync requests, please try again later")
|
||||
}
|
||||
s.syncSem.Add(1)
|
||||
@@ -294,7 +300,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
||||
metahash := metaHash(peerMeta, realIP.String())
|
||||
s.loginFilter.addLogin(peerKey.String(), metahash)
|
||||
|
||||
peer, netMap, postureChecks, dnsFwdPort, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP)
|
||||
peer, netMap, postureChecks, dnsFwdPort, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP, reqStart)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("error while syncing peer %s: %v", peerKey.String(), err)
|
||||
s.syncSem.Add(-1)
|
||||
@@ -305,7 +311,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("error while sending initial sync for %s: %v", peerKey.String(), err)
|
||||
s.syncSem.Add(-1)
|
||||
s.cancelPeerRoutines(ctx, accountID, peer)
|
||||
s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, reqStart)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -313,7 +319,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("error while notify peer connected for %s: %v", peerKey.String(), err)
|
||||
s.syncSem.Add(-1)
|
||||
s.cancelPeerRoutines(ctx, accountID, peer)
|
||||
s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, reqStart)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -330,7 +336,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
||||
|
||||
s.syncSem.Add(-1)
|
||||
|
||||
return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv)
|
||||
return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv, reqStart)
|
||||
}
|
||||
|
||||
func (s *Server) handleHandshake(ctx context.Context, srv proto.ManagementService_JobServer) (wgtypes.Key, error) {
|
||||
@@ -398,11 +404,20 @@ func (s *Server) sendJobsLoop(ctx context.Context, accountID string, peerKey wgt
|
||||
}
|
||||
|
||||
// handleUpdates sends updates to the connected peer until the updates channel is closed.
|
||||
func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *network_map.UpdateMessage, srv proto.ManagementService_SyncServer) error {
|
||||
// It implements a backpressure mechanism that sends the first update immediately,
|
||||
// then debounces subsequent rapid updates, ensuring only the latest update is sent
|
||||
// after a quiet period.
|
||||
func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *network_map.UpdateMessage, srv proto.ManagementService_SyncServer, streamStartTime time.Time) error {
|
||||
log.WithContext(ctx).Tracef("starting to handle updates for peer %s", peerKey.String())
|
||||
|
||||
// Create a debouncer for this peer connection
|
||||
debouncer := NewUpdateDebouncer(1000 * time.Millisecond)
|
||||
defer debouncer.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
// condition when there are some updates
|
||||
// todo set the updates channel size to 1
|
||||
case update, open := <-updates:
|
||||
if s.appMetrics != nil {
|
||||
s.appMetrics.GRPCMetrics().UpdateChannelQueueLength(len(updates) + 1)
|
||||
@@ -410,20 +425,38 @@ func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wg
|
||||
|
||||
if !open {
|
||||
log.WithContext(ctx).Debugf("updates channel for peer %s was closed", peerKey.String())
|
||||
s.cancelPeerRoutines(ctx, accountID, peer)
|
||||
s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime)
|
||||
return nil
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("received an update for peer %s", peerKey.String())
|
||||
if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv); err != nil {
|
||||
log.WithContext(ctx).Debugf("error while sending an update to peer %s: %v", peerKey.String(), err)
|
||||
return err
|
||||
if debouncer.ProcessUpdate(update) {
|
||||
// Send immediately (first update or after quiet period)
|
||||
if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv, streamStartTime); err != nil {
|
||||
log.WithContext(ctx).Debugf("error while sending an update to peer %s: %v", peerKey.String(), err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Timer expired - quiet period reached, send pending updates if any
|
||||
case <-debouncer.TimerChannel():
|
||||
pendingUpdates := debouncer.GetPendingUpdates()
|
||||
if len(pendingUpdates) == 0 {
|
||||
continue
|
||||
}
|
||||
log.WithContext(ctx).Debugf("sending %d debounced update(s) for peer %s", len(pendingUpdates), peerKey.String())
|
||||
for _, pendingUpdate := range pendingUpdates {
|
||||
if err := s.sendUpdate(ctx, accountID, peerKey, peer, pendingUpdate, srv, streamStartTime); err != nil {
|
||||
log.WithContext(ctx).Debugf("error while sending an update to peer %s: %v", peerKey.String(), err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// condition when client <-> server connection has been terminated
|
||||
case <-srv.Context().Done():
|
||||
// happens when connection drops, e.g. client disconnects
|
||||
log.WithContext(ctx).Debugf("stream of peer %s has been closed", peerKey.String())
|
||||
s.cancelPeerRoutines(ctx, accountID, peer)
|
||||
s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime)
|
||||
return srv.Context().Err()
|
||||
}
|
||||
}
|
||||
@@ -431,16 +464,16 @@ func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wg
|
||||
|
||||
// sendUpdate encrypts the update message using the peer key and the server's wireguard key,
|
||||
// then sends the encrypted message to the connected peer via the sync server.
|
||||
func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, update *network_map.UpdateMessage, srv proto.ManagementService_SyncServer) error {
|
||||
func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, update *network_map.UpdateMessage, srv proto.ManagementService_SyncServer, streamStartTime time.Time) error {
|
||||
key, err := s.secretsManager.GetWGKey()
|
||||
if err != nil {
|
||||
s.cancelPeerRoutines(ctx, accountID, peer)
|
||||
s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime)
|
||||
return status.Errorf(codes.Internal, "failed processing update message")
|
||||
}
|
||||
|
||||
encryptedResp, err := encryption.EncryptMessage(peerKey, key, update.Update)
|
||||
if err != nil {
|
||||
s.cancelPeerRoutines(ctx, accountID, peer)
|
||||
s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime)
|
||||
return status.Errorf(codes.Internal, "failed processing update message")
|
||||
}
|
||||
err = srv.Send(&proto.EncryptedMessage{
|
||||
@@ -448,7 +481,7 @@ func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtyp
|
||||
Body: encryptedResp,
|
||||
})
|
||||
if err != nil {
|
||||
s.cancelPeerRoutines(ctx, accountID, peer)
|
||||
s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime)
|
||||
return status.Errorf(codes.Internal, "failed sending update message")
|
||||
}
|
||||
log.WithContext(ctx).Debugf("sent an update to peer %s", peerKey.String())
|
||||
@@ -480,11 +513,15 @@ func (s *Server) sendJob(ctx context.Context, peerKey wgtypes.Key, job *job.Even
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer) {
|
||||
func (s *Server) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer, streamStartTime time.Time) {
|
||||
unlock := s.acquirePeerLockByUID(ctx, peer.Key)
|
||||
defer unlock()
|
||||
|
||||
err := s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key)
|
||||
s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, streamStartTime)
|
||||
}
|
||||
|
||||
func (s *Server) cancelPeerRoutinesWithoutLock(ctx context.Context, accountID string, peer *nbpeer.Peer, streamStartTime time.Time) {
|
||||
err := s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key, streamStartTime)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to disconnect peer %s properly: %v", peer.Key, err)
|
||||
}
|
||||
|
||||
@@ -242,7 +242,10 @@ func (m *TimeBasedAuthSecretsManager) pushNewTURNAndRelayTokens(ctx context.Cont
|
||||
m.extendNetbirdConfig(ctx, peerID, accountID, update)
|
||||
|
||||
log.WithContext(ctx).Debugf("sending new TURN credentials to peer %s", peerID)
|
||||
m.updateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{Update: update})
|
||||
m.updateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{
|
||||
Update: update,
|
||||
MessageType: network_map.MessageTypeControlConfig,
|
||||
})
|
||||
}
|
||||
|
||||
func (m *TimeBasedAuthSecretsManager) pushNewRelayTokens(ctx context.Context, accountID, peerID string) {
|
||||
@@ -266,7 +269,10 @@ func (m *TimeBasedAuthSecretsManager) pushNewRelayTokens(ctx context.Context, ac
|
||||
m.extendNetbirdConfig(ctx, peerID, accountID, update)
|
||||
|
||||
log.WithContext(ctx).Debugf("sending new relay credentials to peer %s", peerID)
|
||||
m.updateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{Update: update})
|
||||
m.updateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{
|
||||
Update: update,
|
||||
MessageType: network_map.MessageTypeControlConfig,
|
||||
})
|
||||
}
|
||||
|
||||
func (m *TimeBasedAuthSecretsManager) extendNetbirdConfig(ctx context.Context, peerID, accountID string, update *proto.SyncResponse) {
|
||||
|
||||
103
management/internals/shared/grpc/update_debouncer.go
Normal file
103
management/internals/shared/grpc/update_debouncer.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
)
|
||||
|
||||
// UpdateDebouncer implements a backpressure mechanism that:
|
||||
// - Sends the first update immediately
|
||||
// - Coalesces rapid subsequent network map updates (only latest matters)
|
||||
// - Queues control/config updates (all must be delivered)
|
||||
// - Preserves the order of messages (important for control configs between network maps)
|
||||
// - Ensures pending updates are sent after a quiet period
|
||||
type UpdateDebouncer struct {
|
||||
debounceInterval time.Duration
|
||||
timer *time.Timer
|
||||
pendingUpdates []*network_map.UpdateMessage // Queue that preserves order
|
||||
timerC <-chan time.Time
|
||||
}
|
||||
|
||||
// NewUpdateDebouncer creates a new debouncer with the specified interval
|
||||
func NewUpdateDebouncer(interval time.Duration) *UpdateDebouncer {
|
||||
return &UpdateDebouncer{
|
||||
debounceInterval: interval,
|
||||
}
|
||||
}
|
||||
|
||||
// ProcessUpdate handles an incoming update and returns whether it should be sent immediately
|
||||
func (d *UpdateDebouncer) ProcessUpdate(update *network_map.UpdateMessage) bool {
|
||||
if d.timer == nil {
|
||||
// No active debounce timer, signal to send immediately
|
||||
// and start the debounce period
|
||||
d.startTimer()
|
||||
return true
|
||||
}
|
||||
|
||||
// Already in debounce period, accumulate this update preserving order
|
||||
// Check if we should coalesce with the last pending update
|
||||
if len(d.pendingUpdates) > 0 &&
|
||||
update.MessageType == network_map.MessageTypeNetworkMap &&
|
||||
d.pendingUpdates[len(d.pendingUpdates)-1].MessageType == network_map.MessageTypeNetworkMap {
|
||||
// Replace the last network map with this one (coalesce consecutive network maps)
|
||||
d.pendingUpdates[len(d.pendingUpdates)-1] = update
|
||||
} else {
|
||||
// Append to the queue (preserves order for control configs and non-consecutive network maps)
|
||||
d.pendingUpdates = append(d.pendingUpdates, update)
|
||||
}
|
||||
d.resetTimer()
|
||||
return false
|
||||
}
|
||||
|
||||
// TimerChannel returns the timer channel for select statements
|
||||
func (d *UpdateDebouncer) TimerChannel() <-chan time.Time {
|
||||
if d.timer == nil {
|
||||
return nil
|
||||
}
|
||||
return d.timerC
|
||||
}
|
||||
|
||||
// GetPendingUpdates returns and clears all pending updates after timer expiration.
|
||||
// Updates are returned in the order they were received, with consecutive network maps
|
||||
// already coalesced to only the latest one.
|
||||
// If there were pending updates, it restarts the timer to continue debouncing.
|
||||
// If there were no pending updates, it clears the timer (true quiet period).
|
||||
func (d *UpdateDebouncer) GetPendingUpdates() []*network_map.UpdateMessage {
|
||||
updates := d.pendingUpdates
|
||||
d.pendingUpdates = nil
|
||||
|
||||
if len(updates) > 0 {
|
||||
// There were pending updates, so updates are still coming rapidly
|
||||
// Restart the timer to continue debouncing mode
|
||||
if d.timer != nil {
|
||||
d.timer.Reset(d.debounceInterval)
|
||||
}
|
||||
} else {
|
||||
// No pending updates means true quiet period - return to immediate mode
|
||||
d.timer = nil
|
||||
d.timerC = nil
|
||||
}
|
||||
|
||||
return updates
|
||||
}
|
||||
|
||||
// Stop stops the debouncer and cleans up resources
|
||||
func (d *UpdateDebouncer) Stop() {
|
||||
if d.timer != nil {
|
||||
d.timer.Stop()
|
||||
d.timer = nil
|
||||
d.timerC = nil
|
||||
}
|
||||
d.pendingUpdates = nil
|
||||
}
|
||||
|
||||
func (d *UpdateDebouncer) startTimer() {
|
||||
d.timer = time.NewTimer(d.debounceInterval)
|
||||
d.timerC = d.timer.C
|
||||
}
|
||||
|
||||
func (d *UpdateDebouncer) resetTimer() {
|
||||
d.timer.Stop()
|
||||
d.timer.Reset(d.debounceInterval)
|
||||
}
|
||||
587
management/internals/shared/grpc/update_debouncer_test.go
Normal file
587
management/internals/shared/grpc/update_debouncer_test.go
Normal file
@@ -0,0 +1,587 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
func TestUpdateDebouncer_FirstUpdateSentImmediately(t *testing.T) {
|
||||
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
|
||||
defer debouncer.Stop()
|
||||
|
||||
update := &network_map.UpdateMessage{
|
||||
Update: &proto.SyncResponse{},
|
||||
MessageType: network_map.MessageTypeNetworkMap,
|
||||
}
|
||||
|
||||
shouldSend := debouncer.ProcessUpdate(update)
|
||||
|
||||
if !shouldSend {
|
||||
t.Error("First update should be sent immediately")
|
||||
}
|
||||
|
||||
if debouncer.TimerChannel() == nil {
|
||||
t.Error("Timer should be started after first update")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateDebouncer_RapidUpdatesCoalesced(t *testing.T) {
|
||||
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
|
||||
defer debouncer.Stop()
|
||||
|
||||
update1 := &network_map.UpdateMessage{
|
||||
Update: &proto.SyncResponse{},
|
||||
MessageType: network_map.MessageTypeNetworkMap,
|
||||
}
|
||||
update2 := &network_map.UpdateMessage{
|
||||
Update: &proto.SyncResponse{},
|
||||
MessageType: network_map.MessageTypeNetworkMap,
|
||||
}
|
||||
update3 := &network_map.UpdateMessage{
|
||||
Update: &proto.SyncResponse{},
|
||||
MessageType: network_map.MessageTypeNetworkMap,
|
||||
}
|
||||
|
||||
// First update should be sent immediately
|
||||
if !debouncer.ProcessUpdate(update1) {
|
||||
t.Error("First update should be sent immediately")
|
||||
}
|
||||
|
||||
// Rapid subsequent updates should be coalesced
|
||||
if debouncer.ProcessUpdate(update2) {
|
||||
t.Error("Second rapid update should not be sent immediately")
|
||||
}
|
||||
|
||||
if debouncer.ProcessUpdate(update3) {
|
||||
t.Error("Third rapid update should not be sent immediately")
|
||||
}
|
||||
|
||||
// Wait for debounce period
|
||||
select {
|
||||
case <-debouncer.TimerChannel():
|
||||
pendingUpdates := debouncer.GetPendingUpdates()
|
||||
if len(pendingUpdates) != 1 {
|
||||
t.Errorf("Should get exactly 1 pending update, got %d", len(pendingUpdates))
|
||||
}
|
||||
if pendingUpdates[0] != update3 {
|
||||
t.Error("Should get the last update (update3)")
|
||||
}
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Error("Timer should have fired")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateDebouncer_LastUpdateAlwaysSent(t *testing.T) {
|
||||
debouncer := NewUpdateDebouncer(30 * time.Millisecond)
|
||||
defer debouncer.Stop()
|
||||
|
||||
update1 := &network_map.UpdateMessage{
|
||||
Update: &proto.SyncResponse{},
|
||||
MessageType: network_map.MessageTypeNetworkMap,
|
||||
}
|
||||
update2 := &network_map.UpdateMessage{
|
||||
Update: &proto.SyncResponse{},
|
||||
MessageType: network_map.MessageTypeNetworkMap,
|
||||
}
|
||||
|
||||
// Send first update
|
||||
debouncer.ProcessUpdate(update1)
|
||||
|
||||
// Send second update within debounce period
|
||||
debouncer.ProcessUpdate(update2)
|
||||
|
||||
// Wait for timer
|
||||
select {
|
||||
case <-debouncer.TimerChannel():
|
||||
pendingUpdates := debouncer.GetPendingUpdates()
|
||||
if len(pendingUpdates) != 1 {
|
||||
t.Errorf("Should get exactly 1 pending update, got %d", len(pendingUpdates))
|
||||
}
|
||||
if pendingUpdates[0] != update2 {
|
||||
t.Error("Should get the last update")
|
||||
}
|
||||
if pendingUpdates[0] == update1 {
|
||||
t.Error("Should not get the first update")
|
||||
}
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Error("Timer should have fired")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateDebouncer_TimerResetOnNewUpdate(t *testing.T) {
|
||||
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
|
||||
defer debouncer.Stop()
|
||||
|
||||
update1 := &network_map.UpdateMessage{
|
||||
Update: &proto.SyncResponse{},
|
||||
MessageType: network_map.MessageTypeNetworkMap,
|
||||
}
|
||||
update2 := &network_map.UpdateMessage{
|
||||
Update: &proto.SyncResponse{},
|
||||
MessageType: network_map.MessageTypeNetworkMap,
|
||||
}
|
||||
update3 := &network_map.UpdateMessage{
|
||||
Update: &proto.SyncResponse{},
|
||||
MessageType: network_map.MessageTypeNetworkMap,
|
||||
}
|
||||
|
||||
// Send first update
|
||||
debouncer.ProcessUpdate(update1)
|
||||
|
||||
// Wait a bit, but not the full debounce period
|
||||
time.Sleep(30 * time.Millisecond)
|
||||
|
||||
// Send second update - should reset timer
|
||||
debouncer.ProcessUpdate(update2)
|
||||
|
||||
// Wait a bit more
|
||||
time.Sleep(30 * time.Millisecond)
|
||||
|
||||
// Send third update - should reset timer again
|
||||
debouncer.ProcessUpdate(update3)
|
||||
|
||||
// Now wait for the timer (should fire after last update's reset)
|
||||
select {
|
||||
case <-debouncer.TimerChannel():
|
||||
pendingUpdates := debouncer.GetPendingUpdates()
|
||||
if len(pendingUpdates) != 1 {
|
||||
t.Errorf("Should get exactly 1 pending update, got %d", len(pendingUpdates))
|
||||
}
|
||||
if pendingUpdates[0] != update3 {
|
||||
t.Error("Should get the last update (update3)")
|
||||
}
|
||||
// Timer should be restarted since there was a pending update
|
||||
if debouncer.TimerChannel() == nil {
|
||||
t.Error("Timer should be restarted after sending pending update")
|
||||
}
|
||||
case <-time.After(150 * time.Millisecond):
|
||||
t.Error("Timer should have fired")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateDebouncer_TimerRestartsAfterPendingUpdateSent(t *testing.T) {
|
||||
debouncer := NewUpdateDebouncer(30 * time.Millisecond)
|
||||
defer debouncer.Stop()
|
||||
|
||||
update1 := &network_map.UpdateMessage{
|
||||
Update: &proto.SyncResponse{},
|
||||
MessageType: network_map.MessageTypeNetworkMap,
|
||||
}
|
||||
update2 := &network_map.UpdateMessage{
|
||||
Update: &proto.SyncResponse{},
|
||||
MessageType: network_map.MessageTypeNetworkMap,
|
||||
}
|
||||
update3 := &network_map.UpdateMessage{
|
||||
Update: &proto.SyncResponse{},
|
||||
MessageType: network_map.MessageTypeNetworkMap,
|
||||
}
|
||||
|
||||
// First update sent immediately
|
||||
debouncer.ProcessUpdate(update1)
|
||||
|
||||
// Second update coalesced
|
||||
debouncer.ProcessUpdate(update2)
|
||||
|
||||
// Wait for timer to expire
|
||||
select {
|
||||
case <-debouncer.TimerChannel():
|
||||
pendingUpdates := debouncer.GetPendingUpdates()
|
||||
|
||||
if len(pendingUpdates) == 0 {
|
||||
t.Fatal("Should have pending update")
|
||||
}
|
||||
|
||||
// After sending pending update, timer is restarted, so next update is NOT immediate
|
||||
if debouncer.ProcessUpdate(update3) {
|
||||
t.Error("Update after debounced send should not be sent immediately (timer restarted)")
|
||||
}
|
||||
|
||||
// Wait for the restarted timer and verify update3 is pending
|
||||
select {
|
||||
case <-debouncer.TimerChannel():
|
||||
finalUpdates := debouncer.GetPendingUpdates()
|
||||
if len(finalUpdates) != 1 || finalUpdates[0] != update3 {
|
||||
t.Error("Should get update3 as pending")
|
||||
}
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Error("Timer should have fired for restarted timer")
|
||||
}
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Error("Timer should have fired")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateDebouncer_StopCleansUp(t *testing.T) {
|
||||
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
|
||||
|
||||
update := &network_map.UpdateMessage{
|
||||
Update: &proto.SyncResponse{},
|
||||
MessageType: network_map.MessageTypeNetworkMap,
|
||||
}
|
||||
|
||||
// Send update to start timer
|
||||
debouncer.ProcessUpdate(update)
|
||||
|
||||
// Stop should clean up
|
||||
debouncer.Stop()
|
||||
|
||||
// Multiple stops should be safe
|
||||
debouncer.Stop()
|
||||
}
|
||||
|
||||
func TestUpdateDebouncer_HighFrequencyUpdates(t *testing.T) {
|
||||
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
|
||||
defer debouncer.Stop()
|
||||
|
||||
// Simulate high-frequency updates
|
||||
var lastUpdate *network_map.UpdateMessage
|
||||
sentImmediately := 0
|
||||
for i := 0; i < 100; i++ {
|
||||
update := &network_map.UpdateMessage{
|
||||
Update: &proto.SyncResponse{
|
||||
NetworkMap: &proto.NetworkMap{
|
||||
Serial: uint64(i),
|
||||
},
|
||||
},
|
||||
MessageType: network_map.MessageTypeNetworkMap,
|
||||
}
|
||||
lastUpdate = update
|
||||
if debouncer.ProcessUpdate(update) {
|
||||
sentImmediately++
|
||||
}
|
||||
time.Sleep(1 * time.Millisecond) // Very rapid updates
|
||||
}
|
||||
|
||||
// Only first update should be sent immediately
|
||||
if sentImmediately != 1 {
|
||||
t.Errorf("Expected only 1 update sent immediately, got %d", sentImmediately)
|
||||
}
|
||||
|
||||
// Wait for debounce period
|
||||
select {
|
||||
case <-debouncer.TimerChannel():
|
||||
pendingUpdates := debouncer.GetPendingUpdates()
|
||||
if len(pendingUpdates) != 1 {
|
||||
t.Errorf("Should get exactly 1 pending update, got %d", len(pendingUpdates))
|
||||
}
|
||||
if pendingUpdates[0] != lastUpdate {
|
||||
t.Error("Should get the very last update")
|
||||
}
|
||||
if pendingUpdates[0].Update.NetworkMap.Serial != 99 {
|
||||
t.Errorf("Expected serial 99, got %d", pendingUpdates[0].Update.NetworkMap.Serial)
|
||||
}
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
t.Error("Timer should have fired")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateDebouncer_NoUpdatesAfterFirst(t *testing.T) {
|
||||
debouncer := NewUpdateDebouncer(30 * time.Millisecond)
|
||||
defer debouncer.Stop()
|
||||
|
||||
update := &network_map.UpdateMessage{
|
||||
Update: &proto.SyncResponse{},
|
||||
MessageType: network_map.MessageTypeNetworkMap,
|
||||
}
|
||||
|
||||
// Send first update
|
||||
if !debouncer.ProcessUpdate(update) {
|
||||
t.Error("First update should be sent immediately")
|
||||
}
|
||||
|
||||
// Wait for timer to expire with no additional updates (true quiet period)
|
||||
select {
|
||||
case <-debouncer.TimerChannel():
|
||||
pendingUpdates := debouncer.GetPendingUpdates()
|
||||
if len(pendingUpdates) != 0 {
|
||||
t.Error("Should have no pending updates")
|
||||
}
|
||||
// After true quiet period, timer should be cleared
|
||||
if debouncer.TimerChannel() != nil {
|
||||
t.Error("Timer should be cleared after quiet period")
|
||||
}
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Error("Timer should have fired")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateDebouncer_IntermediateUpdatesDropped(t *testing.T) {
|
||||
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
|
||||
defer debouncer.Stop()
|
||||
|
||||
updates := make([]*network_map.UpdateMessage, 5)
|
||||
for i := range updates {
|
||||
updates[i] = &network_map.UpdateMessage{
|
||||
Update: &proto.SyncResponse{
|
||||
NetworkMap: &proto.NetworkMap{
|
||||
Serial: uint64(i),
|
||||
},
|
||||
},
|
||||
MessageType: network_map.MessageTypeNetworkMap,
|
||||
}
|
||||
}
|
||||
|
||||
// First update sent immediately
|
||||
debouncer.ProcessUpdate(updates[0])
|
||||
|
||||
// Send updates 1, 2, 3, 4 rapidly - only last one should remain pending
|
||||
debouncer.ProcessUpdate(updates[1])
|
||||
debouncer.ProcessUpdate(updates[2])
|
||||
debouncer.ProcessUpdate(updates[3])
|
||||
debouncer.ProcessUpdate(updates[4])
|
||||
|
||||
// Wait for debounce
|
||||
<-debouncer.TimerChannel()
|
||||
pendingUpdates := debouncer.GetPendingUpdates()
|
||||
|
||||
if len(pendingUpdates) != 1 {
|
||||
t.Errorf("Should get exactly 1 pending update, got %d", len(pendingUpdates))
|
||||
}
|
||||
if pendingUpdates[0].Update.NetworkMap.Serial != 4 {
|
||||
t.Errorf("Expected only the last update (serial 4), got serial %d", pendingUpdates[0].Update.NetworkMap.Serial)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateDebouncer_TrueQuietPeriodResetsToImmediateMode(t *testing.T) {
|
||||
debouncer := NewUpdateDebouncer(30 * time.Millisecond)
|
||||
defer debouncer.Stop()
|
||||
|
||||
update1 := &network_map.UpdateMessage{
|
||||
Update: &proto.SyncResponse{},
|
||||
MessageType: network_map.MessageTypeNetworkMap,
|
||||
}
|
||||
update2 := &network_map.UpdateMessage{
|
||||
Update: &proto.SyncResponse{},
|
||||
MessageType: network_map.MessageTypeNetworkMap,
|
||||
}
|
||||
|
||||
// First update sent immediately
|
||||
if !debouncer.ProcessUpdate(update1) {
|
||||
t.Error("First update should be sent immediately")
|
||||
}
|
||||
|
||||
// Wait for timer without sending any more updates (true quiet period)
|
||||
<-debouncer.TimerChannel()
|
||||
pendingUpdates := debouncer.GetPendingUpdates()
|
||||
|
||||
if len(pendingUpdates) != 0 {
|
||||
t.Error("Should have no pending updates during quiet period")
|
||||
}
|
||||
|
||||
// After true quiet period, next update should be sent immediately
|
||||
if !debouncer.ProcessUpdate(update2) {
|
||||
t.Error("Update after true quiet period should be sent immediately")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateDebouncer_ContinuousHighFrequencyStaysInDebounceMode(t *testing.T) {
|
||||
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
|
||||
defer debouncer.Stop()
|
||||
|
||||
// Simulate continuous high-frequency updates
|
||||
for i := 0; i < 10; i++ {
|
||||
update := &network_map.UpdateMessage{
|
||||
Update: &proto.SyncResponse{
|
||||
NetworkMap: &proto.NetworkMap{
|
||||
Serial: uint64(i),
|
||||
},
|
||||
},
|
||||
MessageType: network_map.MessageTypeNetworkMap,
|
||||
}
|
||||
|
||||
if i == 0 {
|
||||
// First one sent immediately
|
||||
if !debouncer.ProcessUpdate(update) {
|
||||
t.Error("First update should be sent immediately")
|
||||
}
|
||||
} else {
|
||||
// All others should be coalesced (not sent immediately)
|
||||
if debouncer.ProcessUpdate(update) {
|
||||
t.Errorf("Update %d should not be sent immediately", i)
|
||||
}
|
||||
}
|
||||
|
||||
// Wait a bit but send next update before debounce expires
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
}
|
||||
|
||||
// Now wait for final debounce
|
||||
select {
|
||||
case <-debouncer.TimerChannel():
|
||||
pendingUpdates := debouncer.GetPendingUpdates()
|
||||
if len(pendingUpdates) == 0 {
|
||||
t.Fatal("Should have the last update pending")
|
||||
}
|
||||
if pendingUpdates[0].Update.NetworkMap.Serial != 9 {
|
||||
t.Errorf("Expected serial 9, got %d", pendingUpdates[0].Update.NetworkMap.Serial)
|
||||
}
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
t.Error("Timer should have fired")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateDebouncer_ControlConfigMessagesQueued(t *testing.T) {
|
||||
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
|
||||
defer debouncer.Stop()
|
||||
|
||||
netmapUpdate := &network_map.UpdateMessage{
|
||||
Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: 1}},
|
||||
MessageType: network_map.MessageTypeNetworkMap,
|
||||
}
|
||||
tokenUpdate1 := &network_map.UpdateMessage{
|
||||
Update: &proto.SyncResponse{NetbirdConfig: &proto.NetbirdConfig{}},
|
||||
MessageType: network_map.MessageTypeControlConfig,
|
||||
}
|
||||
tokenUpdate2 := &network_map.UpdateMessage{
|
||||
Update: &proto.SyncResponse{NetbirdConfig: &proto.NetbirdConfig{}},
|
||||
MessageType: network_map.MessageTypeControlConfig,
|
||||
}
|
||||
|
||||
// First update sent immediately
|
||||
debouncer.ProcessUpdate(netmapUpdate)
|
||||
|
||||
// Send multiple control config updates - they should all be queued
|
||||
debouncer.ProcessUpdate(tokenUpdate1)
|
||||
debouncer.ProcessUpdate(tokenUpdate2)
|
||||
|
||||
// Wait for debounce period
|
||||
select {
|
||||
case <-debouncer.TimerChannel():
|
||||
pendingUpdates := debouncer.GetPendingUpdates()
|
||||
// Should get both control config updates
|
||||
if len(pendingUpdates) != 2 {
|
||||
t.Errorf("Expected 2 control config updates, got %d", len(pendingUpdates))
|
||||
}
|
||||
// Control configs should come first
|
||||
if pendingUpdates[0] != tokenUpdate1 {
|
||||
t.Error("First pending update should be tokenUpdate1")
|
||||
}
|
||||
if pendingUpdates[1] != tokenUpdate2 {
|
||||
t.Error("Second pending update should be tokenUpdate2")
|
||||
}
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
t.Error("Timer should have fired")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateDebouncer_MixedMessageTypes(t *testing.T) {
|
||||
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
|
||||
defer debouncer.Stop()
|
||||
|
||||
netmapUpdate1 := &network_map.UpdateMessage{
|
||||
Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: 1}},
|
||||
MessageType: network_map.MessageTypeNetworkMap,
|
||||
}
|
||||
netmapUpdate2 := &network_map.UpdateMessage{
|
||||
Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: 2}},
|
||||
MessageType: network_map.MessageTypeNetworkMap,
|
||||
}
|
||||
tokenUpdate := &network_map.UpdateMessage{
|
||||
Update: &proto.SyncResponse{NetbirdConfig: &proto.NetbirdConfig{}},
|
||||
MessageType: network_map.MessageTypeControlConfig,
|
||||
}
|
||||
|
||||
// First update sent immediately
|
||||
debouncer.ProcessUpdate(netmapUpdate1)
|
||||
|
||||
// Send token update and network map update
|
||||
debouncer.ProcessUpdate(tokenUpdate)
|
||||
debouncer.ProcessUpdate(netmapUpdate2)
|
||||
|
||||
// Wait for debounce period
|
||||
select {
|
||||
case <-debouncer.TimerChannel():
|
||||
pendingUpdates := debouncer.GetPendingUpdates()
|
||||
// Should get 2 updates in order: token, then network map
|
||||
if len(pendingUpdates) != 2 {
|
||||
t.Errorf("Expected 2 pending updates, got %d", len(pendingUpdates))
|
||||
}
|
||||
// Token update should come first (preserves order)
|
||||
if pendingUpdates[0] != tokenUpdate {
|
||||
t.Error("First pending update should be tokenUpdate")
|
||||
}
|
||||
// Network map update should come second
|
||||
if pendingUpdates[1] != netmapUpdate2 {
|
||||
t.Error("Second pending update should be netmapUpdate2")
|
||||
}
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
t.Error("Timer should have fired")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateDebouncer_OrderPreservation(t *testing.T) {
|
||||
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
|
||||
defer debouncer.Stop()
|
||||
|
||||
// Simulate: 50 network maps -> 1 control config -> 50 network maps
|
||||
// Expected result: 3 messages (netmap, controlConfig, netmap)
|
||||
|
||||
// Send first network map immediately
|
||||
firstNetmap := &network_map.UpdateMessage{
|
||||
Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: 0}},
|
||||
MessageType: network_map.MessageTypeNetworkMap,
|
||||
}
|
||||
if !debouncer.ProcessUpdate(firstNetmap) {
|
||||
t.Error("First update should be sent immediately")
|
||||
}
|
||||
|
||||
// Send 49 more network maps (will be coalesced to last one)
|
||||
var lastNetmapBatch1 *network_map.UpdateMessage
|
||||
for i := 1; i < 50; i++ {
|
||||
lastNetmapBatch1 = &network_map.UpdateMessage{
|
||||
Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: uint64(i)}},
|
||||
MessageType: network_map.MessageTypeNetworkMap,
|
||||
}
|
||||
debouncer.ProcessUpdate(lastNetmapBatch1)
|
||||
}
|
||||
|
||||
// Send 1 control config
|
||||
controlConfig := &network_map.UpdateMessage{
|
||||
Update: &proto.SyncResponse{NetbirdConfig: &proto.NetbirdConfig{}},
|
||||
MessageType: network_map.MessageTypeControlConfig,
|
||||
}
|
||||
debouncer.ProcessUpdate(controlConfig)
|
||||
|
||||
// Send 50 more network maps (will be coalesced to last one)
|
||||
var lastNetmapBatch2 *network_map.UpdateMessage
|
||||
for i := 50; i < 100; i++ {
|
||||
lastNetmapBatch2 = &network_map.UpdateMessage{
|
||||
Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: uint64(i)}},
|
||||
MessageType: network_map.MessageTypeNetworkMap,
|
||||
}
|
||||
debouncer.ProcessUpdate(lastNetmapBatch2)
|
||||
}
|
||||
|
||||
// Wait for debounce period
|
||||
select {
|
||||
case <-debouncer.TimerChannel():
|
||||
pendingUpdates := debouncer.GetPendingUpdates()
|
||||
// Should get exactly 3 updates: netmap, controlConfig, netmap
|
||||
if len(pendingUpdates) != 3 {
|
||||
t.Errorf("Expected 3 pending updates, got %d", len(pendingUpdates))
|
||||
}
|
||||
// First should be the last netmap from batch 1
|
||||
if pendingUpdates[0] != lastNetmapBatch1 {
|
||||
t.Error("First pending update should be last netmap from batch 1")
|
||||
}
|
||||
if pendingUpdates[0].Update.NetworkMap.Serial != 49 {
|
||||
t.Errorf("Expected serial 49, got %d", pendingUpdates[0].Update.NetworkMap.Serial)
|
||||
}
|
||||
// Second should be the control config
|
||||
if pendingUpdates[1] != controlConfig {
|
||||
t.Error("Second pending update should be control config")
|
||||
}
|
||||
// Third should be the last netmap from batch 2
|
||||
if pendingUpdates[2] != lastNetmapBatch2 {
|
||||
t.Error("Third pending update should be last netmap from batch 2")
|
||||
}
|
||||
if pendingUpdates[2].Update.NetworkMap.Serial != 99 {
|
||||
t.Errorf("Expected serial 99, got %d", pendingUpdates[2].Update.NetworkMap.Serial)
|
||||
}
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
t.Error("Timer should have fired")
|
||||
}
|
||||
}
|
||||
@@ -26,7 +26,6 @@ import (
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
nbdomain "github.com/netbirdio/netbird/shared/management/domain"
|
||||
"github.com/netbirdio/netbird/formatter/hook"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||
@@ -49,6 +48,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
nbdomain "github.com/netbirdio/netbird/shared/management/domain"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
@@ -795,6 +795,19 @@ func IsEmbeddedIdp(i idp.Manager) bool {
|
||||
return ok
|
||||
}
|
||||
|
||||
// IsLocalAuthDisabled checks if local (email/password) authentication is disabled.
|
||||
// Returns true only when using embedded IDP with local auth disabled in config.
|
||||
func IsLocalAuthDisabled(ctx context.Context, i idp.Manager) bool {
|
||||
if isNil(i) {
|
||||
return false
|
||||
}
|
||||
embeddedIdp, ok := i.(*idp.EmbeddedIdPManager)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return embeddedIdp.IsLocalAuthDisabled()
|
||||
}
|
||||
|
||||
// addAccountIDToIDPAppMeta update user's app metadata in idp manager
|
||||
func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, accountID string) error {
|
||||
if !isNil(am.idpManager) && !IsEmbeddedIdp(am.idpManager) {
|
||||
@@ -1657,13 +1670,13 @@ func domainIsUpToDate(domain string, domainCategory string, userAuth auth.UserAu
|
||||
return domainCategory == types.PrivateCategory || userAuth.DomainCategory != types.PrivateCategory || domain != userAuth.Domain
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
|
||||
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
|
||||
peer, netMap, postureChecks, dnsfwdPort, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, 0, fmt.Errorf("error syncing peer: %w", err)
|
||||
}
|
||||
|
||||
err = am.MarkPeerConnected(ctx, peerPubKey, true, realIP, accountID)
|
||||
err = am.MarkPeerConnected(ctx, peerPubKey, true, realIP, accountID, syncTime)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err)
|
||||
}
|
||||
@@ -1671,8 +1684,20 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID
|
||||
return peer, netMap, postureChecks, dnsfwdPort, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error {
|
||||
err := am.MarkPeerConnected(ctx, peerPubKey, false, nil, accountID)
|
||||
func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string, streamStartTime time.Time) error {
|
||||
peer, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peerPubKey)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to get peer %s for disconnect check: %v", peerPubKey, err)
|
||||
return nil
|
||||
}
|
||||
|
||||
if peer.Status.LastSeen.After(streamStartTime) {
|
||||
log.WithContext(ctx).Tracef("peer %s has newer activity (lastSeen=%s > streamStart=%s), skipping disconnect",
|
||||
peerPubKey, peer.Status.LastSeen.Format(time.RFC3339), streamStartTime.Format(time.RFC3339))
|
||||
return nil
|
||||
}
|
||||
|
||||
err = am.MarkPeerConnected(ctx, peerPubKey, false, nil, accountID, time.Now().UTC())
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed marking peer as disconnected %s %v", peerPubKey, err)
|
||||
}
|
||||
|
||||
@@ -58,7 +58,7 @@ type Manager interface {
|
||||
GetUserFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
|
||||
ListUsers(ctx context.Context, accountID string) ([]*types.User, error)
|
||||
GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
|
||||
MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string) error
|
||||
MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error
|
||||
DeletePeer(ctx context.Context, accountID, peerID, userID string) error
|
||||
UpdatePeer(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
|
||||
UpdatePeerIP(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error
|
||||
@@ -114,8 +114,8 @@ type Manager interface {
|
||||
UpdateIntegratedValidator(ctx context.Context, accountID, userID, validator string, groups []string) error
|
||||
GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error)
|
||||
GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error)
|
||||
SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
|
||||
OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error
|
||||
SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
|
||||
OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string, streamStartTime time.Time) error
|
||||
SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
|
||||
FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
||||
GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error)
|
||||
|
||||
@@ -1881,7 +1881,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
|
||||
require.NoError(t, err, "unable to get the account")
|
||||
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID)
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID, time.Now().UTC())
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
|
||||
@@ -1952,7 +1952,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
|
||||
require.NoError(t, err, "unable to get the account")
|
||||
|
||||
// when we mark peer as connected, the peer login expiration routine should trigger
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID)
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID, time.Now().UTC())
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
failed := waitTimeout(wg, time.Second)
|
||||
@@ -1961,6 +1961,82 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
|
||||
require.NoError(t, err, "unable to create an account")
|
||||
|
||||
key, err := wgtypes.GenerateKey()
|
||||
require.NoError(t, err, "unable to generate WireGuard key")
|
||||
peerPubKey := key.PublicKey().String()
|
||||
|
||||
_, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
Key: peerPubKey,
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
|
||||
}, false)
|
||||
require.NoError(t, err, "unable to add peer")
|
||||
|
||||
t.Run("disconnect peer when streamStartTime is after LastSeen", func(t *testing.T) {
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, time.Now().UTC())
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
||||
require.NoError(t, err, "unable to get peer")
|
||||
require.True(t, peer.Status.Connected, "peer should be connected")
|
||||
|
||||
streamStartTime := time.Now().UTC()
|
||||
|
||||
err = manager.OnPeerDisconnected(context.Background(), accountID, peerPubKey, streamStartTime)
|
||||
require.NoError(t, err)
|
||||
|
||||
peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
||||
require.NoError(t, err)
|
||||
require.False(t, peer.Status.Connected, "peer should be disconnected")
|
||||
})
|
||||
|
||||
t.Run("skip disconnect when LastSeen is after streamStartTime (zombie stream protection)", func(t *testing.T) {
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, time.Now().UTC())
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
||||
require.NoError(t, err)
|
||||
require.True(t, peer.Status.Connected, "peer should be connected")
|
||||
|
||||
streamStartTime := peer.Status.LastSeen.Add(-1 * time.Hour)
|
||||
|
||||
err = manager.OnPeerDisconnected(context.Background(), accountID, peerPubKey, streamStartTime)
|
||||
require.NoError(t, err)
|
||||
|
||||
peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
||||
require.NoError(t, err)
|
||||
require.True(t, peer.Status.Connected,
|
||||
"peer should remain connected because LastSeen > streamStartTime (zombie stream protection)")
|
||||
})
|
||||
|
||||
t.Run("skip stale connect when peer already has newer LastSeen (blocked goroutine protection)", func(t *testing.T) {
|
||||
node2SyncTime := time.Now().UTC()
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, node2SyncTime)
|
||||
require.NoError(t, err, "node 2 should connect peer")
|
||||
|
||||
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
||||
require.NoError(t, err)
|
||||
require.True(t, peer.Status.Connected, "peer should be connected")
|
||||
require.Equal(t, node2SyncTime.Unix(), peer.Status.LastSeen.Unix(), "LastSeen should be node2SyncTime")
|
||||
|
||||
node1StaleSyncTime := node2SyncTime.Add(-1 * time.Minute)
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, node1StaleSyncTime)
|
||||
require.NoError(t, err, "stale connect should not return error")
|
||||
|
||||
peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
||||
require.NoError(t, err)
|
||||
require.True(t, peer.Status.Connected, "peer should still be connected")
|
||||
require.Equal(t, node2SyncTime.Unix(), peer.Status.LastSeen.Unix(),
|
||||
"LastSeen should NOT be overwritten by stale syncTime from blocked goroutine")
|
||||
})
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *testing.T) {
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
@@ -1983,7 +2059,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
|
||||
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||
require.NoError(t, err, "unable to get the account")
|
||||
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID)
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID, time.Now().UTC())
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
@@ -3176,7 +3252,7 @@ func BenchmarkSyncAndMarkPeer(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
start := time.Now()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _, _, _, err := manager.SyncAndMarkPeer(context.Background(), account.Id, account.Peers["peer-1"].Key, nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)}, net.IP{1, 1, 1, 1})
|
||||
_, _, _, _, err := manager.SyncAndMarkPeer(context.Background(), account.Id, account.Peers["peer-1"].Key, nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)}, net.IP{1, 1, 1, 1}, time.Now().UTC())
|
||||
assert.NoError(b, err)
|
||||
}
|
||||
|
||||
|
||||
@@ -9,10 +9,11 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
idpmanager "github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/rs/cors"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
idpmanager "github.com/netbirdio/netbird/management/server/idp"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
@@ -129,15 +130,15 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
|
||||
return nil, fmt.Errorf("register integrations endpoints: %w", err)
|
||||
}
|
||||
|
||||
// Check if embedded IdP is enabled
|
||||
// Check if embedded IdP is enabled for instance manager
|
||||
embeddedIdP, embeddedIdpEnabled := idpManager.(*idpmanager.EmbeddedIdPManager)
|
||||
instanceManager, err := nbinstance.NewManager(ctx, accountManager.GetStore(), embeddedIdP)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create instance manager: %w", err)
|
||||
}
|
||||
|
||||
accounts.AddEndpoints(accountManager, settingsManager, embeddedIdpEnabled, router)
|
||||
peers.AddEndpoints(accountManager, router, networkMapController)
|
||||
accounts.AddEndpoints(accountManager, settingsManager, router)
|
||||
peers.AddEndpoints(accountManager, router, networkMapController, permissionsManager)
|
||||
users.AddEndpoints(accountManager, router)
|
||||
users.AddInvitesEndpoints(accountManager, router)
|
||||
users.AddPublicInvitesEndpoints(accountManager, router)
|
||||
|
||||
@@ -36,24 +36,22 @@ const (
|
||||
|
||||
// handler is a handler that handles the server.Account HTTP endpoints
|
||||
type handler struct {
|
||||
accountManager account.Manager
|
||||
settingsManager settings.Manager
|
||||
embeddedIdpEnabled bool
|
||||
accountManager account.Manager
|
||||
settingsManager settings.Manager
|
||||
}
|
||||
|
||||
func AddEndpoints(accountManager account.Manager, settingsManager settings.Manager, embeddedIdpEnabled bool, router *mux.Router) {
|
||||
accountsHandler := newHandler(accountManager, settingsManager, embeddedIdpEnabled)
|
||||
func AddEndpoints(accountManager account.Manager, settingsManager settings.Manager, router *mux.Router) {
|
||||
accountsHandler := newHandler(accountManager, settingsManager)
|
||||
router.HandleFunc("/accounts/{accountId}", accountsHandler.updateAccount).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/accounts/{accountId}", accountsHandler.deleteAccount).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/accounts", accountsHandler.getAllAccounts).Methods("GET", "OPTIONS")
|
||||
}
|
||||
|
||||
// newHandler creates a new handler HTTP handler
|
||||
func newHandler(accountManager account.Manager, settingsManager settings.Manager, embeddedIdpEnabled bool) *handler {
|
||||
func newHandler(accountManager account.Manager, settingsManager settings.Manager) *handler {
|
||||
return &handler{
|
||||
accountManager: accountManager,
|
||||
settingsManager: settingsManager,
|
||||
embeddedIdpEnabled: embeddedIdpEnabled,
|
||||
accountManager: accountManager,
|
||||
settingsManager: settingsManager,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -165,7 +163,7 @@ func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
resp := toAccountResponse(accountID, settings, meta, onboarding, h.embeddedIdpEnabled)
|
||||
resp := toAccountResponse(accountID, settings, meta, onboarding)
|
||||
util.WriteJSONObject(r.Context(), w, []*api.Account{resp})
|
||||
}
|
||||
|
||||
@@ -292,7 +290,7 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
resp := toAccountResponse(accountID, updatedSettings, meta, updatedOnboarding, h.embeddedIdpEnabled)
|
||||
resp := toAccountResponse(accountID, updatedSettings, meta, updatedOnboarding)
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, &resp)
|
||||
}
|
||||
@@ -321,7 +319,7 @@ func (h *handler) deleteAccount(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||
}
|
||||
|
||||
func toAccountResponse(accountID string, settings *types.Settings, meta *types.AccountMeta, onboarding *types.AccountOnboarding, embeddedIdpEnabled bool) *api.Account {
|
||||
func toAccountResponse(accountID string, settings *types.Settings, meta *types.AccountMeta, onboarding *types.AccountOnboarding) *api.Account {
|
||||
jwtAllowGroups := settings.JWTAllowGroups
|
||||
if jwtAllowGroups == nil {
|
||||
jwtAllowGroups = []string{}
|
||||
@@ -341,7 +339,8 @@ func toAccountResponse(accountID string, settings *types.Settings, meta *types.A
|
||||
LazyConnectionEnabled: &settings.LazyConnectionEnabled,
|
||||
DnsDomain: &settings.DNSDomain,
|
||||
AutoUpdateVersion: &settings.AutoUpdateVersion,
|
||||
EmbeddedIdpEnabled: &embeddedIdpEnabled,
|
||||
EmbeddedIdpEnabled: &settings.EmbeddedIdpEnabled,
|
||||
LocalAuthDisabled: &settings.LocalAuthDisabled,
|
||||
}
|
||||
|
||||
if settings.NetworkRange.IsValid() {
|
||||
|
||||
@@ -33,7 +33,6 @@ func initAccountsTestData(t *testing.T, account *types.Account) *handler {
|
||||
AnyTimes()
|
||||
|
||||
return &handler{
|
||||
embeddedIdpEnabled: false,
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
GetAccountSettingsFunc: func(ctx context.Context, accountID string, userID string) (*types.Settings, error) {
|
||||
return account.Settings, nil
|
||||
@@ -124,6 +123,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
||||
DnsDomain: sr(""),
|
||||
AutoUpdateVersion: sr(""),
|
||||
EmbeddedIdpEnabled: br(false),
|
||||
LocalAuthDisabled: br(false),
|
||||
},
|
||||
expectedArray: true,
|
||||
expectedID: accountID,
|
||||
@@ -148,6 +148,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
||||
DnsDomain: sr(""),
|
||||
AutoUpdateVersion: sr(""),
|
||||
EmbeddedIdpEnabled: br(false),
|
||||
LocalAuthDisabled: br(false),
|
||||
},
|
||||
expectedArray: false,
|
||||
expectedID: accountID,
|
||||
@@ -172,6 +173,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
||||
DnsDomain: sr(""),
|
||||
AutoUpdateVersion: sr("latest"),
|
||||
EmbeddedIdpEnabled: br(false),
|
||||
LocalAuthDisabled: br(false),
|
||||
},
|
||||
expectedArray: false,
|
||||
expectedID: accountID,
|
||||
@@ -196,6 +198,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
||||
DnsDomain: sr(""),
|
||||
AutoUpdateVersion: sr(""),
|
||||
EmbeddedIdpEnabled: br(false),
|
||||
LocalAuthDisabled: br(false),
|
||||
},
|
||||
expectedArray: false,
|
||||
expectedID: accountID,
|
||||
@@ -220,6 +223,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
||||
DnsDomain: sr(""),
|
||||
AutoUpdateVersion: sr(""),
|
||||
EmbeddedIdpEnabled: br(false),
|
||||
LocalAuthDisabled: br(false),
|
||||
},
|
||||
expectedArray: false,
|
||||
expectedID: accountID,
|
||||
@@ -244,6 +248,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
||||
DnsDomain: sr(""),
|
||||
AutoUpdateVersion: sr(""),
|
||||
EmbeddedIdpEnabled: br(false),
|
||||
LocalAuthDisabled: br(false),
|
||||
},
|
||||
expectedArray: false,
|
||||
expectedID: accountID,
|
||||
|
||||
@@ -46,7 +46,7 @@ func (h *handler) getInstanceStatus(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteErrorResponse("failed to check instance status", http.StatusInternalServerError, w)
|
||||
return
|
||||
}
|
||||
|
||||
log.WithContext(r.Context()).Infof("instance setup status: %v", setupRequired)
|
||||
util.WriteJSONObject(r.Context(), w, api.InstanceStatus{
|
||||
SetupRequired: setupRequired,
|
||||
})
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
@@ -26,11 +27,12 @@ import (
|
||||
// Handler is a handler that returns peers of the account
|
||||
type Handler struct {
|
||||
accountManager account.Manager
|
||||
permissionsManager permissions.Manager
|
||||
networkMapController network_map.Controller
|
||||
}
|
||||
|
||||
func AddEndpoints(accountManager account.Manager, router *mux.Router, networkMapController network_map.Controller) {
|
||||
peersHandler := NewHandler(accountManager, networkMapController)
|
||||
func AddEndpoints(accountManager account.Manager, router *mux.Router, networkMapController network_map.Controller, permissionsManager permissions.Manager) {
|
||||
peersHandler := NewHandler(accountManager, networkMapController, permissionsManager)
|
||||
router.HandleFunc("/peers", peersHandler.GetAllPeers).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/peers/{peerId}", peersHandler.HandlePeer).
|
||||
Methods("GET", "PUT", "DELETE", "OPTIONS")
|
||||
@@ -42,10 +44,11 @@ func AddEndpoints(accountManager account.Manager, router *mux.Router, networkMap
|
||||
}
|
||||
|
||||
// NewHandler creates a new peers Handler
|
||||
func NewHandler(accountManager account.Manager, networkMapController network_map.Controller) *Handler {
|
||||
func NewHandler(accountManager account.Manager, networkMapController network_map.Controller, permissionsManager permissions.Manager) *Handler {
|
||||
return &Handler{
|
||||
accountManager: accountManager,
|
||||
networkMapController: networkMapController,
|
||||
permissionsManager: permissionsManager,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -359,13 +362,19 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
account, err := h.accountManager.GetAccountByID(r.Context(), accountID, activity.SystemInitiator)
|
||||
user, err := h.accountManager.GetUserByID(r.Context(), userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.accountManager.GetUserByID(r.Context(), userID)
|
||||
err = h.permissionsManager.ValidateAccountAccess(r.Context(), accountID, user, false)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), status.NewPermissionDeniedError(), w)
|
||||
return
|
||||
}
|
||||
|
||||
account, err := h.accountManager.GetAccountByID(r.Context(), accountID, activity.SystemInitiator)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
|
||||
@@ -13,13 +13,15 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/gorilla/mux"
|
||||
"go.uber.org/mock/gomock"
|
||||
ugomock "go.uber.org/mock/gomock"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
@@ -102,7 +104,7 @@ func initTestMetaData(t *testing.T, peers ...*nbpeer.Peer) *Handler {
|
||||
},
|
||||
}
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
ctrl := ugomock.NewController(t)
|
||||
|
||||
networkMapController := network_map.NewMockController(ctrl)
|
||||
networkMapController.EXPECT().
|
||||
@@ -110,6 +112,10 @@ func initTestMetaData(t *testing.T, peers ...*nbpeer.Peer) *Handler {
|
||||
Return("domain").
|
||||
AnyTimes()
|
||||
|
||||
ctrl2 := gomock.NewController(t)
|
||||
permissionsManager := permissions.NewMockManager(ctrl2)
|
||||
permissionsManager.EXPECT().ValidateAccountAccess(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
||||
|
||||
return &Handler{
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
UpdatePeerFunc: func(_ context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) {
|
||||
@@ -199,6 +205,7 @@ func initTestMetaData(t *testing.T, peers ...*nbpeer.Peer) *Handler {
|
||||
},
|
||||
},
|
||||
networkMapController: networkMapController,
|
||||
permissionsManager: permissionsManager,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -205,6 +205,14 @@ func TestCreateInvite(t *testing.T) {
|
||||
return nil, status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "local auth disabled",
|
||||
requestBody: `{"email":"test@example.com","name":"Test User","role":"user","auto_groups":[]}`,
|
||||
expectedStatus: http.StatusPreconditionFailed,
|
||||
mockFunc: func(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) {
|
||||
return nil, status.Errorf(status.PreconditionFailed, "local user creation is disabled - use an external identity provider")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid JSON",
|
||||
requestBody: `{invalid json}`,
|
||||
@@ -376,6 +384,15 @@ func TestAcceptInvite(t *testing.T) {
|
||||
return status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "local auth disabled",
|
||||
token: testInviteToken,
|
||||
requestBody: `{"password":"SecurePass123!"}`,
|
||||
expectedStatus: http.StatusPreconditionFailed,
|
||||
mockFunc: func(ctx context.Context, token, password string) error {
|
||||
return status.Errorf(status.PreconditionFailed, "local user creation is disabled - use an external identity provider")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing token",
|
||||
token: "",
|
||||
|
||||
@@ -73,7 +73,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
|
||||
proxyController := integrations.NewController(store)
|
||||
userManager := users.NewManager(store)
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
settingsManager := settings.NewManager(store, userManager, integrations.NewManager(&activity.InMemoryEventStore{}), permissionsManager)
|
||||
settingsManager := settings.NewManager(store, userManager, integrations.NewManager(&activity.InMemoryEventStore{}), permissionsManager, settings.IdpConfig{})
|
||||
peersManager := peers.NewManager(store, permissionsManager)
|
||||
|
||||
jobManager := job.NewJobManager(nil, store, peersManager)
|
||||
|
||||
@@ -43,6 +43,11 @@ type EmbeddedIdPConfig struct {
|
||||
Owner *OwnerConfig
|
||||
// SignKeyRefreshEnabled enables automatic key rotation for signing keys
|
||||
SignKeyRefreshEnabled bool
|
||||
// LocalAuthDisabled disables the local (email/password) authentication connector.
|
||||
// When true, users cannot authenticate via email/password, only via external identity providers.
|
||||
// Existing local users are preserved and will be able to login again if re-enabled.
|
||||
// Cannot be enabled if no external identity provider connectors are configured.
|
||||
LocalAuthDisabled bool
|
||||
}
|
||||
|
||||
// EmbeddedStorageConfig holds storage configuration for the embedded IdP.
|
||||
@@ -105,6 +110,8 @@ func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) {
|
||||
Issuer: "NetBird",
|
||||
Theme: "light",
|
||||
},
|
||||
// Always enable password DB initially - we disable the local connector after startup if needed.
|
||||
// This ensures Dex has at least one connector during initialization.
|
||||
EnablePasswordDB: true,
|
||||
StaticClients: []storage.Client{
|
||||
{
|
||||
@@ -192,11 +199,32 @@ func NewEmbeddedIdPManager(ctx context.Context, config *EmbeddedIdPConfig, appMe
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("initializing embedded Dex IDP with config: %+v", config)
|
||||
|
||||
provider, err := dex.NewProviderFromYAML(ctx, yamlConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create embedded IdP provider: %w", err)
|
||||
}
|
||||
|
||||
// If local auth is disabled, validate that other connectors exist
|
||||
if config.LocalAuthDisabled {
|
||||
hasOthers, err := provider.HasNonLocalConnectors(ctx)
|
||||
if err != nil {
|
||||
_ = provider.Stop(ctx)
|
||||
return nil, fmt.Errorf("failed to check connectors: %w", err)
|
||||
}
|
||||
if !hasOthers {
|
||||
_ = provider.Stop(ctx)
|
||||
return nil, fmt.Errorf("cannot disable local authentication: no other identity providers configured")
|
||||
}
|
||||
// Ensure local connector is removed (it might exist from a previous run)
|
||||
if err := provider.DisableLocalAuth(ctx); err != nil {
|
||||
_ = provider.Stop(ctx)
|
||||
return nil, fmt.Errorf("failed to disable local auth: %w", err)
|
||||
}
|
||||
log.WithContext(ctx).Info("local authentication disabled - only external identity providers can be used")
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Infof("embedded Dex IDP initialized with issuer: %s", yamlConfig.Issuer)
|
||||
|
||||
return &EmbeddedIdPManager{
|
||||
@@ -281,6 +309,8 @@ func (m *EmbeddedIdPManager) GetAllAccounts(ctx context.Context) (map[string][]*
|
||||
return nil, fmt.Errorf("failed to list users: %w", err)
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("retrieved %d users from embedded IdP", len(users))
|
||||
|
||||
indexedUsers := make(map[string][]*UserData)
|
||||
for _, user := range users {
|
||||
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], &UserData{
|
||||
@@ -290,11 +320,17 @@ func (m *EmbeddedIdPManager) GetAllAccounts(ctx context.Context) (map[string][]*
|
||||
})
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("retrieved %d users from embedded IdP", len(indexedUsers[UnsetAccountID]))
|
||||
|
||||
return indexedUsers, nil
|
||||
}
|
||||
|
||||
// CreateUser creates a new user in the embedded IdP.
|
||||
func (m *EmbeddedIdPManager) CreateUser(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error) {
|
||||
if m.config.LocalAuthDisabled {
|
||||
return nil, fmt.Errorf("local user creation is disabled")
|
||||
}
|
||||
|
||||
if m.appMetrics != nil {
|
||||
m.appMetrics.IDPMetrics().CountCreateUser()
|
||||
}
|
||||
@@ -364,6 +400,10 @@ func (m *EmbeddedIdPManager) GetUserByEmail(ctx context.Context, email string) (
|
||||
// Unlike CreateUser which auto-generates a password, this method uses the provided password.
|
||||
// This is useful for instance setup where the user provides their own password.
|
||||
func (m *EmbeddedIdPManager) CreateUserWithPassword(ctx context.Context, email, password, name string) (*UserData, error) {
|
||||
if m.config.LocalAuthDisabled {
|
||||
return nil, fmt.Errorf("local user creation is disabled")
|
||||
}
|
||||
|
||||
if m.appMetrics != nil {
|
||||
m.appMetrics.IDPMetrics().CountCreateUser()
|
||||
}
|
||||
@@ -553,3 +593,13 @@ func (m *EmbeddedIdPManager) GetClientIDs() []string {
|
||||
func (m *EmbeddedIdPManager) GetUserIDClaim() string {
|
||||
return defaultUserIDClaim
|
||||
}
|
||||
|
||||
// IsLocalAuthDisabled returns whether local authentication is disabled based on configuration.
|
||||
func (m *EmbeddedIdPManager) IsLocalAuthDisabled() bool {
|
||||
return m.config.LocalAuthDisabled
|
||||
}
|
||||
|
||||
// HasNonLocalConnectors checks if there are any identity provider connectors other than local.
|
||||
func (m *EmbeddedIdPManager) HasNonLocalConnectors(ctx context.Context) (bool, error) {
|
||||
return m.provider.HasNonLocalConnectors(ctx)
|
||||
}
|
||||
|
||||
@@ -370,3 +370,234 @@ func TestEmbeddedIdPManager_GetLocalKeysLocation(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEmbeddedIdPManager_LocalAuthDisabled(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("cannot start with local auth disabled without other connectors", func(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "embedded-idp-test-*")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
config := &EmbeddedIdPConfig{
|
||||
Enabled: true,
|
||||
Issuer: "http://localhost:5556/dex",
|
||||
LocalAuthDisabled: true,
|
||||
Storage: EmbeddedStorageConfig{
|
||||
Type: "sqlite3",
|
||||
Config: EmbeddedStorageTypeConfig{
|
||||
File: filepath.Join(tmpDir, "dex.db"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
_, err = NewEmbeddedIdPManager(ctx, config, nil)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "no other identity providers configured")
|
||||
})
|
||||
|
||||
t.Run("local auth enabled by default", func(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "embedded-idp-test-*")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
config := &EmbeddedIdPConfig{
|
||||
Enabled: true,
|
||||
Issuer: "http://localhost:5556/dex",
|
||||
Storage: EmbeddedStorageConfig{
|
||||
Type: "sqlite3",
|
||||
Config: EmbeddedStorageTypeConfig{
|
||||
File: filepath.Join(tmpDir, "dex.db"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
manager, err := NewEmbeddedIdPManager(ctx, config, nil)
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = manager.Stop(ctx) }()
|
||||
|
||||
// Verify local auth is enabled by default
|
||||
assert.False(t, manager.IsLocalAuthDisabled())
|
||||
})
|
||||
|
||||
t.Run("start with local auth disabled when connector exists", func(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "embedded-idp-test-*")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
dbFile := filepath.Join(tmpDir, "dex.db")
|
||||
|
||||
// First, create a manager with local auth enabled and add a connector
|
||||
config1 := &EmbeddedIdPConfig{
|
||||
Enabled: true,
|
||||
Issuer: "http://localhost:5556/dex",
|
||||
Storage: EmbeddedStorageConfig{
|
||||
Type: "sqlite3",
|
||||
Config: EmbeddedStorageTypeConfig{
|
||||
File: dbFile,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
manager1, err := NewEmbeddedIdPManager(ctx, config1, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a user
|
||||
userData, err := manager1.CreateUser(ctx, "preserved@example.com", "Preserved User", "account1", "admin@example.com")
|
||||
require.NoError(t, err)
|
||||
userID := userData.ID
|
||||
|
||||
// Add an external connector (Google doesn't require OIDC discovery)
|
||||
_, err = manager1.CreateConnector(ctx, &dex.ConnectorConfig{
|
||||
ID: "google-test",
|
||||
Name: "Google Test",
|
||||
Type: "google",
|
||||
ClientID: "test-client-id",
|
||||
ClientSecret: "test-client-secret",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Stop the first manager
|
||||
err = manager1.Stop(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Now create a new manager with local auth disabled
|
||||
config2 := &EmbeddedIdPConfig{
|
||||
Enabled: true,
|
||||
Issuer: "http://localhost:5556/dex",
|
||||
LocalAuthDisabled: true,
|
||||
Storage: EmbeddedStorageConfig{
|
||||
Type: "sqlite3",
|
||||
Config: EmbeddedStorageTypeConfig{
|
||||
File: dbFile,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
manager2, err := NewEmbeddedIdPManager(ctx, config2, nil)
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = manager2.Stop(ctx) }()
|
||||
|
||||
// Verify local auth is disabled via config
|
||||
assert.True(t, manager2.IsLocalAuthDisabled())
|
||||
|
||||
// Verify the user still exists in storage (just can't login via local)
|
||||
lookedUp, err := manager2.GetUserDataByID(ctx, userID, AppMetadata{})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "preserved@example.com", lookedUp.Email)
|
||||
})
|
||||
|
||||
t.Run("CreateUser fails when local auth is disabled", func(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "embedded-idp-test-*")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
dbFile := filepath.Join(tmpDir, "dex.db")
|
||||
|
||||
// First, create a manager and add an external connector
|
||||
config1 := &EmbeddedIdPConfig{
|
||||
Enabled: true,
|
||||
Issuer: "http://localhost:5556/dex",
|
||||
Storage: EmbeddedStorageConfig{
|
||||
Type: "sqlite3",
|
||||
Config: EmbeddedStorageTypeConfig{
|
||||
File: dbFile,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
manager1, err := NewEmbeddedIdPManager(ctx, config1, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = manager1.CreateConnector(ctx, &dex.ConnectorConfig{
|
||||
ID: "google-test",
|
||||
Name: "Google Test",
|
||||
Type: "google",
|
||||
ClientID: "test-client-id",
|
||||
ClientSecret: "test-client-secret",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = manager1.Stop(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create manager with local auth disabled
|
||||
config2 := &EmbeddedIdPConfig{
|
||||
Enabled: true,
|
||||
Issuer: "http://localhost:5556/dex",
|
||||
LocalAuthDisabled: true,
|
||||
Storage: EmbeddedStorageConfig{
|
||||
Type: "sqlite3",
|
||||
Config: EmbeddedStorageTypeConfig{
|
||||
File: dbFile,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
manager2, err := NewEmbeddedIdPManager(ctx, config2, nil)
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = manager2.Stop(ctx) }()
|
||||
|
||||
// Try to create a user - should fail
|
||||
_, err = manager2.CreateUser(ctx, "newuser@example.com", "New User", "account1", "admin@example.com")
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "local user creation is disabled")
|
||||
})
|
||||
|
||||
t.Run("CreateUserWithPassword fails when local auth is disabled", func(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "embedded-idp-test-*")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
dbFile := filepath.Join(tmpDir, "dex.db")
|
||||
|
||||
// First, create a manager and add an external connector
|
||||
config1 := &EmbeddedIdPConfig{
|
||||
Enabled: true,
|
||||
Issuer: "http://localhost:5556/dex",
|
||||
Storage: EmbeddedStorageConfig{
|
||||
Type: "sqlite3",
|
||||
Config: EmbeddedStorageTypeConfig{
|
||||
File: dbFile,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
manager1, err := NewEmbeddedIdPManager(ctx, config1, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = manager1.CreateConnector(ctx, &dex.ConnectorConfig{
|
||||
ID: "google-test",
|
||||
Name: "Google Test",
|
||||
Type: "google",
|
||||
ClientID: "test-client-id",
|
||||
ClientSecret: "test-client-secret",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = manager1.Stop(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create manager with local auth disabled
|
||||
config2 := &EmbeddedIdPConfig{
|
||||
Enabled: true,
|
||||
Issuer: "http://localhost:5556/dex",
|
||||
LocalAuthDisabled: true,
|
||||
Storage: EmbeddedStorageConfig{
|
||||
Type: "sqlite3",
|
||||
Config: EmbeddedStorageTypeConfig{
|
||||
File: dbFile,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
manager2, err := NewEmbeddedIdPManager(ctx, config2, nil)
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = manager2.Stop(ctx) }()
|
||||
|
||||
// Try to create a user with password - should fail
|
||||
_, err = manager2.CreateUserWithPassword(ctx, "newuser@example.com", "SecurePass123!", "New User")
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "local user creation is disabled")
|
||||
})
|
||||
}
|
||||
|
||||
@@ -104,13 +104,22 @@ func NewManager(ctx context.Context, store store.Store, idpManager idp.Manager)
|
||||
}
|
||||
|
||||
func (m *DefaultManager) loadSetupRequired(ctx context.Context) error {
|
||||
// Check if there are any accounts in the NetBird store
|
||||
numAccounts, err := m.store.GetAccountsCounter(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
hasAccounts := numAccounts > 0
|
||||
|
||||
// Check if there are any users in the embedded IdP (Dex)
|
||||
users, err := m.embeddedIdpManager.GetAllAccounts(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
hasLocalUsers := len(users) > 0
|
||||
|
||||
m.setupMu.Lock()
|
||||
m.setupRequired = len(users) == 0
|
||||
m.setupRequired = !(hasAccounts || hasLocalUsers)
|
||||
m.setupMu.Unlock()
|
||||
|
||||
return nil
|
||||
|
||||
@@ -610,6 +610,7 @@ func TestSync10PeersGetUpdates(t *testing.T) {
|
||||
|
||||
initialPeers := 10
|
||||
additionalPeers := 10
|
||||
expectedPeerCount := initialPeers + additionalPeers - 1 // -1 because peer doesn't see itself
|
||||
|
||||
var peers []wgtypes.Key
|
||||
for i := 0; i < initialPeers; i++ {
|
||||
@@ -618,8 +619,19 @@ func TestSync10PeersGetUpdates(t *testing.T) {
|
||||
peers = append(peers, key)
|
||||
}
|
||||
|
||||
// Track the maximum peer count each peer has seen
|
||||
type peerState struct {
|
||||
mu sync.Mutex
|
||||
maxPeerCount int
|
||||
done bool
|
||||
}
|
||||
peerStates := make(map[string]*peerState)
|
||||
for _, pk := range peers {
|
||||
peerStates[pk.PublicKey().String()] = &peerState{}
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(initialPeers + initialPeers*additionalPeers)
|
||||
wg.Add(initialPeers) // One completion per initial peer
|
||||
|
||||
var syncClients []mgmtProto.ManagementService_SyncClient
|
||||
for _, pk := range peers {
|
||||
@@ -643,6 +655,9 @@ func TestSync10PeersGetUpdates(t *testing.T) {
|
||||
syncClients = append(syncClients, s)
|
||||
|
||||
go func(pk wgtypes.Key, syncStream mgmtProto.ManagementService_SyncClient) {
|
||||
pubKey := pk.PublicKey().String()
|
||||
state := peerStates[pubKey]
|
||||
|
||||
for {
|
||||
encMsg := &mgmtProto.EncryptedMessage{}
|
||||
err := syncStream.RecvMsg(encMsg)
|
||||
@@ -651,19 +666,28 @@ func TestSync10PeersGetUpdates(t *testing.T) {
|
||||
}
|
||||
decryptedBytes, decErr := encryption.Decrypt(encMsg.Body, ts.serverPubKey, pk)
|
||||
if decErr != nil {
|
||||
t.Errorf("failed to decrypt SyncResponse for peer %s: %v", pk.PublicKey().String(), decErr)
|
||||
t.Errorf("failed to decrypt SyncResponse for peer %s: %v", pubKey, decErr)
|
||||
return
|
||||
}
|
||||
resp := &mgmtProto.SyncResponse{}
|
||||
umErr := pb.Unmarshal(decryptedBytes, resp)
|
||||
if umErr != nil {
|
||||
t.Errorf("failed to unmarshal SyncResponse for peer %s: %v", pk.PublicKey().String(), umErr)
|
||||
t.Errorf("failed to unmarshal SyncResponse for peer %s: %v", pubKey, umErr)
|
||||
return
|
||||
}
|
||||
// We only count if there's a new peer update
|
||||
if len(resp.GetRemotePeers()) > 0 {
|
||||
|
||||
// Track the maximum peer count seen (due to debouncing, updates are coalesced)
|
||||
peerCount := len(resp.GetRemotePeers())
|
||||
state.mu.Lock()
|
||||
if peerCount > state.maxPeerCount {
|
||||
state.maxPeerCount = peerCount
|
||||
}
|
||||
// Signal completion when this peer has seen all expected peers
|
||||
if !state.done && state.maxPeerCount >= expectedPeerCount {
|
||||
state.done = true
|
||||
wg.Done()
|
||||
}
|
||||
state.mu.Unlock()
|
||||
}
|
||||
}(pk, s)
|
||||
}
|
||||
@@ -677,7 +701,30 @@ func TestSync10PeersGetUpdates(t *testing.T) {
|
||||
time.Sleep(time.Duration(n) * time.Millisecond)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
// Wait for debouncer to flush final updates (debounce interval is 1000ms)
|
||||
time.Sleep(1500 * time.Millisecond)
|
||||
|
||||
// Wait with timeout
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
// Success - all peers received expected peer count
|
||||
case <-time.After(5 * time.Second):
|
||||
// Timeout - report which peers didn't receive all updates
|
||||
t.Error("Timeout waiting for all peers to receive updates")
|
||||
for pubKey, state := range peerStates {
|
||||
state.mu.Lock()
|
||||
if state.maxPeerCount < expectedPeerCount {
|
||||
t.Errorf("Peer %s only saw %d peers, expected %d", pubKey, state.maxPeerCount, expectedPeerCount)
|
||||
}
|
||||
state.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
for _, sc := range syncClients {
|
||||
err := sc.CloseSend()
|
||||
|
||||
@@ -37,8 +37,8 @@ type MockAccountManager struct {
|
||||
GetUserFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
|
||||
ListUsersFunc func(ctx context.Context, accountID string) ([]*types.User, error)
|
||||
GetPeersFunc func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
|
||||
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP) error
|
||||
SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
|
||||
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP, syncTime time.Time) error
|
||||
SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
|
||||
DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error
|
||||
GetNetworkMapFunc func(ctx context.Context, peerKey string) (*types.NetworkMap, error)
|
||||
GetPeerNetworkFunc func(ctx context.Context, peerKey string) (*types.Network, error)
|
||||
@@ -214,16 +214,15 @@ func (am *MockAccountManager) DeleteSetupKey(ctx context.Context, accountID, use
|
||||
return status.Errorf(codes.Unimplemented, "method DeleteSetupKey is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
|
||||
func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
|
||||
if am.SyncAndMarkPeerFunc != nil {
|
||||
return am.SyncAndMarkPeerFunc(ctx, accountID, peerPubKey, meta, realIP)
|
||||
return am.SyncAndMarkPeerFunc(ctx, accountID, peerPubKey, meta, realIP, syncTime)
|
||||
}
|
||||
return nil, nil, nil, 0, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID string, peerPubKey string) error {
|
||||
// TODO implement me
|
||||
panic("implement me")
|
||||
func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID string, peerPubKey string, streamStartTime time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error) {
|
||||
@@ -323,9 +322,9 @@ func (am *MockAccountManager) GetAccountIDByUserID(ctx context.Context, userAuth
|
||||
}
|
||||
|
||||
// MarkPeerConnected mock implementation of MarkPeerConnected from server.AccountManager interface
|
||||
func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string) error {
|
||||
func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error {
|
||||
if am.MarkPeerConnectedFunc != nil {
|
||||
return am.MarkPeerConnectedFunc(ctx, peerKey, connected, realIP)
|
||||
return am.MarkPeerConnectedFunc(ctx, peerKey, connected, realIP, syncTime)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
|
||||
}
|
||||
|
||||
@@ -103,11 +103,13 @@ func (am *DefaultAccountManager) getUserAccessiblePeers(ctx context.Context, acc
|
||||
}
|
||||
|
||||
// MarkPeerConnected marks peer as connected (true) or disconnected (false)
|
||||
func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, accountID string) error {
|
||||
// syncTime is used as the LastSeen timestamp and for stale request detection
|
||||
func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error {
|
||||
var peer *nbpeer.Peer
|
||||
var settings *types.Settings
|
||||
var expired bool
|
||||
var err error
|
||||
var skipped bool
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
peer, err = transaction.GetPeerByPeerPubKey(ctx, store.LockingStrengthUpdate, peerPubKey)
|
||||
@@ -115,9 +117,19 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
|
||||
return err
|
||||
}
|
||||
|
||||
expired, err = updatePeerStatusAndLocation(ctx, am.geo, transaction, peer, connected, realIP, accountID)
|
||||
if connected && !syncTime.After(peer.Status.LastSeen) {
|
||||
log.WithContext(ctx).Tracef("peer %s has newer activity (lastSeen=%s >= syncTime=%s), skipping connect",
|
||||
peer.ID, peer.Status.LastSeen.Format(time.RFC3339), syncTime.Format(time.RFC3339))
|
||||
skipped = true
|
||||
return nil
|
||||
}
|
||||
|
||||
expired, err = updatePeerStatusAndLocation(ctx, am.geo, transaction, peer, connected, realIP, accountID, syncTime)
|
||||
return err
|
||||
})
|
||||
if skipped {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -147,10 +159,10 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
|
||||
return nil
|
||||
}
|
||||
|
||||
func updatePeerStatusAndLocation(ctx context.Context, geo geolocation.Geolocation, transaction store.Store, peer *nbpeer.Peer, connected bool, realIP net.IP, accountID string) (bool, error) {
|
||||
func updatePeerStatusAndLocation(ctx context.Context, geo geolocation.Geolocation, transaction store.Store, peer *nbpeer.Peer, connected bool, realIP net.IP, accountID string, syncTime time.Time) (bool, error) {
|
||||
oldStatus := peer.Status.Copy()
|
||||
newStatus := oldStatus
|
||||
newStatus.LastSeen = time.Now().UTC()
|
||||
newStatus.LastSeen = syncTime
|
||||
newStatus.Connected = connected
|
||||
// whenever peer got connected that means that it logged in successfully
|
||||
if newStatus.Connected {
|
||||
|
||||
@@ -24,19 +24,28 @@ type Manager interface {
|
||||
UpdateExtraSettings(ctx context.Context, accountID, userID string, extraSettings *types.ExtraSettings) (bool, error)
|
||||
}
|
||||
|
||||
// IdpConfig holds IdP-related configuration that is set at runtime
|
||||
// and not stored in the database.
|
||||
type IdpConfig struct {
|
||||
EmbeddedIdpEnabled bool
|
||||
LocalAuthDisabled bool
|
||||
}
|
||||
|
||||
type managerImpl struct {
|
||||
store store.Store
|
||||
extraSettingsManager extra_settings.Manager
|
||||
userManager users.Manager
|
||||
permissionsManager permissions.Manager
|
||||
idpConfig IdpConfig
|
||||
}
|
||||
|
||||
func NewManager(store store.Store, userManager users.Manager, extraSettingsManager extra_settings.Manager, permissionsManager permissions.Manager) Manager {
|
||||
func NewManager(store store.Store, userManager users.Manager, extraSettingsManager extra_settings.Manager, permissionsManager permissions.Manager, idpConfig IdpConfig) Manager {
|
||||
return &managerImpl{
|
||||
store: store,
|
||||
extraSettingsManager: extraSettingsManager,
|
||||
userManager: userManager,
|
||||
permissionsManager: permissionsManager,
|
||||
idpConfig: idpConfig,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -74,6 +83,10 @@ func (m *managerImpl) GetSettings(ctx context.Context, accountID, userID string)
|
||||
settings.Extra.FlowDnsCollectionEnabled = extraSettings.FlowDnsCollectionEnabled
|
||||
}
|
||||
|
||||
// Fill in IdP-related runtime settings
|
||||
settings.EmbeddedIdpEnabled = m.idpConfig.EmbeddedIdpEnabled
|
||||
settings.LocalAuthDisabled = m.idpConfig.LocalAuthDisabled
|
||||
|
||||
return settings, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -55,6 +55,14 @@ type Settings struct {
|
||||
|
||||
// AutoUpdateVersion client auto-update version
|
||||
AutoUpdateVersion string `gorm:"default:'disabled'"`
|
||||
|
||||
// EmbeddedIdpEnabled indicates if the embedded identity provider is enabled.
|
||||
// This is a runtime-only field, not stored in the database.
|
||||
EmbeddedIdpEnabled bool `gorm:"-"`
|
||||
|
||||
// LocalAuthDisabled indicates if local (email/password) authentication is disabled.
|
||||
// This is a runtime-only field, not stored in the database.
|
||||
LocalAuthDisabled bool `gorm:"-"`
|
||||
}
|
||||
|
||||
// Copy copies the Settings struct
|
||||
@@ -76,6 +84,8 @@ func (s *Settings) Copy() *Settings {
|
||||
DNSDomain: s.DNSDomain,
|
||||
NetworkRange: s.NetworkRange,
|
||||
AutoUpdateVersion: s.AutoUpdateVersion,
|
||||
EmbeddedIdpEnabled: s.EmbeddedIdpEnabled,
|
||||
LocalAuthDisabled: s.LocalAuthDisabled,
|
||||
}
|
||||
if s.Extra != nil {
|
||||
settings.Extra = s.Extra.Copy()
|
||||
|
||||
@@ -191,6 +191,10 @@ func (am *DefaultAccountManager) createNewIdpUser(ctx context.Context, accountID
|
||||
// Unlike createNewIdpUser, this method fetches user data directly from the database
|
||||
// since the embedded IdP usage ensures the username and email are stored locally in the User table.
|
||||
func (am *DefaultAccountManager) createEmbeddedIdpUser(ctx context.Context, accountID string, inviterID string, invite *types.UserInfo) (*idp.UserData, error) {
|
||||
if IsLocalAuthDisabled(ctx, am.idpManager) {
|
||||
return nil, status.Errorf(status.PreconditionFailed, "local user creation is disabled - use an external identity provider")
|
||||
}
|
||||
|
||||
inviter, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, inviterID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get inviter user: %w", err)
|
||||
@@ -1462,6 +1466,10 @@ func (am *DefaultAccountManager) CreateUserInvite(ctx context.Context, accountID
|
||||
return nil, status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider")
|
||||
}
|
||||
|
||||
if IsLocalAuthDisabled(ctx, am.idpManager) {
|
||||
return nil, status.Errorf(status.PreconditionFailed, "local user creation is disabled - use an external identity provider")
|
||||
}
|
||||
|
||||
if err := validateUserInvite(invite); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1621,6 +1629,10 @@ func (am *DefaultAccountManager) AcceptUserInvite(ctx context.Context, token, pa
|
||||
return status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider")
|
||||
}
|
||||
|
||||
if IsLocalAuthDisabled(ctx, am.idpManager) {
|
||||
return status.Errorf(status.PreconditionFailed, "local user creation is disabled - use an external identity provider")
|
||||
}
|
||||
|
||||
if password == "" {
|
||||
return status.Errorf(status.InvalidArgument, "password is required")
|
||||
}
|
||||
|
||||
@@ -21,6 +21,7 @@ type Client interface {
|
||||
GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error)
|
||||
GetPKCEAuthorizationFlow(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error)
|
||||
GetNetworkMap(sysInfo *system.Info) (*proto.NetworkMap, error)
|
||||
GetServerURL() string
|
||||
IsHealthy() bool
|
||||
SyncMeta(sysInfo *system.Info) error
|
||||
Logout() error
|
||||
|
||||
@@ -46,6 +46,7 @@ type GrpcClient struct {
|
||||
conn *grpc.ClientConn
|
||||
connStateCallback ConnStateNotifier
|
||||
connStateCallbackLock sync.RWMutex
|
||||
serverURL string
|
||||
}
|
||||
|
||||
// NewClient creates a new client to Management service
|
||||
@@ -75,9 +76,15 @@ func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsE
|
||||
ctx: ctx,
|
||||
conn: conn,
|
||||
connStateCallbackLock: sync.RWMutex{},
|
||||
serverURL: addr,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetServerURL returns the management server URL
|
||||
func (c *GrpcClient) GetServerURL() string {
|
||||
return c.serverURL
|
||||
}
|
||||
|
||||
// Close closes connection to the Management Service
|
||||
func (c *GrpcClient) Close() error {
|
||||
return c.conn.Close()
|
||||
|
||||
@@ -18,6 +18,7 @@ type MockClient struct {
|
||||
LoginFunc func(serverKey wgtypes.Key, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
|
||||
GetDeviceAuthorizationFlowFunc func(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error)
|
||||
GetPKCEAuthorizationFlowFunc func(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error)
|
||||
GetServerURLFunc func() string
|
||||
SyncMetaFunc func(sysInfo *system.Info) error
|
||||
LogoutFunc func() error
|
||||
JobFunc func(ctx context.Context, msgHandler func(msg *proto.JobRequest) *proto.JobResponse) error
|
||||
@@ -88,6 +89,14 @@ func (m *MockClient) GetNetworkMap(_ *system.Info) (*proto.NetworkMap, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// GetServerURL mock implementation of GetServerURL from mgm.Client interface
|
||||
func (m *MockClient) GetServerURL() string {
|
||||
if m.GetServerURLFunc == nil {
|
||||
return ""
|
||||
}
|
||||
return m.GetServerURLFunc()
|
||||
}
|
||||
|
||||
func (m *MockClient) SyncMeta(sysInfo *system.Info) error {
|
||||
if m.SyncMetaFunc == nil {
|
||||
return nil
|
||||
|
||||
@@ -294,6 +294,11 @@ components:
|
||||
type: boolean
|
||||
readOnly: true
|
||||
example: false
|
||||
local_auth_disabled:
|
||||
description: Indicates whether local (email/password) authentication is disabled. When true, users can only authenticate via external identity providers. This is a read-only field.
|
||||
type: boolean
|
||||
readOnly: true
|
||||
example: false
|
||||
required:
|
||||
- peer_login_expiration_enabled
|
||||
- peer_login_expiration
|
||||
|
||||
@@ -415,6 +415,9 @@ type AccountSettings struct {
|
||||
// LazyConnectionEnabled Enables or disables experimental lazy connection
|
||||
LazyConnectionEnabled *bool `json:"lazy_connection_enabled,omitempty"`
|
||||
|
||||
// LocalAuthDisabled Indicates whether local (email/password) authentication is disabled. When true, users can only authenticate via external identity providers. This is a read-only field.
|
||||
LocalAuthDisabled *bool `json:"local_auth_disabled,omitempty"`
|
||||
|
||||
// NetworkRange Allows to define a custom network range for the account in CIDR format
|
||||
NetworkRange *string `json:"network_range,omitempty"`
|
||||
|
||||
|
||||
@@ -225,35 +225,42 @@ func (c *Client) OpenConn(ctx context.Context, dstPeerID string) (net.Conn, erro
|
||||
c.mu.Unlock()
|
||||
return nil, ErrConnAlreadyExists
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
if err := c.stateSubscription.WaitToBeOnlineAndSubscribe(ctx, peerID); err != nil {
|
||||
c.log.Errorf("peer not available: %s, %s", peerID, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.log.Infof("remote peer is available, prepare the relayed connection: %s", peerID)
|
||||
msgChannel := make(chan Msg, 100)
|
||||
|
||||
c.mu.Lock()
|
||||
if !c.serviceIsRunning {
|
||||
c.mu.Unlock()
|
||||
return nil, fmt.Errorf("relay connection is not established")
|
||||
}
|
||||
c.log.Infof("prepare the relayed connection, waiting for remote peer: %s", peerID)
|
||||
|
||||
c.muInstanceURL.Lock()
|
||||
instanceURL := c.instanceURL
|
||||
c.muInstanceURL.Unlock()
|
||||
conn := NewConn(c, peerID, msgChannel, instanceURL)
|
||||
|
||||
_, ok = c.conns[peerID]
|
||||
if ok {
|
||||
c.mu.Unlock()
|
||||
_ = conn.Close()
|
||||
return nil, ErrConnAlreadyExists
|
||||
}
|
||||
c.conns[peerID] = newConnContainer(c.log, conn, msgChannel)
|
||||
msgChannel := make(chan Msg, 100)
|
||||
conn := NewConn(c, peerID, msgChannel, instanceURL)
|
||||
container := newConnContainer(c.log, conn, msgChannel)
|
||||
c.conns[peerID] = container
|
||||
c.mu.Unlock()
|
||||
|
||||
if err := c.stateSubscription.WaitToBeOnlineAndSubscribe(ctx, peerID); err != nil {
|
||||
c.log.Errorf("peer not available: %s, %s", peerID, err)
|
||||
c.mu.Lock()
|
||||
if savedContainer, ok := c.conns[peerID]; ok && savedContainer == container {
|
||||
delete(c.conns, peerID)
|
||||
}
|
||||
c.mu.Unlock()
|
||||
container.close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
if !c.serviceIsRunning {
|
||||
if savedContainer, ok := c.conns[peerID]; ok && savedContainer == container {
|
||||
delete(c.conns, peerID)
|
||||
}
|
||||
c.mu.Unlock()
|
||||
container.close()
|
||||
return nil, fmt.Errorf("relay connection is not established")
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
c.log.Infof("remote peer is available: %s", peerID)
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user