mirror of
https://github.com/netbirdio/netbird.git
synced 2026-06-25 17:29:54 +00:00
Compare commits
13 Commits
main
...
fix/revert
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e3c66ced13 | ||
|
|
94284d5400 | ||
|
|
fa91c119c3 | ||
|
|
b5d2f054c2 | ||
|
|
50ff095d68 | ||
|
|
0afe52cfeb | ||
|
|
31c277b1df | ||
|
|
566d21c2c3 | ||
|
|
858e2d1c34 | ||
|
|
35ed69bfe7 | ||
|
|
8446713d28 | ||
|
|
12f2e69af2 | ||
|
|
4cb2c62f2a |
@@ -41,6 +41,7 @@ type ICEBind struct {
|
||||
*wgConn.StdNetBind
|
||||
|
||||
transportNet transport.Net
|
||||
filterFn udpmux.FilterFn
|
||||
address wgaddr.Address
|
||||
mtu uint16
|
||||
|
||||
@@ -60,11 +61,12 @@ type ICEBind struct {
|
||||
ipv6Conn *net.UDPConn
|
||||
}
|
||||
|
||||
func NewICEBind(transportNet transport.Net, address wgaddr.Address, mtu uint16) *ICEBind {
|
||||
func NewICEBind(transportNet transport.Net, filterFn udpmux.FilterFn, address wgaddr.Address, mtu uint16) *ICEBind {
|
||||
b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind)
|
||||
ib := &ICEBind{
|
||||
StdNetBind: b,
|
||||
transportNet: transportNet,
|
||||
filterFn: filterFn,
|
||||
address: address,
|
||||
mtu: mtu,
|
||||
endpoints: make(map[netip.Addr]net.Conn),
|
||||
@@ -263,6 +265,7 @@ func (s *ICEBind) createOrUpdateMux() {
|
||||
udpmux.UniversalUDPMuxParams{
|
||||
UDPConn: muxConn,
|
||||
Net: s.transportNet,
|
||||
FilterFn: s.filterFn,
|
||||
WGAddress: s.address,
|
||||
MTU: s.mtu,
|
||||
},
|
||||
|
||||
@@ -289,7 +289,7 @@ func setupICEBind(t *testing.T) *ICEBind {
|
||||
IP: netip.MustParseAddr("100.64.0.1"),
|
||||
Network: netip.MustParsePrefix("100.64.0.0/10"),
|
||||
}
|
||||
return NewICEBind(transportNet, address, 1280)
|
||||
return NewICEBind(transportNet, nil, address, 1280)
|
||||
}
|
||||
|
||||
func createDualStackConns(t *testing.T) (*net.UDPConn, *net.UDPConn) {
|
||||
|
||||
@@ -32,6 +32,8 @@ type TunKernelDevice struct {
|
||||
link *wgLink
|
||||
udpMuxConn net.PacketConn
|
||||
udpMux *udpmux.UniversalUDPMuxDefault
|
||||
|
||||
filterFn udpmux.FilterFn
|
||||
}
|
||||
|
||||
func NewKernelDevice(name string, address wgaddr.Address, wgPort int, key string, mtu uint16, transportNet transport.Net) *TunKernelDevice {
|
||||
@@ -102,6 +104,7 @@ func (t *TunKernelDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
|
||||
bindParams := udpmux.UniversalUDPMuxParams{
|
||||
UDPConn: nbnet.WrapPacketConn(rawSock),
|
||||
Net: t.transportNet,
|
||||
FilterFn: t.filterFn,
|
||||
WGAddress: t.address,
|
||||
MTU: t.mtu,
|
||||
}
|
||||
|
||||
@@ -63,6 +63,7 @@ type WGIFaceOpts struct {
|
||||
MTU uint16
|
||||
MobileArgs *device.MobileIFaceArguments
|
||||
TransportNet transport.Net
|
||||
FilterFn udpmux.FilterFn
|
||||
DisableDNS bool
|
||||
}
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
|
||||
// NewWGIFace Creates a new WireGuard interface instance
|
||||
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.Address, opts.MTU)
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, opts.Address, opts.MTU)
|
||||
|
||||
var tun WGTunDevice
|
||||
if netstack.IsEnabled() {
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
|
||||
// NewWGIFace Creates a new WireGuard interface instance
|
||||
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.Address, opts.MTU)
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, opts.Address, opts.MTU)
|
||||
|
||||
if netstack.IsEnabled() {
|
||||
wgIFace := &WGIface{
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
|
||||
// NewWGIFace Creates a new WireGuard interface instance
|
||||
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.Address, opts.MTU)
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, opts.Address, opts.MTU)
|
||||
|
||||
wgIFace := &WGIface{
|
||||
tun: device.NewTunDevice(opts.IFaceName, opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunFd),
|
||||
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
// NewWGIFace Creates a new WireGuard interface instance
|
||||
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
if netstack.IsEnabled() {
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.Address, opts.MTU)
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, opts.Address, opts.MTU)
|
||||
return &WGIface{
|
||||
tun: device.NewNetstackDevice(opts.IFaceName, opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()),
|
||||
userspaceBind: true,
|
||||
@@ -30,7 +30,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
}
|
||||
|
||||
if device.ModuleTunIsLoaded() {
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.Address, opts.MTU)
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, opts.Address, opts.MTU)
|
||||
return &WGIface{
|
||||
tun: device.NewTunDevice(opts.IFaceName, opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind),
|
||||
userspaceBind: true,
|
||||
|
||||
@@ -8,6 +8,8 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -20,6 +22,10 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
// FilterFn is a function that filters out candidates based on the address.
|
||||
// If it returns true, the address is to be filtered. It also returns the prefix of matching route.
|
||||
type FilterFn func(address netip.Addr) (bool, netip.Prefix, error)
|
||||
|
||||
// UniversalUDPMuxDefault handles STUN and TURN servers packets by wrapping the original UDPConn
|
||||
// It then passes packets to the UDPMux that does the actual connection muxing.
|
||||
type UniversalUDPMuxDefault struct {
|
||||
@@ -37,6 +43,7 @@ type UniversalUDPMuxParams struct {
|
||||
UDPConn net.PacketConn
|
||||
XORMappedAddrCacheTTL time.Duration
|
||||
Net transport.Net
|
||||
FilterFn FilterFn
|
||||
WGAddress wgaddr.Address
|
||||
MTU uint16
|
||||
}
|
||||
@@ -61,6 +68,7 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
|
||||
PacketConn: params.UDPConn,
|
||||
mux: m,
|
||||
logger: params.Logger,
|
||||
filterFn: params.FilterFn,
|
||||
address: params.WGAddress,
|
||||
}
|
||||
|
||||
@@ -107,12 +115,15 @@ func (m *UniversalUDPMuxDefault) ReadFromConn(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// UDPConn is a wrapper around UDPMux conn that overrides WriteTo to drop packets destined for the overlay subnet.
|
||||
// UDPConn is a wrapper around UDPMux conn that overrides ReadFrom and handles STUN/TURN packets
|
||||
type UDPConn struct {
|
||||
net.PacketConn
|
||||
mux *UniversalUDPMuxDefault
|
||||
logger logging.LeveledLogger
|
||||
address wgaddr.Address
|
||||
mux *UniversalUDPMuxDefault
|
||||
logger logging.LeveledLogger
|
||||
filterFn FilterFn
|
||||
// TODO: reset cache on route changes
|
||||
addrCache sync.Map
|
||||
address wgaddr.Address
|
||||
}
|
||||
|
||||
// GetPacketConn returns the underlying PacketConn
|
||||
@@ -121,18 +132,67 @@ func (u *UDPConn) GetPacketConn() net.PacketConn {
|
||||
}
|
||||
|
||||
func (u *UDPConn) WriteTo(b []byte, addr net.Addr) (int, error) {
|
||||
udpAddr, ok := addr.(*net.UDPAddr)
|
||||
if !ok {
|
||||
if u.filterFn == nil {
|
||||
return u.PacketConn.WriteTo(b, addr)
|
||||
}
|
||||
dst := udpAddr.AddrPort().Addr().Unmap()
|
||||
if (u.address.Network.IsValid() && u.address.Network.Contains(dst)) || (u.address.IPv6Net.IsValid() && u.address.IPv6Net.Contains(dst)) {
|
||||
log.Warnf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
|
||||
return 0, fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
|
||||
|
||||
if isRouted, found := u.addrCache.Load(addr.String()); found {
|
||||
return u.handleCachedAddress(isRouted.(bool), b, addr)
|
||||
}
|
||||
|
||||
return u.handleUncachedAddress(b, addr)
|
||||
}
|
||||
|
||||
func (u *UDPConn) handleCachedAddress(isRouted bool, b []byte, addr net.Addr) (int, error) {
|
||||
if isRouted {
|
||||
return 0, fmt.Errorf("address %s is part of a routed network, refusing to write", addr)
|
||||
}
|
||||
return u.PacketConn.WriteTo(b, addr)
|
||||
}
|
||||
|
||||
func (u *UDPConn) handleUncachedAddress(b []byte, addr net.Addr) (int, error) {
|
||||
if err := u.performFilterCheck(addr); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return u.PacketConn.WriteTo(b, addr)
|
||||
}
|
||||
|
||||
func (u *UDPConn) performFilterCheck(addr net.Addr) error {
|
||||
host, err := getHostFromAddr(addr)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to get host from address %s: %v", addr, err)
|
||||
return nil
|
||||
}
|
||||
|
||||
a, err := netip.ParseAddr(host)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to parse address %s: %v", addr, err)
|
||||
return nil
|
||||
}
|
||||
|
||||
if u.address.Network.Contains(a) {
|
||||
log.Warnf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
|
||||
return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
|
||||
}
|
||||
|
||||
if isRouted, prefix, err := u.filterFn(a); err != nil {
|
||||
log.Errorf("Failed to check if address %s is routed: %v", addr, err)
|
||||
} else {
|
||||
u.addrCache.Store(addr.String(), isRouted)
|
||||
if isRouted {
|
||||
// Extra log, as the error only shows up with ICE logging enabled
|
||||
log.Infof("address %s is part of routed network %s, refusing to write", addr, prefix)
|
||||
return fmt.Errorf("address %s is part of routed network %s, refusing to write", addr, prefix)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getHostFromAddr(addr net.Addr) (string, error) {
|
||||
host, _, err := net.SplitHostPort(addr.String())
|
||||
return host, err
|
||||
}
|
||||
|
||||
// GetSharedConn returns the shared udp conn
|
||||
func (m *UniversalUDPMuxDefault) GetSharedConn() net.PacketConn {
|
||||
return m.params.UDPConn
|
||||
@@ -165,13 +225,6 @@ func (m *UniversalUDPMuxDefault) HandleSTUNMessage(msg *stun.Message, addr net.A
|
||||
return nil
|
||||
}
|
||||
|
||||
src := udpAddr.AddrPort().Addr().Unmap()
|
||||
wg := m.params.WGAddress
|
||||
if (wg.Network.IsValid() && wg.Network.Contains(src)) || (wg.IPv6Net.IsValid() && wg.IPv6Net.Contains(src)) {
|
||||
log.Debugf("dropping STUN message from overlay source %s", udpAddr)
|
||||
return nil
|
||||
}
|
||||
|
||||
if m.isXORMappedResponse(msg, udpAddr.String()) {
|
||||
err := m.handleXORMappedResponse(udpAddr, msg)
|
||||
if err != nil {
|
||||
|
||||
@@ -66,7 +66,7 @@ func seedProxyForProxyCloseByRemoteConn() ([]proxyInstance, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
iceBind := bind.NewICEBind(nil, wgAddress, 1280)
|
||||
iceBind := bind.NewICEBind(nil, nil, wgAddress, 1280)
|
||||
endpointAddress := &net.UDPAddr{
|
||||
IP: net.IPv4(10, 0, 0, 1),
|
||||
Port: 1234,
|
||||
|
||||
@@ -22,7 +22,7 @@ func seedProxyForProxyCloseByRemoteConn() ([]proxyInstance, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
iceBind := bind.NewICEBind(nil, wgAddress, 1280)
|
||||
iceBind := bind.NewICEBind(nil, nil, wgAddress, 1280)
|
||||
endpointAddress := &net.UDPAddr{
|
||||
IP: net.IPv4(10, 0, 0, 1),
|
||||
Port: 1234,
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
@@ -53,6 +54,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/relay"
|
||||
"github.com/netbirdio/netbird/client/internal/rosenpass"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/syncstore"
|
||||
"github.com/netbirdio/netbird/client/internal/updater"
|
||||
@@ -88,6 +90,13 @@ var ErrResetConnection = fmt.Errorf("reset connection")
|
||||
|
||||
var ErrEngineAlreadyStarted = errors.New("engine already started")
|
||||
|
||||
// engineRestartCount and engineLastRestart track client-restart cadence across
|
||||
// engine recreations so a restart loop is distinguishable from rare restarts.
|
||||
var (
|
||||
engineRestartCount atomic.Int64
|
||||
engineLastRestart atomic.Int64
|
||||
)
|
||||
|
||||
type EngineConfig struct {
|
||||
WgPort int
|
||||
WgIfaceName string
|
||||
@@ -899,7 +908,11 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||
started := time.Now()
|
||||
defer func() {
|
||||
duration := time.Since(started)
|
||||
log.Infof("sync finished in %s", duration)
|
||||
if update.GetNetworkMap() != nil {
|
||||
log.Infof("sync finished in %s, %d", duration, update.GetNetworkMap().GetSerial())
|
||||
} else {
|
||||
log.Infof("sync finished in %s", duration)
|
||||
}
|
||||
e.clientMetrics.RecordSyncDuration(e.ctx, duration)
|
||||
}()
|
||||
e.syncMsgMux.Lock()
|
||||
@@ -909,14 +922,23 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||
if e.ctx.Err() != nil {
|
||||
return e.ctx.Err()
|
||||
}
|
||||
serial := update.GetNetworkMap().GetSerial()
|
||||
if nm := update.GetNetworkMap(); nm != nil {
|
||||
log.Infof("sync update: serial=%d remotePeers=%d offlinePeers=%d routes=%d firewallRules=%d checks=%d configPresent=%v remotePeersEmpty=%v",
|
||||
nm.GetSerial(), len(nm.GetRemotePeers()), len(nm.GetOfflinePeers()), len(nm.GetRoutes()),
|
||||
len(nm.GetFirewallRules()), len(update.GetChecks()), update.GetNetbirdConfig() != nil, nm.GetRemotePeersIsEmpty())
|
||||
} else {
|
||||
log.Infof("sync update: config-only (no network map), configPresent=%v", update.GetNetbirdConfig() != nil)
|
||||
}
|
||||
|
||||
if update.NetworkMap != nil && update.NetworkMap.PeerConfig != nil {
|
||||
e.handleAutoUpdateVersion(update.NetworkMap.PeerConfig.AutoUpdate)
|
||||
}
|
||||
|
||||
startTime := time.Now()
|
||||
if err := e.updateNetbirdConfig(update.GetNetbirdConfig()); err != nil {
|
||||
return err
|
||||
}
|
||||
log.Infof("netbird config updated in %s, serial=%d", time.Since(startTime), serial)
|
||||
|
||||
// Posture checks are bound to the network map presence:
|
||||
// NetworkMap != nil, checks present -> apply the received checks
|
||||
@@ -927,17 +949,21 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||
if nm == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
startTime = time.Now()
|
||||
if err := e.updateChecksIfNew(update.Checks); err != nil {
|
||||
return err
|
||||
}
|
||||
log.Infof("checks updated in %s, serial=%d", time.Since(startTime), serial)
|
||||
|
||||
startTime = time.Now()
|
||||
e.persistSyncResponse(update)
|
||||
|
||||
log.Infof("sync response persisted in %s, serial=%d", time.Since(startTime), serial)
|
||||
// only apply new changes and ignore old ones
|
||||
startTime = time.Now()
|
||||
if err := e.updateNetworkMap(nm); err != nil {
|
||||
return err
|
||||
}
|
||||
log.Infof("network map updated in %s, serial=%d", time.Since(startTime), serial)
|
||||
|
||||
e.statusRecorder.PublishEvent(cProto.SystemEvent_INFO, cProto.SystemEvent_SYSTEM, "Network map updated", "", nil)
|
||||
|
||||
@@ -1357,44 +1383,56 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
|
||||
dnsConfig := toDNSConfig(protoDNSConfig, e.wgInterface.Address())
|
||||
|
||||
startTime := time.Now()
|
||||
if err := e.dnsServer.UpdateDNSServer(serial, dnsConfig); err != nil {
|
||||
log.Errorf("failed to update dns server, err: %v", err)
|
||||
}
|
||||
log.Infof("updated dns server in %v, serial=%d", time.Since(startTime), serial)
|
||||
|
||||
e.routeManager.SetDNSForwarderPort(dnsConfig.ForwarderPort)
|
||||
|
||||
// apply routes first, route related actions might depend on routing being enabled
|
||||
startTime = time.Now()
|
||||
routes := toRoutes(networkMap.GetRoutes())
|
||||
serverRoutes, clientRoutes := e.routeManager.ClassifyRoutes(routes)
|
||||
|
||||
log.Infof("updated routes in %v, serial=%d", time.Since(startTime), serial)
|
||||
// lazy mgr needs to be aware of which routes are available before they are applied
|
||||
if e.connMgr != nil {
|
||||
e.connMgr.UpdateRouteHAMap(clientRoutes)
|
||||
log.Debugf("updated lazy connection manager with %d HA groups", len(clientRoutes))
|
||||
}
|
||||
|
||||
startTime = time.Now()
|
||||
dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap)
|
||||
if err := e.routeManager.UpdateRoutes(serial, serverRoutes, clientRoutes, dnsRouteFeatureFlag); err != nil {
|
||||
log.Errorf("failed to update routes: %v", err)
|
||||
}
|
||||
log.Infof("updated routes in %v, serial=%d", time.Since(startTime), serial)
|
||||
|
||||
startTime = time.Now()
|
||||
if e.acl != nil {
|
||||
e.acl.ApplyFiltering(networkMap, dnsRouteFeatureFlag)
|
||||
}
|
||||
log.Infof("updated filtering in %v, serial=%d", time.Since(startTime), serial)
|
||||
|
||||
startTime = time.Now()
|
||||
fwdEntries := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), routes)
|
||||
e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries)
|
||||
log.Infof("updated DNS forwarder in %v, serial=%d", time.Since(startTime), serial)
|
||||
|
||||
startTime = time.Now()
|
||||
// Ingress forward rules
|
||||
forwardingRules, err := e.updateForwardRules(networkMap.GetForwardingRules())
|
||||
if err != nil {
|
||||
log.Errorf("failed to update forward rules, err: %v", err)
|
||||
}
|
||||
log.Infof("updated forward rules in %v, serial=%d", time.Since(startTime), serial)
|
||||
|
||||
log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers()))
|
||||
|
||||
startTime = time.Now()
|
||||
e.updateOfflinePeers(networkMap.GetOfflinePeers())
|
||||
|
||||
log.Infof("updated offline peers in %v, serial=%d", time.Since(startTime), serial)
|
||||
// Filter out own peer from the remote peers list
|
||||
localPubKey := e.config.WgPrivateKey.PublicKey().String()
|
||||
remotePeers := make([]*mgmProto.RemotePeerConfig, 0, len(networkMap.GetRemotePeers()))
|
||||
@@ -1412,20 +1450,24 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
startTime = time.Now()
|
||||
err := e.removePeers(remotePeers)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Infof("removed peers in %v, serial=%d", time.Since(startTime), serial)
|
||||
startTime = time.Now()
|
||||
err = e.modifyPeers(remotePeers)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Infof("modified peers in %v, serial=%d", time.Since(startTime), serial)
|
||||
startTime = time.Now()
|
||||
err = e.addNewPeers(remotePeers)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.Infof("added peers in %v, serial=%d", time.Since(startTime), serial)
|
||||
|
||||
e.statusRecorder.FinishPeerListModifications()
|
||||
|
||||
@@ -1438,9 +1480,11 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
e.updateSSHServerAuth(networkMap.GetSshAuth())
|
||||
}
|
||||
|
||||
startTime = time.Now()
|
||||
// must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store
|
||||
excludedLazyPeers := e.toExcludedLazyPeers(forwardingRules, remotePeers)
|
||||
e.connMgr.SetExcludeList(e.ctx, excludedLazyPeers)
|
||||
log.Infof("updated lazy connection manager exclude list in %v, serial=%d", time.Since(startTime), serial)
|
||||
|
||||
e.networkSerial = serial
|
||||
|
||||
@@ -1955,6 +1999,7 @@ func (e *Engine) newWgIface() (*iface.WGIface, error) {
|
||||
WGPrivKey: e.config.WgPrivateKey.String(),
|
||||
MTU: e.config.MTU,
|
||||
TransportNet: transportNet,
|
||||
FilterFn: e.addrViaRoutes,
|
||||
DisableDNS: e.config.DisableDNS,
|
||||
}
|
||||
|
||||
@@ -2171,7 +2216,14 @@ func (e *Engine) triggerClientRestart() {
|
||||
return
|
||||
}
|
||||
|
||||
log.Info("restarting engine")
|
||||
// Cadence survives engine recreation (package-level), so a restart loop shows
|
||||
// as a fast-climbing count with a short gap, distinct from rare intentional restarts.
|
||||
n := engineRestartCount.Add(1)
|
||||
var sinceLast time.Duration
|
||||
if prev := engineLastRestart.Swap(time.Now().UnixNano()); prev != 0 {
|
||||
sinceLast = time.Since(time.Unix(0, prev))
|
||||
}
|
||||
log.Infof("restarting engine (restart #%d, %s since previous)", n, sinceLast.Round(time.Second))
|
||||
CtxGetState(e.ctx).Set(StatusConnecting)
|
||||
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
|
||||
log.Infof("cancelling client context, engine will be recreated")
|
||||
@@ -2202,6 +2254,21 @@ func (e *Engine) startNetworkMonitor() {
|
||||
}()
|
||||
}
|
||||
|
||||
func (e *Engine) addrViaRoutes(addr netip.Addr) (bool, netip.Prefix, error) {
|
||||
var vpnRoutes []netip.Prefix
|
||||
for _, routes := range e.routeManager.GetClientRoutes() {
|
||||
if len(routes) > 0 && routes[0] != nil {
|
||||
vpnRoutes = append(vpnRoutes, routes[0].Network)
|
||||
}
|
||||
}
|
||||
|
||||
if isVpn, prefix := systemops.IsAddrRouted(addr, vpnRoutes); isVpn {
|
||||
return true, prefix, nil
|
||||
}
|
||||
|
||||
return false, netip.Prefix{}, nil
|
||||
}
|
||||
|
||||
func (e *Engine) stopDNSServer() {
|
||||
if e.dnsServer == nil {
|
||||
return
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"runtime/debug"
|
||||
"slices"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -192,6 +193,7 @@ func (s *StatusChangeSubscription) Events() chan map[string]RouterState {
|
||||
// Pure read methods take RLock; anything that mutates state takes Lock.
|
||||
type Status struct {
|
||||
mux sync.RWMutex
|
||||
muxRelays sync.RWMutex
|
||||
peers map[string]State
|
||||
ipToKey map[string]string
|
||||
changeNotify map[string]map[string]*StatusChangeSubscription // map[peerID]map[subscriptionID]*StatusChangeSubscription
|
||||
@@ -226,6 +228,8 @@ type Status struct {
|
||||
|
||||
routeIDLookup routeIDLookup
|
||||
wgIface WGIfaceStatus
|
||||
|
||||
profile *StatusProfile
|
||||
}
|
||||
|
||||
// NewRecorder returns a new Status instance
|
||||
@@ -240,16 +244,23 @@ func NewRecorder(mgmAddress string) *Status {
|
||||
notifier: newNotifier(),
|
||||
mgmAddress: mgmAddress,
|
||||
resolvedDomainsStates: map[domain.Domain]ResolvedDomainInfo{},
|
||||
profile: NewStatusProfile(context.Background()),
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Status) StartProfile(ctx context.Context) {
|
||||
d.profile.Start(ctx)
|
||||
}
|
||||
|
||||
func (d *Status) SetRelayMgr(manager *relayClient.Manager) {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
d.profile.inc("SetRelayMgr")
|
||||
d.muxRelays.Lock()
|
||||
defer d.muxRelays.Unlock()
|
||||
d.relayMgr = manager
|
||||
}
|
||||
|
||||
func (d *Status) SetIngressGwMgr(ingressGwMgr *ingressgw.Manager) {
|
||||
d.profile.inc("SetIngressGwMgr")
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
d.ingressGwMgr = ingressGwMgr
|
||||
@@ -257,6 +268,7 @@ func (d *Status) SetIngressGwMgr(ingressGwMgr *ingressgw.Manager) {
|
||||
|
||||
// ReplaceOfflinePeers replaces
|
||||
func (d *Status) ReplaceOfflinePeers(replacement []State) {
|
||||
d.profile.inc("ReplaceOfflinePeers")
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
d.offlinePeers = make([]State, len(replacement))
|
||||
@@ -268,6 +280,7 @@ func (d *Status) ReplaceOfflinePeers(replacement []State) {
|
||||
|
||||
// AddPeer adds peer to Daemon status map
|
||||
func (d *Status) AddPeer(peerPubKey string, fqdn string, ip string, ipv6 string) error {
|
||||
d.profile.inc("AddPeer")
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
@@ -295,6 +308,7 @@ func (d *Status) AddPeer(peerPubKey string, fqdn string, ip string, ipv6 string)
|
||||
|
||||
// GetPeer adds peer to Daemon status map
|
||||
func (d *Status) GetPeer(peerPubKey string) (State, error) {
|
||||
d.profile.inc("GetPeer")
|
||||
d.mux.RLock()
|
||||
defer d.mux.RUnlock()
|
||||
|
||||
@@ -306,6 +320,7 @@ func (d *Status) GetPeer(peerPubKey string) (State, error) {
|
||||
}
|
||||
|
||||
func (d *Status) PeerByIP(ip string) (string, bool) {
|
||||
d.profile.inc("PeerByIP")
|
||||
d.mux.RLock()
|
||||
defer d.mux.RUnlock()
|
||||
|
||||
@@ -323,6 +338,7 @@ func (d *Status) PeerByIP(ip string) (string, bool) {
|
||||
// active peers are matched; peers moved into the offline slice by
|
||||
// ReplaceOfflinePeers are intentionally treated as unknown.
|
||||
func (d *Status) PeerStateByIP(ip string) (State, bool) {
|
||||
d.profile.inc("PeerStateByIP")
|
||||
if ip == "" {
|
||||
return State{}, false
|
||||
}
|
||||
@@ -341,6 +357,7 @@ func (d *Status) PeerStateByIP(ip string) (State, bool) {
|
||||
|
||||
// RemovePeer removes peer from Daemon status map
|
||||
func (d *Status) RemovePeer(peerPubKey string) error {
|
||||
d.profile.inc("RemovePeer")
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
@@ -362,6 +379,7 @@ func (d *Status) RemovePeer(peerPubKey string) error {
|
||||
|
||||
// UpdatePeerState updates peer status
|
||||
func (d *Status) UpdatePeerState(receivedState State) error {
|
||||
d.profile.inc("UpdatePeerState")
|
||||
d.mux.Lock()
|
||||
|
||||
peerState, ok := d.peers[receivedState.PubKey]
|
||||
@@ -404,6 +422,7 @@ func (d *Status) UpdatePeerState(receivedState State) error {
|
||||
}
|
||||
|
||||
func (d *Status) AddPeerStateRoute(peer string, route string, resourceId route.ResID) error {
|
||||
d.profile.inc("AddPeerStateRoute")
|
||||
d.mux.Lock()
|
||||
|
||||
peerState, ok := d.peers[peer]
|
||||
@@ -429,6 +448,7 @@ func (d *Status) AddPeerStateRoute(peer string, route string, resourceId route.R
|
||||
}
|
||||
|
||||
func (d *Status) RemovePeerStateRoute(peer string, route string) error {
|
||||
d.profile.inc("RemovePeerStateRoute")
|
||||
d.mux.Lock()
|
||||
|
||||
peerState, ok := d.peers[peer]
|
||||
@@ -459,11 +479,13 @@ func (d *Status) CheckRoutes(ip netip.Addr) ([]byte, bool) {
|
||||
if d == nil {
|
||||
return nil, false
|
||||
}
|
||||
d.profile.inc("CheckRoutes")
|
||||
resId, isExitNode := d.routeIDLookup.Lookup(ip)
|
||||
return []byte(resId), isExitNode
|
||||
}
|
||||
|
||||
func (d *Status) UpdatePeerICEState(receivedState State) error {
|
||||
d.profile.inc("UpdatePeerICEState")
|
||||
d.mux.Lock()
|
||||
|
||||
peerState, ok := d.peers[receivedState.PubKey]
|
||||
@@ -503,6 +525,7 @@ func (d *Status) UpdatePeerICEState(receivedState State) error {
|
||||
}
|
||||
|
||||
func (d *Status) UpdatePeerRelayedState(receivedState State) error {
|
||||
d.profile.inc("UpdatePeerRelayedState")
|
||||
d.mux.Lock()
|
||||
|
||||
peerState, ok := d.peers[receivedState.PubKey]
|
||||
@@ -539,6 +562,7 @@ func (d *Status) UpdatePeerRelayedState(receivedState State) error {
|
||||
}
|
||||
|
||||
func (d *Status) UpdatePeerRelayedStateToDisconnected(receivedState State) error {
|
||||
d.profile.inc("UpdatePeerRelayedStateToDisconnected")
|
||||
d.mux.Lock()
|
||||
|
||||
peerState, ok := d.peers[receivedState.PubKey]
|
||||
@@ -574,6 +598,7 @@ func (d *Status) UpdatePeerRelayedStateToDisconnected(receivedState State) error
|
||||
}
|
||||
|
||||
func (d *Status) UpdatePeerICEStateToDisconnected(receivedState State) error {
|
||||
d.profile.inc("UpdatePeerICEStateToDisconnected")
|
||||
d.mux.Lock()
|
||||
|
||||
peerState, ok := d.peers[receivedState.PubKey]
|
||||
@@ -613,6 +638,7 @@ func (d *Status) UpdatePeerICEStateToDisconnected(receivedState State) error {
|
||||
|
||||
// UpdateWireGuardPeerState updates the WireGuard bits of the peer state
|
||||
func (d *Status) UpdateWireGuardPeerState(pubKey string, wgStats configurer.WGStats) error {
|
||||
d.profile.inc("UpdateWireGuardPeerState")
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
@@ -640,6 +666,7 @@ func hasConnStatusChanged(oldStatus, newStatus ConnStatus) bool {
|
||||
|
||||
// UpdatePeerFQDN update peer's state fqdn only
|
||||
func (d *Status) UpdatePeerFQDN(peerPubKey, fqdn string) error {
|
||||
d.profile.inc("UpdatePeerFQDN")
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
@@ -656,6 +683,7 @@ func (d *Status) UpdatePeerFQDN(peerPubKey, fqdn string) error {
|
||||
|
||||
// UpdatePeerSSHHostKey updates peer's SSH host key
|
||||
func (d *Status) UpdatePeerSSHHostKey(peerPubKey string, sshHostKey []byte) error {
|
||||
d.profile.inc("UpdatePeerSSHHostKey")
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
@@ -672,6 +700,7 @@ func (d *Status) UpdatePeerSSHHostKey(peerPubKey string, sshHostKey []byte) erro
|
||||
|
||||
// FinishPeerListModifications this event invoke the notification
|
||||
func (d *Status) FinishPeerListModifications() {
|
||||
d.profile.inc("FinishPeerListModifications")
|
||||
d.mux.Lock()
|
||||
|
||||
if !d.peerListChangedForNotification {
|
||||
@@ -704,6 +733,7 @@ func (d *Status) FinishPeerListModifications() {
|
||||
}
|
||||
|
||||
func (d *Status) SubscribeToPeerStateChanges(ctx context.Context, peerID string) *StatusChangeSubscription {
|
||||
d.profile.inc("SubscribeToPeerStateChanges")
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
@@ -717,6 +747,7 @@ func (d *Status) SubscribeToPeerStateChanges(ctx context.Context, peerID string)
|
||||
}
|
||||
|
||||
func (d *Status) UnsubscribePeerStateChanges(subscription *StatusChangeSubscription) {
|
||||
d.profile.inc("UnsubscribePeerStateChanges")
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
@@ -742,6 +773,7 @@ func (d *Status) UnsubscribePeerStateChanges(subscription *StatusChangeSubscript
|
||||
|
||||
// GetLocalPeerState returns the local peer state
|
||||
func (d *Status) GetLocalPeerState() LocalPeerState {
|
||||
d.profile.inc("GetLocalPeerState")
|
||||
d.mux.RLock()
|
||||
defer d.mux.RUnlock()
|
||||
return d.localPeer.Clone()
|
||||
@@ -749,6 +781,7 @@ func (d *Status) GetLocalPeerState() LocalPeerState {
|
||||
|
||||
// UpdateLocalPeerState updates local peer status
|
||||
func (d *Status) UpdateLocalPeerState(localPeerState LocalPeerState) {
|
||||
d.profile.inc("UpdateLocalPeerState")
|
||||
d.mux.Lock()
|
||||
d.localPeer = localPeerState
|
||||
fqdn := d.localPeer.FQDN
|
||||
@@ -763,6 +796,7 @@ func (d *Status) UpdateLocalPeerState(localPeerState LocalPeerState) {
|
||||
|
||||
// AddLocalPeerStateRoute adds a route to the local peer state
|
||||
func (d *Status) AddLocalPeerStateRoute(route string, resourceId route.ResID) {
|
||||
d.profile.inc("AddLocalPeerStateRoute")
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
@@ -780,6 +814,7 @@ func (d *Status) AddLocalPeerStateRoute(route string, resourceId route.ResID) {
|
||||
|
||||
// RemoveLocalPeerStateRoute removes a route from the local peer state
|
||||
func (d *Status) RemoveLocalPeerStateRoute(route string) {
|
||||
d.profile.inc("RemoveLocalPeerStateRoute")
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
@@ -793,6 +828,7 @@ func (d *Status) RemoveLocalPeerStateRoute(route string) {
|
||||
|
||||
// AddResolvedIPLookupEntry adds a resolved IP lookup entry
|
||||
func (d *Status) AddResolvedIPLookupEntry(prefix netip.Prefix, resourceId route.ResID) {
|
||||
d.profile.inc("AddResolvedIPLookupEntry")
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
@@ -801,6 +837,7 @@ func (d *Status) AddResolvedIPLookupEntry(prefix netip.Prefix, resourceId route.
|
||||
|
||||
// RemoveResolvedIPLookupEntry removes a resolved IP lookup entry
|
||||
func (d *Status) RemoveResolvedIPLookupEntry(route string) {
|
||||
d.profile.inc("RemoveResolvedIPLookupEntry")
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
@@ -812,6 +849,7 @@ func (d *Status) RemoveResolvedIPLookupEntry(route string) {
|
||||
|
||||
// CleanLocalPeerStateRoutes cleans all routes from the local peer state
|
||||
func (d *Status) CleanLocalPeerStateRoutes() {
|
||||
d.profile.inc("CleanLocalPeerStateRoutes")
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
@@ -820,6 +858,7 @@ func (d *Status) CleanLocalPeerStateRoutes() {
|
||||
|
||||
// CleanLocalPeerState cleans local peer status
|
||||
func (d *Status) CleanLocalPeerState() {
|
||||
d.profile.inc("CleanLocalPeerState")
|
||||
d.mux.Lock()
|
||||
d.localPeer = LocalPeerState{}
|
||||
fqdn := d.localPeer.FQDN
|
||||
@@ -831,6 +870,7 @@ func (d *Status) CleanLocalPeerState() {
|
||||
|
||||
// MarkManagementDisconnected sets ManagementState to disconnected
|
||||
func (d *Status) MarkManagementDisconnected(err error) {
|
||||
d.profile.inc("MarkManagementDisconnected")
|
||||
d.mux.Lock()
|
||||
d.managementState = false
|
||||
d.managementError = err
|
||||
@@ -843,6 +883,7 @@ func (d *Status) MarkManagementDisconnected(err error) {
|
||||
|
||||
// MarkManagementConnected sets ManagementState to connected
|
||||
func (d *Status) MarkManagementConnected() {
|
||||
d.profile.inc("MarkManagementConnected")
|
||||
d.mux.Lock()
|
||||
d.managementState = true
|
||||
d.managementError = nil
|
||||
@@ -855,6 +896,7 @@ func (d *Status) MarkManagementConnected() {
|
||||
|
||||
// UpdateSignalAddress update the address of the signal server
|
||||
func (d *Status) UpdateSignalAddress(signalURL string) {
|
||||
d.profile.inc("UpdateSignalAddress")
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
d.signalAddress = signalURL
|
||||
@@ -862,6 +904,7 @@ func (d *Status) UpdateSignalAddress(signalURL string) {
|
||||
|
||||
// UpdateManagementAddress update the address of the management server
|
||||
func (d *Status) UpdateManagementAddress(mgmAddress string) {
|
||||
d.profile.inc("UpdateManagementAddress")
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
d.mgmAddress = mgmAddress
|
||||
@@ -869,6 +912,7 @@ func (d *Status) UpdateManagementAddress(mgmAddress string) {
|
||||
|
||||
// UpdateRosenpass update the Rosenpass configuration
|
||||
func (d *Status) UpdateRosenpass(rosenpassEnabled, rosenpassPermissive bool) {
|
||||
d.profile.inc("UpdateRosenpass")
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
d.rosenpassPermissive = rosenpassPermissive
|
||||
@@ -876,6 +920,7 @@ func (d *Status) UpdateRosenpass(rosenpassEnabled, rosenpassPermissive bool) {
|
||||
}
|
||||
|
||||
func (d *Status) UpdateLazyConnection(enabled bool) {
|
||||
d.profile.inc("UpdateLazyConnection")
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
d.lazyConnectionEnabled = enabled
|
||||
@@ -883,6 +928,7 @@ func (d *Status) UpdateLazyConnection(enabled bool) {
|
||||
|
||||
// MarkSignalDisconnected sets SignalState to disconnected
|
||||
func (d *Status) MarkSignalDisconnected(err error) {
|
||||
d.profile.inc("MarkSignalDisconnected")
|
||||
d.mux.Lock()
|
||||
d.signalState = false
|
||||
d.signalError = err
|
||||
@@ -895,6 +941,7 @@ func (d *Status) MarkSignalDisconnected(err error) {
|
||||
|
||||
// MarkSignalConnected sets SignalState to connected
|
||||
func (d *Status) MarkSignalConnected() {
|
||||
d.profile.inc("MarkSignalConnected")
|
||||
d.mux.Lock()
|
||||
d.signalState = true
|
||||
d.signalError = nil
|
||||
@@ -906,18 +953,21 @@ func (d *Status) MarkSignalConnected() {
|
||||
}
|
||||
|
||||
func (d *Status) UpdateRelayStates(relayResults []relay.ProbeResult) {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
d.profile.inc("UpdateRelayStates")
|
||||
d.muxRelays.Lock()
|
||||
defer d.muxRelays.Unlock()
|
||||
d.relayStates = relayResults
|
||||
}
|
||||
|
||||
func (d *Status) UpdateDNSStates(dnsStates []NSGroupState) {
|
||||
d.profile.inc("UpdateDNSStates")
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
d.nsGroupStates = dnsStates
|
||||
}
|
||||
|
||||
func (d *Status) UpdateResolvedDomainsStates(originalDomain domain.Domain, resolvedDomain domain.Domain, prefixes []netip.Prefix, resourceId route.ResID) {
|
||||
d.profile.inc("UpdateResolvedDomainsStates")
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
@@ -933,6 +983,7 @@ func (d *Status) UpdateResolvedDomainsStates(originalDomain domain.Domain, resol
|
||||
}
|
||||
|
||||
func (d *Status) DeleteResolvedDomainsStates(domain domain.Domain) {
|
||||
d.profile.inc("DeleteResolvedDomainsStates")
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
@@ -949,6 +1000,7 @@ func (d *Status) DeleteResolvedDomainsStates(domain domain.Domain) {
|
||||
}
|
||||
|
||||
func (d *Status) GetRosenpassState() RosenpassState {
|
||||
d.profile.inc("GetRosenpassState")
|
||||
d.mux.RLock()
|
||||
defer d.mux.RUnlock()
|
||||
return RosenpassState{
|
||||
@@ -958,12 +1010,14 @@ func (d *Status) GetRosenpassState() RosenpassState {
|
||||
}
|
||||
|
||||
func (d *Status) GetLazyConnection() bool {
|
||||
d.profile.inc("GetLazyConnection")
|
||||
d.mux.RLock()
|
||||
defer d.mux.RUnlock()
|
||||
return d.lazyConnectionEnabled
|
||||
}
|
||||
|
||||
func (d *Status) GetManagementState() ManagementState {
|
||||
d.profile.inc("GetManagementState")
|
||||
d.mux.RLock()
|
||||
defer d.mux.RUnlock()
|
||||
return ManagementState{
|
||||
@@ -974,6 +1028,7 @@ func (d *Status) GetManagementState() ManagementState {
|
||||
}
|
||||
|
||||
func (d *Status) UpdateLatency(pubKey string, latency time.Duration) error {
|
||||
d.profile.inc("UpdateLatency")
|
||||
if latency <= 0 {
|
||||
return nil
|
||||
}
|
||||
@@ -991,6 +1046,7 @@ func (d *Status) UpdateLatency(pubKey string, latency time.Duration) error {
|
||||
|
||||
// IsLoginRequired determines if a peer's login has expired.
|
||||
func (d *Status) IsLoginRequired() bool {
|
||||
d.profile.inc("IsLoginRequired")
|
||||
d.mux.RLock()
|
||||
defer d.mux.RUnlock()
|
||||
|
||||
@@ -1007,6 +1063,7 @@ func (d *Status) IsLoginRequired() bool {
|
||||
}
|
||||
|
||||
func (d *Status) GetSignalState() SignalState {
|
||||
d.profile.inc("GetSignalState")
|
||||
d.mux.RLock()
|
||||
defer d.mux.RUnlock()
|
||||
return SignalState{
|
||||
@@ -1018,24 +1075,34 @@ func (d *Status) GetSignalState() SignalState {
|
||||
|
||||
// GetRelayStates returns the stun/turn/permanent relay states
|
||||
func (d *Status) GetRelayStates() []relay.ProbeResult {
|
||||
d.mux.RLock()
|
||||
defer d.mux.RUnlock()
|
||||
d.profile.inc("GetRelayStates")
|
||||
d.muxRelays.RLock()
|
||||
|
||||
// debug lines
|
||||
started := time.Now()
|
||||
defer func() {
|
||||
debugElapsed("GetRelayStates", started)
|
||||
}()
|
||||
|
||||
if d.relayMgr == nil {
|
||||
defer d.muxRelays.RUnlock()
|
||||
return d.relayStates
|
||||
}
|
||||
|
||||
relayMgr := d.relayMgr
|
||||
// extend the list of stun, turn servers with the relay server connections
|
||||
relayStates := slices.Clone(d.relayStates)
|
||||
d.muxRelays.RUnlock()
|
||||
|
||||
states := d.relayMgr.RelayStates()
|
||||
states := relayMgr.RelayStates()
|
||||
if len(states) == 0 {
|
||||
// no relay connection tracked yet; surface configured servers as
|
||||
// unavailable with the real reconnect error when known
|
||||
err := relayClient.ErrRelayClientNotConnected
|
||||
if connErr := d.relayMgr.RelayConnectError(); connErr != nil {
|
||||
if connErr := relayMgr.RelayConnectError(); connErr != nil {
|
||||
err = connErr
|
||||
}
|
||||
for _, r := range d.relayMgr.ServerURLs() {
|
||||
for _, r := range relayMgr.ServerURLs() {
|
||||
relayStates = append(relayStates, relay.ProbeResult{
|
||||
URI: r,
|
||||
Err: err,
|
||||
@@ -1055,7 +1122,15 @@ func (d *Status) GetRelayStates() []relay.ProbeResult {
|
||||
}
|
||||
|
||||
func (d *Status) ForwardingRules() []firewall.ForwardRule {
|
||||
d.profile.inc("ForwardingRules")
|
||||
d.mux.RLock()
|
||||
|
||||
// debug lines
|
||||
started := time.Now()
|
||||
defer func() {
|
||||
debugElapsed("ForwardingRules", started)
|
||||
}()
|
||||
|
||||
defer d.mux.RUnlock()
|
||||
if d.ingressGwMgr == nil {
|
||||
return nil
|
||||
@@ -1065,6 +1140,7 @@ func (d *Status) ForwardingRules() []firewall.ForwardRule {
|
||||
}
|
||||
|
||||
func (d *Status) GetDNSStates() []NSGroupState {
|
||||
d.profile.inc("GetDNSStates")
|
||||
d.mux.RLock()
|
||||
defer d.mux.RUnlock()
|
||||
|
||||
@@ -1073,6 +1149,7 @@ func (d *Status) GetDNSStates() []NSGroupState {
|
||||
}
|
||||
|
||||
func (d *Status) GetResolvedDomainsStates() map[domain.Domain]ResolvedDomainInfo {
|
||||
d.profile.inc("GetResolvedDomainsStates")
|
||||
d.mux.RLock()
|
||||
defer d.mux.RUnlock()
|
||||
return maps.Clone(d.resolvedDomainsStates)
|
||||
@@ -1080,6 +1157,7 @@ func (d *Status) GetResolvedDomainsStates() map[domain.Domain]ResolvedDomainInfo
|
||||
|
||||
// GetFullStatus gets full status
|
||||
func (d *Status) GetFullStatus() FullStatus {
|
||||
d.profile.inc("GetFullStatus")
|
||||
fullStatus := FullStatus{
|
||||
ManagementState: d.GetManagementState(),
|
||||
SignalState: d.GetSignalState(),
|
||||
@@ -1106,26 +1184,31 @@ func (d *Status) GetFullStatus() FullStatus {
|
||||
|
||||
// ClientStart will notify all listeners about the new service state
|
||||
func (d *Status) ClientStart() {
|
||||
d.profile.inc("ClientStart")
|
||||
d.notifier.clientStart()
|
||||
}
|
||||
|
||||
// ClientStop will notify all listeners about the new service state
|
||||
func (d *Status) ClientStop() {
|
||||
d.profile.inc("ClientStop")
|
||||
d.notifier.clientStop()
|
||||
}
|
||||
|
||||
// ClientTeardown will notify all listeners about the service is under teardown
|
||||
func (d *Status) ClientTeardown() {
|
||||
d.profile.inc("ClientTeardown")
|
||||
d.notifier.clientTearDown()
|
||||
}
|
||||
|
||||
// SetConnectionListener set a listener to the notifier
|
||||
func (d *Status) SetConnectionListener(listener Listener) {
|
||||
d.profile.inc("SetConnectionListener")
|
||||
d.notifier.setListener(listener)
|
||||
}
|
||||
|
||||
// RemoveConnectionListener remove the listener from the notifier
|
||||
func (d *Status) RemoveConnectionListener() {
|
||||
d.profile.inc("RemoveConnectionListener")
|
||||
d.notifier.removeListener()
|
||||
}
|
||||
|
||||
@@ -1197,6 +1280,7 @@ func (d *Status) PublishEvent(
|
||||
userMsg string,
|
||||
metadata map[string]string,
|
||||
) {
|
||||
d.profile.inc("PublishEvent")
|
||||
event := &proto.SystemEvent{
|
||||
Id: uuid.New().String(),
|
||||
Severity: severity,
|
||||
@@ -1208,6 +1292,13 @@ func (d *Status) PublishEvent(
|
||||
}
|
||||
|
||||
d.eventMux.Lock()
|
||||
|
||||
// debug lines
|
||||
started := time.Now()
|
||||
defer func() {
|
||||
debugElapsed("PublishEvent", started)
|
||||
}()
|
||||
|
||||
defer d.eventMux.Unlock()
|
||||
|
||||
d.eventQueue.Add(event)
|
||||
@@ -1225,6 +1316,7 @@ func (d *Status) PublishEvent(
|
||||
|
||||
// SubscribeToEvents returns a new event subscription
|
||||
func (d *Status) SubscribeToEvents() *EventSubscription {
|
||||
d.profile.inc("SubscribeToEvents")
|
||||
d.eventMux.Lock()
|
||||
defer d.eventMux.Unlock()
|
||||
|
||||
@@ -1240,6 +1332,7 @@ func (d *Status) SubscribeToEvents() *EventSubscription {
|
||||
|
||||
// UnsubscribeFromEvents removes an event subscription
|
||||
func (d *Status) UnsubscribeFromEvents(sub *EventSubscription) {
|
||||
d.profile.inc("UnsubscribeFromEvents")
|
||||
if sub == nil {
|
||||
return
|
||||
}
|
||||
@@ -1255,10 +1348,12 @@ func (d *Status) UnsubscribeFromEvents(sub *EventSubscription) {
|
||||
|
||||
// GetEventHistory returns all events in the queue
|
||||
func (d *Status) GetEventHistory() []*proto.SystemEvent {
|
||||
d.profile.inc("GetEventHistory")
|
||||
return d.eventQueue.GetAll()
|
||||
}
|
||||
|
||||
func (d *Status) SetWgIface(wgInterface WGIfaceStatus) {
|
||||
d.profile.inc("SetWgIface")
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
@@ -1266,7 +1361,15 @@ func (d *Status) SetWgIface(wgInterface WGIfaceStatus) {
|
||||
}
|
||||
|
||||
func (d *Status) PeersStatus() (*configurer.Stats, error) {
|
||||
d.profile.inc("PeersStatus")
|
||||
d.mux.RLock()
|
||||
|
||||
// debug lines
|
||||
started := time.Now()
|
||||
defer func() {
|
||||
debugElapsed("PeersStatus", started)
|
||||
}()
|
||||
|
||||
defer d.mux.RUnlock()
|
||||
if d.wgIface == nil {
|
||||
return nil, fmt.Errorf("wgInterface is nil, cannot retrieve peers status")
|
||||
@@ -1279,7 +1382,15 @@ func (d *Status) PeersStatus() (*configurer.Stats, error) {
|
||||
// and updates the cached peer states. This ensures accurate handshake times and
|
||||
// transfer statistics in status reports without running full health probes.
|
||||
func (d *Status) RefreshWireGuardStats() error {
|
||||
d.profile.inc("RefreshWireGuardStats")
|
||||
d.mux.Lock()
|
||||
|
||||
// debug lines
|
||||
started := time.Now()
|
||||
defer func() {
|
||||
debugElapsed("RefreshWireGuardStats", started)
|
||||
}()
|
||||
|
||||
defer d.mux.Unlock()
|
||||
|
||||
if d.wgIface == nil {
|
||||
@@ -1444,3 +1555,10 @@ func (fs FullStatus) ToProto() *proto.FullStatus {
|
||||
|
||||
return &pbFullStatus
|
||||
}
|
||||
|
||||
func debugElapsed(msg string, startTime time.Time) {
|
||||
if elapsed := time.Since(startTime); elapsed > 1*time.Second {
|
||||
log.Infof("run %s took %s", msg, elapsed)
|
||||
debug.PrintStack()
|
||||
}
|
||||
}
|
||||
|
||||
97
client/internal/peer/status_profile.go
Normal file
97
client/internal/peer/status_profile.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package peer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type StatusProfile struct {
|
||||
counts sync.Map
|
||||
}
|
||||
|
||||
func NewStatusProfile(ctx context.Context) *StatusProfile {
|
||||
s := &StatusProfile{}
|
||||
go s.Start(ctx)
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *StatusProfile) Start(ctx context.Context) {
|
||||
ticker := time.NewTicker(10 * time.Second)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
s.logCounts()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *StatusProfile) inc(method string) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
if v, ok := s.counts.Load(method); ok {
|
||||
v.(*atomic.Int64).Add(1)
|
||||
return
|
||||
}
|
||||
cnt := &atomic.Int64{}
|
||||
actual, _ := s.counts.LoadOrStore(method, cnt)
|
||||
actual.(*atomic.Int64).Add(1)
|
||||
}
|
||||
|
||||
func (s *StatusProfile) snapshot() map[string]int64 {
|
||||
out := make(map[string]int64)
|
||||
s.counts.Range(func(k, v any) bool {
|
||||
out[k.(string)] = v.(*atomic.Int64).Load()
|
||||
return true
|
||||
})
|
||||
return out
|
||||
}
|
||||
|
||||
func (s *StatusProfile) logCounts() {
|
||||
counts := s.snapshot()
|
||||
if len(counts) == 0 {
|
||||
log.Infof("status profile: no Status method calls so far")
|
||||
return
|
||||
}
|
||||
|
||||
type kv struct {
|
||||
method string
|
||||
count int64
|
||||
}
|
||||
sorted := make([]kv, 0, len(counts))
|
||||
var total int64
|
||||
for m, c := range counts {
|
||||
if c == 0 {
|
||||
continue
|
||||
}
|
||||
sorted = append(sorted, kv{m, c})
|
||||
total += c
|
||||
}
|
||||
sort.Slice(sorted, func(i, j int) bool {
|
||||
if sorted[i].count != sorted[j].count {
|
||||
return sorted[i].count > sorted[j].count
|
||||
}
|
||||
return sorted[i].method < sorted[j].method
|
||||
})
|
||||
|
||||
var b strings.Builder
|
||||
for i, e := range sorted {
|
||||
if i > 0 {
|
||||
b.WriteString(", ")
|
||||
}
|
||||
b.WriteString(e.method)
|
||||
b.WriteByte('=')
|
||||
b.WriteString(strconv.FormatInt(e.count, 10))
|
||||
}
|
||||
log.Infof("status profile (cumulative total=%d): %s", total, b.String())
|
||||
}
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -164,6 +165,10 @@ func (w *WorkerICE) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HA
|
||||
return
|
||||
}
|
||||
|
||||
if candidateViaRoutes(candidate, haRoutes) {
|
||||
return
|
||||
}
|
||||
|
||||
if err := w.agent.AddRemoteCandidate(candidate); err != nil {
|
||||
w.log.Errorf("error while handling remote candidate")
|
||||
return
|
||||
@@ -461,7 +466,7 @@ func (w *WorkerICE) createForwardedCandidate(srflxCandidate ice.Candidate, mappi
|
||||
}
|
||||
|
||||
func (w *WorkerICE) onICESelectedCandidatePair(agent *icemaker.ThreadSafeAgent, c1, c2 ice.Candidate) {
|
||||
w.log.Debugf("selected candidate pair [local <-> remote] -> [%s <-> %s], peer %s", c1.String(), c2.String(),
|
||||
w.log.Infof("selected candidate pair [local <-> remote] -> [%s <-> %s], peer %s", c1.String(), c2.String(),
|
||||
w.config.Key)
|
||||
|
||||
pairStat, ok := agent.GetSelectedCandidatePairStats()
|
||||
@@ -584,6 +589,34 @@ func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive
|
||||
return ec, nil
|
||||
}
|
||||
|
||||
func candidateViaRoutes(candidate ice.Candidate, clientRoutes route.HAMap) bool {
|
||||
addr, err := netip.ParseAddr(candidate.Address())
|
||||
if err != nil {
|
||||
log.Errorf("Failed to parse IP address %s: %v", candidate.Address(), err)
|
||||
return false
|
||||
}
|
||||
|
||||
var routePrefixes []netip.Prefix
|
||||
for _, routes := range clientRoutes {
|
||||
if len(routes) > 0 && routes[0] != nil {
|
||||
routePrefixes = append(routePrefixes, routes[0].Network)
|
||||
}
|
||||
}
|
||||
|
||||
for _, prefix := range routePrefixes {
|
||||
// default route is handled by route exclusion / ip rules
|
||||
if prefix.Bits() == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if prefix.Contains(addr) {
|
||||
log.Debugf("Ignoring candidate [%s], its address is part of routed network %s", candidate.String(), prefix)
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func isRelayCandidate(candidate ice.Candidate) bool {
|
||||
return candidate.Type() == ice.CandidateTypeRelay
|
||||
}
|
||||
|
||||
@@ -121,12 +121,9 @@ func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf wgIface, init
|
||||
return Nexthop{}, vars.ErrRouteNotAllowed
|
||||
}
|
||||
|
||||
// BSDs blackhole a /32 added inside a directly-connected subnet; Linux/Windows need it to beat the wt0 route.
|
||||
switch runtime.GOOS {
|
||||
case "darwin", "freebsd", "netbsd", "openbsd", "dragonfly":
|
||||
if isLocal, subnet := r.isPrefixInLocalSubnets(prefix); isLocal {
|
||||
return Nexthop{}, fmt.Errorf("prefix %s is part of local subnet %s: %w", prefix, subnet, vars.ErrRouteNotAllowed)
|
||||
}
|
||||
// Check if the prefix is part of any local subnets
|
||||
if isLocal, subnet := r.isPrefixInLocalSubnets(prefix); isLocal {
|
||||
return Nexthop{}, fmt.Errorf("prefix %s is part of local subnet %s: %w", prefix, subnet, vars.ErrRouteNotAllowed)
|
||||
}
|
||||
|
||||
// Determine the exit interface and next hop for the prefix, so we can add a specific route
|
||||
|
||||
@@ -136,6 +136,7 @@ func New(ctx context.Context, logFile string, configFile string, profilesDisable
|
||||
networksDisabled: networksDisabled,
|
||||
jwtCache: newJWTCache(),
|
||||
}
|
||||
go s.statusRecorder.StartProfile(ctx)
|
||||
agent := &serverAgent{s}
|
||||
s.sleepHandler = sleephandler.New(agent)
|
||||
s.startSleepDetector()
|
||||
|
||||
@@ -418,14 +418,7 @@ func newServiceClient(args *newServiceClientArgs) *serviceClient {
|
||||
case args.showProfiles:
|
||||
s.showProfilesUI()
|
||||
case args.showQuickActions:
|
||||
// Suppress the on-boot Quick Actions popup when the daemon
|
||||
// reports DisableAutoConnect=true — that flag carries both the
|
||||
// user's "Connect on Startup = off" preference AND any MDM-
|
||||
// enforced override (applyMDMPolicy writes the policy value
|
||||
// into the same Config field). See netbirdio/netbird#5744.
|
||||
if !s.disableAutoConnectFromDaemon() {
|
||||
s.showQuickActionsUI()
|
||||
}
|
||||
s.showQuickActionsUI()
|
||||
case args.showUpdate:
|
||||
s.showUpdateProgress(ctx, args.showUpdateVersion)
|
||||
}
|
||||
@@ -1345,40 +1338,6 @@ func (s *serviceClient) getFeatures() (*proto.GetFeaturesResponse, error) {
|
||||
return features, nil
|
||||
}
|
||||
|
||||
// disableAutoConnectFromDaemon returns true when the daemon reports
|
||||
// the active profile has DisableAutoConnect=true. Used by the
|
||||
// --quick-actions startup path to suppress the on-boot popup when the
|
||||
// user (or an MDM admin) opted out of auto-connecting; both cases
|
||||
// converge on the same Config field because applyMDMPolicy writes the
|
||||
// policy value into it. Returns false on any RPC / lookup failure so a
|
||||
// daemon hiccup does not silently swallow the popup.
|
||||
func (s *serviceClient) disableAutoConnectFromDaemon() bool {
|
||||
activeProf, err := s.profileManager.GetActiveProfile()
|
||||
if err != nil {
|
||||
log.Warnf("disableAutoConnectFromDaemon: get active profile: %v", err)
|
||||
return false
|
||||
}
|
||||
currUser, err := user.Current()
|
||||
if err != nil {
|
||||
log.Warnf("disableAutoConnectFromDaemon: get current user: %v", err)
|
||||
return false
|
||||
}
|
||||
conn, err := s.getSrvClient(failFastTimeout)
|
||||
if err != nil {
|
||||
log.Warnf("disableAutoConnectFromDaemon: get daemon client: %v", err)
|
||||
return false
|
||||
}
|
||||
srvCfg, err := conn.GetConfig(s.ctx, &proto.GetConfigRequest{
|
||||
ProfileName: activeProf.ID.String(),
|
||||
Username: currUser.Username,
|
||||
})
|
||||
if err != nil {
|
||||
log.Warnf("disableAutoConnectFromDaemon: GetConfig RPC: %v", err)
|
||||
return false
|
||||
}
|
||||
return srvCfg.GetDisableAutoConnect()
|
||||
}
|
||||
|
||||
// getSrvConfig from the service to show it in the settings window.
|
||||
func (s *serviceClient) getSrvConfig() {
|
||||
s.managementURL = profilemanager.DefaultManagementURL
|
||||
|
||||
@@ -497,7 +497,7 @@ func (c *Controller) BufferUpdateAffectedPeers(ctx context.Context, accountID st
|
||||
c.accountManagerMetrics.CountUpdateAccountPeersTriggered(string(reason.Resource), string(reason.Operation))
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Tracef("buffer updating %d affected peers for account %s from %s with reason %s/%s", len(peerIDs), accountID, util.GetCallerName(), reason.Operation, reason.Resource)
|
||||
log.WithContext(ctx).Tracef("buffer updating %d affected peers for account %s from %s", len(peerIDs), accountID, util.GetCallerName())
|
||||
|
||||
bufUpd, _ := c.affectedPeerUpdateLocks.LoadOrStore(accountID, &bufferAffectedUpdate{
|
||||
peerIDs: make(map[string]struct{}),
|
||||
@@ -610,10 +610,12 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
|
||||
return nil, nil, 0, err
|
||||
}
|
||||
|
||||
startPosture := time.Now()
|
||||
postureChecks, err := c.getPeerPostureChecks(account, peerID)
|
||||
if err != nil {
|
||||
return nil, nil, 0, err
|
||||
}
|
||||
log.WithContext(ctx).Debugf("getPeerPostureChecks took %s", time.Since(startPosture))
|
||||
|
||||
accountZones, err := c.repo.GetAccountZones(ctx, account.Id)
|
||||
if err != nil {
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
|
||||
const (
|
||||
reconnThreshold = 5 * time.Minute
|
||||
baseBlockDuration = 30 * time.Minute // Duration for which a peer is banned after exceeding the reconnection limit
|
||||
baseBlockDuration = 10 * time.Minute // Duration for which a peer is banned after exceeding the reconnection limit
|
||||
reconnLimitForBan = 30 // Number of reconnections within the reconnTreshold that triggers a ban
|
||||
metaChangeLimit = 3 // Number of reconnections with different metadata that triggers a ban of one peer
|
||||
)
|
||||
@@ -139,13 +139,22 @@ func (l *loginFilter) addLogin(wgPubKey string, metaHash uint64) {
|
||||
state.lastSeen = now
|
||||
}
|
||||
|
||||
func metaHash(meta nbpeer.PeerSystemMeta) uint64 {
|
||||
func metaHash(meta nbpeer.PeerSystemMeta, pubip string) uint64 {
|
||||
h := fnv.New64a()
|
||||
|
||||
h.Write([]byte(meta.WtVersion))
|
||||
h.Write([]byte(meta.OSVersion))
|
||||
h.Write([]byte(meta.KernelVersion))
|
||||
h.Write([]byte(meta.Hostname))
|
||||
h.Write([]byte(meta.SystemSerialNumber))
|
||||
h.Write([]byte(pubip))
|
||||
|
||||
return h.Sum64()
|
||||
macs := uint64(0)
|
||||
for _, na := range meta.NetworkAddresses {
|
||||
for _, r := range na.Mac {
|
||||
macs += uint64(r)
|
||||
}
|
||||
}
|
||||
|
||||
return h.Sum64() + macs
|
||||
}
|
||||
|
||||
@@ -164,7 +164,9 @@ func BenchmarkHashingMethods(b *testing.B) {
|
||||
KernelVersion: "5.15.0-76-generic",
|
||||
Hostname: "prod-server-database-01",
|
||||
SystemSerialNumber: "PC-1234567890",
|
||||
NetworkAddresses: []nbpeer.NetworkAddress{{Mac: "00:1B:44:11:3A:B7"}, {Mac: "00:1B:44:11:3A:B8"}},
|
||||
}
|
||||
pubip := "8.8.8.8"
|
||||
|
||||
var resultString string
|
||||
var resultUint uint64
|
||||
@@ -173,7 +175,7 @@ func BenchmarkHashingMethods(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
resultString = builderString(meta)
|
||||
resultString = builderString(meta, pubip)
|
||||
}
|
||||
})
|
||||
|
||||
@@ -181,7 +183,7 @@ func BenchmarkHashingMethods(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
resultString = fnvHashToString(meta)
|
||||
resultString = fnvHashToString(meta, pubip)
|
||||
}
|
||||
})
|
||||
|
||||
@@ -189,7 +191,7 @@ func BenchmarkHashingMethods(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
resultUint = metaHash(meta)
|
||||
resultUint = metaHash(meta, pubip)
|
||||
}
|
||||
})
|
||||
|
||||
@@ -197,20 +199,29 @@ func BenchmarkHashingMethods(b *testing.B) {
|
||||
_ = resultUint
|
||||
}
|
||||
|
||||
func fnvHashToString(meta nbpeer.PeerSystemMeta) string {
|
||||
func fnvHashToString(meta nbpeer.PeerSystemMeta, pubip string) string {
|
||||
h := fnv.New64a()
|
||||
|
||||
if len(meta.NetworkAddresses) != 0 {
|
||||
for _, na := range meta.NetworkAddresses {
|
||||
h.Write([]byte(na.Mac))
|
||||
}
|
||||
}
|
||||
|
||||
h.Write([]byte(meta.WtVersion))
|
||||
h.Write([]byte(meta.OSVersion))
|
||||
h.Write([]byte(meta.KernelVersion))
|
||||
h.Write([]byte(meta.Hostname))
|
||||
h.Write([]byte(meta.SystemSerialNumber))
|
||||
h.Write([]byte(pubip))
|
||||
|
||||
return strconv.FormatUint(h.Sum64(), 16)
|
||||
}
|
||||
|
||||
func builderString(meta nbpeer.PeerSystemMeta) string {
|
||||
estimatedSize := len(meta.WtVersion) + len(meta.OSVersion) + len(meta.KernelVersion) + len(meta.Hostname) + len(meta.SystemSerialNumber) + 4
|
||||
func builderString(meta nbpeer.PeerSystemMeta, pubip string) string {
|
||||
mac := getMacAddress(meta.NetworkAddresses)
|
||||
estimatedSize := len(meta.WtVersion) + len(meta.OSVersion) + len(meta.KernelVersion) + len(meta.Hostname) + len(meta.SystemSerialNumber) +
|
||||
len(pubip) + len(mac) + 6
|
||||
|
||||
var b strings.Builder
|
||||
b.Grow(estimatedSize)
|
||||
@@ -224,10 +235,23 @@ func builderString(meta nbpeer.PeerSystemMeta) string {
|
||||
b.WriteString(meta.Hostname)
|
||||
b.WriteByte('|')
|
||||
b.WriteString(meta.SystemSerialNumber)
|
||||
b.WriteByte('|')
|
||||
b.WriteString(pubip)
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func getMacAddress(nas []nbpeer.NetworkAddress) string {
|
||||
if len(nas) == 0 {
|
||||
return ""
|
||||
}
|
||||
macs := make([]string, 0, len(nas))
|
||||
for _, na := range nas {
|
||||
macs = append(macs, na.Mac)
|
||||
}
|
||||
return strings.Join(macs, "/")
|
||||
}
|
||||
|
||||
func BenchmarkLoginFilter_ParallelLoad(b *testing.B) {
|
||||
filter := newLoginFilterWithCfg(testAdvancedCfg())
|
||||
numKeys := 100000
|
||||
|
||||
@@ -254,7 +254,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
||||
return mapError(ctx, err)
|
||||
}
|
||||
|
||||
metahashed := metaHash(peerMeta)
|
||||
metahashed := metaHash(peerMeta, sRealIP)
|
||||
if userID == "" && !s.loginFilter.allowLogin(peerKey.String(), metahashed) {
|
||||
if s.appMetrics != nil {
|
||||
s.appMetrics.GRPCMetrics().CountSyncRequestBlocked()
|
||||
@@ -306,7 +306,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
||||
log.WithContext(ctx).Tracef("peer system meta has to be provided on sync. Peer %s, remote addr %s", peerKey.String(), realIP)
|
||||
}
|
||||
|
||||
metahash := metaHash(peerMeta)
|
||||
metahash := metaHash(peerMeta, realIP.String())
|
||||
s.loginFilter.addLogin(peerKey.String(), metahash)
|
||||
|
||||
peer, netMap, postureChecks, dnsFwdPort, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP, syncStart)
|
||||
@@ -732,7 +732,7 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
|
||||
}
|
||||
|
||||
peerMeta := extractPeerMeta(ctx, loginReq.GetMeta())
|
||||
metahashed := metaHash(peerMeta)
|
||||
metahashed := metaHash(peerMeta, sRealIP)
|
||||
if !s.loginFilter.allowLogin(peerKey.String(), metahashed) {
|
||||
if s.logBlockedPeers {
|
||||
log.WithContext(ctx).Tracef("peer %s with meta hash %d is blocked from login", peerKey.String(), metahashed)
|
||||
@@ -788,11 +788,7 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
|
||||
ExtraDNSLabels: loginReq.GetDnsLabels(),
|
||||
})
|
||||
if err != nil {
|
||||
if errors.Is(err, internalStatus.ErrNoAuthMethodProvided) {
|
||||
log.WithContext(ctx).Tracef("failed logging in peer %s: %s", peerKey, err)
|
||||
} else {
|
||||
log.WithContext(ctx).Warnf("failed logging in peer %s: %s", peerKey, err)
|
||||
}
|
||||
log.WithContext(ctx).Warnf("failed logging in peer %s: %s", peerKey, err)
|
||||
return nil, mapError(ctx, err)
|
||||
}
|
||||
|
||||
|
||||
@@ -41,7 +41,7 @@ func TestAffectedPeers_DependencyCoverageMatrix(t *testing.T) {
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
return affectedpeers.Change{ChangedPeerIDs: []string{s.routerPeerID}},
|
||||
[]string{s.sourcePeerID, s.routerPeerID}, []string{s.unrelatedPeerID}
|
||||
[]string{s.sourcePeerID}, []string{s.unrelatedPeerID}
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -106,9 +106,11 @@ func TestAffectedPeers_DependencyCoverageMatrix(t *testing.T) {
|
||||
change, mustContain, mustExclude := r.build(t, s, ctx)
|
||||
affected := resolveAffected(t, s.manager.Store, s.accountID, change)
|
||||
|
||||
assert.ElementsMatch(t, affected, mustContain, "expected peer to be affected")
|
||||
for _, peerID := range mustExclude {
|
||||
assert.NotContains(t, affected, peerID, "peer must not be affected")
|
||||
for _, id := range mustContain {
|
||||
assert.Contains(t, affected, id, "expected peer to be affected")
|
||||
}
|
||||
for _, id := range mustExclude {
|
||||
assert.NotContains(t, affected, id, "peer must not be affected")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -251,9 +251,7 @@ func TestAffectedPeers_E2E_UpdateResource_DestinationResourcePolicy_RefreshesSou
|
||||
}
|
||||
}
|
||||
|
||||
// A disabled sibling router routes to nobody, so updating a resource on its network
|
||||
// must NOT refresh its peer (the enabled router carries the bridge instead).
|
||||
func TestAffectedPeers_E2E_UpdateResource_DisabledSiblingRouterNotBridged(t *testing.T) {
|
||||
func TestAffectedPeers_E2E_UpdateResource_DisabledSiblingRouter_StillBridged(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
@@ -276,18 +274,13 @@ func TestAffectedPeers_E2E_UpdateResource_DisabledSiblingRouterNotBridged(t *tes
|
||||
require.NoError(t, err)
|
||||
|
||||
disabledCh := s.updateManager.CreateChannel(ctx, disabledRouterPeer.ID)
|
||||
enabledCh := s.updateManager.CreateChannel(ctx, s.routerPeerID)
|
||||
t.Cleanup(func() {
|
||||
s.updateManager.CloseChannel(ctx, disabledRouterPeer.ID)
|
||||
s.updateManager.CloseChannel(ctx, s.routerPeerID)
|
||||
})
|
||||
t.Cleanup(func() { s.updateManager.CloseChannel(ctx, disabledRouterPeer.ID) })
|
||||
|
||||
settleAffectedUpdates(disabledCh, enabledCh)
|
||||
settleAffectedUpdates(disabledCh)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, enabledCh)
|
||||
peerShouldNotReceiveUpdate(t, disabledCh)
|
||||
peerShouldReceiveUpdate(t, disabledCh)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -305,7 +298,7 @@ func TestAffectedPeers_E2E_UpdateResource_DisabledSiblingRouterNotBridged(t *tes
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout")
|
||||
t.Error("timeout: resource update did not refresh the disabled sibling router's peer")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -682,9 +682,6 @@ func TestAffectedPeers_AllRoutingPeers_Network(t *testing.T) {
|
||||
assert.Contains(t, affected, secondRouterPeer.ID, "second routing peer on the same network must also be affected")
|
||||
}
|
||||
|
||||
// A disabled router in the snapshot routes to nobody, so it is skipped when the
|
||||
// walk scans existing account data: a policy edit still folds the literal source
|
||||
// group, but not the disabled router's peer.
|
||||
func TestAffectedPeers_DisabledRouter(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
@@ -697,13 +694,11 @@ func TestAffectedPeers_DisabledRouter(t *testing.T) {
|
||||
|
||||
affected := s.resolvePolicyAffected(ctx, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID))
|
||||
|
||||
assert.Contains(t, affected, s.sourcePeerID, "source peer (literal policy source group) must be affected")
|
||||
assert.NotContains(t, affected, s.routerPeerID,
|
||||
"a disabled router routes to nobody, so its peer must not be folded from snapshot data")
|
||||
assert.Contains(t, affected, s.sourcePeerID, "source peer must be affected")
|
||||
assert.Contains(t, affected, s.routerPeerID,
|
||||
"disabled router's peer must still be affected: Enabled must not gate affected-peers")
|
||||
}
|
||||
|
||||
// A disabled resource in the snapshot is skipped: the policy edit still folds the
|
||||
// literal source group, but the resource no longer bridges to its network's router.
|
||||
func TestAffectedPeers_DisabledResource(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
@@ -715,9 +710,9 @@ func TestAffectedPeers_DisabledResource(t *testing.T) {
|
||||
|
||||
affected := s.resolvePolicyAffected(ctx, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID))
|
||||
|
||||
assert.Contains(t, affected, s.sourcePeerID, "source peer (literal policy source group) must be affected")
|
||||
assert.NotContains(t, affected, s.routerPeerID,
|
||||
"a disabled resource routes to nobody, so its network's router must not be folded from snapshot data")
|
||||
assert.Contains(t, affected, s.sourcePeerID, "source peer must be affected")
|
||||
assert.Contains(t, affected, s.routerPeerID,
|
||||
"disabled resource must still resolve the routing peer: Enabled must not gate affected-peers")
|
||||
}
|
||||
|
||||
func TestAffectedPeers_DisabledRule(t *testing.T) {
|
||||
|
||||
@@ -96,54 +96,33 @@ func affectedGroupID(i int) string { return fmt.Sprintf("affected-grp-%d", i)
|
||||
func affectedGroupName(i int) string { return fmt.Sprintf("AffectedGroup%d", i) }
|
||||
|
||||
func TestCollectGroupChange_PolicyLinked(t *testing.T) {
|
||||
manager, s, accountID, peerIDs, groupIDs := setupAffectedPeersTest(t)
|
||||
manager, s, accountID, _, groupIDs := setupAffectedPeersTest(t)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := manager.SavePolicy(ctx, accountID, userID, &types.Policy{
|
||||
Enabled: true,
|
||||
Rules: []*types.PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{groupIDs[0]},
|
||||
Destinations: []string{groupIDs[1]},
|
||||
SourceResource: types.Resource{ID: peerIDs[0], Type: types.ResourceTypePeer},
|
||||
DestinationResource: types.Resource{ID: peerIDs[1], Type: types.ResourceTypePeer},
|
||||
Bidirectional: true,
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
},
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{groupIDs[0]},
|
||||
Destinations: []string{groupIDs[1]},
|
||||
SourceResource: types.Resource{ID: peerIDs[2], Type: types.ResourceTypeHost},
|
||||
DestinationResource: types.Resource{ID: peerIDs[3], Type: types.ResourceTypeHost},
|
||||
Bidirectional: true,
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
},
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{groupIDs[0]},
|
||||
Destinations: []string{groupIDs[1]},
|
||||
SourceResource: types.Resource{ID: "", Type: types.ResourceTypePeer},
|
||||
DestinationResource: types.Resource{ID: "", Type: types.ResourceTypePeer},
|
||||
Bidirectional: true,
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
Enabled: true,
|
||||
Sources: []string{groupIDs[0]},
|
||||
Destinations: []string{groupIDs[1]},
|
||||
Bidirectional: true,
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
groups, directPeers := collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[0]})
|
||||
assert.ElementsMatch(t, groups, []string{groupIDs[0], groupIDs[1]})
|
||||
assert.ElementsMatch(t, directPeers, []string{peerIDs[1]})
|
||||
groups, _ := collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[0]})
|
||||
assert.Contains(t, groups, groupIDs[0])
|
||||
assert.Contains(t, groups, groupIDs[1])
|
||||
|
||||
groups, directPeers = collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[1]})
|
||||
assert.ElementsMatch(t, groups, []string{groupIDs[0], groupIDs[1]})
|
||||
assert.ElementsMatch(t, directPeers, []string{peerIDs[0]})
|
||||
groups, _ = collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[1]})
|
||||
assert.Contains(t, groups, groupIDs[0])
|
||||
assert.Contains(t, groups, groupIDs[1])
|
||||
|
||||
groups, directPeers = collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[2]})
|
||||
groups, _ = collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[2]})
|
||||
assert.Empty(t, groups)
|
||||
assert.Empty(t, directPeers)
|
||||
}
|
||||
|
||||
func TestCollectGroupChange_PolicyWithDirectPeerResource(t *testing.T) {
|
||||
@@ -154,44 +133,20 @@ func TestCollectGroupChange_PolicyWithDirectPeerResource(t *testing.T) {
|
||||
Enabled: true,
|
||||
Rules: []*types.PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{groupIDs[0]},
|
||||
SourceResource: types.Resource{ID: peerIDs[3], Type: types.ResourceTypePeer},
|
||||
DestinationResource: types.Resource{ID: peerIDs[4], Type: types.ResourceTypePeer},
|
||||
Destinations: []string{groupIDs[1]},
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
},
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{groupIDs[0]},
|
||||
SourceResource: types.Resource{ID: peerIDs[1], Type: types.ResourceTypeHost},
|
||||
DestinationResource: types.Resource{ID: peerIDs[2], Type: types.ResourceTypeHost},
|
||||
Destinations: []string{groupIDs[1]},
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
},
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{groupIDs[0]},
|
||||
SourceResource: types.Resource{ID: "", Type: types.ResourceTypePeer},
|
||||
DestinationResource: types.Resource{ID: "", Type: types.ResourceTypePeer},
|
||||
Destinations: []string{groupIDs[1]},
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
Enabled: true,
|
||||
Sources: []string{groupIDs[0]},
|
||||
SourceResource: types.Resource{ID: peerIDs[3], Type: types.ResourceTypePeer},
|
||||
Destinations: []string{groupIDs[1]},
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
groups, directPeers := collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[0]})
|
||||
assert.ElementsMatch(t, groups, []string{groupIDs[0], groupIDs[1]})
|
||||
assert.ElementsMatch(t, directPeers, []string{peerIDs[4]})
|
||||
|
||||
groups, directPeers = collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[1]})
|
||||
assert.ElementsMatch(t, groups, []string{groupIDs[0], groupIDs[1]})
|
||||
assert.ElementsMatch(t, directPeers, []string{peerIDs[3]})
|
||||
|
||||
groups, directPeers = collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[2]})
|
||||
assert.Empty(t, groups)
|
||||
assert.Empty(t, directPeers)
|
||||
assert.Contains(t, groups, groupIDs[0])
|
||||
assert.Contains(t, groups, groupIDs[1])
|
||||
assert.Contains(t, directPeers, peerIDs[3])
|
||||
}
|
||||
|
||||
func TestCollectGroupChange_PolicyWithNonPeerResource_NoDirectPeers(t *testing.T) {
|
||||
@@ -213,7 +168,8 @@ func TestCollectGroupChange_PolicyWithNonPeerResource_NoDirectPeers(t *testing.T
|
||||
require.NoError(t, err)
|
||||
|
||||
groups, directPeers := collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[0]})
|
||||
assert.ElementsMatch(t, groups, []string{groupIDs[0], groupIDs[1]})
|
||||
assert.Contains(t, groups, groupIDs[0])
|
||||
assert.Contains(t, groups, groupIDs[1])
|
||||
assert.Empty(t, directPeers, "non-peer resources should not produce direct peer IDs")
|
||||
}
|
||||
|
||||
@@ -338,7 +294,6 @@ func TestCollectGroupChange_NetworkRouterLinked(t *testing.T) {
|
||||
AccountID: accountID,
|
||||
PeerGroups: []string{groupIDs[0]},
|
||||
Peer: peerIDs[3],
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -369,7 +324,6 @@ func TestCollectGroupChange_NetworkRouterPeerOnlyNoGroups(t *testing.T) {
|
||||
NetworkID: net1.ID,
|
||||
AccountID: accountID,
|
||||
Peer: peerIDs[4],
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -419,11 +373,17 @@ func TestCollectGroupChange_MultipleEntities(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
groups, directPeers := collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[0]})
|
||||
assert.ElementsMatch(t, groups, []string{groupIDs[0], groupIDs[1]})
|
||||
assert.Contains(t, groups, groupIDs[0])
|
||||
assert.Contains(t, groups, groupIDs[1])
|
||||
assert.NotContains(t, groups, groupIDs[2])
|
||||
assert.NotContains(t, groups, groupIDs[3])
|
||||
assert.Empty(t, directPeers)
|
||||
|
||||
groups, directPeers = collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[3]})
|
||||
assert.ElementsMatch(t, groups, []string{groupIDs[2], groupIDs[3]})
|
||||
assert.Contains(t, groups, groupIDs[2])
|
||||
assert.Contains(t, groups, groupIDs[3])
|
||||
assert.NotContains(t, groups, groupIDs[0])
|
||||
assert.NotContains(t, groups, groupIDs[1])
|
||||
assert.Empty(t, directPeers)
|
||||
}
|
||||
|
||||
@@ -492,9 +452,8 @@ func TestResolveAffectedPeers_PolicyBetweenTwoGroups(t *testing.T) {
|
||||
result = manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[1]})
|
||||
assert.ElementsMatch(t, []string{peerIDs[0], peerIDs[1]}, result)
|
||||
|
||||
// peerIDs[2] is unrelated to the route; only its own map can change.
|
||||
result = manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[2]})
|
||||
assert.ElementsMatch(t, []string{peerIDs[2]}, result)
|
||||
assert.Empty(t, result)
|
||||
}
|
||||
|
||||
func TestResolveAffectedPeers_PolicyThreeGroups(t *testing.T) {
|
||||
@@ -515,7 +474,7 @@ func TestResolveAffectedPeers_PolicyThreeGroups(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
result := manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[0]})
|
||||
assert.ElementsMatch(t, []string{peerIDs[0], peerIDs[2]}, result)
|
||||
assert.ElementsMatch(t, []string{peerIDs[0], peerIDs[1], peerIDs[2]}, result)
|
||||
}
|
||||
|
||||
func TestResolveAffectedPeers_RoutePeerGroups(t *testing.T) {
|
||||
@@ -547,9 +506,8 @@ func TestResolveAffectedPeers_RoutePeerGroups(t *testing.T) {
|
||||
result = manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[1]})
|
||||
assert.ElementsMatch(t, []string{peerIDs[0], peerIDs[1]}, result)
|
||||
|
||||
// peerIDs[2] is in no policy; only its own map can change, so it refreshes itself.
|
||||
result = manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[2]})
|
||||
assert.ElementsMatch(t, []string{peerIDs[2]}, result)
|
||||
assert.Empty(t, result)
|
||||
}
|
||||
|
||||
func TestResolveAffectedPeers_RouteWithDirectPeer(t *testing.T) {
|
||||
@@ -606,9 +564,9 @@ func TestResolveAffectedPeers_RouteWithAccessControlGroups(t *testing.T) {
|
||||
result := manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[2]})
|
||||
assert.ElementsMatch(t, []string{peerIDs[0], peerIDs[1], peerIDs[2]}, result)
|
||||
|
||||
// peer3 is unrelated to the route; only its own map can change.
|
||||
// peer3 is unrelated
|
||||
result = manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[3]})
|
||||
assert.ElementsMatch(t, []string{peerIDs[3]}, result)
|
||||
assert.Empty(t, result)
|
||||
}
|
||||
|
||||
func TestResolveAffectedPeers_NetworkRouter(t *testing.T) {
|
||||
@@ -629,7 +587,6 @@ func TestResolveAffectedPeers_NetworkRouter(t *testing.T) {
|
||||
AccountID: accountID,
|
||||
PeerGroups: []string{groupIDs[0]},
|
||||
Peer: peerIDs[3],
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -702,13 +659,9 @@ func TestResolveAffectedPeers_PeerInMultipleGroups(t *testing.T) {
|
||||
}, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
// peer0 is in group0 AND group1, so both policies apply. A peer change folds
|
||||
// only the changed peer plus the opposite side of each rule: group2 (peer2) via
|
||||
// the group0 policy and group3 (peer3) via the group1 policy. peer1, a co-member
|
||||
// of group1, is a sibling of the changed peer and must NOT refresh.
|
||||
// peer0 is in group0 AND group1, so both policies apply
|
||||
result := manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[0]})
|
||||
assert.ElementsMatch(t, []string{peerIDs[0], peerIDs[2], peerIDs[3]}, result)
|
||||
assert.NotContains(t, result, peerIDs[1], "co-member of the changed peer's group must not refresh")
|
||||
assert.ElementsMatch(t, []string{peerIDs[0], peerIDs[1], peerIDs[2], peerIDs[3]}, result)
|
||||
}
|
||||
|
||||
func TestResolveAffectedPeers_MultipleChangedPeers(t *testing.T) {
|
||||
@@ -744,7 +697,7 @@ func TestResolveAffectedPeers_MultipleChangedPeers(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
result := manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[0], peerIDs[2]})
|
||||
assert.ElementsMatch(t, []string{peerIDs[0], peerIDs[2], peerIDs[1], peerIDs[3]}, result)
|
||||
assert.ElementsMatch(t, []string{peerIDs[0], peerIDs[1], peerIDs[2], peerIDs[3]}, result)
|
||||
}
|
||||
|
||||
func TestResolveAffectedPeers_SharedGroupAcrossPolicyAndRoute(t *testing.T) {
|
||||
@@ -901,9 +854,8 @@ func TestAffectedPeers_IsolatedPolicies(t *testing.T) {
|
||||
assert.NotContains(t, result, peerIDs[0])
|
||||
assert.NotContains(t, result, peerIDs[1])
|
||||
|
||||
// peerIDs[4] is in neither isolated policy; only its own map can change.
|
||||
result = manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[4]})
|
||||
assert.ElementsMatch(t, []string{peerIDs[4]}, result)
|
||||
assert.Empty(t, result)
|
||||
}
|
||||
|
||||
func TestAffectedPeers_IsolatedRouteAndPolicy(t *testing.T) {
|
||||
@@ -1025,13 +977,12 @@ func TestAffectedPeers_GroupUpdateOnlyAffectsLinkedPeers(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
// A peer in no policy/route refreshes only itself — no other peer is affected.
|
||||
func TestAffectedPeers_UnlinkedPeerChange_RefreshesSelfOnly(t *testing.T) {
|
||||
func TestAffectedPeers_UnlinkedGroupChange_NoUpdates(t *testing.T) {
|
||||
manager, s, accountID, peerIDs, _ := setupAffectedPeersTest(t)
|
||||
ctx := context.Background()
|
||||
|
||||
result := manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[0]})
|
||||
assert.ElementsMatch(t, []string{peerIDs[0]}, result)
|
||||
assert.Empty(t, result)
|
||||
}
|
||||
|
||||
// TestAffectedPeers_PolicyChange_UnrelatedPeerNoUpdate verifies that creating/deleting a
|
||||
@@ -1381,7 +1332,6 @@ func TestAffectedPeers_NetworkRouterUnlinkedPeerNoUpdate(t *testing.T) {
|
||||
NetworkID: net1.ID,
|
||||
AccountID: accountID,
|
||||
PeerGroups: []string{"nr-grpA"},
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -1805,9 +1755,7 @@ func TestCollectAffectedFromProxyServices_GroupContainingTargetPeerChanged(t *te
|
||||
assert.Contains(t, directPeers, peerIDs[1], "target peer must be refreshed")
|
||||
}
|
||||
|
||||
// A disabled service in the snapshot proxies nothing, so it is skipped: a changed
|
||||
// target peer does not pull in the service's proxy peer.
|
||||
func TestCollectAffectedFromProxyServices_DisabledServiceSkipped(t *testing.T) {
|
||||
func TestCollectAffectedFromProxyServices_DisabledServiceStillMatches(t *testing.T) {
|
||||
manager, s, accountID, peerIDs, _ := setupAffectedPeersTest(t)
|
||||
ctx := context.Background()
|
||||
|
||||
@@ -1833,7 +1781,8 @@ func TestCollectAffectedFromProxyServices_DisabledServiceSkipped(t *testing.T) {
|
||||
require.NoError(t, s.CreateService(ctx, svc))
|
||||
|
||||
_, directPeers := collectPeerChangeAffectedGroups(ctx, manager.Store, accountID, nil, []string{peerIDs[1]})
|
||||
assert.NotContains(t, directPeers, peerIDs[0], "a disabled service proxies nothing, so its proxy peer must not be folded")
|
||||
assert.Contains(t, directPeers, peerIDs[0], "disabled service should still trigger a refresh so peers are ready when re-enabled")
|
||||
assert.Contains(t, directPeers, peerIDs[1], "disabled target should still trigger a refresh")
|
||||
}
|
||||
|
||||
func TestCollectAffectedFromProxyServices_NonPeerTargetType(t *testing.T) {
|
||||
|
||||
@@ -6,12 +6,7 @@
|
||||
// and before a delete/removal severs the old state).
|
||||
// - Snapshot.Expand: in-memory walk, no store access. Run AFTER the tx commits.
|
||||
//
|
||||
// Enabled handling differs by source. Disabled objects in the SNAPSHOT (existing
|
||||
// account policies/resources/routers/routes/proxy services and their rules/targets)
|
||||
// route to nobody and are skipped — they cannot affect any peer's map. Objects in
|
||||
// the CHANGE itself are processed regardless of Enabled, so disabling one still
|
||||
// refreshes the peers that lose access (the toggle is the observable change, and the
|
||||
// update carries the old∪new state).
|
||||
// Enabled is never consulted: toggling it is itself an observable change.
|
||||
package affectedpeers
|
||||
|
||||
import (
|
||||
@@ -66,8 +61,7 @@ func Load(ctx context.Context, s store.Store, accountID string, c Change) (*Snap
|
||||
// loadCollections reads the policy/route/nameserver/dns/router/resource/proxy
|
||||
// collections a Change can touch, gated to what the walk needs.
|
||||
func (snap *Snapshot) loadCollections(ctx context.Context, s store.Store, accountID string, c Change) error {
|
||||
// LinkGroups drive the same policy/route/dns walk as a changed group or peer.
|
||||
hasGroupOrPeerChange := len(c.ChangedGroupIDs) > 0 || len(c.ChangedPeerIDs) > 0 || len(c.LinkGroups) > 0 || len(c.Resources) > 0
|
||||
hasGroupOrPeerChange := len(c.ChangedGroupIDs) > 0 || len(c.ChangedPeerIDs) > 0 || len(c.Resources) > 0
|
||||
hasNetworkObject := len(c.Routers) > 0 || len(c.Resources) > 0 || len(c.Networks) > 0
|
||||
// the resource<->router bridge can fire for any of these
|
||||
needsRoutersResources := hasGroupOrPeerChange || len(c.PostureCheckIDs) > 0 || len(c.Policies) > 0 || hasNetworkObject
|
||||
@@ -82,7 +76,7 @@ func (snap *Snapshot) loadCollections(ctx context.Context, s store.Store, accoun
|
||||
return err
|
||||
}
|
||||
}
|
||||
if len(c.ChangedGroupIDs) > 0 || len(c.ChangedPeerIDs) > 0 || len(c.LinkGroups) > 0 {
|
||||
if len(c.ChangedGroupIDs) > 0 || len(c.ChangedPeerIDs) > 0 {
|
||||
if err := snap.loadDNS(ctx, s, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -180,24 +174,6 @@ type Change struct {
|
||||
// folded in — but only when the group is linked (an unlinked group has no map
|
||||
// impact), matching how current members are handled.
|
||||
RemovedPeersByGroup map[string][]string
|
||||
|
||||
// OutputPeerIDs are peers folded straight into the result without seeding their
|
||||
// group memberships into the walk. Use for the peer whose group membership changed:
|
||||
// the peer itself must refresh, but its OTHER groups did not change, so they must
|
||||
// not be walked. Contrast ChangedPeerIDs, which seeds ALL of the peer's groups
|
||||
// (correct when the peer's own attributes changed, e.g. IP/status).
|
||||
OutputPeerIDs []string
|
||||
|
||||
// LinkGroups are groups used ONLY to match policies/routes/routers and walk to the
|
||||
// OPPOSITE side — they are never expanded to their own members. Use this when a
|
||||
// peer's group membership changed: pass the peer in ChangedPeerIDs and its
|
||||
// group(s) here. The opposite side of the policies the group participates in
|
||||
// refreshes, but the group's other members (siblings) do not — nothing changed for
|
||||
// them. For an intra-group policy (A→A) the opposite side IS the group, so its
|
||||
// members still refresh via the opposite-side fold, exactly when they genuinely
|
||||
// gain/lose the changed peer. Unlike ChangedGroupIDs, a LinkGroup is not added to
|
||||
// the output, so a one-sided membership change never wakes the whole group.
|
||||
LinkGroups []string
|
||||
}
|
||||
|
||||
func (c Change) isEmpty() bool {
|
||||
@@ -210,9 +186,7 @@ func (c Change) isEmpty() bool {
|
||||
len(c.Networks) == 0 &&
|
||||
len(c.PostureCheckIDs) == 0 &&
|
||||
len(c.DistributionGroupIDs) == 0 &&
|
||||
len(c.RemovedPeersByGroup) == 0 &&
|
||||
len(c.LinkGroups) == 0 &&
|
||||
len(c.OutputPeerIDs) == 0
|
||||
len(c.RemovedPeersByGroup) == 0
|
||||
}
|
||||
|
||||
// Expand returns the deduplicated affected peer IDs from the preloaded Snapshot,
|
||||
@@ -223,8 +197,8 @@ func (snap *Snapshot) Expand(ctx context.Context, accountID string, c Change) []
|
||||
return nil
|
||||
}
|
||||
r := newResolver(ctx, snap, accountID, c)
|
||||
log.WithContext(ctx).Tracef("affectedpeers expand start: account=%s changedGroups=%v changedPeers=%v linkGroups=%v policies=%d routes=%d routers=%d resources=%d networks=%d postureChecks=%v distributionGroups=%v",
|
||||
accountID, c.ChangedGroupIDs, c.ChangedPeerIDs, c.LinkGroups, len(c.Policies), len(c.Routes), len(c.Routers), len(c.Resources), len(c.Networks), c.PostureCheckIDs, c.DistributionGroupIDs)
|
||||
log.WithContext(ctx).Tracef("affectedpeers expand start: account=%s changedGroups=%v changedPeers=%v policies=%d routes=%d routers=%d resources=%d networks=%d postureChecks=%v distributionGroups=%v",
|
||||
accountID, c.ChangedGroupIDs, c.ChangedPeerIDs, len(c.Policies), len(c.Routes), len(c.Routers), len(c.Resources), len(c.Networks), c.PostureCheckIDs, c.DistributionGroupIDs)
|
||||
r.walk()
|
||||
return r.expand()
|
||||
}
|
||||
@@ -242,84 +216,57 @@ func Collect(ctx context.Context, s store.Store, accountID string, c Change) (gr
|
||||
}
|
||||
r := newResolver(ctx, snap, accountID, c)
|
||||
r.walk()
|
||||
return setToSlice(r.affectedGroups), setToSlice(r.affectedPeers)
|
||||
return setToSlice(r.groupSet), setToSlice(r.peerSet)
|
||||
}
|
||||
|
||||
func newResolver(ctx context.Context, snap *Snapshot, accountID string, c Change) *resolver {
|
||||
r := &resolver{
|
||||
ctx: ctx,
|
||||
snap: snap,
|
||||
accountID: accountID,
|
||||
change: c,
|
||||
linkGroups: toSet(c.ChangedGroupIDs),
|
||||
outputGroups: toSet(c.ChangedGroupIDs),
|
||||
changedPeers: toSet(c.ChangedPeerIDs),
|
||||
affectedGroups: make(map[string]struct{}),
|
||||
affectedPeers: make(map[string]struct{}),
|
||||
ctx: ctx,
|
||||
snap: snap,
|
||||
accountID: accountID,
|
||||
change: c,
|
||||
changedGroupSet: toSet(c.ChangedGroupIDs),
|
||||
changedPeerSet: toSet(c.ChangedPeerIDs),
|
||||
groupSet: make(map[string]struct{}),
|
||||
peerSet: make(map[string]struct{}),
|
||||
networkIDs: make(map[string]struct{}),
|
||||
}
|
||||
// LinkGroups match policies/routes to find the opposite side but are NOT output:
|
||||
// they go into linkGroups only, never outputGroups, so their members never fold in.
|
||||
addAll(r.linkGroups, c.LinkGroups)
|
||||
// Resolve each changed peer to its groups here so callers pass only ChangedPeerIDs.
|
||||
r.seedChangedGroupsFromPeers()
|
||||
r.matchedPolicies = append(r.matchedPolicies, c.Policies...)
|
||||
return r
|
||||
}
|
||||
|
||||
// seedChangedGroupsFromPeers adds each changed peer's groups to linkGroups so
|
||||
// seedChangedGroupsFromPeers adds each changed peer's groups to changedGroupSet so
|
||||
// the group-driven walkers fire for memberships, not just direct peer references.
|
||||
// These seeded groups are for MATCHING only — folding the changed entity's own
|
||||
// side is gated on outputGroups (the caller-reported groups), so a seeded group
|
||||
// never folds its whole membership; only the changed peer itself folds in.
|
||||
func (r *resolver) seedChangedGroupsFromPeers() {
|
||||
if len(r.changedPeers) == 0 {
|
||||
if len(r.changedPeerSet) == 0 {
|
||||
return
|
||||
}
|
||||
for groupID, members := range r.snap.groupPeers {
|
||||
for pID := range r.changedPeers {
|
||||
for pID := range r.changedPeerSet {
|
||||
if _, ok := members[pID]; ok {
|
||||
r.linkGroups[groupID] = struct{}{}
|
||||
r.changedGroupSet[groupID] = struct{}{}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// policySide selects which side of a policy rule to walk.
|
||||
type policySide int
|
||||
|
||||
const (
|
||||
sideSource policySide = iota
|
||||
sideDestination
|
||||
)
|
||||
|
||||
func (s policySide) opposite() policySide {
|
||||
if s == sideSource {
|
||||
return sideDestination
|
||||
}
|
||||
return sideSource
|
||||
}
|
||||
|
||||
// walk resolves affected peers in two buckets, by how far each change propagates.
|
||||
//
|
||||
// BOTH-SIDES — the rule itself changed (an explicit policy edit, or a policy whose
|
||||
// posture check changed). Source AND destination refresh, so each such policy is
|
||||
// walked on both sides.
|
||||
//
|
||||
// OPPOSITE-SIDE — an endpoint moved but no rule changed. For each policy the change
|
||||
// touches we fold only the side AWAY from the change:
|
||||
// - a changed peer/group sits ON a policy side -> fold the opposite side;
|
||||
// - a changed router/resource/network sits on a NETWORK -> fold the SOURCE side of
|
||||
// the policies whose destination reaches it (and the routers it implies).
|
||||
//
|
||||
// Routes, nameserver groups, DNS and embedded-proxy services distribute to their own
|
||||
// member peers, outside the policy graph, and are folded here too.
|
||||
func (r *resolver) walk() {
|
||||
for _, policy := range r.bothSidesPolicies() {
|
||||
r.foldPolicySide(policy, sideSource)
|
||||
r.foldPolicySide(policy, sideDestination)
|
||||
}
|
||||
r.collectFromExplicitPolicies()
|
||||
r.collectFromExplicitRoutes(r.change.Routes)
|
||||
r.collectFromExplicitRouters(r.change.Routers)
|
||||
r.collectFromExplicitResources(r.change.Resources)
|
||||
r.collectFromExplicitNetworks(r.change.Networks)
|
||||
r.collectFromPostureChecks(r.change.PostureCheckIDs)
|
||||
|
||||
if len(r.linkGroups) > 0 || len(r.changedPeers) > 0 {
|
||||
// Distribution groups (nameserver/DNS) affect only their member peers: fold them
|
||||
// straight into groupSet so expand() maps them to members, without the policy/
|
||||
// route walk that changedGroupSet would trigger.
|
||||
addAll(r.groupSet, r.change.DistributionGroupIDs)
|
||||
|
||||
if len(r.changedGroupSet) > 0 || len(r.changedPeerSet) > 0 {
|
||||
r.collectFromPolicies()
|
||||
r.collectFromRoutes()
|
||||
r.collectFromNameServers()
|
||||
@@ -328,31 +275,7 @@ func (r *resolver) walk() {
|
||||
r.collectFromProxyServices()
|
||||
}
|
||||
|
||||
r.collectFromChangedRoutes(r.change.Routes)
|
||||
r.collectFromChangedRouters(r.change.Routers)
|
||||
r.collectFromChangedResources(r.change.Resources)
|
||||
r.collectFromChangedNetworks(r.change.Networks)
|
||||
|
||||
// The explicitly changed peers always refresh their own maps. OnPeersUpdated only
|
||||
// refreshes the resolver's output (it ignores the separately-passed changed peers),
|
||||
// so the changed peer reaches its own new map only via here. An offline/deleted
|
||||
// peer in the set is filtered downstream (filterConnectedAffectedPeers).
|
||||
addAll(r.affectedPeers, setToSlice(r.changedPeers))
|
||||
// OutputPeerIDs refresh themselves too, but unlike changedPeers their group
|
||||
// memberships were not seeded into the walk (only the changed group was).
|
||||
addAll(r.affectedPeers, r.change.OutputPeerIDs)
|
||||
|
||||
// Distribution groups (nameserver/DNS) affect only their member peers: fold them
|
||||
// straight into affectedGroups so expand() maps them to members, without the
|
||||
// policy/route walk that linkGroups would trigger.
|
||||
addAll(r.affectedGroups, r.change.DistributionGroupIDs)
|
||||
}
|
||||
|
||||
// bothSidesPolicies are the policies whose rule changed: the explicitly edited ones
|
||||
// plus those gated by a changed posture check. walk folds both their sides.
|
||||
func (r *resolver) bothSidesPolicies() []*types.Policy {
|
||||
policies := append([]*types.Policy(nil), r.change.Policies...)
|
||||
return r.appendPoliciesForPostureChecks(policies, r.change.PostureCheckIDs)
|
||||
r.collectResourceRouterBridge()
|
||||
}
|
||||
|
||||
type resolver struct {
|
||||
@@ -361,71 +284,27 @@ type resolver struct {
|
||||
accountID string
|
||||
change Change
|
||||
|
||||
// Inputs — what changed. Set once at construction, read-only during the walk
|
||||
// (except linkGroups, which collectFromExplicitResources also seeds).
|
||||
//
|
||||
// linkGroups is the MATCH set: caller-changed groups ∪ the groups of changed
|
||||
// peers ∪ changed-resource groups. A rule/route/router matches the change when
|
||||
// one of its groups is here — used only to find the opposite side to fold.
|
||||
//
|
||||
// outputGroups is the FOLD-WHOLE-GROUP set: ONLY Change.ChangedGroupIDs. When a
|
||||
// matched group is here, its whole membership is affected. A peer-seeded group
|
||||
// is in linkGroups but NOT outputGroups, so it folds only the changed peer
|
||||
// (changedPeers), never its siblings.
|
||||
linkGroups map[string]struct{}
|
||||
outputGroups map[string]struct{}
|
||||
changedPeers map[string]struct{}
|
||||
changedGroupSet map[string]struct{}
|
||||
changedPeerSet map[string]struct{}
|
||||
|
||||
// Outputs — the answer. The only sets the walk accumulates into. affectedGroups
|
||||
// is expanded to its member peers in expand().
|
||||
affectedGroups map[string]struct{}
|
||||
affectedPeers map[string]struct{}
|
||||
groupSet map[string]struct{}
|
||||
peerSet map[string]struct{}
|
||||
|
||||
matchedPolicies []*types.Policy
|
||||
networkIDs map[string]struct{}
|
||||
}
|
||||
|
||||
// policies returns the account's ENABLED policies from the snapshot. Disabled
|
||||
// policies grant no access, so the walk skips them when scanning existing account
|
||||
// data. Explicitly changed policies (Change.Policies, via bothSidesPolicies) are
|
||||
// processed regardless of Enabled, so disabling one still refreshes its peers.
|
||||
func (r *resolver) policies() []*types.Policy {
|
||||
enabled := make([]*types.Policy, 0, len(r.snap.policies))
|
||||
for _, policy := range r.snap.policies {
|
||||
if policy != nil && policy.Enabled {
|
||||
enabled = append(enabled, policy)
|
||||
}
|
||||
}
|
||||
return enabled
|
||||
}
|
||||
func (r *resolver) policies() []*types.Policy { return r.snap.policies }
|
||||
|
||||
// networkResources / networkRouters return the account's ENABLED resources/routers
|
||||
// from the snapshot. Disabled objects route to nobody, so the walk skips them when
|
||||
// it scans existing account data. The explicitly changed objects in the Change are
|
||||
// processed regardless of Enabled (collectFromChanged*), so disabling one still
|
||||
// refreshes the peers that lose access.
|
||||
func (r *resolver) networkResources() []*resourceTypes.NetworkResource {
|
||||
enabled := make([]*resourceTypes.NetworkResource, 0, len(r.snap.resources))
|
||||
for _, resource := range r.snap.resources {
|
||||
if resource.Enabled {
|
||||
enabled = append(enabled, resource)
|
||||
}
|
||||
}
|
||||
return enabled
|
||||
}
|
||||
func (r *resolver) networkResources() []*resourceTypes.NetworkResource { return r.snap.resources }
|
||||
|
||||
func (r *resolver) networkRouters() []*routerTypes.NetworkRouter {
|
||||
enabled := make([]*routerTypes.NetworkRouter, 0, len(r.snap.routers))
|
||||
for _, router := range r.snap.routers {
|
||||
if router.Enabled {
|
||||
enabled = append(enabled, router)
|
||||
}
|
||||
}
|
||||
return enabled
|
||||
}
|
||||
func (r *resolver) networkRouters() []*routerTypes.NetworkRouter { return r.snap.routers }
|
||||
|
||||
// peerIDsForGroups maps a group set to its member peer IDs via the preloaded index.
|
||||
func (r *resolver) peerIDsForGroups(groups map[string]struct{}) []string {
|
||||
func (r *resolver) peerIDsForGroups(groupSet map[string]struct{}) []string {
|
||||
seen := make(map[string]struct{})
|
||||
var ids []string
|
||||
for gID := range groups {
|
||||
for gID := range groupSet {
|
||||
for pID := range r.snap.groupPeers[gID] {
|
||||
if _, ok := seen[pID]; ok {
|
||||
continue
|
||||
@@ -438,25 +317,25 @@ func (r *resolver) peerIDsForGroups(groups map[string]struct{}) []string {
|
||||
}
|
||||
|
||||
func (r *resolver) expand() []string {
|
||||
peerIDs := r.peerIDsForGroups(r.affectedGroups)
|
||||
peerIDs := r.peerIDsForGroups(r.groupSet)
|
||||
|
||||
log.WithContext(r.ctx).Tracef("affectedpeers expand: account=%s affectedGroups=%v -> %d group-member peers; direct peers=%v",
|
||||
r.accountID, setToSlice(r.affectedGroups), len(peerIDs), setToSlice(r.affectedPeers))
|
||||
r.accountID, setToSlice(r.groupSet), len(peerIDs), setToSlice(r.peerSet))
|
||||
|
||||
seen := make(map[string]struct{}, len(peerIDs))
|
||||
for _, id := range peerIDs {
|
||||
seen[id] = struct{}{}
|
||||
}
|
||||
for id := range r.affectedPeers {
|
||||
for id := range r.peerSet {
|
||||
if _, ok := seen[id]; !ok {
|
||||
peerIDs = append(peerIDs, id)
|
||||
seen[id] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
// Fold in removed peers only when their group is linked (in affectedGroups).
|
||||
// Fold in removed peers only when their group is linked (in groupSet).
|
||||
for groupID, removed := range r.change.RemovedPeersByGroup {
|
||||
if _, linked := r.affectedGroups[groupID]; !linked {
|
||||
if _, linked := r.groupSet[groupID]; !linked {
|
||||
continue
|
||||
}
|
||||
for _, id := range removed {
|
||||
@@ -472,349 +351,169 @@ func (r *resolver) expand() []string {
|
||||
return peerIDs
|
||||
}
|
||||
|
||||
// ruleSideGroups / ruleSideResource return the groups and the resource on the given
|
||||
// side of a rule.
|
||||
func ruleSideGroups(rule *types.PolicyRule, side policySide) []string {
|
||||
if side == sideDestination {
|
||||
return rule.Destinations
|
||||
}
|
||||
return rule.Sources
|
||||
}
|
||||
|
||||
func ruleSideResource(rule *types.PolicyRule, side policySide) types.Resource {
|
||||
if side == sideDestination {
|
||||
return rule.DestinationResource
|
||||
}
|
||||
return rule.SourceResource
|
||||
}
|
||||
|
||||
// foldPolicySide folds one side of a policy down to affected peers: its groups
|
||||
// (resolved to members in expand) and its direct peer. When the side is the
|
||||
// DESTINATION and references a network resource (directly or via a destination
|
||||
// group's resources), it also folds the routers that serve that resource's network
|
||||
// — a destination resource is reached through its routers. A resource on the SOURCE
|
||||
// side routes to nobody (GetPoliciesForNetworkResource matches destinations only),
|
||||
// so the router hop is destination-only.
|
||||
func (r *resolver) foldPolicySide(policy *types.Policy, side policySide) {
|
||||
if policy == nil {
|
||||
return
|
||||
}
|
||||
for _, rule := range policy.Rules {
|
||||
addAll(r.affectedGroups, ruleSideGroups(rule, side))
|
||||
res := ruleSideResource(rule, side)
|
||||
if res.Type == types.ResourceTypePeer && res.ID != "" {
|
||||
r.affectedPeers[res.ID] = struct{}{}
|
||||
}
|
||||
}
|
||||
if side == sideDestination {
|
||||
r.foldRoutersForResources(r.policyDestinationResourceIDs(policy))
|
||||
}
|
||||
}
|
||||
|
||||
// appendPoliciesForPostureChecks appends every policy that references a changed
|
||||
// posture check (a rule change, so walk both sides).
|
||||
func (r *resolver) appendPoliciesForPostureChecks(policies []*types.Policy, postureCheckIDs []string) []*types.Policy {
|
||||
if len(postureCheckIDs) == 0 {
|
||||
return policies
|
||||
}
|
||||
ids := toSet(postureCheckIDs)
|
||||
for _, policy := range r.policies() {
|
||||
if !policyReferencesPostureChecks(policy, ids) || !policy.Enabled {
|
||||
func (r *resolver) collectFromExplicitPolicies() {
|
||||
for _, policy := range r.matchedPolicies {
|
||||
if policy == nil {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("appendPoliciesForPostureChecks: policy %s (%s) references changed posture checks %v -> both-sides policy",
|
||||
policy.ID, policy.Name, postureCheckIDs)
|
||||
policies = append(policies, policy)
|
||||
}
|
||||
return policies
|
||||
}
|
||||
|
||||
// collectFromPolicies folds, for every policy whose rule a changed group or peer
|
||||
// touches, only the OPPOSITE side (down to peers, incl. destination routers), plus
|
||||
// the changed entity's own side: the changed group's whole membership when the
|
||||
// group itself changed (outputGroups), or the changed peer alone when matched via a
|
||||
// peer-seeded group (never its co-members).
|
||||
func (r *resolver) collectFromPolicies() {
|
||||
for _, policy := range r.policies() {
|
||||
for _, rule := range policy.Rules {
|
||||
if !rule.Enabled {
|
||||
continue // a disabled rule grants no access
|
||||
}
|
||||
r.foldRuleSideIfChanged(policy, rule, sideSource)
|
||||
r.foldRuleSideIfChanged(policy, rule, sideDestination)
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromExplicitPolicies: changed policy %s (%s) -> folding rule groups %v + direct peers",
|
||||
policy.ID, policy.Name, policy.RuleGroups())
|
||||
addAll(r.groupSet, policy.RuleGroups())
|
||||
collectPolicyDirectPeers(policy, r.peerSet)
|
||||
}
|
||||
}
|
||||
|
||||
// foldRuleSideIfChanged: when a changed group or direct peer sits on `side` of the
|
||||
// rule, fold the opposite side fully (groups/peers + destination routers) and fold
|
||||
// the changed entity's own side (the whole changed group, or the changed peer alone).
|
||||
func (r *resolver) foldRuleSideIfChanged(policy *types.Policy, rule *types.PolicyRule, side policySide) {
|
||||
nearGroups := ruleSideGroups(rule, side)
|
||||
nearResource := ruleSideResource(rule, side)
|
||||
|
||||
matchedByGroup := anyInSet(nearGroups, r.linkGroups)
|
||||
matchedByPeer := isDirectPeerInSet(nearResource, r.changedPeers)
|
||||
if !matchedByGroup && !matchedByPeer {
|
||||
return
|
||||
}
|
||||
|
||||
// Opposite side, fully down to peers (a destination opposite also folds routers).
|
||||
r.foldPolicySideForRule(policy, rule, side.opposite())
|
||||
|
||||
// Own side: fold the whole changed group's members only when the group itself
|
||||
// changed (outputGroups). A peer-seeded or link-only group is not folded here —
|
||||
// its siblings never refresh. The changed peers themselves are folded once, after
|
||||
// the walk (see walk()).
|
||||
for _, gID := range nearGroups {
|
||||
if _, ok := r.outputGroups[gID]; ok {
|
||||
r.affectedGroups[gID] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
// When the changed side IS a destination, the resources it targets are reached
|
||||
// through their network's routers, so those routers refresh too (e.g. attaching a
|
||||
// resource to a destination group, or a changed destination group/resource).
|
||||
if side == sideDestination {
|
||||
r.foldRoutersForResources(r.ruleDestinationResourceIDs(rule))
|
||||
}
|
||||
}
|
||||
|
||||
// foldPolicySideForRule folds one side of a single rule (groups + direct peer), and
|
||||
// for a destination side the routers of that rule's destination resources.
|
||||
func (r *resolver) foldPolicySideForRule(policy *types.Policy, rule *types.PolicyRule, side policySide) {
|
||||
addAll(r.affectedGroups, ruleSideGroups(rule, side))
|
||||
res := ruleSideResource(rule, side)
|
||||
if res.Type == types.ResourceTypePeer && res.ID != "" {
|
||||
r.affectedPeers[res.ID] = struct{}{}
|
||||
}
|
||||
if side == sideDestination {
|
||||
r.foldRoutersForResources(r.ruleDestinationResourceIDs(rule))
|
||||
}
|
||||
}
|
||||
|
||||
// collectFromChangedRoutes folds an explicitly changed route's own groups and peer.
|
||||
func (r *resolver) collectFromChangedRoutes(routes []*route.Route) {
|
||||
func (r *resolver) collectFromExplicitRoutes(routes []*route.Route) {
|
||||
for _, rt := range routes {
|
||||
if rt == nil {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromChangedRoutes: changed route %s -> folding groups=%v peerGroups=%v accessControlGroups=%v peer=%q",
|
||||
log.WithContext(r.ctx).Tracef("collectFromExplicitRoutes: changed route %s -> folding groups=%v peerGroups=%v accessControlGroups=%v peer=%q",
|
||||
rt.ID, rt.Groups, rt.PeerGroups, rt.AccessControlGroups, rt.Peer)
|
||||
addAll(r.affectedGroups, rt.Groups, rt.PeerGroups, rt.AccessControlGroups)
|
||||
addAll(r.groupSet, rt.Groups, rt.PeerGroups, rt.AccessControlGroups)
|
||||
if rt.Peer != "" {
|
||||
r.affectedPeers[rt.Peer] = struct{}{}
|
||||
r.peerSet[rt.Peer] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// collectFromChangedRouters: a changed router refreshes its OWN backing peer/groups
|
||||
// (the changed entity) and the SOURCE side of every policy reaching a resource on
|
||||
// its network (the router serves the whole network). Sibling routers on the network
|
||||
// are independent and are NOT folded. Passing the old router state keeps a repointed
|
||||
// router's previous backing affected without a post-commit read.
|
||||
func (r *resolver) collectFromChangedRouters(routers []*routerTypes.NetworkRouter) {
|
||||
// collectFromExplicitRouters folds changed routers' peers and marks their networks
|
||||
// for the bridge. Passing the old router keeps a repointed router's previous peers
|
||||
// affected without a post-commit read.
|
||||
func (r *resolver) collectFromExplicitRouters(routers []*routerTypes.NetworkRouter) {
|
||||
for _, router := range routers {
|
||||
if router == nil {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromChangedRouters: changed router %s on network %s -> folding its own peerGroups=%v peer=%q + sources reaching network resources",
|
||||
log.WithContext(r.ctx).Tracef("collectFromExplicitRouters: changed router %s on network %s -> folding peerGroups=%v peer=%q and marking network for source bridge",
|
||||
router.ID, router.NetworkID, router.PeerGroups, router.Peer)
|
||||
addAll(r.affectedGroups, router.PeerGroups)
|
||||
addAll(r.groupSet, router.PeerGroups)
|
||||
if router.Peer != "" {
|
||||
r.affectedPeers[router.Peer] = struct{}{}
|
||||
r.peerSet[router.Peer] = struct{}{}
|
||||
}
|
||||
if router.NetworkID != "" {
|
||||
r.foldPolicySourcesForResources(r.networkResourceIDs(router.NetworkID))
|
||||
r.networkIDs[router.NetworkID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// collectFromChangedResources: a changed resource refreshes the SOURCE side of the
|
||||
// policies targeting EXACTLY that resource — directly, or via one of the resource's
|
||||
// own groups (old∪new across the change, so a now-detached group's sources still
|
||||
// refresh) — plus the routers serving its network (the resource is reached through
|
||||
// them). It does not touch sibling resources on the same network.
|
||||
func (r *resolver) collectFromChangedResources(resources []*resourceTypes.NetworkResource) {
|
||||
// collectFromExplicitResources marks changed resources' networks for the bridge and
|
||||
// treats their group IDs as changed, so policies targeting the resource via a
|
||||
// now-detached (old) group still refresh.
|
||||
func (r *resolver) collectFromExplicitResources(resources []*resourceTypes.NetworkResource) {
|
||||
for _, resource := range resources {
|
||||
if resource == nil {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromChangedResources: changed resource %s on network %s (groups %v) -> folding sources of policies targeting it + its network's routers",
|
||||
log.WithContext(r.ctx).Tracef("collectFromExplicitResources: changed resource %s on network %s -> marking network for bridge and treating groups %v as changed",
|
||||
resource.ID, resource.NetworkID, resource.GroupIDs)
|
||||
r.foldPolicySourcesForResource(resource.ID, resource.GroupIDs)
|
||||
addAll(r.changedGroupSet, resource.GroupIDs)
|
||||
if resource.NetworkID != "" {
|
||||
r.foldRoutersOnNetworks(map[string]struct{}{resource.NetworkID: {}})
|
||||
r.networkIDs[resource.NetworkID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// foldPolicySourcesForResource folds the source side of every policy whose
|
||||
// destination is the given resource — referenced directly, or via any of the given
|
||||
// groups (the resource's own old∪new groups, which captures a detached group).
|
||||
func (r *resolver) foldPolicySourcesForResource(resourceID string, groupIDs []string) {
|
||||
groups := toSet(groupIDs)
|
||||
for _, policy := range r.policies() {
|
||||
if !policyTargetsResourceOrGroups(policy, resourceID, groups) {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("foldPolicySourcesForResource: policy %s (%s) targets changed resource %s -> folding its source groups/peers", policy.ID, policy.Name, resourceID)
|
||||
collectPolicySources(policy, r.affectedGroups, r.affectedPeers)
|
||||
}
|
||||
}
|
||||
|
||||
// policyTargetsResourceOrGroups reports whether a policy's destination is the given
|
||||
// resource directly, or one of the given destination groups.
|
||||
func policyTargetsResourceOrGroups(policy *types.Policy, resourceID string, groups map[string]struct{}) bool {
|
||||
if policy == nil {
|
||||
return false
|
||||
}
|
||||
for _, rule := range policy.Rules {
|
||||
if !rule.Enabled {
|
||||
continue
|
||||
}
|
||||
if rule.DestinationResource.Type != types.ResourceTypePeer && rule.DestinationResource.ID == resourceID && resourceID != "" {
|
||||
return true
|
||||
}
|
||||
if anyInSet(rule.Destinations, groups) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// collectFromChangedNetworks: a changed network refreshes the SOURCE side of the
|
||||
// policies reaching any of its resources, plus its routers. A network has no
|
||||
// groups/peers of its own.
|
||||
func (r *resolver) collectFromChangedNetworks(networks []*networkTypes.Network) {
|
||||
// collectFromExplicitNetworks marks changed networks for the bridge. A network has
|
||||
// no groups/peers of its own.
|
||||
func (r *resolver) collectFromExplicitNetworks(networks []*networkTypes.Network) {
|
||||
for _, network := range networks {
|
||||
if network == nil || network.ID == "" {
|
||||
if network == nil {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromChangedNetworks: changed network %s -> folding sources reaching its resources + its routers", network.ID)
|
||||
resourceIDs := r.networkResourceIDs(network.ID)
|
||||
r.foldPolicySourcesForResources(resourceIDs)
|
||||
r.foldRoutersOnNetworks(map[string]struct{}{network.ID: {}})
|
||||
log.WithContext(r.ctx).Tracef("collectFromExplicitNetworks: changed network %s -> marking for bridge", network.ID)
|
||||
if network.ID != "" {
|
||||
r.networkIDs[network.ID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// foldPolicySourcesForResources folds the source groups/peers of every policy whose
|
||||
// destination targets one of resourceIDs (directly or via a destination group).
|
||||
func (r *resolver) foldPolicySourcesForResources(resourceIDs map[string]struct{}) {
|
||||
if len(resourceIDs) == 0 {
|
||||
func (r *resolver) collectFromPostureChecks(postureCheckIDs []string) {
|
||||
if len(postureCheckIDs) == 0 {
|
||||
return
|
||||
}
|
||||
ids := toSet(postureCheckIDs)
|
||||
for _, policy := range r.policies() {
|
||||
if r.policyTargetsResources(policy, resourceIDs) {
|
||||
log.WithContext(r.ctx).Tracef("foldPolicySourcesForResources: policy %s (%s) targets a changed resource -> folding its source groups/peers", policy.ID, policy.Name)
|
||||
collectPolicySources(policy, r.affectedGroups, r.affectedPeers)
|
||||
if !policyReferencesPostureChecks(policy, ids) {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromPostureChecks: policy %s (%s) references changed posture checks %v -> folding rule groups %v + direct peers",
|
||||
policy.ID, policy.Name, postureCheckIDs, policy.RuleGroups())
|
||||
addAll(r.groupSet, policy.RuleGroups())
|
||||
collectPolicyDirectPeers(policy, r.peerSet)
|
||||
r.matchedPolicies = append(r.matchedPolicies, policy)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resolver) collectFromPolicies() {
|
||||
for _, policy := range r.policies() {
|
||||
matchedByGroup := policyReferencesGroups(policy, r.changedGroupSet)
|
||||
matchedByPeer := len(r.changedPeerSet) > 0 && policyReferencesDirectPeers(policy, r.changedPeerSet)
|
||||
if !matchedByGroup && !matchedByPeer {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromPolicies: policy %s (%s) matched (byGroup=%t byPeer=%t) -> folding rule groups %v + direct peers",
|
||||
policy.ID, policy.Name, matchedByGroup, matchedByPeer, policy.RuleGroups())
|
||||
addAll(r.groupSet, policy.RuleGroups())
|
||||
collectPolicyDirectPeers(policy, r.peerSet)
|
||||
r.matchedPolicies = append(r.matchedPolicies, policy)
|
||||
}
|
||||
}
|
||||
|
||||
// collectFromRoutes folds, per matched route, the OPPOSITE side(s) fully and the
|
||||
// matched side's own groups only on a whole-group change (outputGroups). A route has
|
||||
// three peer sides — routing (Peer/PeerGroups), consumer (Groups) and ACL
|
||||
// (AccessControlGroups) — that each refresh the others; the changed side's own group
|
||||
// folds its siblings only when the group itself changed, never on a one-peer move.
|
||||
func (r *resolver) collectFromRoutes() {
|
||||
for _, rt := range r.snap.routes {
|
||||
if !rt.Enabled {
|
||||
continue // disabled routes route to nobody; skip existing account data
|
||||
}
|
||||
routing := anyInSet(rt.PeerGroups, r.linkGroups) || (rt.Peer != "" && isInSet(rt.Peer, r.changedPeers))
|
||||
consumer := anyInSet(rt.Groups, r.linkGroups)
|
||||
acl := anyInSet(rt.AccessControlGroups, r.linkGroups)
|
||||
if !routing && !consumer && !acl {
|
||||
matchedByGroup := anyInSet(rt.Groups, r.changedGroupSet) || anyInSet(rt.PeerGroups, r.changedGroupSet) || anyInSet(rt.AccessControlGroups, r.changedGroupSet)
|
||||
matchedByPeer := rt.Peer != "" && len(r.changedPeerSet) > 0 && isInSet(rt.Peer, r.changedPeerSet)
|
||||
if !matchedByGroup && !matchedByPeer {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromRoutes: route %s matched (routing=%t consumer=%t acl=%t) -> folding opposite sides; own side gated on outputGroups",
|
||||
rt.ID, routing, consumer, acl)
|
||||
r.foldRouteSide(rt.PeerGroups, routing)
|
||||
r.foldRouteSide(rt.Groups, consumer)
|
||||
r.foldRouteSide(rt.AccessControlGroups, acl)
|
||||
// The single routing Peer folds when the routing side is the OPPOSITE of the
|
||||
// match (consumer/acl need it), or when that very peer is the change.
|
||||
if rt.Peer != "" && (consumer || acl || isInSet(rt.Peer, r.changedPeers)) {
|
||||
r.affectedPeers[rt.Peer] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// foldRouteSide folds a route side: when this side is the one that matched, fold its
|
||||
// groups only on a whole-group change (outputGroups) so siblings of a single moved
|
||||
// peer stay put; otherwise it is an opposite side and folds fully.
|
||||
func (r *resolver) foldRouteSide(groups []string, matchedHere bool) {
|
||||
if matchedHere {
|
||||
r.foldOutputGroups(groups)
|
||||
return
|
||||
}
|
||||
addAll(r.affectedGroups, groups)
|
||||
}
|
||||
|
||||
// foldOutputGroups folds only the groups that the caller reported as wholly changed
|
||||
// (outputGroups). Used for a matched object's OWN side, where a peer-seeded or
|
||||
// link-only group must not pull in its siblings.
|
||||
func (r *resolver) foldOutputGroups(groups ...[]string) {
|
||||
for _, gs := range groups {
|
||||
for _, gID := range gs {
|
||||
if _, ok := r.outputGroups[gID]; ok {
|
||||
r.affectedGroups[gID] = struct{}{}
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromRoutes: route %s matched (byGroup=%t byPeer=%t) -> folding groups=%v peerGroups=%v accessControlGroups=%v peer=%q",
|
||||
rt.ID, matchedByGroup, matchedByPeer, rt.Groups, rt.PeerGroups, rt.AccessControlGroups, rt.Peer)
|
||||
addAll(r.groupSet, rt.Groups, rt.PeerGroups, rt.AccessControlGroups)
|
||||
if rt.Peer != "" {
|
||||
r.peerSet[rt.Peer] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resolver) collectFromNameServers() {
|
||||
if len(r.linkGroups) == 0 {
|
||||
if len(r.changedGroupSet) == 0 {
|
||||
return
|
||||
}
|
||||
for _, ns := range r.snap.nsGroups {
|
||||
if anyInSet(ns.Groups, r.linkGroups) {
|
||||
// A nameserver group has no opposite side: a peer's DNS config depends only
|
||||
// on its own membership, so a one-peer move refreshes that peer alone (folded
|
||||
// elsewhere). Fold the referenced groups only on a whole-group change.
|
||||
log.WithContext(r.ctx).Tracef("collectFromNameServers: nameserver group %s references a linked group -> folding its groups %v (outputGroups only)", ns.ID, ns.Groups)
|
||||
r.foldOutputGroups(ns.Groups)
|
||||
if anyInSet(ns.Groups, r.changedGroupSet) {
|
||||
log.WithContext(r.ctx).Tracef("collectFromNameServers: nameserver group %s references a changed group -> folding its groups %v", ns.ID, ns.Groups)
|
||||
addAll(r.groupSet, ns.Groups)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resolver) collectFromDNSSettings() {
|
||||
if len(r.linkGroups) == 0 || r.snap.dnsSettings == nil {
|
||||
if len(r.changedGroupSet) == 0 || r.snap.dnsSettings == nil {
|
||||
return
|
||||
}
|
||||
for _, gID := range r.snap.dnsSettings.DisabledManagementGroups {
|
||||
if _, ok := r.linkGroups[gID]; ok {
|
||||
if _, ok := r.changedGroupSet[gID]; ok {
|
||||
log.WithContext(r.ctx).Tracef("collectFromDNSSettings: changed group %s is in DisabledManagementGroups -> folding it", gID)
|
||||
r.affectedGroups[gID] = struct{}{}
|
||||
r.groupSet[gID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// collectFromNetworkRouters handles a changed group/peer that BACKS a router (the
|
||||
// routing peer set moved): the router's own peers refresh and so do the sources of
|
||||
// the policies reaching its network's resources. Sibling routers on the network are
|
||||
// independent and are not folded.
|
||||
func (r *resolver) collectFromNetworkRouters() {
|
||||
for _, router := range r.networkRouters() {
|
||||
matchedByGroup := anyInSet(router.PeerGroups, r.linkGroups)
|
||||
matchedByPeer := router.Peer != "" && len(r.changedPeers) > 0 && isInSet(router.Peer, r.changedPeers)
|
||||
matchedByGroup := anyInSet(router.PeerGroups, r.changedGroupSet)
|
||||
matchedByPeer := router.Peer != "" && len(r.changedPeerSet) > 0 && isInSet(router.Peer, r.changedPeerSet)
|
||||
if !matchedByGroup && !matchedByPeer {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromNetworkRouters: router %s on network %s matched (byGroup=%t byPeer=%t) -> folding its peerGroups=%v peer=%q (own groups on outputGroups) + sources reaching network resources",
|
||||
log.WithContext(r.ctx).Tracef("collectFromNetworkRouters: router %s on network %s matched (byGroup=%t byPeer=%t) -> folding peerGroups=%v peer=%q and marking network for source bridge",
|
||||
router.ID, router.NetworkID, matchedByGroup, matchedByPeer, router.PeerGroups, router.Peer)
|
||||
// The backing PeerGroups are the matched (own) side: fold them only on a
|
||||
// whole-group change so a one-peer move does not wake sibling backing peers. The
|
||||
// opposite side (policy sources reaching the network) is folded below.
|
||||
r.foldOutputGroups(router.PeerGroups)
|
||||
addAll(r.groupSet, router.PeerGroups)
|
||||
if router.Peer != "" {
|
||||
r.affectedPeers[router.Peer] = struct{}{}
|
||||
}
|
||||
if router.NetworkID != "" {
|
||||
r.foldPolicySourcesForResources(r.networkResourceIDs(router.NetworkID))
|
||||
r.peerSet[router.Peer] = struct{}{}
|
||||
}
|
||||
r.networkIDs[router.NetworkID] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -827,48 +526,42 @@ func (r *resolver) collectFromProxyServices() {
|
||||
expanded := r.expandChangedPeersWithGroups()
|
||||
|
||||
for _, svc := range services {
|
||||
if svc == nil || !svc.Enabled {
|
||||
continue // a disabled service proxies nothing; skip existing account data
|
||||
if svc == nil {
|
||||
continue
|
||||
}
|
||||
proxyPeers := proxyByCluster[svc.ProxyCluster]
|
||||
if len(proxyPeers) == 0 {
|
||||
continue
|
||||
}
|
||||
matchedByPeer := serviceMatchesChangedPeers(svc, proxyPeers, expanded)
|
||||
matchedByAccessGroup := anyInSet(svc.AccessGroups, r.linkGroups)
|
||||
matchedByAccessGroup := anyInSet(svc.AccessGroups, r.changedGroupSet)
|
||||
if !matchedByPeer && !matchedByAccessGroup {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromProxyServices: service %s (cluster=%s) matched (byProxyOrTargetPeer=%t byAccessGroup=%t) -> folding %d proxy peers, peer targets; access groups %v on outputGroups only",
|
||||
log.WithContext(r.ctx).Tracef("collectFromProxyServices: service %s (cluster=%s) matched (byProxyOrTargetPeer=%t byAccessGroup=%t) -> folding %d proxy peers, peer targets and access groups %v",
|
||||
svc.ID, svc.ProxyCluster, matchedByPeer, matchedByAccessGroup, len(proxyPeers), svc.AccessGroups)
|
||||
for _, pid := range proxyPeers {
|
||||
r.affectedPeers[pid] = struct{}{}
|
||||
r.peerSet[pid] = struct{}{}
|
||||
}
|
||||
for _, target := range svc.Targets {
|
||||
if !target.Enabled {
|
||||
continue // a disabled target forwards nothing
|
||||
}
|
||||
if target.TargetType == rpservice.TargetTypePeer && target.TargetId != "" {
|
||||
r.affectedPeers[target.TargetId] = struct{}{}
|
||||
r.peerSet[target.TargetId] = struct{}{}
|
||||
}
|
||||
}
|
||||
// AccessGroups are the matched (own) side with no opposite to fold: a member's
|
||||
// proxy access is self-contained, so a one-peer move refreshes that peer alone.
|
||||
// Fold the groups only on a whole-group change.
|
||||
r.foldOutputGroups(svc.AccessGroups)
|
||||
addAll(r.groupSet, svc.AccessGroups)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resolver) expandChangedPeersWithGroups() map[string]struct{} {
|
||||
if len(r.linkGroups) == 0 {
|
||||
return r.changedPeers
|
||||
if len(r.changedGroupSet) == 0 {
|
||||
return r.changedPeerSet
|
||||
}
|
||||
ids := r.peerIDsForGroups(r.linkGroups)
|
||||
ids := r.peerIDsForGroups(r.changedGroupSet)
|
||||
if len(ids) == 0 {
|
||||
return r.changedPeers
|
||||
return r.changedPeerSet
|
||||
}
|
||||
merged := make(map[string]struct{}, len(r.changedPeers)+len(ids))
|
||||
for id := range r.changedPeers {
|
||||
merged := make(map[string]struct{}, len(r.changedPeerSet)+len(ids))
|
||||
for id := range r.changedPeerSet {
|
||||
merged[id] = struct{}{}
|
||||
}
|
||||
for _, id := range ids {
|
||||
@@ -877,36 +570,54 @@ func (r *resolver) expandChangedPeersWithGroups() map[string]struct{} {
|
||||
return merged
|
||||
}
|
||||
|
||||
// foldRoutersForResources folds the routers serving the networks of the given
|
||||
// resources (a destination resource is reached through its network's routers). It is
|
||||
// the resource -> network -> router hop used by foldPolicySide for a destination.
|
||||
func (r *resolver) foldRoutersForResources(resourceIDs map[string]struct{}) {
|
||||
// collectResourceRouterBridge crosses between source peers and routing peers, which
|
||||
// are reachable only via resource -> network -> router, not through the policy's own
|
||||
// groups: source -> router (targeted resources' networks), then router -> source.
|
||||
func (r *resolver) collectResourceRouterBridge() {
|
||||
r.bridgeSourceToRouters()
|
||||
r.bridgeRoutersToSources()
|
||||
}
|
||||
|
||||
func (r *resolver) bridgeSourceToRouters() {
|
||||
resourceIDs := r.policyDestinationResourceIDs(r.matchedPolicies...)
|
||||
if len(resourceIDs) == 0 {
|
||||
return
|
||||
}
|
||||
r.foldRoutersOnNetworks(r.resourceNetworkIDs(resourceIDs))
|
||||
}
|
||||
|
||||
// ruleDestinationResourceIDs returns the destination resource IDs of a single rule:
|
||||
// the direct DestinationResource plus the resources of its destination groups.
|
||||
func (r *resolver) ruleDestinationResourceIDs(rule *types.PolicyRule) map[string]struct{} {
|
||||
resourceIDs := make(map[string]struct{})
|
||||
if rule.DestinationResource.Type != types.ResourceTypePeer && rule.DestinationResource.ID != "" {
|
||||
resourceIDs[rule.DestinationResource.ID] = struct{}{}
|
||||
networkIDs := r.resourceNetworkIDs(resourceIDs)
|
||||
log.WithContext(r.ctx).Tracef("bridgeSourceToRouters: targeted resources %v -> networks %v (their routers become affected via the router->source pass)",
|
||||
setToSlice(resourceIDs), setToSlice(networkIDs))
|
||||
for id := range networkIDs {
|
||||
r.networkIDs[id] = struct{}{}
|
||||
}
|
||||
r.addGroupResourceIDs(toSet(rule.Destinations), resourceIDs)
|
||||
return resourceIDs
|
||||
}
|
||||
|
||||
// networkResourceIDs returns the IDs of all resources on the given network.
|
||||
func (r *resolver) networkResourceIDs(networkID string) map[string]struct{} {
|
||||
func (r *resolver) bridgeRoutersToSources() {
|
||||
if len(r.networkIDs) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
log.WithContext(r.ctx).Tracef("bridgeRoutersToSources: affected networks %v -> folding their routing peers and the source peers of policies targeting their resources",
|
||||
setToSlice(r.networkIDs))
|
||||
|
||||
r.foldRoutersOnNetworks(r.networkIDs)
|
||||
|
||||
resourceIDs := make(map[string]struct{})
|
||||
for _, resource := range r.networkResources() {
|
||||
if resource.NetworkID == networkID {
|
||||
if _, ok := r.networkIDs[resource.NetworkID]; ok {
|
||||
resourceIDs[resource.ID] = struct{}{}
|
||||
}
|
||||
}
|
||||
return resourceIDs
|
||||
if len(resourceIDs) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
for _, policy := range r.policies() {
|
||||
if r.policyTargetsResources(policy, resourceIDs) {
|
||||
log.WithContext(r.ctx).Tracef("bridgeRoutersToSources: policy %s (%s) targets an affected-network resource -> folding its source groups/peers", policy.ID, policy.Name)
|
||||
collectPolicySources(policy, r.groupSet, r.peerSet)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resolver) foldRoutersOnNetworks(networkIDs map[string]struct{}) {
|
||||
@@ -916,9 +627,9 @@ func (r *resolver) foldRoutersOnNetworks(networkIDs map[string]struct{}) {
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("bridgeRoutersToSources: router %s serves affected network %s -> folding peerGroups=%v peer=%q",
|
||||
router.ID, router.NetworkID, router.PeerGroups, router.Peer)
|
||||
addAll(r.affectedGroups, router.PeerGroups)
|
||||
addAll(r.groupSet, router.PeerGroups)
|
||||
if router.Peer != "" {
|
||||
r.affectedPeers[router.Peer] = struct{}{}
|
||||
r.peerSet[router.Peer] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -939,9 +650,6 @@ func (r *resolver) policyTargetsResources(policy *types.Policy, resourceIDs map[
|
||||
}
|
||||
destGroupSet := make(map[string]struct{})
|
||||
for _, rule := range policy.Rules {
|
||||
if !rule.Enabled {
|
||||
continue
|
||||
}
|
||||
if rule.DestinationResource.Type != types.ResourceTypePeer && isInSet(rule.DestinationResource.ID, resourceIDs) {
|
||||
return true
|
||||
}
|
||||
@@ -1006,20 +714,44 @@ func (r *resolver) addGroupResourceIDs(groupIDs map[string]struct{}, resourceIDs
|
||||
}
|
||||
}
|
||||
|
||||
// collectPolicySources folds the source groups/peers of a snapshot policy's enabled
|
||||
// rules (a disabled rule grants no access).
|
||||
func collectPolicySources(policy *types.Policy, groups, peers map[string]struct{}) {
|
||||
func collectPolicyDirectPeers(policy *types.Policy, peerSet map[string]struct{}) {
|
||||
for _, rule := range policy.Rules {
|
||||
if !rule.Enabled {
|
||||
continue
|
||||
}
|
||||
addAll(groups, rule.Sources)
|
||||
if rule.SourceResource.Type == types.ResourceTypePeer && rule.SourceResource.ID != "" {
|
||||
peers[rule.SourceResource.ID] = struct{}{}
|
||||
peerSet[rule.SourceResource.ID] = struct{}{}
|
||||
}
|
||||
if rule.DestinationResource.Type == types.ResourceTypePeer && rule.DestinationResource.ID != "" {
|
||||
peerSet[rule.DestinationResource.ID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func collectPolicySources(policy *types.Policy, groupSet, peerSet map[string]struct{}) {
|
||||
for _, rule := range policy.Rules {
|
||||
addAll(groupSet, rule.Sources)
|
||||
if rule.SourceResource.Type == types.ResourceTypePeer && rule.SourceResource.ID != "" {
|
||||
peerSet[rule.SourceResource.ID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func policyReferencesGroups(policy *types.Policy, groupSet map[string]struct{}) bool {
|
||||
for _, rule := range policy.Rules {
|
||||
if anyInSet(rule.Sources, groupSet) || anyInSet(rule.Destinations, groupSet) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func policyReferencesDirectPeers(policy *types.Policy, changedSet map[string]struct{}) bool {
|
||||
for _, rule := range policy.Rules {
|
||||
if isDirectPeerInSet(rule.SourceResource, changedSet) || isDirectPeerInSet(rule.DestinationResource, changedSet) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func policyReferencesPostureChecks(policy *types.Policy, ids map[string]struct{}) bool {
|
||||
for _, id := range policy.SourcePostureChecks {
|
||||
if _, ok := ids[id]; ok {
|
||||
@@ -1044,7 +776,7 @@ func serviceMatchesChangedPeers(svc *rpservice.Service, proxyPeers []string, cha
|
||||
}
|
||||
}
|
||||
for _, target := range svc.Targets {
|
||||
if !target.Enabled || target.TargetType != rpservice.TargetTypePeer || target.TargetId == "" {
|
||||
if target.TargetType != rpservice.TargetTypePeer || target.TargetId == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := changedPeers[target.TargetId]; ok {
|
||||
|
||||
@@ -10,8 +10,8 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// policyGroupsAndPeers mirrors the both-sides extraction (RuleGroups + direct peers)
|
||||
// the resolver folds in for a changed policy, for asserting the pure logic.
|
||||
// policyGroupsAndPeers mirrors the explicit-policy extraction (RuleGroups +
|
||||
// direct peers) the resolver folds in, for asserting the pure logic.
|
||||
func policyGroupsAndPeers(policies ...*types.Policy) (groups []string, peers []string) {
|
||||
peerSet := map[string]struct{}{}
|
||||
for _, p := range policies {
|
||||
@@ -19,14 +19,7 @@ func policyGroupsAndPeers(policies ...*types.Policy) (groups []string, peers []s
|
||||
continue
|
||||
}
|
||||
groups = append(groups, p.RuleGroups()...)
|
||||
for _, rule := range p.Rules {
|
||||
if rule.SourceResource.Type == types.ResourceTypePeer && rule.SourceResource.ID != "" {
|
||||
peerSet[rule.SourceResource.ID] = struct{}{}
|
||||
}
|
||||
if rule.DestinationResource.Type == types.ResourceTypePeer && rule.DestinationResource.ID != "" {
|
||||
peerSet[rule.DestinationResource.ID] = struct{}{}
|
||||
}
|
||||
}
|
||||
collectPolicyDirectPeers(p, peerSet)
|
||||
}
|
||||
for id := range peerSet {
|
||||
peers = append(peers, id)
|
||||
@@ -87,6 +80,26 @@ func TestChangeIsEmpty(t *testing.T) {
|
||||
assert.False(t, Change{PostureCheckIDs: []string{"pc"}}.isEmpty())
|
||||
}
|
||||
|
||||
func TestPolicyReferencesGroups(t *testing.T) {
|
||||
policy := &types.Policy{Rules: []*types.PolicyRule{{Sources: []string{"g1", "g2"}, Destinations: []string{"g3"}}}}
|
||||
|
||||
assert.True(t, policyReferencesGroups(policy, map[string]struct{}{"g1": {}}))
|
||||
assert.True(t, policyReferencesGroups(policy, map[string]struct{}{"g3": {}}))
|
||||
assert.False(t, policyReferencesGroups(policy, map[string]struct{}{"g4": {}}))
|
||||
assert.False(t, policyReferencesGroups(policy, map[string]struct{}{}))
|
||||
}
|
||||
|
||||
func TestPolicyReferencesDirectPeers(t *testing.T) {
|
||||
policy := &types.Policy{Rules: []*types.PolicyRule{{
|
||||
SourceResource: types.Resource{Type: types.ResourceTypePeer, ID: "p1"},
|
||||
DestinationResource: types.Resource{Type: types.ResourceTypeHost, ID: "r1"},
|
||||
}}}
|
||||
|
||||
assert.True(t, policyReferencesDirectPeers(policy, map[string]struct{}{"p1": {}}))
|
||||
assert.False(t, policyReferencesDirectPeers(policy, map[string]struct{}{"r1": {}}))
|
||||
assert.False(t, policyReferencesDirectPeers(policy, map[string]struct{}{"p2": {}}))
|
||||
}
|
||||
|
||||
func TestPolicyReferencesPostureChecks(t *testing.T) {
|
||||
policy := &types.Policy{SourcePostureChecks: []string{"pc1", "pc2"}}
|
||||
|
||||
@@ -94,9 +107,24 @@ func TestPolicyReferencesPostureChecks(t *testing.T) {
|
||||
assert.False(t, policyReferencesPostureChecks(policy, map[string]struct{}{"pc3": {}}))
|
||||
}
|
||||
|
||||
func TestCollectPolicyDirectPeers(t *testing.T) {
|
||||
policy := &types.Policy{Rules: []*types.PolicyRule{{
|
||||
SourceResource: types.Resource{Type: types.ResourceTypePeer, ID: "p1"},
|
||||
DestinationResource: types.Resource{Type: types.ResourceTypePeer, ID: "p2"},
|
||||
}, {
|
||||
DestinationResource: types.Resource{Type: types.ResourceTypeHost, ID: "r1"},
|
||||
}}}
|
||||
|
||||
peerSet := map[string]struct{}{}
|
||||
collectPolicyDirectPeers(policy, peerSet)
|
||||
|
||||
assert.Contains(t, peerSet, "p1")
|
||||
assert.Contains(t, peerSet, "p2")
|
||||
assert.NotContains(t, peerSet, "r1")
|
||||
}
|
||||
|
||||
func TestCollectPolicySources(t *testing.T) {
|
||||
policy := &types.Policy{Rules: []*types.PolicyRule{{
|
||||
Enabled: true,
|
||||
Sources: []string{"g1"},
|
||||
SourceResource: types.Resource{Type: types.ResourceTypePeer, ID: "p1"},
|
||||
Destinations: []string{"g2"},
|
||||
|
||||
@@ -520,12 +520,7 @@ func collectDeletableGroups(ctx context.Context, transaction store.Store, accoun
|
||||
// GroupAddPeer appends peer to the group
|
||||
func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
|
||||
var snap *affectedpeers.Snapshot
|
||||
// A membership change affects only the peer itself and the opposite side of THIS
|
||||
// group's policies — not the group's other members, and not the peer's other
|
||||
// groups. LinkGroups walks only this group (matched, not expanded); OutputPeerIDs
|
||||
// refreshes the peer without seeding its other group memberships. For an
|
||||
// intra-group policy the opposite side is the group, so its members still refresh.
|
||||
change := affectedpeers.Change{OutputPeerIDs: []string{peerID}, LinkGroups: []string{groupID}}
|
||||
change := affectedpeers.Change{ChangedGroupIDs: []string{groupID}}
|
||||
|
||||
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err := transaction.AddPeerToGroup(ctx, accountID, peerID, groupID); err != nil {
|
||||
@@ -591,11 +586,10 @@ func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID
|
||||
// GroupDeletePeer removes peer from the group
|
||||
func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, groupID, peerID string) error {
|
||||
var snap *affectedpeers.Snapshot
|
||||
// Same as GroupAddPeer: the removed peer and the opposite side of THIS group's
|
||||
// policies refresh, not the group's other members or the peer's other groups. The
|
||||
// peer is no longer in the group's index, but LinkGroups still drives the
|
||||
// opposite-side walk, and OutputPeerIDs refreshes the removed peer itself.
|
||||
change := affectedpeers.Change{OutputPeerIDs: []string{peerID}, LinkGroups: []string{groupID}}
|
||||
change := affectedpeers.Change{
|
||||
ChangedGroupIDs: []string{groupID},
|
||||
RemovedPeersByGroup: map[string][]string{groupID: {peerID}},
|
||||
}
|
||||
|
||||
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err := transaction.RemovePeerFromGroup(ctx, peerID, groupID); err != nil {
|
||||
@@ -606,6 +600,8 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
|
||||
return err
|
||||
}
|
||||
|
||||
// The removed peer is carried in change.RemovedPeersByGroup and folded in
|
||||
// only when the group is linked, so loading post-removal is correct.
|
||||
var err error
|
||||
if snap, err = affectedpeers.Load(ctx, transaction, accountID, change); err != nil {
|
||||
return err
|
||||
|
||||
@@ -220,7 +220,7 @@ func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
includeServiceUser, err := strconv.ParseBool(serviceUser)
|
||||
log.WithContext(r.Context()).Tracef("Should include service user: %v", includeServiceUser)
|
||||
log.WithContext(r.Context()).Debugf("Should include service user: %v", includeServiceUser)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid service_user query parameter"), w)
|
||||
return
|
||||
|
||||
@@ -209,14 +209,14 @@ func (am *DefaultAccountManager) resolvePeerLocation(ctx context.Context, peer *
|
||||
if am.geo == nil || realIP == nil {
|
||||
return nil
|
||||
}
|
||||
if peer.Location.ConnectionIP != nil && peer.Location.ConnectionIP.Equal(realIP) {
|
||||
return nil
|
||||
}
|
||||
location, err := am.geo.Lookup(realIP)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to get location for peer %s realip: [%s]: %v", peer.ID, realIP.String(), err)
|
||||
return nil
|
||||
}
|
||||
if peer.Location.ConnectionIP != nil && peer.Location.ConnectionIP.Equal(realIP) && peer.Location.GeoNameID == location.City.GeonameID {
|
||||
return nil
|
||||
}
|
||||
return &nbpeer.Location{
|
||||
ConnectionIP: realIP,
|
||||
CountryCode: location.Country.ISOCode,
|
||||
@@ -730,7 +730,7 @@ func (am *DefaultAccountManager) handleSetupKeyAddedPeer(ctx context.Context, en
|
||||
func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKey, userID string, peer *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.Network, []*posture.Checks, bool, error) {
|
||||
if setupKey == "" && userID == "" && !peer.ProxyMeta.Embedded {
|
||||
// no auth method provided => reject access
|
||||
return nil, nil, nil, false, status.ErrNoAuthMethodProvided
|
||||
return nil, nil, nil, false, status.Errorf(status.Unauthenticated, "no peer auth method provided, please use a setup key or interactive SSO login")
|
||||
}
|
||||
|
||||
upperKey := strings.ToUpper(setupKey)
|
||||
@@ -1051,8 +1051,8 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
|
||||
return nil, nil, nil, 0, err
|
||||
}
|
||||
|
||||
metaDiffAffectsPosture := posture.AffectsPosture(ctx, &metaDiff, resPostureChecks)
|
||||
if requiresPeerUpdate(ctx, isStatusChanged, sync.UpdateAccountPeers, ipv6CapabilityChanged, metaDiffAffectsPosture, metaDiff.VersionChanged(), metaDiff.HostnameChanged()) {
|
||||
metaDiffAffectsPosture := posture.AffectsPosture(&metaDiff, resPostureChecks)
|
||||
if isStatusChanged || sync.UpdateAccountPeers || ipv6CapabilityChanged || metaDiffAffectsPosture || metaDiff.VersionChanged || metaDiff.Hostname {
|
||||
changedPeerIDs := []string{peer.ID}
|
||||
affectedPeerIDs := am.syncPeerAffectedPeers(ctx, accountID, peer.ID, nmap, peerNotValid, metaDiffAffectsPosture)
|
||||
if err = am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs); err != nil {
|
||||
@@ -1063,29 +1063,6 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
|
||||
return peer, nmap, resPostureChecks, dnsFwdPort, nil
|
||||
}
|
||||
|
||||
func requiresPeerUpdate(ctx context.Context, isStatusChanged, updateAccountPeers, ipv6CapabilityChanged, metaDiffAffectsPosture, versionChanged, hostname bool) bool {
|
||||
var reason string
|
||||
switch {
|
||||
case isStatusChanged:
|
||||
reason = "status changed"
|
||||
case updateAccountPeers:
|
||||
reason = "update account peers"
|
||||
case ipv6CapabilityChanged:
|
||||
reason = "ipv6 capability changed"
|
||||
case metaDiffAffectsPosture:
|
||||
reason = "meta diff affects posture"
|
||||
case versionChanged:
|
||||
reason = "version changed"
|
||||
case hostname:
|
||||
reason = "hostname changed"
|
||||
default:
|
||||
return false
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Tracef("peer update required: %s", reason)
|
||||
return true
|
||||
}
|
||||
|
||||
// syncPeerAffectedPeers resolves the peers affected by a SyncPeer change. The
|
||||
// peer's own validated network map is bidirectional for policy and routing
|
||||
// reachability, so when the peer stays valid and no source-posture gate is in
|
||||
|
||||
@@ -107,15 +107,6 @@ type Location struct {
|
||||
GeoNameID uint // city level geoname id
|
||||
}
|
||||
|
||||
// equal reports whether two locations match. ConnectionIP is a net.IP slice, so it uses
|
||||
// IP.Equal, not ==.
|
||||
func (l Location) equal(other Location) bool {
|
||||
return l.CountryCode == other.CountryCode &&
|
||||
l.CityName == other.CityName &&
|
||||
l.GeoNameID == other.GeoNameID &&
|
||||
l.ConnectionIP.Equal(other.ConnectionIP)
|
||||
}
|
||||
|
||||
// NetworkAddress is the IP address with network and MAC address of a network interface
|
||||
type NetworkAddress struct {
|
||||
NetIP netip.Prefix `gorm:"serializer:json"`
|
||||
@@ -276,139 +267,183 @@ func (p *Peer) UpdateMetaIfNew(ctx context.Context, meta PeerSystemMeta, newLoca
|
||||
return MetaDiff{}
|
||||
}
|
||||
|
||||
versionChanged := p.Meta.WtVersion != meta.WtVersion
|
||||
|
||||
// Avoid overwriting UIVersion if the update was triggered sole by the CLI client
|
||||
if meta.UIVersion == "" {
|
||||
meta.UIVersion = p.Meta.UIVersion
|
||||
}
|
||||
|
||||
effectiveLocation := p.Location
|
||||
if newLocation != nil {
|
||||
effectiveLocation = *newLocation
|
||||
}
|
||||
oldVersion := p.Meta.WtVersion
|
||||
|
||||
diff := diffMeta(p.Meta, meta, p.Location, effectiveLocation)
|
||||
if diff.Updated() {
|
||||
diff := diffMeta(p.Meta, meta)
|
||||
if diff.Any() {
|
||||
p.Meta = meta
|
||||
}
|
||||
p.Location = effectiveLocation
|
||||
diff.VersionChanged = versionChanged
|
||||
|
||||
if diff.Updated() {
|
||||
log.WithContext(ctx).Debug(diff.LogSummary())
|
||||
locationInfo := ""
|
||||
if newLocation != nil {
|
||||
p.Location = *newLocation
|
||||
diff.LocationChanged = true
|
||||
locationInfo = fmt.Sprintf("location changed to %s, ", newLocation.ConnectionIP)
|
||||
}
|
||||
|
||||
versionInfo := ""
|
||||
if diff.VersionChanged {
|
||||
versionInfo = fmt.Sprintf("version changed: %s -> %s, ", oldVersion, meta.WtVersion)
|
||||
}
|
||||
|
||||
if diff.Any() || diff.VersionChanged || diff.LocationChanged {
|
||||
log.WithContext(ctx).
|
||||
Debugf("peer meta updated, %s%s%d field(s) changed: %s", versionInfo, locationInfo, len(diff.Changed), strings.Join(diff.Changed, ", "))
|
||||
}
|
||||
|
||||
return diff
|
||||
}
|
||||
|
||||
// MetaDiff holds a peer's full before/after state across a sync: both metas and both
|
||||
// connection locations (the location lives on Peer, not PeerSystemMeta, but posture
|
||||
// checks read it). Changed lists what moved, for logging and the persistence decision;
|
||||
// the snapshots let a posture check be replayed against old and new. Everything is derived
|
||||
// from these fields, so there are no parallel per-field flags to keep in sync.
|
||||
// MetaDiff records which PeerSystemMeta fields differ between two metas. Each bool
|
||||
// maps to a single struct field, except Environment, which is split into Cloud and
|
||||
// Platform. Changed holds the human-readable `field: <old> -> <new>` entries so the
|
||||
// existing log line and isEqual can be derived from the same comparison.
|
||||
//
|
||||
// VersionChanged and LocationChanged sit outside the per-meta-field set:
|
||||
// VersionChanged tracks the WireGuard client version specifically (compared before
|
||||
// the UIVersion fixup, to signal client upgrades) and LocationChanged tracks the
|
||||
// peer's connection geo location, which lives on Peer rather than PeerSystemMeta.
|
||||
// Neither contributes an entry to Changed, so the field-coverage accounting stays
|
||||
// driven purely by the PeerSystemMeta comparison.
|
||||
type MetaDiff struct {
|
||||
OldMeta PeerSystemMeta
|
||||
NewMeta PeerSystemMeta
|
||||
OldLocation Location
|
||||
NewLocation Location
|
||||
Hostname bool
|
||||
GoOS bool
|
||||
Kernel bool
|
||||
KernelVersion bool
|
||||
Core bool
|
||||
Platform bool
|
||||
OS bool
|
||||
OSVersion bool
|
||||
WtVersion bool
|
||||
UIVersion bool
|
||||
SystemSerialNumber bool
|
||||
SystemProductName bool
|
||||
SystemManufacturer bool
|
||||
EnvironmentCloud bool
|
||||
EnvironmentPlatform bool
|
||||
Flags bool
|
||||
Capabilities bool
|
||||
NetworkAddresses bool
|
||||
Files bool
|
||||
|
||||
VersionChanged bool
|
||||
LocationChanged bool
|
||||
|
||||
Changed []string
|
||||
}
|
||||
|
||||
// Updated reports whether anything changed and the peer must be persisted. diffMeta fills
|
||||
// Changed in the pass that builds the diff, so this is a length check, not a re-comparison.
|
||||
// Pointer receiver: MetaDiff embeds two metas, so copying it per call is wasteful.
|
||||
func (d *MetaDiff) Updated() bool {
|
||||
// Any reports whether any PeerSystemMeta field changed.
|
||||
func (d MetaDiff) Any() bool {
|
||||
return len(d.Changed) != 0
|
||||
}
|
||||
|
||||
// VersionChanged reports whether the WireGuard client version changed (a client upgrade).
|
||||
func (d *MetaDiff) VersionChanged() bool {
|
||||
return d.OldMeta.WtVersion != d.NewMeta.WtVersion
|
||||
}
|
||||
|
||||
// HostnameChanged reports whether the peer's hostname changed.
|
||||
func (d *MetaDiff) HostnameChanged() bool {
|
||||
return d.OldMeta.Hostname != d.NewMeta.Hostname
|
||||
}
|
||||
|
||||
// LogSummary renders the changed fields as a single human-readable line.
|
||||
func (d *MetaDiff) LogSummary() string {
|
||||
return fmt.Sprintf("peer meta updated, %d field(s) changed: %s",
|
||||
len(d.Changed), strings.Join(d.Changed, ", "))
|
||||
// Updated reports whether the peer needs to be persisted: any meta field changed
|
||||
// or the geo location changed. The version flag alone does not imply a write,
|
||||
// since a version change is also reflected in the WtVersion meta field.
|
||||
func (d MetaDiff) Updated() bool {
|
||||
return d.Any() || d.LocationChanged || d.VersionChanged
|
||||
}
|
||||
|
||||
func metaDiff(oldMeta, newMeta PeerSystemMeta) []string {
|
||||
return diffMeta(oldMeta, newMeta, Location{}, Location{}).Changed
|
||||
return diffMeta(oldMeta, newMeta).Changed
|
||||
}
|
||||
|
||||
// diffMeta snapshots a peer's old and new state and records a Changed entry per field that
|
||||
// moved. It is the single source of truth for the comparison: isEqual is an empty Changed
|
||||
// list, so the log line and the persistence decision can never disagree.
|
||||
func diffMeta(oldMeta, newMeta PeerSystemMeta, oldLocation, newLocation Location) MetaDiff {
|
||||
d := MetaDiff{OldMeta: oldMeta, NewMeta: newMeta, OldLocation: oldLocation, NewLocation: newLocation}
|
||||
// diffMeta compares two metas field by field, returning both a per-field flag set
|
||||
// (for callers that need to know exactly what changed, e.g. matching against
|
||||
// posture checks) and the human-readable Changed list. It is the single source of
|
||||
// truth for meta comparison: isEqual reports equality as an empty diff, so the log
|
||||
// line, the change decision, and the flags can never disagree.
|
||||
func diffMeta(oldMeta, newMeta PeerSystemMeta) MetaDiff {
|
||||
var d MetaDiff
|
||||
add := func(field string, oldVal, newVal any) {
|
||||
d.Changed = append(d.Changed, fmt.Sprintf("%s: %v -> %v", field, oldVal, newVal))
|
||||
}
|
||||
|
||||
if oldMeta.Hostname != newMeta.Hostname {
|
||||
d.Hostname = true
|
||||
add("hostname", oldMeta.Hostname, newMeta.Hostname)
|
||||
}
|
||||
if oldMeta.GoOS != newMeta.GoOS {
|
||||
d.GoOS = true
|
||||
add("goos", oldMeta.GoOS, newMeta.GoOS)
|
||||
}
|
||||
if oldMeta.Kernel != newMeta.Kernel {
|
||||
d.Kernel = true
|
||||
add("kernel", oldMeta.Kernel, newMeta.Kernel)
|
||||
}
|
||||
if oldMeta.KernelVersion != newMeta.KernelVersion {
|
||||
d.KernelVersion = true
|
||||
add("kernel_version", oldMeta.KernelVersion, newMeta.KernelVersion)
|
||||
}
|
||||
if oldMeta.Core != newMeta.Core {
|
||||
d.Core = true
|
||||
add("core", oldMeta.Core, newMeta.Core)
|
||||
}
|
||||
if oldMeta.Platform != newMeta.Platform {
|
||||
d.Platform = true
|
||||
add("platform", oldMeta.Platform, newMeta.Platform)
|
||||
}
|
||||
if oldMeta.OS != newMeta.OS {
|
||||
d.OS = true
|
||||
add("os", oldMeta.OS, newMeta.OS)
|
||||
}
|
||||
if oldMeta.OSVersion != newMeta.OSVersion {
|
||||
d.OSVersion = true
|
||||
add("os_version", oldMeta.OSVersion, newMeta.OSVersion)
|
||||
}
|
||||
if oldMeta.WtVersion != newMeta.WtVersion {
|
||||
d.WtVersion = true
|
||||
add("wt_version", oldMeta.WtVersion, newMeta.WtVersion)
|
||||
}
|
||||
if oldMeta.UIVersion != newMeta.UIVersion {
|
||||
d.UIVersion = true
|
||||
add("ui_version", oldMeta.UIVersion, newMeta.UIVersion)
|
||||
}
|
||||
if oldMeta.SystemSerialNumber != newMeta.SystemSerialNumber {
|
||||
d.SystemSerialNumber = true
|
||||
add("system_serial_number", oldMeta.SystemSerialNumber, newMeta.SystemSerialNumber)
|
||||
}
|
||||
if oldMeta.SystemProductName != newMeta.SystemProductName {
|
||||
d.SystemProductName = true
|
||||
add("system_product_name", oldMeta.SystemProductName, newMeta.SystemProductName)
|
||||
}
|
||||
if oldMeta.SystemManufacturer != newMeta.SystemManufacturer {
|
||||
d.SystemManufacturer = true
|
||||
add("system_manufacturer", oldMeta.SystemManufacturer, newMeta.SystemManufacturer)
|
||||
}
|
||||
if oldMeta.Environment.Cloud != newMeta.Environment.Cloud {
|
||||
d.EnvironmentCloud = true
|
||||
add("environment_cloud", oldMeta.Environment.Cloud, newMeta.Environment.Cloud)
|
||||
}
|
||||
if oldMeta.Environment.Platform != newMeta.Environment.Platform {
|
||||
d.EnvironmentPlatform = true
|
||||
add("environment_platform", oldMeta.Environment.Platform, newMeta.Environment.Platform)
|
||||
}
|
||||
if !oldMeta.Flags.isEqual(newMeta.Flags) {
|
||||
d.Flags = true
|
||||
add("flags", fmt.Sprintf("%+v", oldMeta.Flags), fmt.Sprintf("%+v", newMeta.Flags))
|
||||
}
|
||||
if !capabilitiesEqual(oldMeta.Capabilities, newMeta.Capabilities) {
|
||||
d.Capabilities = true
|
||||
add("capabilities", oldMeta.Capabilities, newMeta.Capabilities)
|
||||
}
|
||||
|
||||
if !sameMultiset(oldMeta.NetworkAddresses, newMeta.NetworkAddresses) {
|
||||
d.NetworkAddresses = true
|
||||
add("network_addresses", fmt.Sprintf("%v", oldMeta.NetworkAddresses), fmt.Sprintf("%v", newMeta.NetworkAddresses))
|
||||
}
|
||||
if !sameMultiset(oldMeta.Files, newMeta.Files) {
|
||||
add("files", fmt.Sprintf("%v", oldMeta.Files), fmt.Sprintf("%v", newMeta.Files))
|
||||
}
|
||||
|
||||
if !oldLocation.equal(newLocation) {
|
||||
add("connection_ip", oldLocation.ConnectionIP, newLocation.ConnectionIP)
|
||||
if !sameMultiset(oldMeta.Files, newMeta.Files) {
|
||||
d.Files = true
|
||||
add("files", fmt.Sprintf("%v", oldMeta.Files), fmt.Sprintf("%v", newMeta.Files))
|
||||
}
|
||||
|
||||
return d
|
||||
|
||||
@@ -49,7 +49,6 @@ import (
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
@@ -2894,141 +2893,3 @@ func TestUpdatePeer_DnsLabelUniqueName(t *testing.T) {
|
||||
require.NoError(t, err, "renaming to unique FQDN should succeed")
|
||||
assert.Equal(t, "api-server", updated.DNSLabel, "DNS label should be first label of FQDN")
|
||||
}
|
||||
|
||||
// fakeGeo is a configurable geolocation.Geolocation implementation for tests. It
|
||||
// returns a record built from the configured city geoname id, or an error when set.
|
||||
type fakeGeo struct {
|
||||
geoNameID uint
|
||||
isoCode string
|
||||
cityName string
|
||||
err error
|
||||
}
|
||||
|
||||
func (g *fakeGeo) Lookup(net.IP) (*geolocation.Record, error) {
|
||||
if g.err != nil {
|
||||
return nil, g.err
|
||||
}
|
||||
record := &geolocation.Record{}
|
||||
record.City.GeonameID = g.geoNameID
|
||||
record.City.Names.En = g.cityName
|
||||
record.Country.ISOCode = g.isoCode
|
||||
return record, nil
|
||||
}
|
||||
|
||||
func (g *fakeGeo) GetAllCountries() ([]geolocation.Country, error) { return nil, nil }
|
||||
|
||||
func (g *fakeGeo) GetCitiesByCountry(string) ([]geolocation.City, error) { return nil, nil }
|
||||
|
||||
func (g *fakeGeo) Stop() error { return nil }
|
||||
|
||||
func TestResolvePeerLocation(t *testing.T) {
|
||||
realIP := net.ParseIP("203.0.113.10")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
geo geolocation.Geolocation
|
||||
peer *nbpeer.Peer
|
||||
realIP net.IP
|
||||
want *nbpeer.Location
|
||||
wantNil bool
|
||||
}{
|
||||
{
|
||||
name: "no geo configured returns nil",
|
||||
geo: nil,
|
||||
peer: &nbpeer.Peer{ID: "p1"},
|
||||
realIP: realIP,
|
||||
wantNil: true,
|
||||
},
|
||||
{
|
||||
name: "nil real IP returns nil",
|
||||
geo: &fakeGeo{geoNameID: 100},
|
||||
peer: &nbpeer.Peer{ID: "p1"},
|
||||
realIP: nil,
|
||||
wantNil: true,
|
||||
},
|
||||
{
|
||||
name: "lookup error returns nil",
|
||||
geo: &fakeGeo{err: fmt.Errorf("lookup boom")},
|
||||
peer: &nbpeer.Peer{ID: "p1"},
|
||||
realIP: realIP,
|
||||
wantNil: true,
|
||||
},
|
||||
{
|
||||
name: "same IP and same geoname returns nil",
|
||||
geo: &fakeGeo{geoNameID: 100, isoCode: "US", cityName: "City A"},
|
||||
peer: &nbpeer.Peer{
|
||||
ID: "p1",
|
||||
Location: nbpeer.Location{
|
||||
ConnectionIP: realIP,
|
||||
GeoNameID: 100,
|
||||
},
|
||||
},
|
||||
realIP: realIP,
|
||||
wantNil: true,
|
||||
},
|
||||
{
|
||||
name: "same IP but changed geoname returns location",
|
||||
geo: &fakeGeo{geoNameID: 200, isoCode: "US", cityName: "City B"},
|
||||
peer: &nbpeer.Peer{
|
||||
ID: "p1",
|
||||
Location: nbpeer.Location{
|
||||
ConnectionIP: realIP,
|
||||
GeoNameID: 100,
|
||||
},
|
||||
},
|
||||
realIP: realIP,
|
||||
want: &nbpeer.Location{
|
||||
ConnectionIP: realIP,
|
||||
CountryCode: "US",
|
||||
CityName: "City B",
|
||||
GeoNameID: 200,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "different IP returns location",
|
||||
geo: &fakeGeo{geoNameID: 100, isoCode: "US", cityName: "City A"},
|
||||
peer: &nbpeer.Peer{
|
||||
ID: "p1",
|
||||
Location: nbpeer.Location{
|
||||
ConnectionIP: net.ParseIP("198.51.100.7"),
|
||||
GeoNameID: 100,
|
||||
},
|
||||
},
|
||||
realIP: realIP,
|
||||
want: &nbpeer.Location{
|
||||
ConnectionIP: realIP,
|
||||
CountryCode: "US",
|
||||
CityName: "City A",
|
||||
GeoNameID: 100,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "no prior location returns location",
|
||||
geo: &fakeGeo{geoNameID: 100, isoCode: "US", cityName: "City A"},
|
||||
peer: &nbpeer.Peer{ID: "p1"},
|
||||
realIP: realIP,
|
||||
want: &nbpeer.Location{
|
||||
ConnectionIP: realIP,
|
||||
CountryCode: "US",
|
||||
CityName: "City A",
|
||||
GeoNameID: 100,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
am := &DefaultAccountManager{geo: tt.geo}
|
||||
got := am.resolvePeerLocation(context.Background(), tt.peer, tt.realIP)
|
||||
if tt.wantNil {
|
||||
assert.Nil(t, got, "resolved location should be nil")
|
||||
return
|
||||
}
|
||||
require.NotNil(t, got, "resolved location should not be nil")
|
||||
assert.True(t, tt.want.ConnectionIP.Equal(got.ConnectionIP), "connection IP should match")
|
||||
assert.Equal(t, tt.want.CountryCode, got.CountryCode, "country code should match")
|
||||
assert.Equal(t, tt.want.CityName, got.CityName, "city name should match")
|
||||
assert.Equal(t, tt.want.GeoNameID, got.GeoNameID, "geoname id should match")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,202 +0,0 @@
|
||||
package posture
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
)
|
||||
|
||||
// diffFrom builds a MetaDiff from the old/new snapshots AffectsPosture replays against.
|
||||
func diffFrom(oldMeta, newMeta nbpeer.PeerSystemMeta, oldLoc, newLoc nbpeer.Location) *nbpeer.MetaDiff {
|
||||
return &nbpeer.MetaDiff{
|
||||
OldMeta: oldMeta,
|
||||
NewMeta: newMeta,
|
||||
OldLocation: oldLoc,
|
||||
NewLocation: newLoc,
|
||||
}
|
||||
}
|
||||
|
||||
func checks(def ChecksDefinition) []*Checks {
|
||||
return []*Checks{{Checks: def}}
|
||||
}
|
||||
|
||||
func TestAffectsPosture_NilDiff(t *testing.T) {
|
||||
assert.False(t, AffectsPosture(context.Background(), nil, checks(ChecksDefinition{
|
||||
NBVersionCheck: &NBVersionCheck{MinVersion: "1.0.0"},
|
||||
})))
|
||||
}
|
||||
|
||||
func TestAffectsPosture_NBVersion(t *testing.T) {
|
||||
c := checks(ChecksDefinition{NBVersionCheck: &NBVersionCheck{MinVersion: "1.2.0"}})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
oldVer, newVer string
|
||||
want bool
|
||||
}{
|
||||
{"both above min, no flip", "1.3.0", "1.4.0", false},
|
||||
{"both below min, no flip", "1.0.0", "1.1.0", false},
|
||||
{"crosses up below->above", "1.1.0", "1.3.0", true},
|
||||
{"crosses down above->below", "1.3.0", "1.1.0", true},
|
||||
{"unparsable old only -> flip", "garbage", "1.3.0", true},
|
||||
{"unparsable both -> no flip", "garbage", "junk", false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
diff := diffFrom(
|
||||
nbpeer.PeerSystemMeta{WtVersion: tt.oldVer},
|
||||
nbpeer.PeerSystemMeta{WtVersion: tt.newVer},
|
||||
nbpeer.Location{}, nbpeer.Location{},
|
||||
)
|
||||
assert.Equal(t, tt.want, AffectsPosture(context.Background(), diff, c))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAffectsPosture_OSVersion_KernelBumpWithinMin(t *testing.T) {
|
||||
c := checks(ChecksDefinition{OSVersionCheck: &OSVersionCheck{
|
||||
Linux: &MinKernelVersionCheck{MinKernelVersion: "5.0.0"},
|
||||
}})
|
||||
|
||||
// Kernel moves but stays above the minimum: verdict stays pass -> not affected.
|
||||
withinMin := diffFrom(
|
||||
nbpeer.PeerSystemMeta{GoOS: "linux", KernelVersion: "5.10.0-arch1"},
|
||||
nbpeer.PeerSystemMeta{GoOS: "linux", KernelVersion: "5.15.0-arch2"},
|
||||
nbpeer.Location{}, nbpeer.Location{},
|
||||
)
|
||||
assert.False(t, AffectsPosture(context.Background(), withinMin, c))
|
||||
|
||||
// Kernel drops below the minimum: verdict flips pass -> fail -> affected.
|
||||
crossesDown := diffFrom(
|
||||
nbpeer.PeerSystemMeta{GoOS: "linux", KernelVersion: "5.10.0-arch1"},
|
||||
nbpeer.PeerSystemMeta{GoOS: "linux", KernelVersion: "4.19.0-arch1"},
|
||||
nbpeer.Location{}, nbpeer.Location{},
|
||||
)
|
||||
assert.True(t, AffectsPosture(context.Background(), crossesDown, c))
|
||||
}
|
||||
|
||||
func TestAffectsPosture_OSVersion_GoOSSwitchFlipsVerdict(t *testing.T) {
|
||||
// Only Linux is constrained. An OS outside the switch (freebsd) passes; switching to a
|
||||
// failing linux kernel flips the verdict pass -> fail.
|
||||
c := checks(ChecksDefinition{OSVersionCheck: &OSVersionCheck{
|
||||
Linux: &MinKernelVersionCheck{MinKernelVersion: "6.0.0"},
|
||||
}})
|
||||
|
||||
diff := diffFrom(
|
||||
nbpeer.PeerSystemMeta{GoOS: "freebsd"},
|
||||
nbpeer.PeerSystemMeta{GoOS: "linux", KernelVersion: "4.19.0"},
|
||||
nbpeer.Location{}, nbpeer.Location{},
|
||||
)
|
||||
assert.True(t, AffectsPosture(context.Background(), diff, c))
|
||||
}
|
||||
|
||||
func TestAffectsPosture_Process_GoOSSwitchFlipsVerdict(t *testing.T) {
|
||||
// Process runs at a linux path. Switching GoOS to windows (no WindowsPath configured)
|
||||
// flips the verdict.
|
||||
c := checks(ChecksDefinition{ProcessCheck: &ProcessCheck{
|
||||
Processes: []Process{{LinuxPath: "/usr/bin/foo"}},
|
||||
}})
|
||||
|
||||
files := []nbpeer.File{{Path: "/usr/bin/foo", ProcessIsRunning: true}}
|
||||
diff := diffFrom(
|
||||
nbpeer.PeerSystemMeta{GoOS: "linux", Files: files},
|
||||
nbpeer.PeerSystemMeta{GoOS: "windows", Files: files},
|
||||
nbpeer.Location{}, nbpeer.Location{},
|
||||
)
|
||||
assert.True(t, AffectsPosture(context.Background(), diff, c))
|
||||
}
|
||||
|
||||
func TestAffectsPosture_Process_UnrelatedFileChange(t *testing.T) {
|
||||
// A tracked process stays running while an unrelated file is added: the verdict does
|
||||
// not move, so posture is not affected.
|
||||
c := checks(ChecksDefinition{ProcessCheck: &ProcessCheck{
|
||||
Processes: []Process{{LinuxPath: "/usr/bin/foo"}},
|
||||
}})
|
||||
|
||||
diff := diffFrom(
|
||||
nbpeer.PeerSystemMeta{GoOS: "linux", Files: []nbpeer.File{
|
||||
{Path: "/usr/bin/foo", ProcessIsRunning: true},
|
||||
}},
|
||||
nbpeer.PeerSystemMeta{GoOS: "linux", Files: []nbpeer.File{
|
||||
{Path: "/usr/bin/foo", ProcessIsRunning: true},
|
||||
{Path: "/usr/bin/bar", ProcessIsRunning: true},
|
||||
}},
|
||||
nbpeer.Location{}, nbpeer.Location{},
|
||||
)
|
||||
assert.False(t, AffectsPosture(context.Background(), diff, c))
|
||||
}
|
||||
|
||||
func TestAffectsPosture_GeoLocation(t *testing.T) {
|
||||
c := checks(ChecksDefinition{GeoLocationCheck: &GeoLocationCheck{
|
||||
Action: CheckActionAllow,
|
||||
Locations: []Location{{CountryCode: "DE"}},
|
||||
}})
|
||||
|
||||
// Moving within allowed countries keeps the verdict; moving out flips it.
|
||||
stayAllowed := diffFrom(
|
||||
nbpeer.PeerSystemMeta{}, nbpeer.PeerSystemMeta{},
|
||||
nbpeer.Location{CountryCode: "DE", CityName: "Berlin"},
|
||||
nbpeer.Location{CountryCode: "DE", CityName: "Munich"},
|
||||
)
|
||||
assert.False(t, AffectsPosture(context.Background(), stayAllowed, c))
|
||||
|
||||
moveOut := diffFrom(
|
||||
nbpeer.PeerSystemMeta{}, nbpeer.PeerSystemMeta{},
|
||||
nbpeer.Location{CountryCode: "DE"},
|
||||
nbpeer.Location{CountryCode: "FR"},
|
||||
)
|
||||
assert.True(t, AffectsPosture(context.Background(), moveOut, c))
|
||||
}
|
||||
|
||||
func TestAffectsPosture_PeerNetworkRange_ConnectionIP(t *testing.T) {
|
||||
// The check reads the connection IP. Moving out of the allowed range flips the verdict;
|
||||
// moving within it does not.
|
||||
_, allowed, _ := net.ParseCIDR("10.0.0.0/8")
|
||||
c := checks(ChecksDefinition{PeerNetworkRangeCheck: &PeerNetworkRangeCheck{
|
||||
Action: CheckActionAllow,
|
||||
Ranges: []netip.Prefix{netip.MustParsePrefix(allowed.String())},
|
||||
}})
|
||||
|
||||
movesOutOfRange := diffFrom(
|
||||
nbpeer.PeerSystemMeta{}, nbpeer.PeerSystemMeta{},
|
||||
nbpeer.Location{ConnectionIP: net.ParseIP("10.1.2.3")},
|
||||
nbpeer.Location{ConnectionIP: net.ParseIP("8.8.8.8")},
|
||||
)
|
||||
assert.True(t, AffectsPosture(context.Background(), movesOutOfRange, c))
|
||||
|
||||
staysInRange := diffFrom(
|
||||
nbpeer.PeerSystemMeta{}, nbpeer.PeerSystemMeta{},
|
||||
nbpeer.Location{ConnectionIP: net.ParseIP("10.1.2.3")},
|
||||
nbpeer.Location{ConnectionIP: net.ParseIP("10.9.9.9")},
|
||||
)
|
||||
assert.False(t, AffectsPosture(context.Background(), staysInRange, c))
|
||||
}
|
||||
|
||||
func TestAffectsPosture_IrrelevantFieldChange(t *testing.T) {
|
||||
// Hostname changes but no check reads it: not affected even with checks present.
|
||||
c := checks(ChecksDefinition{
|
||||
NBVersionCheck: &NBVersionCheck{MinVersion: "1.0.0"},
|
||||
GeoLocationCheck: &GeoLocationCheck{Action: CheckActionAllow, Locations: []Location{{CountryCode: "DE"}}},
|
||||
})
|
||||
|
||||
diff := diffFrom(
|
||||
nbpeer.PeerSystemMeta{Hostname: "old", WtVersion: "1.5.0"},
|
||||
nbpeer.PeerSystemMeta{Hostname: "new", WtVersion: "1.5.0"},
|
||||
nbpeer.Location{CountryCode: "DE"}, nbpeer.Location{CountryCode: "DE"},
|
||||
)
|
||||
assert.False(t, AffectsPosture(context.Background(), diff, c))
|
||||
}
|
||||
|
||||
func TestAffectsPosture_NoChecks(t *testing.T) {
|
||||
diff := diffFrom(
|
||||
nbpeer.PeerSystemMeta{WtVersion: "1.0.0"},
|
||||
nbpeer.PeerSystemMeta{WtVersion: "2.0.0"},
|
||||
nbpeer.Location{}, nbpeer.Location{},
|
||||
)
|
||||
assert.False(t, AffectsPosture(context.Background(), diff, nil))
|
||||
}
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"regexp"
|
||||
|
||||
"github.com/hashicorp/go-version"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
@@ -53,46 +52,34 @@ type Checks struct {
|
||||
Checks ChecksDefinition `gorm:"serializer:json"`
|
||||
}
|
||||
|
||||
// AffectsPosture reports whether the change in diff flips the verdict of any check. It
|
||||
// replays each check against the peer's old and new state and compares verdicts, so a
|
||||
// change that moves a field but stays the right side of a threshold (e.g. a kernel bump
|
||||
// still above the minimum) does not force a re-evaluation. See verdictChanged for how an
|
||||
// evaluation error counts.
|
||||
func AffectsPosture(ctx context.Context, diff *nbpeer.MetaDiff, checks []*Checks) bool {
|
||||
// AffectsPosture reports whether the peer metadata changes described by diff can
|
||||
// alter the outcome of any of the given posture checks. It maps each check kind to
|
||||
// the metadata fields it inspects, so an unrelated change (e.g. a hostname update)
|
||||
// does not force a posture re-evaluation.
|
||||
func AffectsPosture(diff *nbpeer.MetaDiff, checks []*Checks) bool {
|
||||
if diff == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
oldPeer := nbpeer.Peer{Meta: diff.OldMeta, Location: diff.OldLocation}
|
||||
newPeer := nbpeer.Peer{Meta: diff.NewMeta, Location: diff.NewLocation}
|
||||
|
||||
for _, c := range checks {
|
||||
for _, check := range c.GetChecks() {
|
||||
if verdictChanged(ctx, check, oldPeer, newPeer) {
|
||||
return true
|
||||
}
|
||||
if c.Checks.ProcessCheck != nil && diff.Files {
|
||||
return true
|
||||
}
|
||||
if c.Checks.OSVersionCheck != nil && (diff.OSVersion || diff.OS || diff.KernelVersion) {
|
||||
return true
|
||||
}
|
||||
if c.Checks.NBVersionCheck != nil && diff.WtVersion {
|
||||
return true
|
||||
}
|
||||
if c.Checks.GeoLocationCheck != nil && diff.LocationChanged {
|
||||
return true
|
||||
}
|
||||
if c.Checks.PeerNetworkRangeCheck != nil && diff.NetworkAddresses {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// verdictChanged replays check against old and new state and reports whether the verdict
|
||||
// differs. Like callers, it treats an evaluation error as deny: two errors are the same
|
||||
// verdict (no change), an error on one side only is a flip.
|
||||
func verdictChanged(ctx context.Context, check Check, oldPeer, newPeer nbpeer.Peer) bool {
|
||||
oldPass, oldErr := check.Check(ctx, oldPeer)
|
||||
newPass, newErr := check.Check(ctx, newPeer)
|
||||
|
||||
oldVerdict := oldPass && (oldErr == nil)
|
||||
newVerdict := newPass && (newErr == nil)
|
||||
changed := oldVerdict != newVerdict
|
||||
|
||||
log.WithContext(ctx).Tracef("posture check %s replay: verdict %t -> %t (changed=%t), errs: %v -> %v",
|
||||
check.Name(), oldVerdict, newVerdict, changed, oldErr, newErr)
|
||||
|
||||
return changed
|
||||
}
|
||||
|
||||
// ChecksDefinition contains definition of actual check
|
||||
type ChecksDefinition struct {
|
||||
NBVersionCheck *NBVersionCheck `json:",omitempty"`
|
||||
|
||||
@@ -489,7 +489,6 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) {
|
||||
|
||||
policy := &types.Policy{
|
||||
AccountID: account.Id,
|
||||
Enabled: true,
|
||||
Rules: []*types.PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
|
||||
@@ -1059,8 +1059,8 @@ func (am *DefaultAccountManager) BuildUserInfosForAccount(ctx context.Context, a
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
log.WithContext(ctx).Tracef("Got %d users from ExternalCache for account %s", len(usersFromIntegration), accountID)
|
||||
log.WithContext(ctx).Tracef("Got %d users from InternalCache for account %s", len(queriedUsers), accountID)
|
||||
log.WithContext(ctx).Debugf("Got %d users from ExternalCache for account %s", len(usersFromIntegration), accountID)
|
||||
log.WithContext(ctx).Debugf("Got %d users from InternalCache for account %s", len(queriedUsers), accountID)
|
||||
queriedUsers = append(queriedUsers, usersFromIntegration...)
|
||||
}
|
||||
|
||||
|
||||
@@ -48,10 +48,6 @@ type Type int32
|
||||
var (
|
||||
ErrExtraSettingsNotFound = errors.New("extra settings not found")
|
||||
ErrPeerAlreadyLoggedIn = errors.New("peer with the same public key is already logged in")
|
||||
|
||||
// ErrNoAuthMethodProvided is returned when a peer login attempt carries neither a
|
||||
// setup key nor an SSO token. Match it with errors.Is.
|
||||
ErrNoAuthMethodProvided = Errorf(Unauthenticated, "no peer auth method provided, please use a setup key or interactive SSO login")
|
||||
)
|
||||
|
||||
// Error is an internal error
|
||||
@@ -70,16 +66,6 @@ func (e *Error) Error() string {
|
||||
return e.Message
|
||||
}
|
||||
|
||||
// Is reports whether target is an *Error with the same type and message,
|
||||
// enabling matching with errors.Is against sentinel errors.
|
||||
func (e *Error) Is(target error) bool {
|
||||
var t *Error
|
||||
if !errors.As(target, &t) {
|
||||
return false
|
||||
}
|
||||
return e.ErrorType == t.ErrorType && e.Message == t.Message
|
||||
}
|
||||
|
||||
// Errorf returns Error(ErrorType, fmt.Sprintf(format, a...)).
|
||||
func Errorf(errorType Type, format string, a ...interface{}) error {
|
||||
return &Error{
|
||||
|
||||
@@ -243,7 +243,11 @@ func NewClientWithServerIP(serverURL string, serverIP netip.Addr, authTokenStore
|
||||
|
||||
// Connect establishes a connection to the relay server. It blocks until the connection is established or an error occurs.
|
||||
func (c *Client) Connect(ctx context.Context) error {
|
||||
c.log.Infof("connecting to relay server")
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
c.log.Infof("connect elapsed time: %v", time.Since(start))
|
||||
}()
|
||||
|
||||
c.readLoopMutex.Lock()
|
||||
defer c.readLoopMutex.Unlock()
|
||||
|
||||
@@ -287,6 +291,11 @@ func (c *Client) Connect(ctx context.Context) error {
|
||||
func (c *Client) OpenConn(ctx context.Context, dstPeerID string) (net.Conn, error) {
|
||||
peerID := messages.HashID(dstPeerID)
|
||||
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
c.log.Infof("connect elapsed time: %v", time.Since(start))
|
||||
}()
|
||||
|
||||
c.mu.Lock()
|
||||
if !c.serviceIsRunning {
|
||||
c.mu.Unlock()
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"go.opentelemetry.io/otel"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
@@ -252,7 +253,7 @@ func TestClient_ConnectedIPParsesRemoteAddr(t *testing.T) {
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Client{relayConn: stubConn{remote: staticAddr{s: tt.s}}}
|
||||
c := &Client{log: log.WithField("relay", tt.name), relayConn: stubConn{remote: staticAddr{s: tt.s}}}
|
||||
got := c.ConnectedIP()
|
||||
var gotStr string
|
||||
if got.IsValid() {
|
||||
|
||||
@@ -78,6 +78,23 @@ type GrpcClient struct {
|
||||
// transport-alive but no longer delivering messages. It is the source of
|
||||
// truth IsHealthy reads, and is cleared once any frame is received again.
|
||||
receiveStalled atomic.Bool
|
||||
// receiveHandoffBlocked is set while the receive loop is parked handing a
|
||||
// message to a busy decryption worker. The loop stops calling Recv (and
|
||||
// markReceived) in that window, so the stream looks silent though it is
|
||||
// healthy. The watchdog reads this to avoid misreading self-inflicted
|
||||
// receive backpressure as a dead stream: reconnecting cannot help, since the
|
||||
// new stream feeds the same worker, and only triggers a reconnect storm.
|
||||
receiveHandoffBlocked atomic.Bool
|
||||
// lastDecrypt holds the Unix-nano timestamp of the last message the decryption
|
||||
// worker pulled off its queue. Diagnostic only: it lets a stall log show
|
||||
// whether the worker was draining (busy) or idle when the stream went silent.
|
||||
lastDecrypt atomic.Int64
|
||||
// handoffWaitTotal, handoffWaitMax (nanos) and handoffWaitCount accumulate the
|
||||
// time the receive loop spent blocked handing messages to the worker. This is
|
||||
// time not spent reading the stream, so it quantifies receive backpressure.
|
||||
handoffWaitTotal atomic.Int64
|
||||
handoffWaitMax atomic.Int64
|
||||
handoffWaitCount atomic.Int64
|
||||
}
|
||||
|
||||
// NewClient creates a new Signal client
|
||||
@@ -353,6 +370,8 @@ func (c *GrpcClient) SendToStream(msg *proto.EncryptedMessage) error {
|
||||
|
||||
// decryptMessage decrypts the body of the msg using Wireguard private key and Remote peer's public key
|
||||
func (c *GrpcClient) decryptMessage(msg *proto.EncryptedMessage) (*proto.Message, error) {
|
||||
c.lastDecrypt.Store(time.Now().UnixNano())
|
||||
|
||||
remoteKey, err := wgtypes.ParseKey(msg.GetKey())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -439,6 +458,22 @@ func (c *GrpcClient) idleSinceReceive() time.Duration {
|
||||
return time.Since(time.Unix(0, c.lastReceived.Load()))
|
||||
}
|
||||
|
||||
// idleSinceDecrypt returns how long since the worker last pulled a message.
|
||||
// Diagnostic only: distinguishes a busy/wedged worker from an idle one.
|
||||
func (c *GrpcClient) idleSinceDecrypt() time.Duration {
|
||||
return time.Since(time.Unix(0, c.lastDecrypt.Load()))
|
||||
}
|
||||
|
||||
// receiveAlive reports whether the receive stream shows liveness: it delivered a
|
||||
// frame within the inactivity threshold, or the receive loop is currently parked
|
||||
// handing a message to a busy decryption worker. In the latter case the loop has
|
||||
// stopped calling Recv, so the stream looks silent while being healthy, and
|
||||
// reconnecting would not help, so the watchdog must treat it as alive.
|
||||
func (c *GrpcClient) receiveAlive() bool {
|
||||
return c.idleSinceReceive() < receiveInactivityThreshold ||
|
||||
c.receiveHandoffBlocked.Load()
|
||||
}
|
||||
|
||||
// watchReceiveStream guards against a receive stream that is transport-alive but
|
||||
// no longer delivering messages. While the stream is idle past
|
||||
// receiveInactivityThreshold it sends a self-addressed probe that the Signal
|
||||
@@ -450,18 +485,55 @@ func (c *GrpcClient) watchReceiveStream(ctx context.Context, cancelStream contex
|
||||
defer ticker.Stop()
|
||||
|
||||
var probeSentAt time.Time
|
||||
var holdLogged bool
|
||||
var statTicks int
|
||||
var lastStatTotal int64
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if c.idleSinceReceive() < receiveInactivityThreshold {
|
||||
// Periodic backpressure summary so time lost to the worker handoff is
|
||||
// visible even when no stall fires. Emitted ~once a minute and only
|
||||
// when the wait grew, to stay quiet on a healthy stream.
|
||||
if statTicks++; statTicks >= int(time.Minute/receiveWatchdogInterval) {
|
||||
statTicks = 0
|
||||
if total, max, count := c.handoffWaitStats(); int64(total) > lastStatTotal {
|
||||
log.Infof("signal receive backpressure: handoffWaitTotal=%s (+%s last min) handoffWaitMax=%s handoffMsgs=%d",
|
||||
total.Round(time.Second), (total - time.Duration(lastStatTotal)).Round(time.Millisecond),
|
||||
max.Round(time.Millisecond), count)
|
||||
lastStatTotal = int64(total)
|
||||
}
|
||||
}
|
||||
|
||||
if c.receiveAlive() {
|
||||
// Attribute the case that matters in the field: silent past the
|
||||
// threshold but held because the receive loop is parked on the
|
||||
// worker handoff (backpressure), not a dead stream. Log once per
|
||||
// hold episode so a persistent worker stall is visible at info.
|
||||
if c.idleSinceReceive() >= receiveInactivityThreshold && c.receiveHandoffBlocked.Load() {
|
||||
if !holdLogged {
|
||||
total, max, count := c.handoffWaitStats()
|
||||
log.Infof("signal receive idle %s, loop blocked on worker handoff (idleDecrypt=%s queueDepth=%d connState=%s handoffWaitTotal=%s handoffWaitMax=%s handoffMsgs=%d); holding stream",
|
||||
c.idleSinceReceive().Round(time.Second), c.idleSinceDecrypt().Round(time.Second),
|
||||
c.decryptionWorker.QueueLen(), c.signalConn.GetState(),
|
||||
total.Round(time.Second), max.Round(time.Millisecond), count)
|
||||
holdLogged = true
|
||||
}
|
||||
} else {
|
||||
holdLogged = false
|
||||
}
|
||||
probeSentAt = time.Time{}
|
||||
continue
|
||||
}
|
||||
holdLogged = false
|
||||
|
||||
if !probeSentAt.IsZero() && time.Since(probeSentAt) >= receiveProbeTimeout {
|
||||
log.Warnf("signal receive stream stalled: no messages for %s and probe did not return, reconnecting", c.idleSinceReceive().Round(time.Second))
|
||||
total, max, count := c.handoffWaitStats()
|
||||
log.Warnf("signal receive stream stalled, reconnecting: idleRecv=%s idleDecrypt=%s handoffBlocked=%v queueDepth=%d connState=%s handoffWaitTotal=%s handoffWaitMax=%s handoffMsgs=%d probe did not return",
|
||||
c.idleSinceReceive().Round(time.Second), c.idleSinceDecrypt().Round(time.Second),
|
||||
c.receiveHandoffBlocked.Load(), c.decryptionWorker.QueueLen(), c.signalConn.GetState(),
|
||||
total.Round(time.Second), max.Round(time.Millisecond), count)
|
||||
c.receiveStalled.Store(true)
|
||||
c.notifyDisconnected(errReceiveStreamStalled)
|
||||
cancelStream()
|
||||
@@ -517,12 +589,37 @@ func (c *GrpcClient) receive(stream proto.SignalExchange_ConnectStreamClient) er
|
||||
continue
|
||||
}
|
||||
|
||||
// The handoff blocks while the worker is busy, which parks this loop and
|
||||
// stops Recv. Flag it so the watchdog does not read the resulting silence
|
||||
// as a dead stream, and account the wait as receive backpressure.
|
||||
handoffStart := time.Now()
|
||||
c.receiveHandoffBlocked.Store(true)
|
||||
if err := c.decryptionWorker.AddMsg(c.ctx, msg); err != nil {
|
||||
log.Errorf("failed to add message to decryption worker: %v", err)
|
||||
}
|
||||
c.receiveHandoffBlocked.Store(false)
|
||||
c.recordHandoffWait(time.Since(handoffStart))
|
||||
}
|
||||
}
|
||||
|
||||
// recordHandoffWait accumulates the time the receive loop was blocked handing a
|
||||
// message to the worker.
|
||||
func (c *GrpcClient) recordHandoffWait(d time.Duration) {
|
||||
c.handoffWaitTotal.Add(int64(d))
|
||||
c.handoffWaitCount.Add(1)
|
||||
for {
|
||||
cur := c.handoffWaitMax.Load()
|
||||
if int64(d) <= cur || c.handoffWaitMax.CompareAndSwap(cur, int64(d)) {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handoffWaitStats returns cumulative receive-loop handoff backpressure.
|
||||
func (c *GrpcClient) handoffWaitStats() (total, max time.Duration, count int64) {
|
||||
return time.Duration(c.handoffWaitTotal.Load()), time.Duration(c.handoffWaitMax.Load()), c.handoffWaitCount.Load()
|
||||
}
|
||||
|
||||
func (c *GrpcClient) startEncryptionWorker(handler func(msg *proto.Message) error) {
|
||||
if c.decryptionWorker != nil {
|
||||
return
|
||||
|
||||
@@ -82,3 +82,27 @@ func TestReceiveProbeRoundTrips(t *testing.T) {
|
||||
t.Fatal("self-addressed heartbeat did not round-trip back through the signal server")
|
||||
}
|
||||
}
|
||||
|
||||
// TestReceiveAliveTreatsHandoffBlockAsLiveness reproduces the false positive
|
||||
// where a busy decryption worker parks the receive loop on the worker handoff,
|
||||
// so Recv (and markReceived) stops firing even though the stream is healthy.
|
||||
// With the receive stream silent past the inactivity threshold but the loop
|
||||
// blocked on handoff, the watchdog must consider the stream alive rather than
|
||||
// tear it down (reconnecting feeds the same worker and would not help).
|
||||
func TestReceiveAliveTreatsHandoffBlockAsLiveness(t *testing.T) {
|
||||
c := &GrpcClient{}
|
||||
|
||||
// Receive stream silent and the loop not blocked on handoff: genuinely stalled.
|
||||
c.lastReceived.Store(time.Now().Add(-2 * receiveInactivityThreshold).UnixNano())
|
||||
require.False(t, c.receiveAlive(), "silent stream with the receive loop idle must be treated as stalled")
|
||||
|
||||
// Receive stream silent but the loop is parked handing a message to a busy
|
||||
// worker: self-inflicted backpressure, not a dead stream, must not tear down.
|
||||
c.receiveHandoffBlocked.Store(true)
|
||||
require.True(t, c.receiveAlive(), "a receive loop blocked on worker handoff must keep the stream alive")
|
||||
|
||||
// Handoff drained, loop back to reading, a frame just arrived: alive via the receive path.
|
||||
c.receiveHandoffBlocked.Store(false)
|
||||
c.markReceived()
|
||||
require.True(t, c.receiveAlive(), "a freshly received frame must keep the stream alive")
|
||||
}
|
||||
|
||||
@@ -32,6 +32,13 @@ func (w *Worker) AddMsg(ctx context.Context, msg *proto.EncryptedMessage) error
|
||||
return nil
|
||||
}
|
||||
|
||||
// QueueLen returns the number of messages buffered for decryption. Diagnostic
|
||||
// only: a non-empty queue while the receive stream is silent indicates the
|
||||
// receive loop is parked on the handoff rather than the stream being dead.
|
||||
func (w *Worker) QueueLen() int {
|
||||
return len(w.encryptedMsgPool)
|
||||
}
|
||||
|
||||
func (w *Worker) Work(ctx context.Context) {
|
||||
for {
|
||||
select {
|
||||
|
||||
Reference in New Issue
Block a user