mirror of
https://github.com/netbirdio/netbird.git
synced 2026-06-30 03:39:55 +00:00
Compare commits
6 Commits
fix/revert
...
fix/mgmt-c
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
df6e422e10 | ||
|
|
3236a4c7fd | ||
|
|
08ac4855f6 | ||
|
|
b6c79f1f71 | ||
|
|
37be8811a3 | ||
|
|
a7d85ff3ab |
@@ -130,7 +130,7 @@ func debugConfigDump(cmd *cobra.Command, _ []string) error {
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
resp, err := client.GetConfig(cmd.Context(), &proto.GetConfigRequest{
|
||||
ProfileName: string(activeProf.ID),
|
||||
ProfileName: activeProf.Name,
|
||||
Username: currUser.Username,
|
||||
})
|
||||
if err != nil {
|
||||
|
||||
@@ -41,7 +41,6 @@ type ICEBind struct {
|
||||
*wgConn.StdNetBind
|
||||
|
||||
transportNet transport.Net
|
||||
filterFn udpmux.FilterFn
|
||||
address wgaddr.Address
|
||||
mtu uint16
|
||||
|
||||
@@ -61,12 +60,11 @@ type ICEBind struct {
|
||||
ipv6Conn *net.UDPConn
|
||||
}
|
||||
|
||||
func NewICEBind(transportNet transport.Net, filterFn udpmux.FilterFn, address wgaddr.Address, mtu uint16) *ICEBind {
|
||||
func NewICEBind(transportNet transport.Net, address wgaddr.Address, mtu uint16) *ICEBind {
|
||||
b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind)
|
||||
ib := &ICEBind{
|
||||
StdNetBind: b,
|
||||
transportNet: transportNet,
|
||||
filterFn: filterFn,
|
||||
address: address,
|
||||
mtu: mtu,
|
||||
endpoints: make(map[netip.Addr]net.Conn),
|
||||
@@ -265,7 +263,6 @@ func (s *ICEBind) createOrUpdateMux() {
|
||||
udpmux.UniversalUDPMuxParams{
|
||||
UDPConn: muxConn,
|
||||
Net: s.transportNet,
|
||||
FilterFn: s.filterFn,
|
||||
WGAddress: s.address,
|
||||
MTU: s.mtu,
|
||||
},
|
||||
|
||||
@@ -289,7 +289,7 @@ func setupICEBind(t *testing.T) *ICEBind {
|
||||
IP: netip.MustParseAddr("100.64.0.1"),
|
||||
Network: netip.MustParsePrefix("100.64.0.0/10"),
|
||||
}
|
||||
return NewICEBind(transportNet, nil, address, 1280)
|
||||
return NewICEBind(transportNet, address, 1280)
|
||||
}
|
||||
|
||||
func createDualStackConns(t *testing.T) (*net.UDPConn, *net.UDPConn) {
|
||||
|
||||
@@ -32,8 +32,6 @@ type TunKernelDevice struct {
|
||||
link *wgLink
|
||||
udpMuxConn net.PacketConn
|
||||
udpMux *udpmux.UniversalUDPMuxDefault
|
||||
|
||||
filterFn udpmux.FilterFn
|
||||
}
|
||||
|
||||
func NewKernelDevice(name string, address wgaddr.Address, wgPort int, key string, mtu uint16, transportNet transport.Net) *TunKernelDevice {
|
||||
@@ -104,7 +102,6 @@ func (t *TunKernelDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
|
||||
bindParams := udpmux.UniversalUDPMuxParams{
|
||||
UDPConn: nbnet.WrapPacketConn(rawSock),
|
||||
Net: t.transportNet,
|
||||
FilterFn: t.filterFn,
|
||||
WGAddress: t.address,
|
||||
MTU: t.mtu,
|
||||
}
|
||||
|
||||
@@ -63,7 +63,6 @@ type WGIFaceOpts struct {
|
||||
MTU uint16
|
||||
MobileArgs *device.MobileIFaceArguments
|
||||
TransportNet transport.Net
|
||||
FilterFn udpmux.FilterFn
|
||||
DisableDNS bool
|
||||
}
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
|
||||
// NewWGIFace Creates a new WireGuard interface instance
|
||||
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, opts.Address, opts.MTU)
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.Address, opts.MTU)
|
||||
|
||||
var tun WGTunDevice
|
||||
if netstack.IsEnabled() {
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
|
||||
// NewWGIFace Creates a new WireGuard interface instance
|
||||
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, opts.Address, opts.MTU)
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.Address, opts.MTU)
|
||||
|
||||
if netstack.IsEnabled() {
|
||||
wgIFace := &WGIface{
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
|
||||
// NewWGIFace Creates a new WireGuard interface instance
|
||||
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, opts.Address, opts.MTU)
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.Address, opts.MTU)
|
||||
|
||||
wgIFace := &WGIface{
|
||||
tun: device.NewTunDevice(opts.IFaceName, opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunFd),
|
||||
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
// NewWGIFace Creates a new WireGuard interface instance
|
||||
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
if netstack.IsEnabled() {
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, opts.Address, opts.MTU)
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.Address, opts.MTU)
|
||||
return &WGIface{
|
||||
tun: device.NewNetstackDevice(opts.IFaceName, opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()),
|
||||
userspaceBind: true,
|
||||
@@ -30,7 +30,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
}
|
||||
|
||||
if device.ModuleTunIsLoaded() {
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, opts.Address, opts.MTU)
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.Address, opts.MTU)
|
||||
return &WGIface{
|
||||
tun: device.NewTunDevice(opts.IFaceName, opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind),
|
||||
userspaceBind: true,
|
||||
|
||||
@@ -8,8 +8,6 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -22,10 +20,6 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
// FilterFn is a function that filters out candidates based on the address.
|
||||
// If it returns true, the address is to be filtered. It also returns the prefix of matching route.
|
||||
type FilterFn func(address netip.Addr) (bool, netip.Prefix, error)
|
||||
|
||||
// UniversalUDPMuxDefault handles STUN and TURN servers packets by wrapping the original UDPConn
|
||||
// It then passes packets to the UDPMux that does the actual connection muxing.
|
||||
type UniversalUDPMuxDefault struct {
|
||||
@@ -43,7 +37,6 @@ type UniversalUDPMuxParams struct {
|
||||
UDPConn net.PacketConn
|
||||
XORMappedAddrCacheTTL time.Duration
|
||||
Net transport.Net
|
||||
FilterFn FilterFn
|
||||
WGAddress wgaddr.Address
|
||||
MTU uint16
|
||||
}
|
||||
@@ -68,7 +61,6 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
|
||||
PacketConn: params.UDPConn,
|
||||
mux: m,
|
||||
logger: params.Logger,
|
||||
filterFn: params.FilterFn,
|
||||
address: params.WGAddress,
|
||||
}
|
||||
|
||||
@@ -115,15 +107,12 @@ func (m *UniversalUDPMuxDefault) ReadFromConn(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// UDPConn is a wrapper around UDPMux conn that overrides ReadFrom and handles STUN/TURN packets
|
||||
// UDPConn is a wrapper around UDPMux conn that overrides WriteTo to drop packets destined for the overlay subnet.
|
||||
type UDPConn struct {
|
||||
net.PacketConn
|
||||
mux *UniversalUDPMuxDefault
|
||||
logger logging.LeveledLogger
|
||||
filterFn FilterFn
|
||||
// TODO: reset cache on route changes
|
||||
addrCache sync.Map
|
||||
address wgaddr.Address
|
||||
mux *UniversalUDPMuxDefault
|
||||
logger logging.LeveledLogger
|
||||
address wgaddr.Address
|
||||
}
|
||||
|
||||
// GetPacketConn returns the underlying PacketConn
|
||||
@@ -132,65 +121,16 @@ func (u *UDPConn) GetPacketConn() net.PacketConn {
|
||||
}
|
||||
|
||||
func (u *UDPConn) WriteTo(b []byte, addr net.Addr) (int, error) {
|
||||
if u.filterFn == nil {
|
||||
udpAddr, ok := addr.(*net.UDPAddr)
|
||||
if !ok {
|
||||
return u.PacketConn.WriteTo(b, addr)
|
||||
}
|
||||
|
||||
if isRouted, found := u.addrCache.Load(addr.String()); found {
|
||||
return u.handleCachedAddress(isRouted.(bool), b, addr)
|
||||
}
|
||||
|
||||
return u.handleUncachedAddress(b, addr)
|
||||
}
|
||||
|
||||
func (u *UDPConn) handleCachedAddress(isRouted bool, b []byte, addr net.Addr) (int, error) {
|
||||
if isRouted {
|
||||
return 0, fmt.Errorf("address %s is part of a routed network, refusing to write", addr)
|
||||
}
|
||||
return u.PacketConn.WriteTo(b, addr)
|
||||
}
|
||||
|
||||
func (u *UDPConn) handleUncachedAddress(b []byte, addr net.Addr) (int, error) {
|
||||
if err := u.performFilterCheck(addr); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return u.PacketConn.WriteTo(b, addr)
|
||||
}
|
||||
|
||||
func (u *UDPConn) performFilterCheck(addr net.Addr) error {
|
||||
host, err := getHostFromAddr(addr)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to get host from address %s: %v", addr, err)
|
||||
return nil
|
||||
}
|
||||
|
||||
a, err := netip.ParseAddr(host)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to parse address %s: %v", addr, err)
|
||||
return nil
|
||||
}
|
||||
|
||||
if u.address.Network.Contains(a) {
|
||||
dst := udpAddr.AddrPort().Addr().Unmap()
|
||||
if (u.address.Network.IsValid() && u.address.Network.Contains(dst)) || (u.address.IPv6Net.IsValid() && u.address.IPv6Net.Contains(dst)) {
|
||||
log.Warnf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
|
||||
return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
|
||||
return 0, fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
|
||||
}
|
||||
|
||||
if isRouted, prefix, err := u.filterFn(a); err != nil {
|
||||
log.Errorf("Failed to check if address %s is routed: %v", addr, err)
|
||||
} else {
|
||||
u.addrCache.Store(addr.String(), isRouted)
|
||||
if isRouted {
|
||||
// Extra log, as the error only shows up with ICE logging enabled
|
||||
log.Infof("address %s is part of routed network %s, refusing to write", addr, prefix)
|
||||
return fmt.Errorf("address %s is part of routed network %s, refusing to write", addr, prefix)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getHostFromAddr(addr net.Addr) (string, error) {
|
||||
host, _, err := net.SplitHostPort(addr.String())
|
||||
return host, err
|
||||
return u.PacketConn.WriteTo(b, addr)
|
||||
}
|
||||
|
||||
// GetSharedConn returns the shared udp conn
|
||||
@@ -225,6 +165,13 @@ func (m *UniversalUDPMuxDefault) HandleSTUNMessage(msg *stun.Message, addr net.A
|
||||
return nil
|
||||
}
|
||||
|
||||
src := udpAddr.AddrPort().Addr().Unmap()
|
||||
wg := m.params.WGAddress
|
||||
if (wg.Network.IsValid() && wg.Network.Contains(src)) || (wg.IPv6Net.IsValid() && wg.IPv6Net.Contains(src)) {
|
||||
log.Debugf("dropping STUN message from overlay source %s", udpAddr)
|
||||
return nil
|
||||
}
|
||||
|
||||
if m.isXORMappedResponse(msg, udpAddr.String()) {
|
||||
err := m.handleXORMappedResponse(udpAddr, msg)
|
||||
if err != nil {
|
||||
|
||||
@@ -66,7 +66,7 @@ func seedProxyForProxyCloseByRemoteConn() ([]proxyInstance, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
iceBind := bind.NewICEBind(nil, nil, wgAddress, 1280)
|
||||
iceBind := bind.NewICEBind(nil, wgAddress, 1280)
|
||||
endpointAddress := &net.UDPAddr{
|
||||
IP: net.IPv4(10, 0, 0, 1),
|
||||
Port: 1234,
|
||||
|
||||
@@ -22,7 +22,7 @@ func seedProxyForProxyCloseByRemoteConn() ([]proxyInstance, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
iceBind := bind.NewICEBind(nil, nil, wgAddress, 1280)
|
||||
iceBind := bind.NewICEBind(nil, wgAddress, 1280)
|
||||
endpointAddress := &net.UDPAddr{
|
||||
IP: net.IPv4(10, 0, 0, 1),
|
||||
Port: 1234,
|
||||
|
||||
23
client/internal/dns/mgmt/export_test.go
Normal file
23
client/internal/dns/mgmt/export_test.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package mgmt
|
||||
|
||||
import "time"
|
||||
|
||||
// pendingCount returns how many initial resolves are still in flight. Test-only.
|
||||
func (m *Resolver) pendingCount() int {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
return len(m.pending)
|
||||
}
|
||||
|
||||
// waitForPendingResolves blocks until all pending resolves settle or the
|
||||
// timeout elapses, returning true if all settled. Test-only.
|
||||
func (m *Resolver) waitForPendingResolves(timeout time.Duration) bool {
|
||||
deadline := time.Now().Add(timeout)
|
||||
for m.pendingCount() > 0 {
|
||||
if time.Now().After(deadline) {
|
||||
return false
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
return true
|
||||
}
|
||||
@@ -50,24 +50,31 @@ type cachedRecord struct {
|
||||
consecFailures int
|
||||
}
|
||||
|
||||
// pendingEntry marks a domain whose initial resolve is in flight, so ServeDNS
|
||||
// can wait on it instead of falling through to upstream.
|
||||
type pendingEntry struct{}
|
||||
|
||||
// Resolver caches critical NetBird infrastructure domains.
|
||||
// records, refreshing, failedResolves, mgmtDomain and serverDomains are all
|
||||
// guarded by mutex.
|
||||
// records, refreshing, pending, mgmtDomain and serverDomains are all guarded by mutex.
|
||||
type Resolver struct {
|
||||
// ctx is the server-lifetime context for background resolves.
|
||||
ctx context.Context
|
||||
|
||||
records map[dns.Question]*cachedRecord
|
||||
mgmtDomain *domain.Domain
|
||||
serverDomains *dnsconfig.ServerDomains
|
||||
mutex sync.RWMutex
|
||||
|
||||
// failedResolves records the last failed initial resolve per domain so a
|
||||
// domain that never resolves isn't retried on every server-domains update
|
||||
// until refreshBackoff elapses. Entries are cleared on success and pruned
|
||||
// to the current server-domains set.
|
||||
failedResolves map[domain.Domain]time.Time
|
||||
// pending holds domains whose initial resolve is in flight, keyed by
|
||||
// punycode FQDN (trailing dot).
|
||||
pending map[string]pendingEntry
|
||||
|
||||
chain ChainResolver
|
||||
chainMaxPriority int
|
||||
refreshGroup singleflight.Group
|
||||
// resolveGroup dedups initial (cold-cache) resolves; kept separate from
|
||||
// refreshGroup so initial and stale-refresh flights don't collapse.
|
||||
resolveGroup singleflight.Group
|
||||
|
||||
// refreshing tracks questions whose refresh is running via the OS
|
||||
// fallback path. A ServeDNS hit for a question in this map indicates
|
||||
@@ -81,12 +88,13 @@ type Resolver struct {
|
||||
}
|
||||
|
||||
// NewResolver creates a new management domains cache resolver.
|
||||
func NewResolver() *Resolver {
|
||||
func NewResolver(ctx context.Context) *Resolver {
|
||||
return &Resolver{
|
||||
records: make(map[dns.Question]*cachedRecord),
|
||||
refreshing: make(map[dns.Question]*atomic.Bool),
|
||||
failedResolves: make(map[domain.Domain]time.Time),
|
||||
cacheTTL: resolveCacheTTL(),
|
||||
ctx: ctx,
|
||||
records: make(map[dns.Question]*cachedRecord),
|
||||
refreshing: make(map[dns.Question]*atomic.Bool),
|
||||
pending: make(map[string]pendingEntry),
|
||||
cacheTTL: resolveCacheTTL(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -125,6 +133,7 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
m.mutex.RLock()
|
||||
cached, found := m.records[question]
|
||||
inflight := m.refreshing[question]
|
||||
_, isPending := m.pending[question.Name]
|
||||
var shouldRefresh bool
|
||||
if found {
|
||||
stale := time.Since(cached.cachedAt) > m.cacheTTL
|
||||
@@ -134,8 +143,17 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
m.mutex.RUnlock()
|
||||
|
||||
if !found {
|
||||
m.continueToNext(w, r)
|
||||
return
|
||||
// Registered but not resolved yet: wait on the in-flight resolve
|
||||
// rather than falling through to (possibly dead) upstream.
|
||||
if isPending && m.awaitPendingResolve(question.Name) {
|
||||
m.mutex.RLock()
|
||||
cached, found = m.records[question]
|
||||
m.mutex.RUnlock()
|
||||
}
|
||||
if !found {
|
||||
m.continueToNext(w, r)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if inflight != nil && inflight.CompareAndSwap(false, true) {
|
||||
@@ -181,9 +199,7 @@ func (m *Resolver) continueToNext(w dns.ResponseWriter, r *dns.Msg) {
|
||||
|
||||
// AddDomain resolves a domain and stores its A/AAAA records in the cache.
|
||||
// A family that resolves NODATA (nil err, zero records) evicts any stale
|
||||
// entry for that qtype. When one family hard-errors while the other succeeds,
|
||||
// the resolved family is still cached but AddDomain returns an error so the
|
||||
// caller retries the incomplete resolve rather than treating it as complete.
|
||||
// entry for that qtype.
|
||||
func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
|
||||
dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
|
||||
|
||||
@@ -213,10 +229,6 @@ func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
|
||||
log.Debugf("added/updated domain=%s with %d A records and %d AAAA records",
|
||||
d.SafeString(), len(aRecords), len(aaaaRecords))
|
||||
|
||||
if errA != nil || errAAAA != nil {
|
||||
return fmt.Errorf("resolve %s: incomplete, a family failed: %w", d.SafeString(), errors.Join(errA, errAAAA))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -476,12 +488,18 @@ func (m *Resolver) RemoveDomain(d domain.Domain) error {
|
||||
delete(m.records, qAAAA)
|
||||
delete(m.refreshing, qA)
|
||||
delete(m.refreshing, qAAAA)
|
||||
delete(m.failedResolves, d)
|
||||
|
||||
log.Debugf("removed domain=%s from cache", d.SafeString())
|
||||
return nil
|
||||
}
|
||||
|
||||
// RequestedDomains returns the cacheable infrastructure domains (signal, relay,
|
||||
// STUN, TURN; flow excluded) so the cache handler can be registered for them
|
||||
// before resolution completes.
|
||||
func (m *Resolver) RequestedDomains(serverDomains dnsconfig.ServerDomains) domain.List {
|
||||
return m.extractDomainsFromServerDomains(serverDomains)
|
||||
}
|
||||
|
||||
// GetCachedDomains returns a list of all cached domains.
|
||||
func (m *Resolver) GetCachedDomains() domain.List {
|
||||
m.mutex.RLock()
|
||||
@@ -501,10 +519,12 @@ func (m *Resolver) GetCachedDomains() domain.List {
|
||||
return domains
|
||||
}
|
||||
|
||||
// UpdateFromServerDomains updates the cache with server domains from network configuration.
|
||||
// It merges new domains with existing ones, replacing entire domain types when updated.
|
||||
// Empty updates are ignored to prevent clearing infrastructure domains during partial updates.
|
||||
func (m *Resolver) UpdateFromServerDomains(ctx context.Context, serverDomains dnsconfig.ServerDomains) (domain.List, error) {
|
||||
// UpdateFromServerDomains merges server domains into the cache and resolves
|
||||
// them. New types replace whole types; empty updates are ignored. Resolution is
|
||||
// async (off the caller's sync lock) except for cold domains when dnsWillBeServed
|
||||
// and takeover is pending, which kickoffResolve primes synchronously. ctx is the
|
||||
// server lifetime, so a fast sync won't cancel resolves but Stop will.
|
||||
func (m *Resolver) UpdateFromServerDomains(ctx context.Context, serverDomains dnsconfig.ServerDomains, dnsWillBeServed bool) (domain.List, error) {
|
||||
newDomains := m.extractDomainsFromServerDomains(serverDomains)
|
||||
var removedDomains domain.List
|
||||
|
||||
@@ -520,14 +540,138 @@ func (m *Resolver) UpdateFromServerDomains(ctx context.Context, serverDomains dn
|
||||
allDomains := m.extractDomainsFromServerDomains(updatedServerDomains)
|
||||
currentDomains := m.GetCachedDomains()
|
||||
removedDomains = m.removeStaleDomains(currentDomains, allDomains)
|
||||
m.pruneFailedResolves(allDomains)
|
||||
}
|
||||
|
||||
m.addNewDomains(ctx, newDomains)
|
||||
m.kickoffResolve(ctx, newDomains, dnsWillBeServed)
|
||||
|
||||
return removedDomains, nil
|
||||
}
|
||||
|
||||
// kickoffResolve resolves each unresolved domain, skipping fresh/in-flight ones.
|
||||
// Cold domains resolve synchronously only before takeover (no upstream root
|
||||
// handler) and when dnsWillBeServed, to prime the cache via the working OS
|
||||
// resolver before OS DNS routes through the tunnel; otherwise async.
|
||||
func (m *Resolver) kickoffResolve(ctx context.Context, domains domain.List, dnsWillBeServed bool) {
|
||||
m.mutex.RLock()
|
||||
chain := m.chain
|
||||
maxPriority := m.chainMaxPriority
|
||||
m.mutex.RUnlock()
|
||||
preTakeover := chain == nil || !chain.HasRootHandlerAtOrBelow(maxPriority)
|
||||
|
||||
for _, d := range domains {
|
||||
dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
|
||||
|
||||
m.mutex.Lock()
|
||||
_, hasPending := m.pending[dnsName]
|
||||
fresh := m.hasFreshRecordLocked(dnsName)
|
||||
cold := !m.hasAnyRecordLocked(dnsName)
|
||||
if !hasPending && !fresh {
|
||||
m.pending[dnsName] = pendingEntry{}
|
||||
}
|
||||
m.mutex.Unlock()
|
||||
|
||||
if hasPending || fresh {
|
||||
continue
|
||||
}
|
||||
|
||||
if cold && preTakeover && dnsWillBeServed {
|
||||
m.resolveInitial(ctx, d, dnsName)
|
||||
continue
|
||||
}
|
||||
|
||||
m.scheduleInitialResolve(ctx, d, dnsName)
|
||||
}
|
||||
}
|
||||
|
||||
// resolveInitial resolves a cold domain synchronously, deduped via resolveGroup
|
||||
// so a concurrent ServeDNS await joins the same flight. Clears pending when done.
|
||||
func (m *Resolver) resolveInitial(ctx context.Context, d domain.Domain, dnsName string) {
|
||||
key := "initial|" + dnsName
|
||||
_, _, _ = m.resolveGroup.Do(key, func() (any, error) {
|
||||
defer m.clearPending(dnsName)
|
||||
if err := m.AddDomain(ctx, d); err != nil {
|
||||
log.Warnf("initial resolve mgmt domain=%s: %v", d.SafeString(), err)
|
||||
return struct{}{}, err
|
||||
}
|
||||
log.Debugf("added/updated management cache domain=%s", d.SafeString())
|
||||
return struct{}{}, nil
|
||||
})
|
||||
}
|
||||
|
||||
// scheduleInitialResolve runs AddDomain in the background, deduped per domain
|
||||
// by resolveGroup, clearing the pending marker when it finishes. ctx is the
|
||||
// server-lifetime context so a Stop cancels in-flight resolves.
|
||||
func (m *Resolver) scheduleInitialResolve(ctx context.Context, d domain.Domain, dnsName string) {
|
||||
key := "initial|" + dnsName
|
||||
_ = m.resolveGroup.DoChan(key, func() (any, error) {
|
||||
defer m.clearPending(dnsName)
|
||||
if err := m.AddDomain(ctx, d); err != nil {
|
||||
log.Warnf("failed to add/update domain=%s: %v", d.SafeString(), err)
|
||||
return struct{}{}, err
|
||||
}
|
||||
log.Debugf("added/updated management cache domain=%s", d.SafeString())
|
||||
return struct{}{}, nil
|
||||
})
|
||||
}
|
||||
|
||||
// hasFreshRecordLocked reports whether a non-stale A or AAAA record exists for
|
||||
// the name. Caller holds m.mutex.
|
||||
func (m *Resolver) hasFreshRecordLocked(dnsName string) bool {
|
||||
for _, qtype := range []uint16{dns.TypeA, dns.TypeAAAA} {
|
||||
q := dns.Question{Name: dnsName, Qtype: qtype, Qclass: dns.ClassINET}
|
||||
if c, ok := m.records[q]; ok && time.Since(c.cachedAt) <= m.cacheTTL {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// hasAnyRecordLocked reports whether any A or AAAA record exists for the name,
|
||||
// fresh or stale. Caller holds m.mutex.
|
||||
func (m *Resolver) hasAnyRecordLocked(dnsName string) bool {
|
||||
for _, qtype := range []uint16{dns.TypeA, dns.TypeAAAA} {
|
||||
q := dns.Question{Name: dnsName, Qtype: qtype, Qclass: dns.ClassINET}
|
||||
if _, ok := m.records[q]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Resolver) clearPending(dnsName string) {
|
||||
m.mutex.Lock()
|
||||
delete(m.pending, dnsName)
|
||||
m.mutex.Unlock()
|
||||
}
|
||||
|
||||
// awaitPendingResolve joins the in-flight resolve for dnsName (bounded by
|
||||
// dnsTimeout) and reports whether a record became available.
|
||||
func (m *Resolver) awaitPendingResolve(dnsName string) bool {
|
||||
key := "initial|" + dnsName
|
||||
d, err := domain.FromString(strings.TrimSuffix(dnsName, "."))
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
ch := m.resolveGroup.DoChan(key, func() (any, error) {
|
||||
defer m.clearPending(dnsName)
|
||||
if err := m.AddDomain(m.ctx, d); err != nil {
|
||||
return struct{}{}, err
|
||||
}
|
||||
return struct{}{}, nil
|
||||
})
|
||||
|
||||
select {
|
||||
case <-ch:
|
||||
case <-time.After(dnsTimeout):
|
||||
return false
|
||||
}
|
||||
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
return m.hasFreshRecordLocked(dnsName)
|
||||
}
|
||||
|
||||
// removeStaleDomains removes cached domains not present in the target domain list.
|
||||
// Management domains are preserved and never removed during server domain updates.
|
||||
func (m *Resolver) removeStaleDomains(currentDomains, newDomains domain.List) domain.List {
|
||||
@@ -593,89 +737,6 @@ func (m *Resolver) isManagementDomain(domain domain.Domain) bool {
|
||||
return m.mgmtDomain != nil && domain == *m.mgmtDomain
|
||||
}
|
||||
|
||||
// addNewDomains resolves and caches domains that are not yet in the cache,
|
||||
// running the lookups concurrently. Domains already cached are skipped and left
|
||||
// to the stale-while-revalidate refresh path, so a sync never re-resolves them
|
||||
// synchronously: once NetBird owns the OS resolver the resolve runs through the
|
||||
// handler chain and would otherwise dial the managed upstreams under the engine
|
||||
// sync lock on every update.
|
||||
func (m *Resolver) addNewDomains(ctx context.Context, newDomains domain.List) {
|
||||
var wg sync.WaitGroup
|
||||
seen := make(map[domain.Domain]struct{}, len(newDomains))
|
||||
for _, newDomain := range newDomains {
|
||||
if _, dup := seen[newDomain]; dup {
|
||||
continue
|
||||
}
|
||||
seen[newDomain] = struct{}{}
|
||||
|
||||
if !m.needsResolve(newDomain) {
|
||||
continue
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
go func(d domain.Domain) {
|
||||
defer wg.Done()
|
||||
if err := m.AddDomain(ctx, d); err != nil {
|
||||
m.markResolveFailed(d)
|
||||
log.Warnf("failed to add/update domain=%s: %v", d.SafeString(), err)
|
||||
return
|
||||
}
|
||||
m.clearResolveFailed(d)
|
||||
log.Debugf("added/updated management cache domain=%s", d.SafeString())
|
||||
}(newDomain)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// needsResolve reports whether d should be resolved now. A recent failed or
|
||||
// incomplete resolve gates retries on the backoff even when one family is
|
||||
// already cached, so a transiently-failed family is retried instead of being
|
||||
// treated as fully resolved. Otherwise a domain with any cached record is left
|
||||
// to the stale-while-revalidate refresh path.
|
||||
func (m *Resolver) needsResolve(d domain.Domain) bool {
|
||||
dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
|
||||
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
|
||||
if failedAt, ok := m.failedResolves[d]; ok {
|
||||
return time.Since(failedAt) >= refreshBackoff
|
||||
}
|
||||
|
||||
for _, qtype := range []uint16{dns.TypeA, dns.TypeAAAA} {
|
||||
q := dns.Question{Name: dnsName, Qtype: qtype, Qclass: dns.ClassINET}
|
||||
if _, ok := m.records[q]; ok {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *Resolver) markResolveFailed(d domain.Domain) {
|
||||
m.mutex.Lock()
|
||||
m.failedResolves[d] = time.Now()
|
||||
m.mutex.Unlock()
|
||||
}
|
||||
|
||||
func (m *Resolver) clearResolveFailed(d domain.Domain) {
|
||||
m.mutex.Lock()
|
||||
delete(m.failedResolves, d)
|
||||
m.mutex.Unlock()
|
||||
}
|
||||
|
||||
// pruneFailedResolves drops failure markers for domains no longer present in
|
||||
// the server-domains set, keeping the map bounded to the current set (a
|
||||
// failed-only domain has no cached record, so RemoveDomain never sees it).
|
||||
func (m *Resolver) pruneFailedResolves(domains domain.List) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
for d := range m.failedResolves {
|
||||
if !slices.Contains(domains, d) {
|
||||
delete(m.failedResolves, d)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.ServerDomains) domain.List {
|
||||
var domains domain.List
|
||||
|
||||
|
||||
@@ -21,7 +21,6 @@ type fakeChain struct {
|
||||
mu sync.Mutex
|
||||
calls map[string]int
|
||||
answers map[string][]dns.RR
|
||||
qErr map[string]error
|
||||
err error
|
||||
hasRoot bool
|
||||
onLookup func()
|
||||
@@ -31,7 +30,6 @@ func newFakeChain() *fakeChain {
|
||||
return &fakeChain{
|
||||
calls: map[string]int{},
|
||||
answers: map[string][]dns.RR{},
|
||||
qErr: map[string]error{},
|
||||
hasRoot: true,
|
||||
}
|
||||
}
|
||||
@@ -49,9 +47,6 @@ func (f *fakeChain) ResolveInternal(ctx context.Context, msg *dns.Msg, maxPriori
|
||||
f.calls[key]++
|
||||
answers := f.answers[key]
|
||||
err := f.err
|
||||
if err == nil {
|
||||
err = f.qErr[key]
|
||||
}
|
||||
onLookup := f.onLookup
|
||||
f.mu.Unlock()
|
||||
|
||||
@@ -80,12 +75,6 @@ func (f *fakeChain) setAnswer(name string, qtype uint16, ip string) {
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fakeChain) setErr(name string, qtype uint16, err error) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
f.qErr[name+"|"+dns.TypeToString[qtype]] = err
|
||||
}
|
||||
|
||||
func (f *fakeChain) callCount(name string, qtype uint16) int {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
@@ -141,7 +130,7 @@ func TestResolver_CacheTTLGatesRefresh(t *testing.T) {
|
||||
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
||||
|
||||
t.Run("short TTL treats entry as stale and refreshes", func(t *testing.T) {
|
||||
r := NewResolver()
|
||||
r := NewResolver(context.Background())
|
||||
r.cacheTTL = 10 * time.Millisecond
|
||||
chain := newFakeChain()
|
||||
chain.setAnswer(q.Name, dns.TypeA, "10.0.0.2")
|
||||
@@ -157,7 +146,7 @@ func TestResolver_CacheTTLGatesRefresh(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("long TTL keeps entry fresh and skips refresh", func(t *testing.T) {
|
||||
r := NewResolver()
|
||||
r := NewResolver(context.Background())
|
||||
r.cacheTTL = time.Hour
|
||||
chain := newFakeChain()
|
||||
chain.setAnswer(q.Name, dns.TypeA, "10.0.0.2")
|
||||
@@ -173,7 +162,7 @@ func TestResolver_CacheTTLGatesRefresh(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestResolver_ServeFresh_NoRefresh(t *testing.T) {
|
||||
r := NewResolver()
|
||||
r := NewResolver(context.Background())
|
||||
chain := newFakeChain()
|
||||
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
|
||||
r.SetChainResolver(chain, 50)
|
||||
@@ -194,7 +183,7 @@ func TestResolver_ServeFresh_NoRefresh(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestResolver_StaleTriggersAsyncRefresh(t *testing.T) {
|
||||
r := NewResolver()
|
||||
r := NewResolver(context.Background())
|
||||
chain := newFakeChain()
|
||||
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
|
||||
r.SetChainResolver(chain, 50)
|
||||
@@ -224,7 +213,7 @@ func TestResolver_StaleTriggersAsyncRefresh(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestResolver_ConcurrentStaleHitsCollapseRefresh(t *testing.T) {
|
||||
r := NewResolver()
|
||||
r := NewResolver(context.Background())
|
||||
chain := newFakeChain()
|
||||
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
|
||||
|
||||
@@ -273,7 +262,7 @@ func TestResolver_ConcurrentStaleHitsCollapseRefresh(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestResolver_RefreshFailureArmsBackoff(t *testing.T) {
|
||||
r := NewResolver()
|
||||
r := NewResolver(context.Background())
|
||||
chain := newFakeChain()
|
||||
chain.err = errors.New("boom")
|
||||
r.SetChainResolver(chain, 50)
|
||||
@@ -310,7 +299,7 @@ func TestResolver_RefreshFailureArmsBackoff(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestResolver_NoRootHandler_SkipsChain(t *testing.T) {
|
||||
r := NewResolver()
|
||||
r := NewResolver(context.Background())
|
||||
chain := newFakeChain()
|
||||
chain.hasRoot = false
|
||||
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
|
||||
@@ -331,7 +320,7 @@ func TestResolver_ServeDuringRefreshSetsLoopFlag(t *testing.T) {
|
||||
// ServeDNS being invoked for a question while a refresh for that question
|
||||
// is inflight indicates a resolver loop (OS resolver sent the recursive
|
||||
// query back to us). The inflightRefresh.loopLoggedOnce flag must be set.
|
||||
r := NewResolver()
|
||||
r := NewResolver(context.Background())
|
||||
|
||||
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
||||
r.records[q] = &cachedRecord{
|
||||
@@ -357,7 +346,7 @@ func TestResolver_ServeDuringRefreshSetsLoopFlag(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestResolver_LoopFlagOnlyTrippedOncePerRefresh(t *testing.T) {
|
||||
r := NewResolver()
|
||||
r := NewResolver(context.Background())
|
||||
|
||||
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
||||
r.records[q] = &cachedRecord{
|
||||
@@ -384,7 +373,7 @@ func TestResolver_LoopFlagOnlyTrippedOncePerRefresh(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestResolver_NoLoopFlagWhenNotRefreshing(t *testing.T) {
|
||||
r := NewResolver()
|
||||
r := NewResolver(context.Background())
|
||||
|
||||
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
||||
r.records[q] = &cachedRecord{
|
||||
@@ -404,7 +393,7 @@ func TestResolver_NoLoopFlagWhenNotRefreshing(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestResolver_AddDomain_UsesChainWhenRootRegistered(t *testing.T) {
|
||||
r := NewResolver()
|
||||
r := NewResolver(context.Background())
|
||||
chain := newFakeChain()
|
||||
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
|
||||
chain.setAnswer("mgmt.example.com.", dns.TypeAAAA, "fd00::2")
|
||||
|
||||
@@ -1,183 +0,0 @@
|
||||
package mgmt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
// A domain already in the cache must not be re-resolved on a subsequent server
|
||||
// domains update; it is left to the stale-while-revalidate refresh path.
|
||||
func TestResolver_UpdateFromServerDomains_SkipsCached(t *testing.T) {
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
chain.setAnswer("signal.example.com.", dns.TypeA, "10.0.0.2")
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
sd := dnsconfig.ServerDomains{Signal: domain.Domain("signal.example.com")}
|
||||
|
||||
_, err := r.UpdateFromServerDomains(context.Background(), sd)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, chain.callCount("signal.example.com.", dns.TypeA),
|
||||
"first update must resolve the domain")
|
||||
|
||||
_, err = r.UpdateFromServerDomains(context.Background(), sd)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, chain.callCount("signal.example.com.", dns.TypeA),
|
||||
"cached domain must not be re-resolved on a subsequent update")
|
||||
}
|
||||
|
||||
// New domains in a single update must resolve concurrently rather than serially.
|
||||
func TestResolver_AddNewDomains_ResolvesConcurrently(t *testing.T) {
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
|
||||
var inflight, maxInflight atomic.Int32
|
||||
chain.onLookup = func() {
|
||||
n := inflight.Add(1)
|
||||
for {
|
||||
old := maxInflight.Load()
|
||||
if n <= old || maxInflight.CompareAndSwap(old, n) {
|
||||
break
|
||||
}
|
||||
}
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
inflight.Add(-1)
|
||||
}
|
||||
|
||||
relays := []domain.Domain{"a.example.com", "b.example.com", "c.example.com", "d.example.com"}
|
||||
for _, d := range relays {
|
||||
chain.setAnswer(dns.Fqdn(string(d)), dns.TypeA, "10.0.0.2")
|
||||
}
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
start := time.Now()
|
||||
_, err := r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Relay: relays})
|
||||
require.NoError(t, err)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
assert.GreaterOrEqual(t, int(maxInflight.Load()), 2, "domains must resolve concurrently")
|
||||
// Serial resolution of 4 domains would take at least 4*50ms; concurrent is far less.
|
||||
assert.Less(t, elapsed, 300*time.Millisecond, "resolution should not be serial")
|
||||
}
|
||||
|
||||
// A domain that fails to resolve must not be retried on every update; the
|
||||
// failure backoff suppresses re-resolution until it expires.
|
||||
func TestResolver_UpdateFromServerDomains_BacksOffFailures(t *testing.T) {
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
chain.err = errors.New("resolve boom")
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
sd := dnsconfig.ServerDomains{Signal: domain.Domain("signal.example.com")}
|
||||
|
||||
_, err := r.UpdateFromServerDomains(context.Background(), sd)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, chain.callCount("signal.example.com.", dns.TypeA),
|
||||
"first update must attempt the resolve")
|
||||
|
||||
_, err = r.UpdateFromServerDomains(context.Background(), sd)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, chain.callCount("signal.example.com.", dns.TypeA),
|
||||
"failed resolve must back off and not retry on the next update")
|
||||
}
|
||||
|
||||
// A domain listed under more than one server-domain type (e.g. STUN and TURN on
|
||||
// the same host) must be resolved once per update, not once per occurrence.
|
||||
func TestResolver_AddNewDomains_DedupesDuplicateDomains(t *testing.T) {
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
chain.setAnswer("dup.example.com.", dns.TypeA, "10.0.0.9")
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
sd := dnsconfig.ServerDomains{
|
||||
Stuns: []domain.Domain{"dup.example.com"},
|
||||
Turns: []domain.Domain{"dup.example.com"},
|
||||
}
|
||||
|
||||
_, err := r.UpdateFromServerDomains(context.Background(), sd)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, chain.callCount("dup.example.com.", dns.TypeA),
|
||||
"a domain appearing under multiple server-domain types must resolve once")
|
||||
}
|
||||
|
||||
// A failure marker must be dropped once its domain leaves the server-domains set
|
||||
// so the map stays bounded to the current set.
|
||||
func TestResolver_UpdateFromServerDomains_PrunesFailedResolves(t *testing.T) {
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
chain.err = errors.New("resolve boom")
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
_, err := r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Signal: domain.Domain("gone.example.com")})
|
||||
require.NoError(t, err)
|
||||
r.mutex.RLock()
|
||||
_, marked := r.failedResolves[domain.Domain("gone.example.com")]
|
||||
r.mutex.RUnlock()
|
||||
require.True(t, marked, "failed resolve must be recorded")
|
||||
|
||||
_, err = r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Signal: domain.Domain("other.example.com")})
|
||||
require.NoError(t, err)
|
||||
r.mutex.RLock()
|
||||
_, stillMarked := r.failedResolves[domain.Domain("gone.example.com")]
|
||||
r.mutex.RUnlock()
|
||||
assert.False(t, stillMarked, "failure marker for a domain no longer in the set must be pruned")
|
||||
}
|
||||
|
||||
// When one family hard-errors while the other resolves, the domain is cached
|
||||
// for the working family but recorded as incomplete so the failed family is
|
||||
// retried under backoff instead of being treated as fully resolved forever.
|
||||
func TestResolver_AddNewDomains_RetriesPartialFamilyFailure(t *testing.T) {
|
||||
d := domain.Domain("relay.example.com")
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
chain.setAnswer("relay.example.com.", dns.TypeA, "10.0.0.2")
|
||||
chain.setErr("relay.example.com.", dns.TypeAAAA, errors.New("servfail"))
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
_, err := r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Relay: []domain.Domain{d}})
|
||||
require.NoError(t, err)
|
||||
|
||||
r.mutex.RLock()
|
||||
_, aCached := r.records[dns.Question{Name: "relay.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}]
|
||||
_, marked := r.failedResolves[d]
|
||||
r.mutex.RUnlock()
|
||||
require.True(t, aCached, "the working family must still be cached")
|
||||
require.True(t, marked, "a partial failure must be recorded so the failed family is retried")
|
||||
|
||||
assert.False(t, r.needsResolve(d), "within the backoff window the domain is not retried")
|
||||
|
||||
r.mutex.Lock()
|
||||
r.failedResolves[d] = time.Now().Add(-2 * refreshBackoff)
|
||||
r.mutex.Unlock()
|
||||
assert.True(t, r.needsResolve(d), "after the backoff elapses the domain is retried to pick up the missing family")
|
||||
}
|
||||
|
||||
// A family that returns NODATA (legitimately absent, e.g. an IPv4-only host) is
|
||||
// not a failure: the domain must not be marked for retry, otherwise it would be
|
||||
// re-resolved on every sync.
|
||||
func TestResolver_AddNewDomains_NodataIsNotFailure(t *testing.T) {
|
||||
d := domain.Domain("v4only.example.com")
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
chain.setAnswer("v4only.example.com.", dns.TypeA, "10.0.0.2")
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
_, err := r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Relay: []domain.Domain{d}})
|
||||
require.NoError(t, err)
|
||||
|
||||
r.mutex.RLock()
|
||||
_, marked := r.failedResolves[d]
|
||||
r.mutex.RUnlock()
|
||||
assert.False(t, marked, "a NODATA family must not be recorded as a failure")
|
||||
assert.False(t, r.needsResolve(d), "an IPv4-only host must not be re-resolved on later syncs")
|
||||
}
|
||||
@@ -17,7 +17,7 @@ import (
|
||||
)
|
||||
|
||||
func TestResolver_NewResolver(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
resolver := NewResolver(context.Background())
|
||||
|
||||
assert.NotNil(t, resolver)
|
||||
assert.NotNil(t, resolver.records)
|
||||
@@ -49,7 +49,7 @@ func TestResolveCacheTTL(t *testing.T) {
|
||||
|
||||
func TestNewResolver_CacheTTLFromEnv(t *testing.T) {
|
||||
t.Setenv(envMgmtCacheTTL, "7s")
|
||||
r := NewResolver()
|
||||
r := NewResolver(context.Background())
|
||||
assert.Equal(t, 7*time.Second, r.cacheTTL, "NewResolver should evaluate cacheTTL once from env")
|
||||
}
|
||||
|
||||
@@ -169,7 +169,7 @@ func TestResolver_PopulateFromConfig(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
resolver := NewResolver()
|
||||
resolver := NewResolver(context.Background())
|
||||
|
||||
// Test with IP address - should return error since IP addresses are rejected
|
||||
mgmtURL, _ := url.Parse("https://127.0.0.1")
|
||||
@@ -184,7 +184,7 @@ func TestResolver_PopulateFromConfig(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestResolver_ServeDNS(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
resolver := NewResolver(context.Background())
|
||||
ctx := context.Background()
|
||||
|
||||
// Add a test domain to the cache - use example.org which is reserved for testing
|
||||
@@ -284,7 +284,7 @@ func TestResolver_ServeDNS(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestResolver_GetCachedDomains(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
resolver := NewResolver(context.Background())
|
||||
ctx := context.Background()
|
||||
|
||||
testDomain, err := domain.FromString("example.org")
|
||||
@@ -304,7 +304,7 @@ func TestResolver_GetCachedDomains(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestResolver_ManagementDomainProtection(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
resolver := NewResolver(context.Background())
|
||||
ctx := context.Background()
|
||||
|
||||
mgmtURL, _ := url.Parse("https://example.org")
|
||||
@@ -325,10 +325,11 @@ func TestResolver_ManagementDomainProtection(t *testing.T) {
|
||||
Relay: []domain.Domain{"cloudflare.com"},
|
||||
}
|
||||
|
||||
_, err = resolver.UpdateFromServerDomains(ctx, serverDomains)
|
||||
_, err = resolver.UpdateFromServerDomains(ctx, serverDomains, true)
|
||||
if err != nil {
|
||||
t.Logf("Server domains update failed: %v", err)
|
||||
}
|
||||
resolver.waitForPendingResolves(10 * time.Second)
|
||||
|
||||
finalDomains := resolver.GetCachedDomains()
|
||||
|
||||
@@ -351,7 +352,7 @@ func extractDomainFromURL(u *url.URL) (domain.Domain, error) {
|
||||
}
|
||||
|
||||
func TestResolver_EmptyUpdateDoesNotRemoveDomains(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
resolver := NewResolver(context.Background())
|
||||
ctx := context.Background()
|
||||
|
||||
// Set up initial domains using resolvable domains
|
||||
@@ -362,10 +363,11 @@ func TestResolver_EmptyUpdateDoesNotRemoveDomains(t *testing.T) {
|
||||
}
|
||||
|
||||
// Add initial domains
|
||||
_, err := resolver.UpdateFromServerDomains(ctx, initialDomains)
|
||||
_, err := resolver.UpdateFromServerDomains(ctx, initialDomains, true)
|
||||
if err != nil {
|
||||
t.Skipf("Skipping test due to DNS resolution failure: %v", err)
|
||||
}
|
||||
resolver.waitForPendingResolves(10 * time.Second)
|
||||
|
||||
// Verify domains were added
|
||||
cachedDomains := resolver.GetCachedDomains()
|
||||
@@ -373,7 +375,7 @@ func TestResolver_EmptyUpdateDoesNotRemoveDomains(t *testing.T) {
|
||||
|
||||
// Update with empty ServerDomains (simulating partial network map update)
|
||||
emptyDomains := dnsconfig.ServerDomains{}
|
||||
removedDomains, err := resolver.UpdateFromServerDomains(ctx, emptyDomains)
|
||||
removedDomains, err := resolver.UpdateFromServerDomains(ctx, emptyDomains, true)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify no domains were removed
|
||||
@@ -385,7 +387,7 @@ func TestResolver_EmptyUpdateDoesNotRemoveDomains(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestResolver_PartialUpdateReplacesOnlyUpdatedTypes(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
resolver := NewResolver(context.Background())
|
||||
ctx := context.Background()
|
||||
|
||||
// Set up initial complete domains using resolvable domains
|
||||
@@ -396,20 +398,22 @@ func TestResolver_PartialUpdateReplacesOnlyUpdatedTypes(t *testing.T) {
|
||||
}
|
||||
|
||||
// Add initial domains
|
||||
_, err := resolver.UpdateFromServerDomains(ctx, initialDomains)
|
||||
_, err := resolver.UpdateFromServerDomains(ctx, initialDomains, true)
|
||||
if err != nil {
|
||||
t.Skipf("Skipping test due to DNS resolution failure: %v", err)
|
||||
}
|
||||
resolver.waitForPendingResolves(10 * time.Second)
|
||||
assert.Len(t, resolver.GetCachedDomains(), 3)
|
||||
|
||||
// Update with partial ServerDomains (only signal domain - this should replace signal but preserve stun/turn)
|
||||
partialDomains := dnsconfig.ServerDomains{
|
||||
Signal: "github.com",
|
||||
}
|
||||
removedDomains, err := resolver.UpdateFromServerDomains(ctx, partialDomains)
|
||||
removedDomains, err := resolver.UpdateFromServerDomains(ctx, partialDomains, true)
|
||||
if err != nil {
|
||||
t.Skipf("Skipping test due to DNS resolution failure: %v", err)
|
||||
}
|
||||
resolver.waitForPendingResolves(10 * time.Second)
|
||||
|
||||
// Should remove only the old signal domain
|
||||
assert.Len(t, removedDomains, 1, "Should remove only the old signal domain")
|
||||
@@ -429,7 +433,7 @@ func TestResolver_PartialUpdateReplacesOnlyUpdatedTypes(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestResolver_PartialUpdateAddsNewTypePreservesExisting(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
resolver := NewResolver(context.Background())
|
||||
ctx := context.Background()
|
||||
|
||||
// Set up initial complete domains using resolvable domains
|
||||
@@ -440,10 +444,11 @@ func TestResolver_PartialUpdateAddsNewTypePreservesExisting(t *testing.T) {
|
||||
}
|
||||
|
||||
// Add initial domains
|
||||
_, err := resolver.UpdateFromServerDomains(ctx, initialDomains)
|
||||
_, err := resolver.UpdateFromServerDomains(ctx, initialDomains, true)
|
||||
if err != nil {
|
||||
t.Skipf("Skipping test due to DNS resolution failure: %v", err)
|
||||
}
|
||||
resolver.waitForPendingResolves(10 * time.Second)
|
||||
assert.Len(t, resolver.GetCachedDomains(), 3)
|
||||
|
||||
// Update with partial ServerDomains (only flow domain - flow is intentionally excluded from
|
||||
@@ -451,10 +456,11 @@ func TestResolver_PartialUpdateAddsNewTypePreservesExisting(t *testing.T) {
|
||||
partialDomains := dnsconfig.ServerDomains{
|
||||
Flow: "github.com",
|
||||
}
|
||||
removedDomains, err := resolver.UpdateFromServerDomains(ctx, partialDomains)
|
||||
removedDomains, err := resolver.UpdateFromServerDomains(ctx, partialDomains, true)
|
||||
if err != nil {
|
||||
t.Skipf("Skipping test due to DNS resolution failure: %v", err)
|
||||
}
|
||||
resolver.waitForPendingResolves(10 * time.Second)
|
||||
|
||||
assert.Len(t, removedDomains, 0, "Should not remove any domains when only flow domain is provided")
|
||||
|
||||
|
||||
@@ -282,7 +282,7 @@ func newDefaultServer(
|
||||
handlerChain := NewHandlerChain()
|
||||
ctx, stop := context.WithCancel(ctx)
|
||||
|
||||
mgmtCacheResolver := mgmt.NewResolver()
|
||||
mgmtCacheResolver := mgmt.NewResolver(ctx)
|
||||
mgmtCacheResolver.SetChainResolver(handlerChain, PriorityUpstream)
|
||||
|
||||
defaultServer := &DefaultServer{
|
||||
@@ -613,7 +613,11 @@ func (s *DefaultServer) UpdateServerConfig(domains dnsconfig.ServerDomains) erro
|
||||
defer s.mux.Unlock()
|
||||
|
||||
if s.mgmtCacheResolver != nil {
|
||||
removedDomains, err := s.mgmtCacheResolver.UpdateFromServerDomains(s.ctx, domains)
|
||||
// Mirrors the Initialize guard: without it NetBird never becomes the
|
||||
// system resolver, so the mgmt cache is never queried and need not be
|
||||
// primed synchronously.
|
||||
dnsWillBeServed := !s.disableSys && !netstack.IsEnabled()
|
||||
removedDomains, err := s.mgmtCacheResolver.UpdateFromServerDomains(s.ctx, domains, dnsWillBeServed)
|
||||
if err != nil {
|
||||
return fmt.Errorf("update management cache resolver: %w", err)
|
||||
}
|
||||
@@ -622,7 +626,9 @@ func (s *DefaultServer) UpdateServerConfig(domains dnsconfig.ServerDomains) erro
|
||||
s.deregisterHandler(removedDomains.ToPunycodeList(), PriorityMgmtCache)
|
||||
}
|
||||
|
||||
newDomains := s.mgmtCacheResolver.GetCachedDomains()
|
||||
// Register for the requested domains, not just resolved ones: resolution
|
||||
// now runs in the background, so the cache may still be empty here.
|
||||
newDomains := s.mgmtCacheResolver.RequestedDomains(domains)
|
||||
if len(newDomains) > 0 {
|
||||
s.registerHandler(newDomains.ToPunycodeList(), s.mgmtCacheResolver, PriorityMgmtCache)
|
||||
}
|
||||
|
||||
@@ -14,7 +14,6 @@ import (
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
@@ -54,7 +53,6 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/relay"
|
||||
"github.com/netbirdio/netbird/client/internal/rosenpass"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/syncstore"
|
||||
"github.com/netbirdio/netbird/client/internal/updater"
|
||||
@@ -90,13 +88,6 @@ var ErrResetConnection = fmt.Errorf("reset connection")
|
||||
|
||||
var ErrEngineAlreadyStarted = errors.New("engine already started")
|
||||
|
||||
// engineRestartCount and engineLastRestart track client-restart cadence across
|
||||
// engine recreations so a restart loop is distinguishable from rare restarts.
|
||||
var (
|
||||
engineRestartCount atomic.Int64
|
||||
engineLastRestart atomic.Int64
|
||||
)
|
||||
|
||||
type EngineConfig struct {
|
||||
WgPort int
|
||||
WgIfaceName string
|
||||
@@ -908,11 +899,7 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||
started := time.Now()
|
||||
defer func() {
|
||||
duration := time.Since(started)
|
||||
if update.GetNetworkMap() != nil {
|
||||
log.Infof("sync finished in %s, %d", duration, update.GetNetworkMap().GetSerial())
|
||||
} else {
|
||||
log.Infof("sync finished in %s", duration)
|
||||
}
|
||||
log.Infof("sync finished in %s", duration)
|
||||
e.clientMetrics.RecordSyncDuration(e.ctx, duration)
|
||||
}()
|
||||
e.syncMsgMux.Lock()
|
||||
@@ -922,23 +909,14 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||
if e.ctx.Err() != nil {
|
||||
return e.ctx.Err()
|
||||
}
|
||||
serial := update.GetNetworkMap().GetSerial()
|
||||
if nm := update.GetNetworkMap(); nm != nil {
|
||||
log.Infof("sync update: serial=%d remotePeers=%d offlinePeers=%d routes=%d firewallRules=%d checks=%d configPresent=%v remotePeersEmpty=%v",
|
||||
nm.GetSerial(), len(nm.GetRemotePeers()), len(nm.GetOfflinePeers()), len(nm.GetRoutes()),
|
||||
len(nm.GetFirewallRules()), len(update.GetChecks()), update.GetNetbirdConfig() != nil, nm.GetRemotePeersIsEmpty())
|
||||
} else {
|
||||
log.Infof("sync update: config-only (no network map), configPresent=%v", update.GetNetbirdConfig() != nil)
|
||||
}
|
||||
|
||||
if update.NetworkMap != nil && update.NetworkMap.PeerConfig != nil {
|
||||
e.handleAutoUpdateVersion(update.NetworkMap.PeerConfig.AutoUpdate)
|
||||
}
|
||||
startTime := time.Now()
|
||||
|
||||
if err := e.updateNetbirdConfig(update.GetNetbirdConfig()); err != nil {
|
||||
return err
|
||||
}
|
||||
log.Infof("netbird config updated in %s, serial=%d", time.Since(startTime), serial)
|
||||
|
||||
// Posture checks are bound to the network map presence:
|
||||
// NetworkMap != nil, checks present -> apply the received checks
|
||||
@@ -949,21 +927,17 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||
if nm == nil {
|
||||
return nil
|
||||
}
|
||||
startTime = time.Now()
|
||||
|
||||
if err := e.updateChecksIfNew(update.Checks); err != nil {
|
||||
return err
|
||||
}
|
||||
log.Infof("checks updated in %s, serial=%d", time.Since(startTime), serial)
|
||||
|
||||
startTime = time.Now()
|
||||
e.persistSyncResponse(update)
|
||||
log.Infof("sync response persisted in %s, serial=%d", time.Since(startTime), serial)
|
||||
|
||||
// only apply new changes and ignore old ones
|
||||
startTime = time.Now()
|
||||
if err := e.updateNetworkMap(nm); err != nil {
|
||||
return err
|
||||
}
|
||||
log.Infof("network map updated in %s, serial=%d", time.Since(startTime), serial)
|
||||
|
||||
e.statusRecorder.PublishEvent(cProto.SystemEvent_INFO, cProto.SystemEvent_SYSTEM, "Network map updated", "", nil)
|
||||
|
||||
@@ -1383,56 +1357,44 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
|
||||
dnsConfig := toDNSConfig(protoDNSConfig, e.wgInterface.Address())
|
||||
|
||||
startTime := time.Now()
|
||||
if err := e.dnsServer.UpdateDNSServer(serial, dnsConfig); err != nil {
|
||||
log.Errorf("failed to update dns server, err: %v", err)
|
||||
}
|
||||
log.Infof("updated dns server in %v, serial=%d", time.Since(startTime), serial)
|
||||
|
||||
e.routeManager.SetDNSForwarderPort(dnsConfig.ForwarderPort)
|
||||
|
||||
// apply routes first, route related actions might depend on routing being enabled
|
||||
startTime = time.Now()
|
||||
routes := toRoutes(networkMap.GetRoutes())
|
||||
serverRoutes, clientRoutes := e.routeManager.ClassifyRoutes(routes)
|
||||
log.Infof("updated routes in %v, serial=%d", time.Since(startTime), serial)
|
||||
|
||||
// lazy mgr needs to be aware of which routes are available before they are applied
|
||||
if e.connMgr != nil {
|
||||
e.connMgr.UpdateRouteHAMap(clientRoutes)
|
||||
log.Debugf("updated lazy connection manager with %d HA groups", len(clientRoutes))
|
||||
}
|
||||
|
||||
startTime = time.Now()
|
||||
dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap)
|
||||
if err := e.routeManager.UpdateRoutes(serial, serverRoutes, clientRoutes, dnsRouteFeatureFlag); err != nil {
|
||||
log.Errorf("failed to update routes: %v", err)
|
||||
}
|
||||
log.Infof("updated routes in %v, serial=%d", time.Since(startTime), serial)
|
||||
|
||||
startTime = time.Now()
|
||||
if e.acl != nil {
|
||||
e.acl.ApplyFiltering(networkMap, dnsRouteFeatureFlag)
|
||||
}
|
||||
log.Infof("updated filtering in %v, serial=%d", time.Since(startTime), serial)
|
||||
|
||||
startTime = time.Now()
|
||||
fwdEntries := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), routes)
|
||||
e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries)
|
||||
log.Infof("updated DNS forwarder in %v, serial=%d", time.Since(startTime), serial)
|
||||
|
||||
startTime = time.Now()
|
||||
// Ingress forward rules
|
||||
forwardingRules, err := e.updateForwardRules(networkMap.GetForwardingRules())
|
||||
if err != nil {
|
||||
log.Errorf("failed to update forward rules, err: %v", err)
|
||||
}
|
||||
log.Infof("updated forward rules in %v, serial=%d", time.Since(startTime), serial)
|
||||
|
||||
log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers()))
|
||||
|
||||
startTime = time.Now()
|
||||
e.updateOfflinePeers(networkMap.GetOfflinePeers())
|
||||
log.Infof("updated offline peers in %v, serial=%d", time.Since(startTime), serial)
|
||||
|
||||
// Filter out own peer from the remote peers list
|
||||
localPubKey := e.config.WgPrivateKey.PublicKey().String()
|
||||
remotePeers := make([]*mgmProto.RemotePeerConfig, 0, len(networkMap.GetRemotePeers()))
|
||||
@@ -1450,24 +1412,20 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
startTime = time.Now()
|
||||
err := e.removePeers(remotePeers)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.Infof("removed peers in %v, serial=%d", time.Since(startTime), serial)
|
||||
startTime = time.Now()
|
||||
|
||||
err = e.modifyPeers(remotePeers)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.Infof("modified peers in %v, serial=%d", time.Since(startTime), serial)
|
||||
startTime = time.Now()
|
||||
|
||||
err = e.addNewPeers(remotePeers)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.Infof("added peers in %v, serial=%d", time.Since(startTime), serial)
|
||||
|
||||
e.statusRecorder.FinishPeerListModifications()
|
||||
|
||||
@@ -1480,11 +1438,9 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
e.updateSSHServerAuth(networkMap.GetSshAuth())
|
||||
}
|
||||
|
||||
startTime = time.Now()
|
||||
// must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store
|
||||
excludedLazyPeers := e.toExcludedLazyPeers(forwardingRules, remotePeers)
|
||||
e.connMgr.SetExcludeList(e.ctx, excludedLazyPeers)
|
||||
log.Infof("updated lazy connection manager exclude list in %v, serial=%d", time.Since(startTime), serial)
|
||||
|
||||
e.networkSerial = serial
|
||||
|
||||
@@ -1999,7 +1955,6 @@ func (e *Engine) newWgIface() (*iface.WGIface, error) {
|
||||
WGPrivKey: e.config.WgPrivateKey.String(),
|
||||
MTU: e.config.MTU,
|
||||
TransportNet: transportNet,
|
||||
FilterFn: e.addrViaRoutes,
|
||||
DisableDNS: e.config.DisableDNS,
|
||||
}
|
||||
|
||||
@@ -2216,14 +2171,7 @@ func (e *Engine) triggerClientRestart() {
|
||||
return
|
||||
}
|
||||
|
||||
// Cadence survives engine recreation (package-level), so a restart loop shows
|
||||
// as a fast-climbing count with a short gap, distinct from rare intentional restarts.
|
||||
n := engineRestartCount.Add(1)
|
||||
var sinceLast time.Duration
|
||||
if prev := engineLastRestart.Swap(time.Now().UnixNano()); prev != 0 {
|
||||
sinceLast = time.Since(time.Unix(0, prev))
|
||||
}
|
||||
log.Infof("restarting engine (restart #%d, %s since previous)", n, sinceLast.Round(time.Second))
|
||||
log.Info("restarting engine")
|
||||
CtxGetState(e.ctx).Set(StatusConnecting)
|
||||
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
|
||||
log.Infof("cancelling client context, engine will be recreated")
|
||||
@@ -2254,21 +2202,6 @@ func (e *Engine) startNetworkMonitor() {
|
||||
}()
|
||||
}
|
||||
|
||||
func (e *Engine) addrViaRoutes(addr netip.Addr) (bool, netip.Prefix, error) {
|
||||
var vpnRoutes []netip.Prefix
|
||||
for _, routes := range e.routeManager.GetClientRoutes() {
|
||||
if len(routes) > 0 && routes[0] != nil {
|
||||
vpnRoutes = append(vpnRoutes, routes[0].Network)
|
||||
}
|
||||
}
|
||||
|
||||
if isVpn, prefix := systemops.IsAddrRouted(addr, vpnRoutes); isVpn {
|
||||
return true, prefix, nil
|
||||
}
|
||||
|
||||
return false, netip.Prefix{}, nil
|
||||
}
|
||||
|
||||
func (e *Engine) stopDNSServer() {
|
||||
if e.dnsServer == nil {
|
||||
return
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"runtime/debug"
|
||||
"slices"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -193,7 +192,6 @@ func (s *StatusChangeSubscription) Events() chan map[string]RouterState {
|
||||
// Pure read methods take RLock; anything that mutates state takes Lock.
|
||||
type Status struct {
|
||||
mux sync.RWMutex
|
||||
muxRelays sync.RWMutex
|
||||
peers map[string]State
|
||||
ipToKey map[string]string
|
||||
changeNotify map[string]map[string]*StatusChangeSubscription // map[peerID]map[subscriptionID]*StatusChangeSubscription
|
||||
@@ -228,8 +226,6 @@ type Status struct {
|
||||
|
||||
routeIDLookup routeIDLookup
|
||||
wgIface WGIfaceStatus
|
||||
|
||||
profile *StatusProfile
|
||||
}
|
||||
|
||||
// NewRecorder returns a new Status instance
|
||||
@@ -244,23 +240,16 @@ func NewRecorder(mgmAddress string) *Status {
|
||||
notifier: newNotifier(),
|
||||
mgmAddress: mgmAddress,
|
||||
resolvedDomainsStates: map[domain.Domain]ResolvedDomainInfo{},
|
||||
profile: NewStatusProfile(context.Background()),
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Status) StartProfile(ctx context.Context) {
|
||||
d.profile.Start(ctx)
|
||||
}
|
||||
|
||||
func (d *Status) SetRelayMgr(manager *relayClient.Manager) {
|
||||
d.profile.inc("SetRelayMgr")
|
||||
d.muxRelays.Lock()
|
||||
defer d.muxRelays.Unlock()
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
d.relayMgr = manager
|
||||
}
|
||||
|
||||
func (d *Status) SetIngressGwMgr(ingressGwMgr *ingressgw.Manager) {
|
||||
d.profile.inc("SetIngressGwMgr")
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
d.ingressGwMgr = ingressGwMgr
|
||||
@@ -268,7 +257,6 @@ func (d *Status) SetIngressGwMgr(ingressGwMgr *ingressgw.Manager) {
|
||||
|
||||
// ReplaceOfflinePeers replaces
|
||||
func (d *Status) ReplaceOfflinePeers(replacement []State) {
|
||||
d.profile.inc("ReplaceOfflinePeers")
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
d.offlinePeers = make([]State, len(replacement))
|
||||
@@ -280,7 +268,6 @@ func (d *Status) ReplaceOfflinePeers(replacement []State) {
|
||||
|
||||
// AddPeer adds peer to Daemon status map
|
||||
func (d *Status) AddPeer(peerPubKey string, fqdn string, ip string, ipv6 string) error {
|
||||
d.profile.inc("AddPeer")
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
@@ -308,7 +295,6 @@ func (d *Status) AddPeer(peerPubKey string, fqdn string, ip string, ipv6 string)
|
||||
|
||||
// GetPeer adds peer to Daemon status map
|
||||
func (d *Status) GetPeer(peerPubKey string) (State, error) {
|
||||
d.profile.inc("GetPeer")
|
||||
d.mux.RLock()
|
||||
defer d.mux.RUnlock()
|
||||
|
||||
@@ -320,7 +306,6 @@ func (d *Status) GetPeer(peerPubKey string) (State, error) {
|
||||
}
|
||||
|
||||
func (d *Status) PeerByIP(ip string) (string, bool) {
|
||||
d.profile.inc("PeerByIP")
|
||||
d.mux.RLock()
|
||||
defer d.mux.RUnlock()
|
||||
|
||||
@@ -338,7 +323,6 @@ func (d *Status) PeerByIP(ip string) (string, bool) {
|
||||
// active peers are matched; peers moved into the offline slice by
|
||||
// ReplaceOfflinePeers are intentionally treated as unknown.
|
||||
func (d *Status) PeerStateByIP(ip string) (State, bool) {
|
||||
d.profile.inc("PeerStateByIP")
|
||||
if ip == "" {
|
||||
return State{}, false
|
||||
}
|
||||
@@ -357,7 +341,6 @@ func (d *Status) PeerStateByIP(ip string) (State, bool) {
|
||||
|
||||
// RemovePeer removes peer from Daemon status map
|
||||
func (d *Status) RemovePeer(peerPubKey string) error {
|
||||
d.profile.inc("RemovePeer")
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
@@ -379,7 +362,6 @@ func (d *Status) RemovePeer(peerPubKey string) error {
|
||||
|
||||
// UpdatePeerState updates peer status
|
||||
func (d *Status) UpdatePeerState(receivedState State) error {
|
||||
d.profile.inc("UpdatePeerState")
|
||||
d.mux.Lock()
|
||||
|
||||
peerState, ok := d.peers[receivedState.PubKey]
|
||||
@@ -422,7 +404,6 @@ func (d *Status) UpdatePeerState(receivedState State) error {
|
||||
}
|
||||
|
||||
func (d *Status) AddPeerStateRoute(peer string, route string, resourceId route.ResID) error {
|
||||
d.profile.inc("AddPeerStateRoute")
|
||||
d.mux.Lock()
|
||||
|
||||
peerState, ok := d.peers[peer]
|
||||
@@ -448,7 +429,6 @@ func (d *Status) AddPeerStateRoute(peer string, route string, resourceId route.R
|
||||
}
|
||||
|
||||
func (d *Status) RemovePeerStateRoute(peer string, route string) error {
|
||||
d.profile.inc("RemovePeerStateRoute")
|
||||
d.mux.Lock()
|
||||
|
||||
peerState, ok := d.peers[peer]
|
||||
@@ -479,13 +459,11 @@ func (d *Status) CheckRoutes(ip netip.Addr) ([]byte, bool) {
|
||||
if d == nil {
|
||||
return nil, false
|
||||
}
|
||||
d.profile.inc("CheckRoutes")
|
||||
resId, isExitNode := d.routeIDLookup.Lookup(ip)
|
||||
return []byte(resId), isExitNode
|
||||
}
|
||||
|
||||
func (d *Status) UpdatePeerICEState(receivedState State) error {
|
||||
d.profile.inc("UpdatePeerICEState")
|
||||
d.mux.Lock()
|
||||
|
||||
peerState, ok := d.peers[receivedState.PubKey]
|
||||
@@ -525,7 +503,6 @@ func (d *Status) UpdatePeerICEState(receivedState State) error {
|
||||
}
|
||||
|
||||
func (d *Status) UpdatePeerRelayedState(receivedState State) error {
|
||||
d.profile.inc("UpdatePeerRelayedState")
|
||||
d.mux.Lock()
|
||||
|
||||
peerState, ok := d.peers[receivedState.PubKey]
|
||||
@@ -562,7 +539,6 @@ func (d *Status) UpdatePeerRelayedState(receivedState State) error {
|
||||
}
|
||||
|
||||
func (d *Status) UpdatePeerRelayedStateToDisconnected(receivedState State) error {
|
||||
d.profile.inc("UpdatePeerRelayedStateToDisconnected")
|
||||
d.mux.Lock()
|
||||
|
||||
peerState, ok := d.peers[receivedState.PubKey]
|
||||
@@ -598,7 +574,6 @@ func (d *Status) UpdatePeerRelayedStateToDisconnected(receivedState State) error
|
||||
}
|
||||
|
||||
func (d *Status) UpdatePeerICEStateToDisconnected(receivedState State) error {
|
||||
d.profile.inc("UpdatePeerICEStateToDisconnected")
|
||||
d.mux.Lock()
|
||||
|
||||
peerState, ok := d.peers[receivedState.PubKey]
|
||||
@@ -638,7 +613,6 @@ func (d *Status) UpdatePeerICEStateToDisconnected(receivedState State) error {
|
||||
|
||||
// UpdateWireGuardPeerState updates the WireGuard bits of the peer state
|
||||
func (d *Status) UpdateWireGuardPeerState(pubKey string, wgStats configurer.WGStats) error {
|
||||
d.profile.inc("UpdateWireGuardPeerState")
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
@@ -666,7 +640,6 @@ func hasConnStatusChanged(oldStatus, newStatus ConnStatus) bool {
|
||||
|
||||
// UpdatePeerFQDN update peer's state fqdn only
|
||||
func (d *Status) UpdatePeerFQDN(peerPubKey, fqdn string) error {
|
||||
d.profile.inc("UpdatePeerFQDN")
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
@@ -683,7 +656,6 @@ func (d *Status) UpdatePeerFQDN(peerPubKey, fqdn string) error {
|
||||
|
||||
// UpdatePeerSSHHostKey updates peer's SSH host key
|
||||
func (d *Status) UpdatePeerSSHHostKey(peerPubKey string, sshHostKey []byte) error {
|
||||
d.profile.inc("UpdatePeerSSHHostKey")
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
@@ -700,7 +672,6 @@ func (d *Status) UpdatePeerSSHHostKey(peerPubKey string, sshHostKey []byte) erro
|
||||
|
||||
// FinishPeerListModifications this event invoke the notification
|
||||
func (d *Status) FinishPeerListModifications() {
|
||||
d.profile.inc("FinishPeerListModifications")
|
||||
d.mux.Lock()
|
||||
|
||||
if !d.peerListChangedForNotification {
|
||||
@@ -733,7 +704,6 @@ func (d *Status) FinishPeerListModifications() {
|
||||
}
|
||||
|
||||
func (d *Status) SubscribeToPeerStateChanges(ctx context.Context, peerID string) *StatusChangeSubscription {
|
||||
d.profile.inc("SubscribeToPeerStateChanges")
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
@@ -747,7 +717,6 @@ func (d *Status) SubscribeToPeerStateChanges(ctx context.Context, peerID string)
|
||||
}
|
||||
|
||||
func (d *Status) UnsubscribePeerStateChanges(subscription *StatusChangeSubscription) {
|
||||
d.profile.inc("UnsubscribePeerStateChanges")
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
@@ -773,7 +742,6 @@ func (d *Status) UnsubscribePeerStateChanges(subscription *StatusChangeSubscript
|
||||
|
||||
// GetLocalPeerState returns the local peer state
|
||||
func (d *Status) GetLocalPeerState() LocalPeerState {
|
||||
d.profile.inc("GetLocalPeerState")
|
||||
d.mux.RLock()
|
||||
defer d.mux.RUnlock()
|
||||
return d.localPeer.Clone()
|
||||
@@ -781,7 +749,6 @@ func (d *Status) GetLocalPeerState() LocalPeerState {
|
||||
|
||||
// UpdateLocalPeerState updates local peer status
|
||||
func (d *Status) UpdateLocalPeerState(localPeerState LocalPeerState) {
|
||||
d.profile.inc("UpdateLocalPeerState")
|
||||
d.mux.Lock()
|
||||
d.localPeer = localPeerState
|
||||
fqdn := d.localPeer.FQDN
|
||||
@@ -796,7 +763,6 @@ func (d *Status) UpdateLocalPeerState(localPeerState LocalPeerState) {
|
||||
|
||||
// AddLocalPeerStateRoute adds a route to the local peer state
|
||||
func (d *Status) AddLocalPeerStateRoute(route string, resourceId route.ResID) {
|
||||
d.profile.inc("AddLocalPeerStateRoute")
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
@@ -814,7 +780,6 @@ func (d *Status) AddLocalPeerStateRoute(route string, resourceId route.ResID) {
|
||||
|
||||
// RemoveLocalPeerStateRoute removes a route from the local peer state
|
||||
func (d *Status) RemoveLocalPeerStateRoute(route string) {
|
||||
d.profile.inc("RemoveLocalPeerStateRoute")
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
@@ -828,7 +793,6 @@ func (d *Status) RemoveLocalPeerStateRoute(route string) {
|
||||
|
||||
// AddResolvedIPLookupEntry adds a resolved IP lookup entry
|
||||
func (d *Status) AddResolvedIPLookupEntry(prefix netip.Prefix, resourceId route.ResID) {
|
||||
d.profile.inc("AddResolvedIPLookupEntry")
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
@@ -837,7 +801,6 @@ func (d *Status) AddResolvedIPLookupEntry(prefix netip.Prefix, resourceId route.
|
||||
|
||||
// RemoveResolvedIPLookupEntry removes a resolved IP lookup entry
|
||||
func (d *Status) RemoveResolvedIPLookupEntry(route string) {
|
||||
d.profile.inc("RemoveResolvedIPLookupEntry")
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
@@ -849,7 +812,6 @@ func (d *Status) RemoveResolvedIPLookupEntry(route string) {
|
||||
|
||||
// CleanLocalPeerStateRoutes cleans all routes from the local peer state
|
||||
func (d *Status) CleanLocalPeerStateRoutes() {
|
||||
d.profile.inc("CleanLocalPeerStateRoutes")
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
@@ -858,7 +820,6 @@ func (d *Status) CleanLocalPeerStateRoutes() {
|
||||
|
||||
// CleanLocalPeerState cleans local peer status
|
||||
func (d *Status) CleanLocalPeerState() {
|
||||
d.profile.inc("CleanLocalPeerState")
|
||||
d.mux.Lock()
|
||||
d.localPeer = LocalPeerState{}
|
||||
fqdn := d.localPeer.FQDN
|
||||
@@ -870,7 +831,6 @@ func (d *Status) CleanLocalPeerState() {
|
||||
|
||||
// MarkManagementDisconnected sets ManagementState to disconnected
|
||||
func (d *Status) MarkManagementDisconnected(err error) {
|
||||
d.profile.inc("MarkManagementDisconnected")
|
||||
d.mux.Lock()
|
||||
d.managementState = false
|
||||
d.managementError = err
|
||||
@@ -883,7 +843,6 @@ func (d *Status) MarkManagementDisconnected(err error) {
|
||||
|
||||
// MarkManagementConnected sets ManagementState to connected
|
||||
func (d *Status) MarkManagementConnected() {
|
||||
d.profile.inc("MarkManagementConnected")
|
||||
d.mux.Lock()
|
||||
d.managementState = true
|
||||
d.managementError = nil
|
||||
@@ -896,7 +855,6 @@ func (d *Status) MarkManagementConnected() {
|
||||
|
||||
// UpdateSignalAddress update the address of the signal server
|
||||
func (d *Status) UpdateSignalAddress(signalURL string) {
|
||||
d.profile.inc("UpdateSignalAddress")
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
d.signalAddress = signalURL
|
||||
@@ -904,7 +862,6 @@ func (d *Status) UpdateSignalAddress(signalURL string) {
|
||||
|
||||
// UpdateManagementAddress update the address of the management server
|
||||
func (d *Status) UpdateManagementAddress(mgmAddress string) {
|
||||
d.profile.inc("UpdateManagementAddress")
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
d.mgmAddress = mgmAddress
|
||||
@@ -912,7 +869,6 @@ func (d *Status) UpdateManagementAddress(mgmAddress string) {
|
||||
|
||||
// UpdateRosenpass update the Rosenpass configuration
|
||||
func (d *Status) UpdateRosenpass(rosenpassEnabled, rosenpassPermissive bool) {
|
||||
d.profile.inc("UpdateRosenpass")
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
d.rosenpassPermissive = rosenpassPermissive
|
||||
@@ -920,7 +876,6 @@ func (d *Status) UpdateRosenpass(rosenpassEnabled, rosenpassPermissive bool) {
|
||||
}
|
||||
|
||||
func (d *Status) UpdateLazyConnection(enabled bool) {
|
||||
d.profile.inc("UpdateLazyConnection")
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
d.lazyConnectionEnabled = enabled
|
||||
@@ -928,7 +883,6 @@ func (d *Status) UpdateLazyConnection(enabled bool) {
|
||||
|
||||
// MarkSignalDisconnected sets SignalState to disconnected
|
||||
func (d *Status) MarkSignalDisconnected(err error) {
|
||||
d.profile.inc("MarkSignalDisconnected")
|
||||
d.mux.Lock()
|
||||
d.signalState = false
|
||||
d.signalError = err
|
||||
@@ -941,7 +895,6 @@ func (d *Status) MarkSignalDisconnected(err error) {
|
||||
|
||||
// MarkSignalConnected sets SignalState to connected
|
||||
func (d *Status) MarkSignalConnected() {
|
||||
d.profile.inc("MarkSignalConnected")
|
||||
d.mux.Lock()
|
||||
d.signalState = true
|
||||
d.signalError = nil
|
||||
@@ -953,21 +906,18 @@ func (d *Status) MarkSignalConnected() {
|
||||
}
|
||||
|
||||
func (d *Status) UpdateRelayStates(relayResults []relay.ProbeResult) {
|
||||
d.profile.inc("UpdateRelayStates")
|
||||
d.muxRelays.Lock()
|
||||
defer d.muxRelays.Unlock()
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
d.relayStates = relayResults
|
||||
}
|
||||
|
||||
func (d *Status) UpdateDNSStates(dnsStates []NSGroupState) {
|
||||
d.profile.inc("UpdateDNSStates")
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
d.nsGroupStates = dnsStates
|
||||
}
|
||||
|
||||
func (d *Status) UpdateResolvedDomainsStates(originalDomain domain.Domain, resolvedDomain domain.Domain, prefixes []netip.Prefix, resourceId route.ResID) {
|
||||
d.profile.inc("UpdateResolvedDomainsStates")
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
@@ -983,7 +933,6 @@ func (d *Status) UpdateResolvedDomainsStates(originalDomain domain.Domain, resol
|
||||
}
|
||||
|
||||
func (d *Status) DeleteResolvedDomainsStates(domain domain.Domain) {
|
||||
d.profile.inc("DeleteResolvedDomainsStates")
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
@@ -1000,7 +949,6 @@ func (d *Status) DeleteResolvedDomainsStates(domain domain.Domain) {
|
||||
}
|
||||
|
||||
func (d *Status) GetRosenpassState() RosenpassState {
|
||||
d.profile.inc("GetRosenpassState")
|
||||
d.mux.RLock()
|
||||
defer d.mux.RUnlock()
|
||||
return RosenpassState{
|
||||
@@ -1010,14 +958,12 @@ func (d *Status) GetRosenpassState() RosenpassState {
|
||||
}
|
||||
|
||||
func (d *Status) GetLazyConnection() bool {
|
||||
d.profile.inc("GetLazyConnection")
|
||||
d.mux.RLock()
|
||||
defer d.mux.RUnlock()
|
||||
return d.lazyConnectionEnabled
|
||||
}
|
||||
|
||||
func (d *Status) GetManagementState() ManagementState {
|
||||
d.profile.inc("GetManagementState")
|
||||
d.mux.RLock()
|
||||
defer d.mux.RUnlock()
|
||||
return ManagementState{
|
||||
@@ -1028,7 +974,6 @@ func (d *Status) GetManagementState() ManagementState {
|
||||
}
|
||||
|
||||
func (d *Status) UpdateLatency(pubKey string, latency time.Duration) error {
|
||||
d.profile.inc("UpdateLatency")
|
||||
if latency <= 0 {
|
||||
return nil
|
||||
}
|
||||
@@ -1046,7 +991,6 @@ func (d *Status) UpdateLatency(pubKey string, latency time.Duration) error {
|
||||
|
||||
// IsLoginRequired determines if a peer's login has expired.
|
||||
func (d *Status) IsLoginRequired() bool {
|
||||
d.profile.inc("IsLoginRequired")
|
||||
d.mux.RLock()
|
||||
defer d.mux.RUnlock()
|
||||
|
||||
@@ -1063,7 +1007,6 @@ func (d *Status) IsLoginRequired() bool {
|
||||
}
|
||||
|
||||
func (d *Status) GetSignalState() SignalState {
|
||||
d.profile.inc("GetSignalState")
|
||||
d.mux.RLock()
|
||||
defer d.mux.RUnlock()
|
||||
return SignalState{
|
||||
@@ -1075,34 +1018,24 @@ func (d *Status) GetSignalState() SignalState {
|
||||
|
||||
// GetRelayStates returns the stun/turn/permanent relay states
|
||||
func (d *Status) GetRelayStates() []relay.ProbeResult {
|
||||
d.profile.inc("GetRelayStates")
|
||||
d.muxRelays.RLock()
|
||||
|
||||
// debug lines
|
||||
started := time.Now()
|
||||
defer func() {
|
||||
debugElapsed("GetRelayStates", started)
|
||||
}()
|
||||
|
||||
d.mux.RLock()
|
||||
defer d.mux.RUnlock()
|
||||
if d.relayMgr == nil {
|
||||
defer d.muxRelays.RUnlock()
|
||||
return d.relayStates
|
||||
}
|
||||
|
||||
relayMgr := d.relayMgr
|
||||
// extend the list of stun, turn servers with the relay server connections
|
||||
relayStates := slices.Clone(d.relayStates)
|
||||
d.muxRelays.RUnlock()
|
||||
|
||||
states := relayMgr.RelayStates()
|
||||
states := d.relayMgr.RelayStates()
|
||||
if len(states) == 0 {
|
||||
// no relay connection tracked yet; surface configured servers as
|
||||
// unavailable with the real reconnect error when known
|
||||
err := relayClient.ErrRelayClientNotConnected
|
||||
if connErr := relayMgr.RelayConnectError(); connErr != nil {
|
||||
if connErr := d.relayMgr.RelayConnectError(); connErr != nil {
|
||||
err = connErr
|
||||
}
|
||||
for _, r := range relayMgr.ServerURLs() {
|
||||
for _, r := range d.relayMgr.ServerURLs() {
|
||||
relayStates = append(relayStates, relay.ProbeResult{
|
||||
URI: r,
|
||||
Err: err,
|
||||
@@ -1122,15 +1055,7 @@ func (d *Status) GetRelayStates() []relay.ProbeResult {
|
||||
}
|
||||
|
||||
func (d *Status) ForwardingRules() []firewall.ForwardRule {
|
||||
d.profile.inc("ForwardingRules")
|
||||
d.mux.RLock()
|
||||
|
||||
// debug lines
|
||||
started := time.Now()
|
||||
defer func() {
|
||||
debugElapsed("ForwardingRules", started)
|
||||
}()
|
||||
|
||||
defer d.mux.RUnlock()
|
||||
if d.ingressGwMgr == nil {
|
||||
return nil
|
||||
@@ -1140,7 +1065,6 @@ func (d *Status) ForwardingRules() []firewall.ForwardRule {
|
||||
}
|
||||
|
||||
func (d *Status) GetDNSStates() []NSGroupState {
|
||||
d.profile.inc("GetDNSStates")
|
||||
d.mux.RLock()
|
||||
defer d.mux.RUnlock()
|
||||
|
||||
@@ -1149,7 +1073,6 @@ func (d *Status) GetDNSStates() []NSGroupState {
|
||||
}
|
||||
|
||||
func (d *Status) GetResolvedDomainsStates() map[domain.Domain]ResolvedDomainInfo {
|
||||
d.profile.inc("GetResolvedDomainsStates")
|
||||
d.mux.RLock()
|
||||
defer d.mux.RUnlock()
|
||||
return maps.Clone(d.resolvedDomainsStates)
|
||||
@@ -1157,7 +1080,6 @@ func (d *Status) GetResolvedDomainsStates() map[domain.Domain]ResolvedDomainInfo
|
||||
|
||||
// GetFullStatus gets full status
|
||||
func (d *Status) GetFullStatus() FullStatus {
|
||||
d.profile.inc("GetFullStatus")
|
||||
fullStatus := FullStatus{
|
||||
ManagementState: d.GetManagementState(),
|
||||
SignalState: d.GetSignalState(),
|
||||
@@ -1184,31 +1106,26 @@ func (d *Status) GetFullStatus() FullStatus {
|
||||
|
||||
// ClientStart will notify all listeners about the new service state
|
||||
func (d *Status) ClientStart() {
|
||||
d.profile.inc("ClientStart")
|
||||
d.notifier.clientStart()
|
||||
}
|
||||
|
||||
// ClientStop will notify all listeners about the new service state
|
||||
func (d *Status) ClientStop() {
|
||||
d.profile.inc("ClientStop")
|
||||
d.notifier.clientStop()
|
||||
}
|
||||
|
||||
// ClientTeardown will notify all listeners about the service is under teardown
|
||||
func (d *Status) ClientTeardown() {
|
||||
d.profile.inc("ClientTeardown")
|
||||
d.notifier.clientTearDown()
|
||||
}
|
||||
|
||||
// SetConnectionListener set a listener to the notifier
|
||||
func (d *Status) SetConnectionListener(listener Listener) {
|
||||
d.profile.inc("SetConnectionListener")
|
||||
d.notifier.setListener(listener)
|
||||
}
|
||||
|
||||
// RemoveConnectionListener remove the listener from the notifier
|
||||
func (d *Status) RemoveConnectionListener() {
|
||||
d.profile.inc("RemoveConnectionListener")
|
||||
d.notifier.removeListener()
|
||||
}
|
||||
|
||||
@@ -1280,7 +1197,6 @@ func (d *Status) PublishEvent(
|
||||
userMsg string,
|
||||
metadata map[string]string,
|
||||
) {
|
||||
d.profile.inc("PublishEvent")
|
||||
event := &proto.SystemEvent{
|
||||
Id: uuid.New().String(),
|
||||
Severity: severity,
|
||||
@@ -1292,13 +1208,6 @@ func (d *Status) PublishEvent(
|
||||
}
|
||||
|
||||
d.eventMux.Lock()
|
||||
|
||||
// debug lines
|
||||
started := time.Now()
|
||||
defer func() {
|
||||
debugElapsed("PublishEvent", started)
|
||||
}()
|
||||
|
||||
defer d.eventMux.Unlock()
|
||||
|
||||
d.eventQueue.Add(event)
|
||||
@@ -1316,7 +1225,6 @@ func (d *Status) PublishEvent(
|
||||
|
||||
// SubscribeToEvents returns a new event subscription
|
||||
func (d *Status) SubscribeToEvents() *EventSubscription {
|
||||
d.profile.inc("SubscribeToEvents")
|
||||
d.eventMux.Lock()
|
||||
defer d.eventMux.Unlock()
|
||||
|
||||
@@ -1332,7 +1240,6 @@ func (d *Status) SubscribeToEvents() *EventSubscription {
|
||||
|
||||
// UnsubscribeFromEvents removes an event subscription
|
||||
func (d *Status) UnsubscribeFromEvents(sub *EventSubscription) {
|
||||
d.profile.inc("UnsubscribeFromEvents")
|
||||
if sub == nil {
|
||||
return
|
||||
}
|
||||
@@ -1348,12 +1255,10 @@ func (d *Status) UnsubscribeFromEvents(sub *EventSubscription) {
|
||||
|
||||
// GetEventHistory returns all events in the queue
|
||||
func (d *Status) GetEventHistory() []*proto.SystemEvent {
|
||||
d.profile.inc("GetEventHistory")
|
||||
return d.eventQueue.GetAll()
|
||||
}
|
||||
|
||||
func (d *Status) SetWgIface(wgInterface WGIfaceStatus) {
|
||||
d.profile.inc("SetWgIface")
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
@@ -1361,15 +1266,7 @@ func (d *Status) SetWgIface(wgInterface WGIfaceStatus) {
|
||||
}
|
||||
|
||||
func (d *Status) PeersStatus() (*configurer.Stats, error) {
|
||||
d.profile.inc("PeersStatus")
|
||||
d.mux.RLock()
|
||||
|
||||
// debug lines
|
||||
started := time.Now()
|
||||
defer func() {
|
||||
debugElapsed("PeersStatus", started)
|
||||
}()
|
||||
|
||||
defer d.mux.RUnlock()
|
||||
if d.wgIface == nil {
|
||||
return nil, fmt.Errorf("wgInterface is nil, cannot retrieve peers status")
|
||||
@@ -1382,15 +1279,7 @@ func (d *Status) PeersStatus() (*configurer.Stats, error) {
|
||||
// and updates the cached peer states. This ensures accurate handshake times and
|
||||
// transfer statistics in status reports without running full health probes.
|
||||
func (d *Status) RefreshWireGuardStats() error {
|
||||
d.profile.inc("RefreshWireGuardStats")
|
||||
d.mux.Lock()
|
||||
|
||||
// debug lines
|
||||
started := time.Now()
|
||||
defer func() {
|
||||
debugElapsed("RefreshWireGuardStats", started)
|
||||
}()
|
||||
|
||||
defer d.mux.Unlock()
|
||||
|
||||
if d.wgIface == nil {
|
||||
@@ -1555,10 +1444,3 @@ func (fs FullStatus) ToProto() *proto.FullStatus {
|
||||
|
||||
return &pbFullStatus
|
||||
}
|
||||
|
||||
func debugElapsed(msg string, startTime time.Time) {
|
||||
if elapsed := time.Since(startTime); elapsed > 1*time.Second {
|
||||
log.Infof("run %s took %s", msg, elapsed)
|
||||
debug.PrintStack()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,97 +0,0 @@
|
||||
package peer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type StatusProfile struct {
|
||||
counts sync.Map
|
||||
}
|
||||
|
||||
func NewStatusProfile(ctx context.Context) *StatusProfile {
|
||||
s := &StatusProfile{}
|
||||
go s.Start(ctx)
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *StatusProfile) Start(ctx context.Context) {
|
||||
ticker := time.NewTicker(10 * time.Second)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
s.logCounts()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *StatusProfile) inc(method string) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
if v, ok := s.counts.Load(method); ok {
|
||||
v.(*atomic.Int64).Add(1)
|
||||
return
|
||||
}
|
||||
cnt := &atomic.Int64{}
|
||||
actual, _ := s.counts.LoadOrStore(method, cnt)
|
||||
actual.(*atomic.Int64).Add(1)
|
||||
}
|
||||
|
||||
func (s *StatusProfile) snapshot() map[string]int64 {
|
||||
out := make(map[string]int64)
|
||||
s.counts.Range(func(k, v any) bool {
|
||||
out[k.(string)] = v.(*atomic.Int64).Load()
|
||||
return true
|
||||
})
|
||||
return out
|
||||
}
|
||||
|
||||
func (s *StatusProfile) logCounts() {
|
||||
counts := s.snapshot()
|
||||
if len(counts) == 0 {
|
||||
log.Infof("status profile: no Status method calls so far")
|
||||
return
|
||||
}
|
||||
|
||||
type kv struct {
|
||||
method string
|
||||
count int64
|
||||
}
|
||||
sorted := make([]kv, 0, len(counts))
|
||||
var total int64
|
||||
for m, c := range counts {
|
||||
if c == 0 {
|
||||
continue
|
||||
}
|
||||
sorted = append(sorted, kv{m, c})
|
||||
total += c
|
||||
}
|
||||
sort.Slice(sorted, func(i, j int) bool {
|
||||
if sorted[i].count != sorted[j].count {
|
||||
return sorted[i].count > sorted[j].count
|
||||
}
|
||||
return sorted[i].method < sorted[j].method
|
||||
})
|
||||
|
||||
var b strings.Builder
|
||||
for i, e := range sorted {
|
||||
if i > 0 {
|
||||
b.WriteString(", ")
|
||||
}
|
||||
b.WriteString(e.method)
|
||||
b.WriteByte('=')
|
||||
b.WriteString(strconv.FormatInt(e.count, 10))
|
||||
}
|
||||
log.Infof("status profile (cumulative total=%d): %s", total, b.String())
|
||||
}
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -165,10 +164,6 @@ func (w *WorkerICE) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HA
|
||||
return
|
||||
}
|
||||
|
||||
if candidateViaRoutes(candidate, haRoutes) {
|
||||
return
|
||||
}
|
||||
|
||||
if err := w.agent.AddRemoteCandidate(candidate); err != nil {
|
||||
w.log.Errorf("error while handling remote candidate")
|
||||
return
|
||||
@@ -466,7 +461,7 @@ func (w *WorkerICE) createForwardedCandidate(srflxCandidate ice.Candidate, mappi
|
||||
}
|
||||
|
||||
func (w *WorkerICE) onICESelectedCandidatePair(agent *icemaker.ThreadSafeAgent, c1, c2 ice.Candidate) {
|
||||
w.log.Infof("selected candidate pair [local <-> remote] -> [%s <-> %s], peer %s", c1.String(), c2.String(),
|
||||
w.log.Debugf("selected candidate pair [local <-> remote] -> [%s <-> %s], peer %s", c1.String(), c2.String(),
|
||||
w.config.Key)
|
||||
|
||||
pairStat, ok := agent.GetSelectedCandidatePairStats()
|
||||
@@ -589,34 +584,6 @@ func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive
|
||||
return ec, nil
|
||||
}
|
||||
|
||||
func candidateViaRoutes(candidate ice.Candidate, clientRoutes route.HAMap) bool {
|
||||
addr, err := netip.ParseAddr(candidate.Address())
|
||||
if err != nil {
|
||||
log.Errorf("Failed to parse IP address %s: %v", candidate.Address(), err)
|
||||
return false
|
||||
}
|
||||
|
||||
var routePrefixes []netip.Prefix
|
||||
for _, routes := range clientRoutes {
|
||||
if len(routes) > 0 && routes[0] != nil {
|
||||
routePrefixes = append(routePrefixes, routes[0].Network)
|
||||
}
|
||||
}
|
||||
|
||||
for _, prefix := range routePrefixes {
|
||||
// default route is handled by route exclusion / ip rules
|
||||
if prefix.Bits() == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if prefix.Contains(addr) {
|
||||
log.Debugf("Ignoring candidate [%s], its address is part of routed network %s", candidate.String(), prefix)
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func isRelayCandidate(candidate ice.Candidate) bool {
|
||||
return candidate.Type() == ice.CandidateTypeRelay
|
||||
}
|
||||
|
||||
@@ -121,9 +121,12 @@ func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf wgIface, init
|
||||
return Nexthop{}, vars.ErrRouteNotAllowed
|
||||
}
|
||||
|
||||
// Check if the prefix is part of any local subnets
|
||||
if isLocal, subnet := r.isPrefixInLocalSubnets(prefix); isLocal {
|
||||
return Nexthop{}, fmt.Errorf("prefix %s is part of local subnet %s: %w", prefix, subnet, vars.ErrRouteNotAllowed)
|
||||
// BSDs blackhole a /32 added inside a directly-connected subnet; Linux/Windows need it to beat the wt0 route.
|
||||
switch runtime.GOOS {
|
||||
case "darwin", "freebsd", "netbsd", "openbsd", "dragonfly":
|
||||
if isLocal, subnet := r.isPrefixInLocalSubnets(prefix); isLocal {
|
||||
return Nexthop{}, fmt.Errorf("prefix %s is part of local subnet %s: %w", prefix, subnet, vars.ErrRouteNotAllowed)
|
||||
}
|
||||
}
|
||||
|
||||
// Determine the exit interface and next hop for the prefix, so we can add a specific route
|
||||
|
||||
@@ -136,7 +136,6 @@ func New(ctx context.Context, logFile string, configFile string, profilesDisable
|
||||
networksDisabled: networksDisabled,
|
||||
jwtCache: newJWTCache(),
|
||||
}
|
||||
go s.statusRecorder.StartProfile(ctx)
|
||||
agent := &serverAgent{s}
|
||||
s.sleepHandler = sleephandler.New(agent)
|
||||
s.startSleepDetector()
|
||||
|
||||
@@ -1205,7 +1205,7 @@ func (s *Server) SyncMeta(ctx context.Context, req *proto.EncryptedMessage) (*pr
|
||||
return nil, msg
|
||||
}
|
||||
|
||||
err = s.accountManager.SyncPeerMeta(ctx, peerKey.String(), extractPeerMeta(ctx, syncMetaReq.GetMeta()), realIP)
|
||||
err = s.accountManager.SyncPeerMeta(ctx, peerKey.String(), extractPeerMeta(ctx, syncMetaReq.GetMeta()))
|
||||
if err != nil {
|
||||
return nil, mapError(ctx, err)
|
||||
}
|
||||
@@ -1254,10 +1254,7 @@ func (s *Server) Logout(ctx context.Context, req *proto.EncryptedMessage) (*prot
|
||||
func toProtocolChecks(ctx context.Context, postureChecks []*posture.Checks) []*proto.Checks {
|
||||
protoChecks := make([]*proto.Checks, 0, len(postureChecks))
|
||||
for _, postureCheck := range postureChecks {
|
||||
check := toProtocolCheck(postureCheck)
|
||||
if check != nil {
|
||||
protoChecks = append(protoChecks, check)
|
||||
}
|
||||
protoChecks = append(protoChecks, toProtocolCheck(postureCheck))
|
||||
}
|
||||
|
||||
return protoChecks
|
||||
@@ -1281,9 +1278,5 @@ func toProtocolCheck(postureCheck *posture.Checks) *proto.Checks {
|
||||
}
|
||||
}
|
||||
|
||||
if len(protoCheck.Files) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return protoCheck
|
||||
}
|
||||
|
||||
@@ -1889,12 +1889,12 @@ func domainIsUpToDate(domain string, domainCategory string, userAuth auth.UserAu
|
||||
// concurrent stream that started earlier loses the optimistic-lock race
|
||||
// in MarkPeerConnected and bails without writing.
|
||||
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
|
||||
peer, netMap, postureChecks, dnsfwdPort, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, RealIP: realIP}, accountID)
|
||||
peer, netMap, postureChecks, dnsfwdPort, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, 0, fmt.Errorf("error syncing peer: %w", err)
|
||||
}
|
||||
|
||||
if err := am.MarkPeerConnected(ctx, peerPubKey, accountID, syncTime.UnixNano(), netMap); err != nil {
|
||||
if err := am.MarkPeerConnected(ctx, peerPubKey, realIP, accountID, syncTime.UnixNano(), netMap); err != nil {
|
||||
log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err)
|
||||
}
|
||||
|
||||
@@ -1914,13 +1914,13 @@ func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, account
|
||||
return nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) error {
|
||||
func (am *DefaultAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error {
|
||||
accountID, err := am.Store.GetAccountIDByPeerPubKey(ctx, peerPubKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, _, _, _, err = am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, RealIP: realIP, UpdateAccountPeers: true}, accountID)
|
||||
_, _, _, _, err = am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, UpdateAccountPeers: true}, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -62,7 +62,7 @@ type Manager interface {
|
||||
GetUserFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
|
||||
ListUsers(ctx context.Context, accountID string) ([]*types.User, error)
|
||||
GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
|
||||
MarkPeerConnected(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error
|
||||
MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error
|
||||
MarkPeerDisconnected(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64) error
|
||||
DeletePeer(ctx context.Context, accountID, peerID, userID string) error
|
||||
UpdatePeer(ctx context.Context, accountID, userID string, p *nbpeer.Peer) (*nbpeer.Peer, error)
|
||||
@@ -123,7 +123,7 @@ type Manager interface {
|
||||
GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error)
|
||||
SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
|
||||
OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string, streamStartTime time.Time) error
|
||||
SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) error
|
||||
SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
|
||||
FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
||||
GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error)
|
||||
GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error)
|
||||
|
||||
@@ -1323,17 +1323,17 @@ func (mr *MockManagerMockRecorder) ExtendPeerSession(ctx, peerPubKey, userID int
|
||||
}
|
||||
|
||||
// MarkPeerConnected mocks base method.
|
||||
func (m *MockManager) MarkPeerConnected(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error {
|
||||
func (m *MockManager) MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "MarkPeerConnected", ctx, peerKey, accountID, sessionStartedAt, nmap)
|
||||
ret := m.ctrl.Call(m, "MarkPeerConnected", ctx, peerKey, realIP, accountID, sessionStartedAt, nmap)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// MarkPeerConnected indicates an expected call of MarkPeerConnected.
|
||||
func (mr *MockManagerMockRecorder) MarkPeerConnected(ctx, peerKey, accountID, sessionStartedAt, nmap interface{}) *gomock.Call {
|
||||
func (mr *MockManagerMockRecorder) MarkPeerConnected(ctx, peerKey, realIP, accountID, sessionStartedAt, nmap interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerConnected", reflect.TypeOf((*MockManager)(nil).MarkPeerConnected), ctx, peerKey, accountID, sessionStartedAt, nmap)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerConnected", reflect.TypeOf((*MockManager)(nil).MarkPeerConnected), ctx, peerKey, realIP, accountID, sessionStartedAt, nmap)
|
||||
}
|
||||
|
||||
// MarkPeerDisconnected mocks base method.
|
||||
@@ -1586,17 +1586,17 @@ func (mr *MockManagerMockRecorder) SyncPeer(ctx, sync, accountID interface{}) *g
|
||||
}
|
||||
|
||||
// SyncPeerMeta mocks base method.
|
||||
func (m *MockManager) SyncPeerMeta(ctx context.Context, peerPubKey string, meta peer.PeerSystemMeta, realIP net.IP) error {
|
||||
func (m *MockManager) SyncPeerMeta(ctx context.Context, peerPubKey string, meta peer.PeerSystemMeta) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SyncPeerMeta", ctx, peerPubKey, meta, realIP)
|
||||
ret := m.ctrl.Call(m, "SyncPeerMeta", ctx, peerPubKey, meta)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// SyncPeerMeta indicates an expected call of SyncPeerMeta.
|
||||
func (mr *MockManagerMockRecorder) SyncPeerMeta(ctx, peerPubKey, meta, realIP interface{}) *gomock.Call {
|
||||
func (mr *MockManagerMockRecorder) SyncPeerMeta(ctx, peerPubKey, meta interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SyncPeerMeta", reflect.TypeOf((*MockManager)(nil).SyncPeerMeta), ctx, peerPubKey, meta, realIP)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SyncPeerMeta", reflect.TypeOf((*MockManager)(nil).SyncPeerMeta), ctx, peerPubKey, meta)
|
||||
}
|
||||
|
||||
// SyncUserJWTGroups mocks base method.
|
||||
|
||||
@@ -1836,7 +1836,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
|
||||
require.NoError(t, err, "unable to get the account")
|
||||
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), accountID, time.Now().UTC().UnixNano(), nil)
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano(), nil)
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
|
||||
@@ -1907,7 +1907,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
|
||||
require.NoError(t, err, "unable to get the account")
|
||||
|
||||
// when we mark peer as connected, the peer login expiration routine should trigger
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), accountID, time.Now().UTC().UnixNano(), nil)
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano(), nil)
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
failed := waitTimeout(wg, time.Second)
|
||||
@@ -1916,117 +1916,6 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_MarkPeerDisconnected_SchedulesInactivityExpiration(t *testing.T) {
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
|
||||
require.NoError(t, err, "unable to create an account")
|
||||
|
||||
key, err := wgtypes.GenerateKey()
|
||||
require.NoError(t, err, "unable to generate WireGuard key")
|
||||
peerPubKey := key.PublicKey().String()
|
||||
|
||||
_, _, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
Key: peerPubKey,
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
|
||||
InactivityExpirationEnabled: true,
|
||||
}, false)
|
||||
require.NoError(t, err, "unable to add peer")
|
||||
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
|
||||
PeerLoginExpiration: time.Hour,
|
||||
PeerLoginExpirationEnabled: true,
|
||||
PeerInactivityExpiration: time.Hour,
|
||||
PeerInactivityExpirationEnabled: true,
|
||||
Extra: &types.ExtraSettings{},
|
||||
})
|
||||
require.NoError(t, err, "expecting to update account settings successfully but got error")
|
||||
|
||||
// Establish a session so the matching-token disconnect is actually applied.
|
||||
streamStartTime := time.Now().UTC()
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, accountID, streamStartTime.UnixNano(), nil)
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
// Install the mock only now, so the assertion observes the disconnect, not
|
||||
// the earlier connect.
|
||||
scheduled := make(chan struct{}, 1)
|
||||
manager.peerInactivityExpiry = &MockScheduler{
|
||||
CancelFunc: func(ctx context.Context, IDs []string) {},
|
||||
ScheduleFunc: func(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) {
|
||||
select {
|
||||
case scheduled <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
err = manager.MarkPeerDisconnected(context.Background(), peerPubKey, accountID, streamStartTime.UnixNano())
|
||||
require.NoError(t, err, "unable to mark peer disconnected")
|
||||
|
||||
select {
|
||||
case <-scheduled:
|
||||
// expected: disconnect re-armed the inactivity expiry timer
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("expected inactivity expiration to be rescheduled when an eligible peer disconnects")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_MarkPeerDisconnected_SkipsInactivityExpirationWhenDisabled(t *testing.T) {
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
|
||||
require.NoError(t, err, "unable to create an account")
|
||||
|
||||
key, err := wgtypes.GenerateKey()
|
||||
require.NoError(t, err, "unable to generate WireGuard key")
|
||||
peerPubKey := key.PublicKey().String()
|
||||
|
||||
_, _, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
Key: peerPubKey,
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
|
||||
InactivityExpirationEnabled: true,
|
||||
}, false)
|
||||
require.NoError(t, err, "unable to add peer")
|
||||
|
||||
// Peer is eligible (SSO + inactivity enabled) but the account-level setting
|
||||
// stays disabled, so disconnect must not schedule anything.
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
|
||||
PeerLoginExpiration: time.Hour,
|
||||
PeerLoginExpirationEnabled: true,
|
||||
PeerInactivityExpiration: time.Hour,
|
||||
PeerInactivityExpirationEnabled: false,
|
||||
Extra: &types.ExtraSettings{},
|
||||
})
|
||||
require.NoError(t, err, "expecting to update account settings successfully but got error")
|
||||
|
||||
streamStartTime := time.Now().UTC()
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, accountID, streamStartTime.UnixNano(), nil)
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
scheduled := make(chan struct{}, 1)
|
||||
manager.peerInactivityExpiry = &MockScheduler{
|
||||
CancelFunc: func(ctx context.Context, IDs []string) {},
|
||||
ScheduleFunc: func(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) {
|
||||
select {
|
||||
case scheduled <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
err = manager.MarkPeerDisconnected(context.Background(), peerPubKey, accountID, streamStartTime.UnixNano())
|
||||
require.NoError(t, err, "unable to mark peer disconnected")
|
||||
|
||||
select {
|
||||
case <-scheduled:
|
||||
t.Fatal("inactivity expiration must not be scheduled while the account-level setting is disabled")
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
// expected: nothing scheduled
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
@@ -2046,7 +1935,7 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
|
||||
|
||||
t.Run("disconnect peer when session token matches", func(t *testing.T) {
|
||||
streamStartTime := time.Now().UTC()
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, accountID, streamStartTime.UnixNano(), nil)
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, streamStartTime.UnixNano(), nil)
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
||||
@@ -2067,7 +1956,7 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
|
||||
t.Run("skip disconnect when stored session is newer (zombie stream protection)", func(t *testing.T) {
|
||||
// Newer stream wins on connect (sets SessionStartedAt = now ns).
|
||||
streamStartTime := time.Now().UTC()
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, accountID, streamStartTime.UnixNano(), nil)
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, streamStartTime.UnixNano(), nil)
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
||||
@@ -2091,7 +1980,7 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
|
||||
|
||||
t.Run("skip stale connect when stored session is newer (blocked goroutine protection)", func(t *testing.T) {
|
||||
node2SyncTime := time.Now().UTC()
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, accountID, node2SyncTime.UnixNano(), nil)
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, node2SyncTime.UnixNano(), nil)
|
||||
require.NoError(t, err, "node 2 should connect peer")
|
||||
|
||||
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
||||
@@ -2101,7 +1990,7 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
|
||||
"SessionStartedAt should equal node2SyncTime token")
|
||||
|
||||
node1StaleSyncTime := node2SyncTime.Add(-1 * time.Minute)
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, accountID, node1StaleSyncTime.UnixNano(), nil)
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, node1StaleSyncTime.UnixNano(), nil)
|
||||
require.NoError(t, err, "stale connect should not return error")
|
||||
|
||||
peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
||||
@@ -2163,7 +2052,7 @@ func TestDefaultAccountManager_MarkPeerConnected_ConcurrentRace(t *testing.T) {
|
||||
defer done.Done()
|
||||
ready.Done()
|
||||
start.Wait()
|
||||
errs <- manager.MarkPeerConnected(context.Background(), peerPubKey, accountID, token, nil)
|
||||
errs <- manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, token, nil)
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -2204,7 +2093,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
|
||||
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||
require.NoError(t, err, "unable to get the account")
|
||||
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), accountID, time.Now().UTC().UnixNano(), nil)
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano(), nil)
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
|
||||
@@ -39,7 +39,7 @@ type MockAccountManager struct {
|
||||
GetUserFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
|
||||
ListUsersFunc func(ctx context.Context, accountID string) ([]*types.User, error)
|
||||
GetPeersFunc func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
|
||||
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error
|
||||
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error
|
||||
MarkPeerDisconnectedFunc func(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64) error
|
||||
SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
|
||||
DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error
|
||||
@@ -114,7 +114,7 @@ type MockAccountManager struct {
|
||||
GetIdpManagerFunc func() idp.Manager
|
||||
UpdateIntegratedValidatorFunc func(ctx context.Context, accountID, userID, validator string, groups []string) error
|
||||
GroupValidationFunc func(ctx context.Context, accountId string, groups []string) (bool, error)
|
||||
SyncPeerMetaFunc func(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) error
|
||||
SyncPeerMetaFunc func(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
|
||||
FindExistingPostureCheckFunc func(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
||||
GetAccountIDForPeerKeyFunc func(ctx context.Context, peerKey string) (string, error)
|
||||
GetAccountByIDFunc func(ctx context.Context, accountID string, userID string) (*types.Account, error)
|
||||
@@ -345,9 +345,9 @@ func (am *MockAccountManager) GetAccountIDByUserID(ctx context.Context, userAuth
|
||||
}
|
||||
|
||||
// MarkPeerConnected mock implementation of MarkPeerConnected from server.AccountManager interface
|
||||
func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error {
|
||||
func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error {
|
||||
if am.MarkPeerConnectedFunc != nil {
|
||||
return am.MarkPeerConnectedFunc(ctx, peerKey, accountID, sessionStartedAt, nmap)
|
||||
return am.MarkPeerConnectedFunc(ctx, peerKey, realIP, accountID, sessionStartedAt, nmap)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
|
||||
}
|
||||
@@ -975,9 +975,9 @@ func (am *MockAccountManager) GroupValidation(ctx context.Context, accountId str
|
||||
}
|
||||
|
||||
// SyncPeerMeta mocks SyncPeerMeta of the AccountManager interface
|
||||
func (am *MockAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) error {
|
||||
func (am *MockAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error {
|
||||
if am.SyncPeerMetaFunc != nil {
|
||||
return am.SyncPeerMetaFunc(ctx, peerPubKey, meta, realIP)
|
||||
return am.SyncPeerMetaFunc(ctx, peerPubKey, meta)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method SyncPeerMeta is not implemented")
|
||||
}
|
||||
|
||||
@@ -74,7 +74,7 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
|
||||
//
|
||||
// Disconnects use MarkPeerDisconnected and require the session to match
|
||||
// exactly; see PeerStatus.SessionStartedAt for the protocol.
|
||||
func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error {
|
||||
func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, realIP net.IP, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error {
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
am.metrics.AccountManagerMetrics().RecordPeerStatusUpdateDuration(telemetry.PeerStatusConnect, time.Since(start))
|
||||
@@ -102,6 +102,10 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
|
||||
}
|
||||
am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusConnect, telemetry.PeerStatusApplied)
|
||||
|
||||
if am.geo != nil && realIP != nil {
|
||||
am.updatePeerLocationIfChanged(ctx, accountID, peer, realIP)
|
||||
}
|
||||
|
||||
if err = am.schedulePeerExpirations(ctx, accountID, peer); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -188,40 +192,27 @@ func (am *DefaultAccountManager) MarkPeerDisconnected(ctx context.Context, peerP
|
||||
}
|
||||
}
|
||||
|
||||
if peer.AddedWithSSOLogin() && peer.InactivityExpirationEnabled {
|
||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed getting account settings to schedule inactivity expiration for peer %s: %v", peer.ID, err)
|
||||
} else if settings.PeerInactivityExpirationEnabled {
|
||||
am.checkAndSchedulePeerInactivityExpiration(ctx, accountID)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// resolvePeerLocation looks up the geo location for realIP, returning nil when
|
||||
// there is nothing to apply: geo disabled, no real IP, the IP is unchanged from
|
||||
// what the peer already has, or the lookup failed. Geo lookups are skipped on
|
||||
// same-IP reconnects since they are comparatively expensive. The returned value
|
||||
// is applied by Peer.UpdateMetaIfNew so the change is persisted by its peer save.
|
||||
func (am *DefaultAccountManager) resolvePeerLocation(ctx context.Context, peer *nbpeer.Peer, realIP net.IP) *nbpeer.Location {
|
||||
if am.geo == nil || realIP == nil {
|
||||
return nil
|
||||
}
|
||||
// updatePeerLocationIfChanged refreshes the geolocation on a separate
|
||||
// row update, only when the connection IP actually changed. Geo lookups
|
||||
// are expensive so we skip same-IP reconnects.
|
||||
func (am *DefaultAccountManager) updatePeerLocationIfChanged(ctx context.Context, accountID string, peer *nbpeer.Peer, realIP net.IP) {
|
||||
if peer.Location.ConnectionIP != nil && peer.Location.ConnectionIP.Equal(realIP) {
|
||||
return nil
|
||||
return
|
||||
}
|
||||
location, err := am.geo.Lookup(realIP)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to get location for peer %s realip: [%s]: %v", peer.ID, realIP.String(), err)
|
||||
return nil
|
||||
return
|
||||
}
|
||||
return &nbpeer.Location{
|
||||
ConnectionIP: realIP,
|
||||
CountryCode: location.Country.ISOCode,
|
||||
CityName: location.City.Names.En,
|
||||
GeoNameID: location.City.GeonameID,
|
||||
peer.Location.ConnectionIP = realIP
|
||||
peer.Location.CountryCode = location.Country.ISOCode
|
||||
peer.Location.CityName = location.City.Names.En
|
||||
peer.Location.GeoNameID = location.City.GeonameID
|
||||
if err := am.Store.SavePeerLocation(ctx, accountID, peer); err != nil {
|
||||
log.WithContext(ctx).Warnf("could not store location for peer %s: %s", peer.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -989,8 +980,7 @@ func getPeerIPDNSLabel(ip netip.Addr, peerHostName string) (string, error) {
|
||||
// SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible
|
||||
func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
|
||||
var peer *nbpeer.Peer
|
||||
var ipv6CapabilityChanged bool
|
||||
var metaDiff nbpeer.MetaDiff
|
||||
var updated, versionChanged, ipv6CapabilityChanged bool
|
||||
var err error
|
||||
|
||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
@@ -1020,10 +1010,9 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
|
||||
}
|
||||
|
||||
oldHasIPv6Cap := peer.HasCapability(nbpeer.PeerCapabilityIPv6Overlay)
|
||||
newLocation := am.resolvePeerLocation(ctx, peer, sync.RealIP)
|
||||
metaDiff = peer.UpdateMetaIfNew(ctx, sync.Meta, newLocation)
|
||||
updated, versionChanged = peer.UpdateMetaIfNew(ctx, sync.Meta)
|
||||
ipv6CapabilityChanged = oldHasIPv6Cap != peer.HasCapability(nbpeer.PeerCapabilityIPv6Overlay)
|
||||
if metaDiff.Updated() {
|
||||
if updated {
|
||||
am.metrics.AccountManagerMetrics().CountPeerMetUpdate()
|
||||
log.WithContext(ctx).Tracef("peer %s metadata updated", peer.ID)
|
||||
if err = transaction.SavePeer(ctx, accountID, peer); err != nil {
|
||||
@@ -1051,10 +1040,9 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
|
||||
return nil, nil, nil, 0, err
|
||||
}
|
||||
|
||||
metaDiffAffectsPosture := posture.AffectsPosture(&metaDiff, resPostureChecks)
|
||||
if isStatusChanged || sync.UpdateAccountPeers || ipv6CapabilityChanged || metaDiffAffectsPosture || metaDiff.VersionChanged || metaDiff.Hostname {
|
||||
if isStatusChanged || sync.UpdateAccountPeers || ipv6CapabilityChanged || (updated && (len(resPostureChecks) > 0 || versionChanged)) {
|
||||
changedPeerIDs := []string{peer.ID}
|
||||
affectedPeerIDs := am.syncPeerAffectedPeers(ctx, accountID, peer.ID, nmap, peerNotValid, metaDiffAffectsPosture)
|
||||
affectedPeerIDs := am.syncPeerAffectedPeers(ctx, accountID, peer.ID, nmap, peerNotValid, updated, len(resPostureChecks) > 0)
|
||||
if err = am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs); err != nil {
|
||||
return nil, nil, nil, 0, fmt.Errorf("notify network map controller of peer update: %w", err)
|
||||
}
|
||||
@@ -1071,8 +1059,8 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
|
||||
// metadata change that flips a posture result removes this peer from others'
|
||||
// maps asymmetrically; that case (and an invalid peer, whose map is empty) falls
|
||||
// back to the resolver.
|
||||
func (am *DefaultAccountManager) syncPeerAffectedPeers(ctx context.Context, accountID, peerID string, nmap *types.NetworkMap, peerNotValid, metaChangeAffectedPosture bool) []string {
|
||||
if peerNotValid || metaChangeAffectedPosture {
|
||||
func (am *DefaultAccountManager) syncPeerAffectedPeers(ctx context.Context, accountID, peerID string, nmap *types.NetworkMap, peerNotValid, metaUpdated, hasPostureChecks bool) []string {
|
||||
if peerNotValid || (metaUpdated && hasPostureChecks) {
|
||||
return am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, []string{peerID})
|
||||
}
|
||||
return affectedPeerIDsFromNetworkMap(nmap, peerID)
|
||||
@@ -1182,7 +1170,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
|
||||
}
|
||||
|
||||
// This is needed to keep in memory for the peer config. Otherwise browser client will end in a retry loop
|
||||
peer.Meta = login.Meta
|
||||
peer.UpdateMetaIfNew(ctx, login.Meta)
|
||||
|
||||
peerGroupIDs, err = getPeerGroupIDs(ctx, am.Store, accountID, peer.ID)
|
||||
if err != nil {
|
||||
|
||||
@@ -256,18 +256,14 @@ func (p *Peer) Copy() *Peer {
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateMetaIfNew updates peer's system metadata and connection geo location if
|
||||
// new information is provided. newLocation is the geo location resolved from the
|
||||
// peer's current connection IP, or nil when there is nothing to apply (geo
|
||||
// disabled, no real IP, or the IP is unchanged); the caller owns the expensive
|
||||
// lookup and the same-IP guard. It returns a MetaDiff describing what changed;
|
||||
// diff.Updated() reports whether the peer needs to be persisted.
|
||||
func (p *Peer) UpdateMetaIfNew(ctx context.Context, meta PeerSystemMeta, newLocation *Location) MetaDiff {
|
||||
// UpdateMetaIfNew updates peer's system metadata if new information is provided
|
||||
// returns true if meta was updated, false otherwise
|
||||
func (p *Peer) UpdateMetaIfNew(ctx context.Context, meta PeerSystemMeta) (updated, versionChanged bool) {
|
||||
if meta.isEmpty() {
|
||||
return MetaDiff{}
|
||||
return updated, versionChanged
|
||||
}
|
||||
|
||||
versionChanged := p.Meta.WtVersion != meta.WtVersion
|
||||
versionChanged = p.Meta.WtVersion != meta.WtVersion
|
||||
|
||||
// Avoid overwriting UIVersion if the update was triggered sole by the CLI client
|
||||
if meta.UIVersion == "" {
|
||||
@@ -276,177 +272,97 @@ func (p *Peer) UpdateMetaIfNew(ctx context.Context, meta PeerSystemMeta, newLoca
|
||||
|
||||
oldVersion := p.Meta.WtVersion
|
||||
|
||||
diff := diffMeta(p.Meta, meta)
|
||||
if diff.Any() {
|
||||
diff := metaDiff(p.Meta, meta)
|
||||
if len(diff) != 0 {
|
||||
p.Meta = meta
|
||||
}
|
||||
diff.VersionChanged = versionChanged
|
||||
|
||||
locationInfo := ""
|
||||
if newLocation != nil {
|
||||
p.Location = *newLocation
|
||||
diff.LocationChanged = true
|
||||
locationInfo = fmt.Sprintf("location changed to %s, ", newLocation.ConnectionIP)
|
||||
updated = true
|
||||
}
|
||||
|
||||
versionInfo := ""
|
||||
if diff.VersionChanged {
|
||||
if versionChanged {
|
||||
versionInfo = fmt.Sprintf("version changed: %s -> %s, ", oldVersion, meta.WtVersion)
|
||||
}
|
||||
|
||||
if diff.Any() || diff.VersionChanged || diff.LocationChanged {
|
||||
if len(diff) > 0 || versionChanged {
|
||||
log.WithContext(ctx).
|
||||
Debugf("peer meta updated, %s%s%d field(s) changed: %s", versionInfo, locationInfo, len(diff.Changed), strings.Join(diff.Changed, ", "))
|
||||
Debugf("peer meta updated, %s%d field(s) changed: %s", versionInfo, len(diff), strings.Join(diff, ", "))
|
||||
}
|
||||
|
||||
return diff
|
||||
}
|
||||
|
||||
// MetaDiff records which PeerSystemMeta fields differ between two metas. Each bool
|
||||
// maps to a single struct field, except Environment, which is split into Cloud and
|
||||
// Platform. Changed holds the human-readable `field: <old> -> <new>` entries so the
|
||||
// existing log line and isEqual can be derived from the same comparison.
|
||||
//
|
||||
// VersionChanged and LocationChanged sit outside the per-meta-field set:
|
||||
// VersionChanged tracks the WireGuard client version specifically (compared before
|
||||
// the UIVersion fixup, to signal client upgrades) and LocationChanged tracks the
|
||||
// peer's connection geo location, which lives on Peer rather than PeerSystemMeta.
|
||||
// Neither contributes an entry to Changed, so the field-coverage accounting stays
|
||||
// driven purely by the PeerSystemMeta comparison.
|
||||
type MetaDiff struct {
|
||||
Hostname bool
|
||||
GoOS bool
|
||||
Kernel bool
|
||||
KernelVersion bool
|
||||
Core bool
|
||||
Platform bool
|
||||
OS bool
|
||||
OSVersion bool
|
||||
WtVersion bool
|
||||
UIVersion bool
|
||||
SystemSerialNumber bool
|
||||
SystemProductName bool
|
||||
SystemManufacturer bool
|
||||
EnvironmentCloud bool
|
||||
EnvironmentPlatform bool
|
||||
Flags bool
|
||||
Capabilities bool
|
||||
NetworkAddresses bool
|
||||
Files bool
|
||||
|
||||
VersionChanged bool
|
||||
LocationChanged bool
|
||||
|
||||
Changed []string
|
||||
}
|
||||
|
||||
// Any reports whether any PeerSystemMeta field changed.
|
||||
func (d MetaDiff) Any() bool {
|
||||
return len(d.Changed) != 0
|
||||
}
|
||||
|
||||
// Updated reports whether the peer needs to be persisted: any meta field changed
|
||||
// or the geo location changed. The version flag alone does not imply a write,
|
||||
// since a version change is also reflected in the WtVersion meta field.
|
||||
func (d MetaDiff) Updated() bool {
|
||||
return d.Any() || d.LocationChanged || d.VersionChanged
|
||||
return updated, versionChanged
|
||||
}
|
||||
|
||||
// metaDiff returns a human-readable list of the fields that differ between the
|
||||
// old and new meta, each formatted as `field: <old> -> <new>`. It is the single
|
||||
// source of truth for meta comparison: isEqual reports equality as an empty
|
||||
// diff, so the log line can never disagree with the change decision. Slices are
|
||||
// cloned before sorting, so callers' meta is not mutated.
|
||||
func metaDiff(oldMeta, newMeta PeerSystemMeta) []string {
|
||||
return diffMeta(oldMeta, newMeta).Changed
|
||||
}
|
||||
|
||||
// diffMeta compares two metas field by field, returning both a per-field flag set
|
||||
// (for callers that need to know exactly what changed, e.g. matching against
|
||||
// posture checks) and the human-readable Changed list. It is the single source of
|
||||
// truth for meta comparison: isEqual reports equality as an empty diff, so the log
|
||||
// line, the change decision, and the flags can never disagree.
|
||||
func diffMeta(oldMeta, newMeta PeerSystemMeta) MetaDiff {
|
||||
var d MetaDiff
|
||||
var diff []string
|
||||
add := func(field string, oldVal, newVal any) {
|
||||
d.Changed = append(d.Changed, fmt.Sprintf("%s: %v -> %v", field, oldVal, newVal))
|
||||
diff = append(diff, fmt.Sprintf("%s: %v -> %v", field, oldVal, newVal))
|
||||
}
|
||||
|
||||
if oldMeta.Hostname != newMeta.Hostname {
|
||||
d.Hostname = true
|
||||
add("hostname", oldMeta.Hostname, newMeta.Hostname)
|
||||
}
|
||||
if oldMeta.GoOS != newMeta.GoOS {
|
||||
d.GoOS = true
|
||||
add("goos", oldMeta.GoOS, newMeta.GoOS)
|
||||
}
|
||||
if oldMeta.Kernel != newMeta.Kernel {
|
||||
d.Kernel = true
|
||||
add("kernel", oldMeta.Kernel, newMeta.Kernel)
|
||||
}
|
||||
if oldMeta.KernelVersion != newMeta.KernelVersion {
|
||||
d.KernelVersion = true
|
||||
add("kernel_version", oldMeta.KernelVersion, newMeta.KernelVersion)
|
||||
}
|
||||
if oldMeta.Core != newMeta.Core {
|
||||
d.Core = true
|
||||
add("core", oldMeta.Core, newMeta.Core)
|
||||
}
|
||||
if oldMeta.Platform != newMeta.Platform {
|
||||
d.Platform = true
|
||||
add("platform", oldMeta.Platform, newMeta.Platform)
|
||||
}
|
||||
if oldMeta.OS != newMeta.OS {
|
||||
d.OS = true
|
||||
add("os", oldMeta.OS, newMeta.OS)
|
||||
}
|
||||
if oldMeta.OSVersion != newMeta.OSVersion {
|
||||
d.OSVersion = true
|
||||
add("os_version", oldMeta.OSVersion, newMeta.OSVersion)
|
||||
}
|
||||
if oldMeta.WtVersion != newMeta.WtVersion {
|
||||
d.WtVersion = true
|
||||
add("wt_version", oldMeta.WtVersion, newMeta.WtVersion)
|
||||
}
|
||||
if oldMeta.UIVersion != newMeta.UIVersion {
|
||||
d.UIVersion = true
|
||||
add("ui_version", oldMeta.UIVersion, newMeta.UIVersion)
|
||||
}
|
||||
if oldMeta.SystemSerialNumber != newMeta.SystemSerialNumber {
|
||||
d.SystemSerialNumber = true
|
||||
add("system_serial_number", oldMeta.SystemSerialNumber, newMeta.SystemSerialNumber)
|
||||
}
|
||||
if oldMeta.SystemProductName != newMeta.SystemProductName {
|
||||
d.SystemProductName = true
|
||||
add("system_product_name", oldMeta.SystemProductName, newMeta.SystemProductName)
|
||||
}
|
||||
if oldMeta.SystemManufacturer != newMeta.SystemManufacturer {
|
||||
d.SystemManufacturer = true
|
||||
add("system_manufacturer", oldMeta.SystemManufacturer, newMeta.SystemManufacturer)
|
||||
}
|
||||
if oldMeta.Environment.Cloud != newMeta.Environment.Cloud {
|
||||
d.EnvironmentCloud = true
|
||||
add("environment_cloud", oldMeta.Environment.Cloud, newMeta.Environment.Cloud)
|
||||
}
|
||||
if oldMeta.Environment.Platform != newMeta.Environment.Platform {
|
||||
d.EnvironmentPlatform = true
|
||||
add("environment_platform", oldMeta.Environment.Platform, newMeta.Environment.Platform)
|
||||
}
|
||||
if !oldMeta.Flags.isEqual(newMeta.Flags) {
|
||||
d.Flags = true
|
||||
add("flags", fmt.Sprintf("%+v", oldMeta.Flags), fmt.Sprintf("%+v", newMeta.Flags))
|
||||
}
|
||||
if !capabilitiesEqual(oldMeta.Capabilities, newMeta.Capabilities) {
|
||||
d.Capabilities = true
|
||||
add("capabilities", oldMeta.Capabilities, newMeta.Capabilities)
|
||||
}
|
||||
|
||||
if !sameMultiset(oldMeta.NetworkAddresses, newMeta.NetworkAddresses) {
|
||||
d.NetworkAddresses = true
|
||||
add("network_addresses", fmt.Sprintf("%v", oldMeta.NetworkAddresses), fmt.Sprintf("%v", newMeta.NetworkAddresses))
|
||||
}
|
||||
|
||||
if !sameMultiset(oldMeta.Files, newMeta.Files) {
|
||||
d.Files = true
|
||||
add("files", fmt.Sprintf("%v", oldMeta.Files), fmt.Sprintf("%v", newMeta.Files))
|
||||
}
|
||||
|
||||
return d
|
||||
return diff
|
||||
}
|
||||
|
||||
// sameMultiset reports whether two slices contain the same elements with the
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"regexp"
|
||||
|
||||
"github.com/hashicorp/go-version"
|
||||
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
@@ -52,34 +51,6 @@ type Checks struct {
|
||||
Checks ChecksDefinition `gorm:"serializer:json"`
|
||||
}
|
||||
|
||||
// AffectsPosture reports whether the peer metadata changes described by diff can
|
||||
// alter the outcome of any of the given posture checks. It maps each check kind to
|
||||
// the metadata fields it inspects, so an unrelated change (e.g. a hostname update)
|
||||
// does not force a posture re-evaluation.
|
||||
func AffectsPosture(diff *nbpeer.MetaDiff, checks []*Checks) bool {
|
||||
if diff == nil {
|
||||
return false
|
||||
}
|
||||
for _, c := range checks {
|
||||
if c.Checks.ProcessCheck != nil && diff.Files {
|
||||
return true
|
||||
}
|
||||
if c.Checks.OSVersionCheck != nil && (diff.OSVersion || diff.OS || diff.KernelVersion) {
|
||||
return true
|
||||
}
|
||||
if c.Checks.NBVersionCheck != nil && diff.WtVersion {
|
||||
return true
|
||||
}
|
||||
if c.Checks.GeoLocationCheck != nil && diff.LocationChanged {
|
||||
return true
|
||||
}
|
||||
if c.Checks.PeerNetworkRangeCheck != nil && diff.NetworkAddresses {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ChecksDefinition contains definition of actual check
|
||||
type ChecksDefinition struct {
|
||||
NBVersionCheck *NBVersionCheck `json:",omitempty"`
|
||||
|
||||
@@ -581,6 +581,28 @@ func (s *SqlStore) MarkPeerDisconnectedIfSameSession(ctx context.Context, accoun
|
||||
return result.RowsAffected > 0, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) SavePeerLocation(ctx context.Context, accountID string, peerWithLocation *nbpeer.Peer) error {
|
||||
// To maintain data integrity, we create a copy of the peer's location to prevent unintended updates to other fields.
|
||||
var peerCopy nbpeer.Peer
|
||||
// Since the location field has been migrated to JSON serialization,
|
||||
// updating the struct ensures the correct data format is inserted into the database.
|
||||
peerCopy.Location = peerWithLocation.Location
|
||||
|
||||
result := s.db.Model(&nbpeer.Peer{}).
|
||||
Where(accountAndIDQueryCondition, accountID, peerWithLocation.ID).
|
||||
Updates(peerCopy)
|
||||
|
||||
if result.Error != nil {
|
||||
return status.Errorf(status.Internal, "failed to save peer locations to store: %v", result.Error)
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return status.Errorf(status.NotFound, peerNotFoundFMT, peerWithLocation.ID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ApproveAccountPeers marks all peers that currently require approval in the given account as approved.
|
||||
func (s *SqlStore) ApproveAccountPeers(ctx context.Context, accountID string) (int, error) {
|
||||
result := s.db.Model(&nbpeer.Peer{}).
|
||||
|
||||
@@ -618,6 +618,56 @@ func TestSqlStore_SavePeerStatus(t *testing.T) {
|
||||
assert.WithinDurationf(t, newStatus.LastSeen, actual.LastSeen.UTC(), time.Millisecond, "LastSeen should be equal")
|
||||
}
|
||||
|
||||
func TestSqlStore_SavePeerLocation(t *testing.T) {
|
||||
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanUp)
|
||||
assert.NoError(t, err)
|
||||
|
||||
account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b")
|
||||
require.NoError(t, err)
|
||||
|
||||
peer := &nbpeer.Peer{
|
||||
AccountID: account.Id,
|
||||
ID: "testpeer",
|
||||
Location: nbpeer.Location{
|
||||
ConnectionIP: net.ParseIP("0.0.0.0"),
|
||||
CountryCode: "YY",
|
||||
CityName: "City",
|
||||
GeoNameID: 1,
|
||||
},
|
||||
CreatedAt: time.Now().UTC(),
|
||||
Meta: nbpeer.PeerSystemMeta{},
|
||||
}
|
||||
// error is expected as peer is not in store yet
|
||||
err = store.SavePeerLocation(context.Background(), account.Id, peer)
|
||||
assert.Error(t, err)
|
||||
|
||||
account.Peers[peer.ID] = peer
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
|
||||
peer.Location.ConnectionIP = net.ParseIP("35.1.1.1")
|
||||
peer.Location.CountryCode = "DE"
|
||||
peer.Location.CityName = "Berlin"
|
||||
peer.Location.GeoNameID = 2950159
|
||||
|
||||
err = store.SavePeerLocation(context.Background(), account.Id, account.Peers[peer.ID])
|
||||
assert.NoError(t, err)
|
||||
|
||||
account, err = store.GetAccount(context.Background(), account.Id)
|
||||
require.NoError(t, err)
|
||||
|
||||
actual := account.Peers[peer.ID].Location
|
||||
assert.Equal(t, peer.Location, actual)
|
||||
|
||||
peer.ID = "non-existing-peer"
|
||||
err = store.SavePeerLocation(context.Background(), account.Id, peer)
|
||||
assert.Error(t, err)
|
||||
parsedErr, ok := status.FromError(err)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error")
|
||||
}
|
||||
|
||||
func Test_TestGetAccountByPrivateDomain(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||
|
||||
@@ -185,6 +185,7 @@ type Store interface {
|
||||
// recorded by the database. Returns true when the update happened,
|
||||
// false when a newer session has taken over.
|
||||
MarkPeerDisconnectedIfSameSession(ctx context.Context, accountID, peerID string, sessionStartedAt int64) (bool, error)
|
||||
SavePeerLocation(ctx context.Context, accountID string, peer *nbpeer.Peer) error
|
||||
ApproveAccountPeers(ctx context.Context, accountID string) (int, error)
|
||||
DeletePeer(ctx context.Context, accountID string, peerID string) error
|
||||
|
||||
|
||||
@@ -2968,6 +2968,20 @@ func (mr *MockStoreMockRecorder) SavePeer(ctx, accountID, peer interface{}) *gom
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SavePeer", reflect.TypeOf((*MockStore)(nil).SavePeer), ctx, accountID, peer)
|
||||
}
|
||||
|
||||
// SavePeerLocation mocks base method.
|
||||
func (m *MockStore) SavePeerLocation(ctx context.Context, accountID string, peer *peer.Peer) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SavePeerLocation", ctx, accountID, peer)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// SavePeerLocation indicates an expected call of SavePeerLocation.
|
||||
func (mr *MockStoreMockRecorder) SavePeerLocation(ctx, accountID, peer interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SavePeerLocation", reflect.TypeOf((*MockStore)(nil).SavePeerLocation), ctx, accountID, peer)
|
||||
}
|
||||
|
||||
// SavePeerStatus mocks base method.
|
||||
func (m *MockStore) SavePeerStatus(ctx context.Context, accountID, peerID string, status peer.PeerStatus) error {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
@@ -12,9 +12,6 @@ type PeerSync struct {
|
||||
WireGuardPubKey string
|
||||
// Meta is the system information passed by peer, must be always present
|
||||
Meta nbpeer.PeerSystemMeta
|
||||
// RealIP is the peer's connection IP, used to refresh its geo location.
|
||||
// May be nil when the request has no associated connection IP.
|
||||
RealIP net.IP
|
||||
// UpdateAccountPeers indicate updating account peers,
|
||||
// which occurs when the peer's metadata is updated
|
||||
UpdateAccountPeers bool
|
||||
|
||||
@@ -243,11 +243,7 @@ func NewClientWithServerIP(serverURL string, serverIP netip.Addr, authTokenStore
|
||||
|
||||
// Connect establishes a connection to the relay server. It blocks until the connection is established or an error occurs.
|
||||
func (c *Client) Connect(ctx context.Context) error {
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
c.log.Infof("connect elapsed time: %v", time.Since(start))
|
||||
}()
|
||||
|
||||
c.log.Infof("connecting to relay server")
|
||||
c.readLoopMutex.Lock()
|
||||
defer c.readLoopMutex.Unlock()
|
||||
|
||||
@@ -291,11 +287,6 @@ func (c *Client) Connect(ctx context.Context) error {
|
||||
func (c *Client) OpenConn(ctx context.Context, dstPeerID string) (net.Conn, error) {
|
||||
peerID := messages.HashID(dstPeerID)
|
||||
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
c.log.Infof("connect elapsed time: %v", time.Since(start))
|
||||
}()
|
||||
|
||||
c.mu.Lock()
|
||||
if !c.serviceIsRunning {
|
||||
c.mu.Unlock()
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"go.opentelemetry.io/otel"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
@@ -253,7 +252,7 @@ func TestClient_ConnectedIPParsesRemoteAddr(t *testing.T) {
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Client{log: log.WithField("relay", tt.name), relayConn: stubConn{remote: staticAddr{s: tt.s}}}
|
||||
c := &Client{relayConn: stubConn{remote: staticAddr{s: tt.s}}}
|
||||
got := c.ConnectedIP()
|
||||
var gotStr string
|
||||
if got.IsValid() {
|
||||
|
||||
@@ -78,23 +78,6 @@ type GrpcClient struct {
|
||||
// transport-alive but no longer delivering messages. It is the source of
|
||||
// truth IsHealthy reads, and is cleared once any frame is received again.
|
||||
receiveStalled atomic.Bool
|
||||
// receiveHandoffBlocked is set while the receive loop is parked handing a
|
||||
// message to a busy decryption worker. The loop stops calling Recv (and
|
||||
// markReceived) in that window, so the stream looks silent though it is
|
||||
// healthy. The watchdog reads this to avoid misreading self-inflicted
|
||||
// receive backpressure as a dead stream: reconnecting cannot help, since the
|
||||
// new stream feeds the same worker, and only triggers a reconnect storm.
|
||||
receiveHandoffBlocked atomic.Bool
|
||||
// lastDecrypt holds the Unix-nano timestamp of the last message the decryption
|
||||
// worker pulled off its queue. Diagnostic only: it lets a stall log show
|
||||
// whether the worker was draining (busy) or idle when the stream went silent.
|
||||
lastDecrypt atomic.Int64
|
||||
// handoffWaitTotal, handoffWaitMax (nanos) and handoffWaitCount accumulate the
|
||||
// time the receive loop spent blocked handing messages to the worker. This is
|
||||
// time not spent reading the stream, so it quantifies receive backpressure.
|
||||
handoffWaitTotal atomic.Int64
|
||||
handoffWaitMax atomic.Int64
|
||||
handoffWaitCount atomic.Int64
|
||||
}
|
||||
|
||||
// NewClient creates a new Signal client
|
||||
@@ -370,8 +353,6 @@ func (c *GrpcClient) SendToStream(msg *proto.EncryptedMessage) error {
|
||||
|
||||
// decryptMessage decrypts the body of the msg using Wireguard private key and Remote peer's public key
|
||||
func (c *GrpcClient) decryptMessage(msg *proto.EncryptedMessage) (*proto.Message, error) {
|
||||
c.lastDecrypt.Store(time.Now().UnixNano())
|
||||
|
||||
remoteKey, err := wgtypes.ParseKey(msg.GetKey())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -458,22 +439,6 @@ func (c *GrpcClient) idleSinceReceive() time.Duration {
|
||||
return time.Since(time.Unix(0, c.lastReceived.Load()))
|
||||
}
|
||||
|
||||
// idleSinceDecrypt returns how long since the worker last pulled a message.
|
||||
// Diagnostic only: distinguishes a busy/wedged worker from an idle one.
|
||||
func (c *GrpcClient) idleSinceDecrypt() time.Duration {
|
||||
return time.Since(time.Unix(0, c.lastDecrypt.Load()))
|
||||
}
|
||||
|
||||
// receiveAlive reports whether the receive stream shows liveness: it delivered a
|
||||
// frame within the inactivity threshold, or the receive loop is currently parked
|
||||
// handing a message to a busy decryption worker. In the latter case the loop has
|
||||
// stopped calling Recv, so the stream looks silent while being healthy, and
|
||||
// reconnecting would not help, so the watchdog must treat it as alive.
|
||||
func (c *GrpcClient) receiveAlive() bool {
|
||||
return c.idleSinceReceive() < receiveInactivityThreshold ||
|
||||
c.receiveHandoffBlocked.Load()
|
||||
}
|
||||
|
||||
// watchReceiveStream guards against a receive stream that is transport-alive but
|
||||
// no longer delivering messages. While the stream is idle past
|
||||
// receiveInactivityThreshold it sends a self-addressed probe that the Signal
|
||||
@@ -485,55 +450,18 @@ func (c *GrpcClient) watchReceiveStream(ctx context.Context, cancelStream contex
|
||||
defer ticker.Stop()
|
||||
|
||||
var probeSentAt time.Time
|
||||
var holdLogged bool
|
||||
var statTicks int
|
||||
var lastStatTotal int64
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
// Periodic backpressure summary so time lost to the worker handoff is
|
||||
// visible even when no stall fires. Emitted ~once a minute and only
|
||||
// when the wait grew, to stay quiet on a healthy stream.
|
||||
if statTicks++; statTicks >= int(time.Minute/receiveWatchdogInterval) {
|
||||
statTicks = 0
|
||||
if total, max, count := c.handoffWaitStats(); int64(total) > lastStatTotal {
|
||||
log.Infof("signal receive backpressure: handoffWaitTotal=%s (+%s last min) handoffWaitMax=%s handoffMsgs=%d",
|
||||
total.Round(time.Second), (total - time.Duration(lastStatTotal)).Round(time.Millisecond),
|
||||
max.Round(time.Millisecond), count)
|
||||
lastStatTotal = int64(total)
|
||||
}
|
||||
}
|
||||
|
||||
if c.receiveAlive() {
|
||||
// Attribute the case that matters in the field: silent past the
|
||||
// threshold but held because the receive loop is parked on the
|
||||
// worker handoff (backpressure), not a dead stream. Log once per
|
||||
// hold episode so a persistent worker stall is visible at info.
|
||||
if c.idleSinceReceive() >= receiveInactivityThreshold && c.receiveHandoffBlocked.Load() {
|
||||
if !holdLogged {
|
||||
total, max, count := c.handoffWaitStats()
|
||||
log.Infof("signal receive idle %s, loop blocked on worker handoff (idleDecrypt=%s queueDepth=%d connState=%s handoffWaitTotal=%s handoffWaitMax=%s handoffMsgs=%d); holding stream",
|
||||
c.idleSinceReceive().Round(time.Second), c.idleSinceDecrypt().Round(time.Second),
|
||||
c.decryptionWorker.QueueLen(), c.signalConn.GetState(),
|
||||
total.Round(time.Second), max.Round(time.Millisecond), count)
|
||||
holdLogged = true
|
||||
}
|
||||
} else {
|
||||
holdLogged = false
|
||||
}
|
||||
if c.idleSinceReceive() < receiveInactivityThreshold {
|
||||
probeSentAt = time.Time{}
|
||||
continue
|
||||
}
|
||||
holdLogged = false
|
||||
|
||||
if !probeSentAt.IsZero() && time.Since(probeSentAt) >= receiveProbeTimeout {
|
||||
total, max, count := c.handoffWaitStats()
|
||||
log.Warnf("signal receive stream stalled, reconnecting: idleRecv=%s idleDecrypt=%s handoffBlocked=%v queueDepth=%d connState=%s handoffWaitTotal=%s handoffWaitMax=%s handoffMsgs=%d probe did not return",
|
||||
c.idleSinceReceive().Round(time.Second), c.idleSinceDecrypt().Round(time.Second),
|
||||
c.receiveHandoffBlocked.Load(), c.decryptionWorker.QueueLen(), c.signalConn.GetState(),
|
||||
total.Round(time.Second), max.Round(time.Millisecond), count)
|
||||
log.Warnf("signal receive stream stalled: no messages for %s and probe did not return, reconnecting", c.idleSinceReceive().Round(time.Second))
|
||||
c.receiveStalled.Store(true)
|
||||
c.notifyDisconnected(errReceiveStreamStalled)
|
||||
cancelStream()
|
||||
@@ -589,37 +517,12 @@ func (c *GrpcClient) receive(stream proto.SignalExchange_ConnectStreamClient) er
|
||||
continue
|
||||
}
|
||||
|
||||
// The handoff blocks while the worker is busy, which parks this loop and
|
||||
// stops Recv. Flag it so the watchdog does not read the resulting silence
|
||||
// as a dead stream, and account the wait as receive backpressure.
|
||||
handoffStart := time.Now()
|
||||
c.receiveHandoffBlocked.Store(true)
|
||||
if err := c.decryptionWorker.AddMsg(c.ctx, msg); err != nil {
|
||||
log.Errorf("failed to add message to decryption worker: %v", err)
|
||||
}
|
||||
c.receiveHandoffBlocked.Store(false)
|
||||
c.recordHandoffWait(time.Since(handoffStart))
|
||||
}
|
||||
}
|
||||
|
||||
// recordHandoffWait accumulates the time the receive loop was blocked handing a
|
||||
// message to the worker.
|
||||
func (c *GrpcClient) recordHandoffWait(d time.Duration) {
|
||||
c.handoffWaitTotal.Add(int64(d))
|
||||
c.handoffWaitCount.Add(1)
|
||||
for {
|
||||
cur := c.handoffWaitMax.Load()
|
||||
if int64(d) <= cur || c.handoffWaitMax.CompareAndSwap(cur, int64(d)) {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handoffWaitStats returns cumulative receive-loop handoff backpressure.
|
||||
func (c *GrpcClient) handoffWaitStats() (total, max time.Duration, count int64) {
|
||||
return time.Duration(c.handoffWaitTotal.Load()), time.Duration(c.handoffWaitMax.Load()), c.handoffWaitCount.Load()
|
||||
}
|
||||
|
||||
func (c *GrpcClient) startEncryptionWorker(handler func(msg *proto.Message) error) {
|
||||
if c.decryptionWorker != nil {
|
||||
return
|
||||
|
||||
@@ -82,27 +82,3 @@ func TestReceiveProbeRoundTrips(t *testing.T) {
|
||||
t.Fatal("self-addressed heartbeat did not round-trip back through the signal server")
|
||||
}
|
||||
}
|
||||
|
||||
// TestReceiveAliveTreatsHandoffBlockAsLiveness reproduces the false positive
|
||||
// where a busy decryption worker parks the receive loop on the worker handoff,
|
||||
// so Recv (and markReceived) stops firing even though the stream is healthy.
|
||||
// With the receive stream silent past the inactivity threshold but the loop
|
||||
// blocked on handoff, the watchdog must consider the stream alive rather than
|
||||
// tear it down (reconnecting feeds the same worker and would not help).
|
||||
func TestReceiveAliveTreatsHandoffBlockAsLiveness(t *testing.T) {
|
||||
c := &GrpcClient{}
|
||||
|
||||
// Receive stream silent and the loop not blocked on handoff: genuinely stalled.
|
||||
c.lastReceived.Store(time.Now().Add(-2 * receiveInactivityThreshold).UnixNano())
|
||||
require.False(t, c.receiveAlive(), "silent stream with the receive loop idle must be treated as stalled")
|
||||
|
||||
// Receive stream silent but the loop is parked handing a message to a busy
|
||||
// worker: self-inflicted backpressure, not a dead stream, must not tear down.
|
||||
c.receiveHandoffBlocked.Store(true)
|
||||
require.True(t, c.receiveAlive(), "a receive loop blocked on worker handoff must keep the stream alive")
|
||||
|
||||
// Handoff drained, loop back to reading, a frame just arrived: alive via the receive path.
|
||||
c.receiveHandoffBlocked.Store(false)
|
||||
c.markReceived()
|
||||
require.True(t, c.receiveAlive(), "a freshly received frame must keep the stream alive")
|
||||
}
|
||||
|
||||
@@ -32,13 +32,6 @@ func (w *Worker) AddMsg(ctx context.Context, msg *proto.EncryptedMessage) error
|
||||
return nil
|
||||
}
|
||||
|
||||
// QueueLen returns the number of messages buffered for decryption. Diagnostic
|
||||
// only: a non-empty queue while the receive stream is silent indicates the
|
||||
// receive loop is parked on the handoff rather than the stream being dead.
|
||||
func (w *Worker) QueueLen() int {
|
||||
return len(w.encryptedMsgPool)
|
||||
}
|
||||
|
||||
func (w *Worker) Work(ctx context.Context) {
|
||||
for {
|
||||
select {
|
||||
|
||||
Reference in New Issue
Block a user