Compare commits

..

30 Commits

Author SHA1 Message Date
pascal
e048f221f0 Merge branch 'feature/less-strict-metaHash' into test/affected-logic 2026-06-25 14:37:08 +02:00
pascal
4e8780e0c3 Merge branch 'chore/change-log-level' into test/affected-logic 2026-06-25 13:43:59 +02:00
pascal
a1933f792a Merge branch 'feature/meta-posture-relation' into test/affected-logic 2026-06-25 13:43:48 +02:00
pascal
ccf1e965c3 Merge branch 'refactor/simplify-affected-peers-ignore-disabled' into test/affected-logic 2026-06-25 13:43:37 +02:00
pascal
38bd53eb71 handle login error 2026-06-25 13:09:58 +02:00
pascal
845dd0c9bb add brackets 2026-06-25 13:08:36 +02:00
pascal
2f0093a7e1 move some logs to trace 2026-06-25 12:20:59 +02:00
pascal
c5d26106f2 make test policies enabled to not fali tests 2026-06-25 09:51:20 +02:00
pascal
9746bc2a08 Merge branch 'refactor/simplify-affected-peers' into refactor/simplify-affected-peers-ignore-disabled 2026-06-25 01:48:57 +02:00
pascal
42867c7a59 ignore siblings when add/remove peer from group 2026-06-25 01:48:39 +02:00
pascal
d7740f9868 respect disabled 2026-06-25 00:48:33 +02:00
pascal
c2db940a8c simplify affected peers walk 2026-06-25 00:01:17 +02:00
pascal
62ffa08744 split networkIDs to check 2026-06-24 22:39:35 +02:00
Dmitri Dolguikh
d8e7f2e9e6 a couple of fixes
Signed-off-by: Dmitri Dolguikh <dmitri.external@netbird.io>
2026-06-24 15:40:44 +02:00
Dmitri Dolguikh
1205641b44 fixed test
Signed-off-by: Dmitri Dolguikh <dmitri.external@netbird.io>
2026-06-24 15:17:41 +02:00
pascal
de0cf0fc7a update callsites and benchmark 2026-06-24 14:36:57 +02:00
pascal
57f475c5a9 less strict metaHash when blocking peers 2026-06-24 14:33:15 +02:00
Dmitri Dolguikh
56e8215ebe updated 'resource-routing-bridge/router-peer-change refreshes policy sources' test to expect router peer among changed peer ids
Signed-off-by: Dmitri Dolguikh <dmitri.external@netbird.io>
2026-06-24 14:29:23 +02:00
pascal
5ae36cd260 remove posture check took timing logs 2026-06-24 14:13:55 +02:00
Dmitri Dolguikh
9b768d1773 fixed a bug in collectFromPolicies
Signed-off-by: Dmitri Dolguikh <dmitri.external@netbird.io>
2026-06-24 14:10:37 +02:00
Dmitri Dolguikh
33954ea15e fixing tests + adding tests
Signed-off-by: Dmitri Dolguikh <dmitri.external@netbird.io>
2026-06-24 13:07:53 +02:00
Dmitri Dolguikh
4c4434a871 fixed a few tests
Signed-off-by: Dmitri Dolguikh <dmitri.external@netbird.io>
2026-06-24 09:48:45 +02:00
pascal
9eadf50f4c fix typo 2026-06-23 23:02:48 +02:00
pascal
be06016ad2 validate posture checks on meta change before account update 2026-06-23 22:58:32 +02:00
Dmitri Dolguikh
7873f337df when collecting group and peer IDs from policies, do so directionally
Signed-off-by: Dmitri Dolguikh <dmitri.external@netbird.io>
2026-06-23 18:39:53 +02:00
Viktor Liu
17b2044596 [client] Skip re-resolving cached management cache domains (#6518) 2026-06-23 17:55:57 +02:00
Bethuel Mmbaga
07101c59ac [management] Reschedule inactivity expiration when a peer disconnects (#6523) 2026-06-23 17:44:32 +03:00
Riccardo Manfrin
51b6f6291b Fixup debug config (#6514) 2026-06-22 22:01:49 +02:00
Pascal Fischer
2ebf26006a [management] empty file check in nmap on other posturechecks (#6511) 2026-06-22 19:54:38 +02:00
Pascal Fischer
211a26019a [management] validate meta change against posture checks (#6510) 2026-06-22 19:42:04 +02:00
42 changed files with 1492 additions and 711 deletions

View File

@@ -130,7 +130,7 @@ func debugConfigDump(cmd *cobra.Command, _ []string) error {
client := proto.NewDaemonServiceClient(conn)
resp, err := client.GetConfig(cmd.Context(), &proto.GetConfigRequest{
ProfileName: activeProf.Name,
ProfileName: string(activeProf.ID),
Username: currUser.Username,
})
if err != nil {

View File

@@ -51,13 +51,20 @@ type cachedRecord struct {
}
// Resolver caches critical NetBird infrastructure domains.
// records, refreshing, mgmtDomain and serverDomains are all guarded by mutex.
// records, refreshing, failedResolves, mgmtDomain and serverDomains are all
// guarded by mutex.
type Resolver struct {
records map[dns.Question]*cachedRecord
mgmtDomain *domain.Domain
serverDomains *dnsconfig.ServerDomains
mutex sync.RWMutex
// failedResolves records the last failed initial resolve per domain so a
// domain that never resolves isn't retried on every server-domains update
// until refreshBackoff elapses. Entries are cleared on success and pruned
// to the current server-domains set.
failedResolves map[domain.Domain]time.Time
chain ChainResolver
chainMaxPriority int
refreshGroup singleflight.Group
@@ -76,9 +83,10 @@ type Resolver struct {
// NewResolver creates a new management domains cache resolver.
func NewResolver() *Resolver {
return &Resolver{
records: make(map[dns.Question]*cachedRecord),
refreshing: make(map[dns.Question]*atomic.Bool),
cacheTTL: resolveCacheTTL(),
records: make(map[dns.Question]*cachedRecord),
refreshing: make(map[dns.Question]*atomic.Bool),
failedResolves: make(map[domain.Domain]time.Time),
cacheTTL: resolveCacheTTL(),
}
}
@@ -173,7 +181,9 @@ func (m *Resolver) continueToNext(w dns.ResponseWriter, r *dns.Msg) {
// AddDomain resolves a domain and stores its A/AAAA records in the cache.
// A family that resolves NODATA (nil err, zero records) evicts any stale
// entry for that qtype.
// entry for that qtype. When one family hard-errors while the other succeeds,
// the resolved family is still cached but AddDomain returns an error so the
// caller retries the incomplete resolve rather than treating it as complete.
func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
@@ -203,6 +213,10 @@ func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
log.Debugf("added/updated domain=%s with %d A records and %d AAAA records",
d.SafeString(), len(aRecords), len(aaaaRecords))
if errA != nil || errAAAA != nil {
return fmt.Errorf("resolve %s: incomplete, a family failed: %w", d.SafeString(), errors.Join(errA, errAAAA))
}
return nil
}
@@ -462,6 +476,7 @@ func (m *Resolver) RemoveDomain(d domain.Domain) error {
delete(m.records, qAAAA)
delete(m.refreshing, qA)
delete(m.refreshing, qAAAA)
delete(m.failedResolves, d)
log.Debugf("removed domain=%s from cache", d.SafeString())
return nil
@@ -505,6 +520,7 @@ func (m *Resolver) UpdateFromServerDomains(ctx context.Context, serverDomains dn
allDomains := m.extractDomainsFromServerDomains(updatedServerDomains)
currentDomains := m.GetCachedDomains()
removedDomains = m.removeStaleDomains(currentDomains, allDomains)
m.pruneFailedResolves(allDomains)
}
m.addNewDomains(ctx, newDomains)
@@ -577,13 +593,85 @@ func (m *Resolver) isManagementDomain(domain domain.Domain) bool {
return m.mgmtDomain != nil && domain == *m.mgmtDomain
}
// addNewDomains resolves and caches all domains from the update
// addNewDomains resolves and caches domains that are not yet in the cache,
// running the lookups concurrently. Domains already cached are skipped and left
// to the stale-while-revalidate refresh path, so a sync never re-resolves them
// synchronously: once NetBird owns the OS resolver the resolve runs through the
// handler chain and would otherwise dial the managed upstreams under the engine
// sync lock on every update.
func (m *Resolver) addNewDomains(ctx context.Context, newDomains domain.List) {
var wg sync.WaitGroup
seen := make(map[domain.Domain]struct{}, len(newDomains))
for _, newDomain := range newDomains {
if err := m.AddDomain(ctx, newDomain); err != nil {
log.Warnf("failed to add/update domain=%s: %v", newDomain.SafeString(), err)
} else {
log.Debugf("added/updated management cache domain=%s", newDomain.SafeString())
if _, dup := seen[newDomain]; dup {
continue
}
seen[newDomain] = struct{}{}
if !m.needsResolve(newDomain) {
continue
}
wg.Add(1)
go func(d domain.Domain) {
defer wg.Done()
if err := m.AddDomain(ctx, d); err != nil {
m.markResolveFailed(d)
log.Warnf("failed to add/update domain=%s: %v", d.SafeString(), err)
return
}
m.clearResolveFailed(d)
log.Debugf("added/updated management cache domain=%s", d.SafeString())
}(newDomain)
}
wg.Wait()
}
// needsResolve reports whether d should be resolved now. A recent failed or
// incomplete resolve gates retries on the backoff even when one family is
// already cached, so a transiently-failed family is retried instead of being
// treated as fully resolved. Otherwise a domain with any cached record is left
// to the stale-while-revalidate refresh path.
func (m *Resolver) needsResolve(d domain.Domain) bool {
dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
m.mutex.RLock()
defer m.mutex.RUnlock()
if failedAt, ok := m.failedResolves[d]; ok {
return time.Since(failedAt) >= refreshBackoff
}
for _, qtype := range []uint16{dns.TypeA, dns.TypeAAAA} {
q := dns.Question{Name: dnsName, Qtype: qtype, Qclass: dns.ClassINET}
if _, ok := m.records[q]; ok {
return false
}
}
return true
}
func (m *Resolver) markResolveFailed(d domain.Domain) {
m.mutex.Lock()
m.failedResolves[d] = time.Now()
m.mutex.Unlock()
}
func (m *Resolver) clearResolveFailed(d domain.Domain) {
m.mutex.Lock()
delete(m.failedResolves, d)
m.mutex.Unlock()
}
// pruneFailedResolves drops failure markers for domains no longer present in
// the server-domains set, keeping the map bounded to the current set (a
// failed-only domain has no cached record, so RemoveDomain never sees it).
func (m *Resolver) pruneFailedResolves(domains domain.List) {
m.mutex.Lock()
defer m.mutex.Unlock()
for d := range m.failedResolves {
if !slices.Contains(domains, d) {
delete(m.failedResolves, d)
}
}
}

View File

@@ -21,6 +21,7 @@ type fakeChain struct {
mu sync.Mutex
calls map[string]int
answers map[string][]dns.RR
qErr map[string]error
err error
hasRoot bool
onLookup func()
@@ -30,6 +31,7 @@ func newFakeChain() *fakeChain {
return &fakeChain{
calls: map[string]int{},
answers: map[string][]dns.RR{},
qErr: map[string]error{},
hasRoot: true,
}
}
@@ -47,6 +49,9 @@ func (f *fakeChain) ResolveInternal(ctx context.Context, msg *dns.Msg, maxPriori
f.calls[key]++
answers := f.answers[key]
err := f.err
if err == nil {
err = f.qErr[key]
}
onLookup := f.onLookup
f.mu.Unlock()
@@ -75,6 +80,12 @@ func (f *fakeChain) setAnswer(name string, qtype uint16, ip string) {
}
}
func (f *fakeChain) setErr(name string, qtype uint16, err error) {
f.mu.Lock()
defer f.mu.Unlock()
f.qErr[name+"|"+dns.TypeToString[qtype]] = err
}
func (f *fakeChain) callCount(name string, qtype uint16) int {
f.mu.Lock()
defer f.mu.Unlock()

View File

@@ -0,0 +1,183 @@
package mgmt
import (
"context"
"errors"
"sync/atomic"
"testing"
"time"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
"github.com/netbirdio/netbird/shared/management/domain"
)
// A domain already in the cache must not be re-resolved on a subsequent server
// domains update; it is left to the stale-while-revalidate refresh path.
func TestResolver_UpdateFromServerDomains_SkipsCached(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.setAnswer("signal.example.com.", dns.TypeA, "10.0.0.2")
r.SetChainResolver(chain, 50)
sd := dnsconfig.ServerDomains{Signal: domain.Domain("signal.example.com")}
_, err := r.UpdateFromServerDomains(context.Background(), sd)
require.NoError(t, err)
require.Equal(t, 1, chain.callCount("signal.example.com.", dns.TypeA),
"first update must resolve the domain")
_, err = r.UpdateFromServerDomains(context.Background(), sd)
require.NoError(t, err)
assert.Equal(t, 1, chain.callCount("signal.example.com.", dns.TypeA),
"cached domain must not be re-resolved on a subsequent update")
}
// New domains in a single update must resolve concurrently rather than serially.
func TestResolver_AddNewDomains_ResolvesConcurrently(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
var inflight, maxInflight atomic.Int32
chain.onLookup = func() {
n := inflight.Add(1)
for {
old := maxInflight.Load()
if n <= old || maxInflight.CompareAndSwap(old, n) {
break
}
}
time.Sleep(50 * time.Millisecond)
inflight.Add(-1)
}
relays := []domain.Domain{"a.example.com", "b.example.com", "c.example.com", "d.example.com"}
for _, d := range relays {
chain.setAnswer(dns.Fqdn(string(d)), dns.TypeA, "10.0.0.2")
}
r.SetChainResolver(chain, 50)
start := time.Now()
_, err := r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Relay: relays})
require.NoError(t, err)
elapsed := time.Since(start)
assert.GreaterOrEqual(t, int(maxInflight.Load()), 2, "domains must resolve concurrently")
// Serial resolution of 4 domains would take at least 4*50ms; concurrent is far less.
assert.Less(t, elapsed, 300*time.Millisecond, "resolution should not be serial")
}
// A domain that fails to resolve must not be retried on every update; the
// failure backoff suppresses re-resolution until it expires.
func TestResolver_UpdateFromServerDomains_BacksOffFailures(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.err = errors.New("resolve boom")
r.SetChainResolver(chain, 50)
sd := dnsconfig.ServerDomains{Signal: domain.Domain("signal.example.com")}
_, err := r.UpdateFromServerDomains(context.Background(), sd)
require.NoError(t, err)
require.Equal(t, 1, chain.callCount("signal.example.com.", dns.TypeA),
"first update must attempt the resolve")
_, err = r.UpdateFromServerDomains(context.Background(), sd)
require.NoError(t, err)
assert.Equal(t, 1, chain.callCount("signal.example.com.", dns.TypeA),
"failed resolve must back off and not retry on the next update")
}
// A domain listed under more than one server-domain type (e.g. STUN and TURN on
// the same host) must be resolved once per update, not once per occurrence.
func TestResolver_AddNewDomains_DedupesDuplicateDomains(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.setAnswer("dup.example.com.", dns.TypeA, "10.0.0.9")
r.SetChainResolver(chain, 50)
sd := dnsconfig.ServerDomains{
Stuns: []domain.Domain{"dup.example.com"},
Turns: []domain.Domain{"dup.example.com"},
}
_, err := r.UpdateFromServerDomains(context.Background(), sd)
require.NoError(t, err)
assert.Equal(t, 1, chain.callCount("dup.example.com.", dns.TypeA),
"a domain appearing under multiple server-domain types must resolve once")
}
// A failure marker must be dropped once its domain leaves the server-domains set
// so the map stays bounded to the current set.
func TestResolver_UpdateFromServerDomains_PrunesFailedResolves(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.err = errors.New("resolve boom")
r.SetChainResolver(chain, 50)
_, err := r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Signal: domain.Domain("gone.example.com")})
require.NoError(t, err)
r.mutex.RLock()
_, marked := r.failedResolves[domain.Domain("gone.example.com")]
r.mutex.RUnlock()
require.True(t, marked, "failed resolve must be recorded")
_, err = r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Signal: domain.Domain("other.example.com")})
require.NoError(t, err)
r.mutex.RLock()
_, stillMarked := r.failedResolves[domain.Domain("gone.example.com")]
r.mutex.RUnlock()
assert.False(t, stillMarked, "failure marker for a domain no longer in the set must be pruned")
}
// When one family hard-errors while the other resolves, the domain is cached
// for the working family but recorded as incomplete so the failed family is
// retried under backoff instead of being treated as fully resolved forever.
func TestResolver_AddNewDomains_RetriesPartialFamilyFailure(t *testing.T) {
d := domain.Domain("relay.example.com")
r := NewResolver()
chain := newFakeChain()
chain.setAnswer("relay.example.com.", dns.TypeA, "10.0.0.2")
chain.setErr("relay.example.com.", dns.TypeAAAA, errors.New("servfail"))
r.SetChainResolver(chain, 50)
_, err := r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Relay: []domain.Domain{d}})
require.NoError(t, err)
r.mutex.RLock()
_, aCached := r.records[dns.Question{Name: "relay.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}]
_, marked := r.failedResolves[d]
r.mutex.RUnlock()
require.True(t, aCached, "the working family must still be cached")
require.True(t, marked, "a partial failure must be recorded so the failed family is retried")
assert.False(t, r.needsResolve(d), "within the backoff window the domain is not retried")
r.mutex.Lock()
r.failedResolves[d] = time.Now().Add(-2 * refreshBackoff)
r.mutex.Unlock()
assert.True(t, r.needsResolve(d), "after the backoff elapses the domain is retried to pick up the missing family")
}
// A family that returns NODATA (legitimately absent, e.g. an IPv4-only host) is
// not a failure: the domain must not be marked for retry, otherwise it would be
// re-resolved on every sync.
func TestResolver_AddNewDomains_NodataIsNotFailure(t *testing.T) {
d := domain.Domain("v4only.example.com")
r := NewResolver()
chain := newFakeChain()
chain.setAnswer("v4only.example.com.", dns.TypeA, "10.0.0.2")
r.SetChainResolver(chain, 50)
_, err := r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Relay: []domain.Domain{d}})
require.NoError(t, err)
r.mutex.RLock()
_, marked := r.failedResolves[d]
r.mutex.RUnlock()
assert.False(t, marked, "a NODATA family must not be recorded as a failure")
assert.False(t, r.needsResolve(d), "an IPv4-only host must not be re-resolved on later syncs")
}

View File

@@ -82,12 +82,6 @@ const (
PeerConnectionTimeoutMax = 45000 // ms
PeerConnectionTimeoutMin = 30000 // ms
disableAutoUpdate = "disabled"
// systemInfoTimeout bounds how long the sync loop waits for system info / posture
// check gathering. The gathering runs uncancellable system calls (process scan,
// exec, os.Stat); without this bound a single stuck call freezes handleSync, and
// thus syncMsgMux, for as long as the call hangs (observed multi-minute freezes).
systemInfoTimeout = 15 * time.Second
)
var ErrResetConnection = fmt.Errorf("reset connection")
@@ -1072,22 +1066,11 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
}
e.checks = checks
info, ok := system.GetInfoWithChecksTimeout(e.ctx, systemInfoTimeout, checks)
if !ok {
// Gathering timed out; skip the meta sync this cycle rather than blocking the
// sync loop (and syncMsgMux) on a stuck system call. A later sync will retry.
return nil
info, err := system.GetInfoWithChecks(e.ctx, checks)
if err != nil {
log.Warnf("failed to get system info with checks: %v", err)
info = system.GetInfo(e.ctx)
}
e.applyInfoFlags(info)
if err := e.mgmClient.SyncMeta(info); err != nil {
return fmt.Errorf("could not sync meta: error %s", err)
}
return nil
}
// applyInfoFlags sets the engine's config-derived feature flags on the gathered system info.
func (e *Engine) applyInfoFlags(info *system.Info) {
info.SetFlags(
e.config.RosenpassEnabled,
e.config.RosenpassPermissive,
@@ -1106,6 +1089,12 @@ func (e *Engine) applyInfoFlags(info *system.Info) {
e.config.EnableSSHRemotePortForwarding,
e.config.DisableSSHAuth,
)
if err := e.mgmClient.SyncMeta(info); err != nil {
log.Errorf("could not sync meta: error %s", err)
return err
}
return nil
}
func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
@@ -1251,15 +1240,31 @@ func (e *Engine) receiveManagementEvents() {
e.shutdownWg.Add(1)
go func() {
defer e.shutdownWg.Done()
info, ok := system.GetInfoWithChecksTimeout(e.ctx, systemInfoTimeout, e.checks)
if !ok {
// Gathering timed out; connect the stream with base info so management
// connectivity still comes up rather than blocking here.
info, err := system.GetInfoWithChecks(e.ctx, e.checks)
if err != nil {
log.Warnf("failed to get system info with checks: %v", err)
info = system.GetInfo(e.ctx)
}
e.applyInfoFlags(info)
info.SetFlags(
e.config.RosenpassEnabled,
e.config.RosenpassPermissive,
&e.config.ServerSSHAllowed,
e.config.DisableClientRoutes,
e.config.DisableServerRoutes,
e.config.DisableDNS,
e.config.DisableFirewall,
e.config.BlockLANAccess,
e.config.BlockInbound,
e.config.DisableIPv6,
e.config.LazyConnectionEnabled,
e.config.EnableSSHRoot,
e.config.EnableSSHSFTP,
e.config.EnableSSHLocalPortForwarding,
e.config.EnableSSHRemotePortForwarding,
e.config.DisableSSHAuth,
)
err := e.mgmClient.Sync(e.ctx, info, e.handleSync)
err = e.mgmClient.Sync(e.ctx, info, e.handleSync)
if err != nil {
// happens if management is unavailable for a long time.
// We want to cancel the operation of the whole client

View File

@@ -2,10 +2,8 @@ package system
import (
"context"
"errors"
"net/netip"
"strings"
"time"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc/metadata"
@@ -156,7 +154,7 @@ func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks) (*Info, erro
processCheckPaths = append(processCheckPaths, check.GetFiles()...)
}
files, err := checkFileAndProcess(ctx, processCheckPaths)
files, err := checkFileAndProcess(processCheckPaths)
if err != nil {
return nil, err
}
@@ -168,39 +166,3 @@ func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks) (*Info, erro
log.Debugf("all system information gathered successfully")
return info, nil
}
// GetInfoWithChecksTimeout is GetInfoWithChecks bounded by timeout. Posture-check gathering
// runs uncancellable system calls (process enumeration, os.Stat), so calling it inline can
// block the caller for as long as such a call hangs. It runs in a goroutine instead: if it
// does not return within timeout the caller gets (nil, false) and should proceed with
// degraded behavior rather than block. On a gathering error it falls back to base GetInfo.
//
// The buffered channel lets the abandoned goroutine finish and exit once its blocking call
// returns, so it does not leak beyond the duration of that call.
func GetInfoWithChecksTimeout(ctx context.Context, timeout time.Duration, checks []*proto.Checks) (*Info, bool) {
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
infoCh := make(chan *Info, 1)
go func() {
info, err := GetInfoWithChecks(ctx, checks)
if err != nil {
log.Warnf("failed to get system info with checks: %v", err)
info = GetInfo(ctx)
}
infoCh <- info
}()
select {
case info := <-infoCh:
return info, true
case <-ctx.Done():
if errors.Is(ctx.Err(), context.DeadlineExceeded) {
log.Warnf("gathering system info with checks timed out after %s", timeout)
} else {
// Parent context canceled (e.g. shutdown), not a timeout.
log.Warnf("gathering system info with checks canceled: %v", ctx.Err())
}
return nil, false
}
}

View File

@@ -50,7 +50,7 @@ func GetInfo(ctx context.Context) *Info {
}
// checkFileAndProcess checks if the file path exists and if a process is running at that path.
func checkFileAndProcess(_ context.Context, _ []string) ([]File, error) {
func checkFileAndProcess(paths []string) ([]File, error) {
return []File{}, nil
}

View File

@@ -32,7 +32,7 @@ func GetInfo(ctx context.Context) *Info {
sysName := string(bytes.Split(utsname.Sysname[:], []byte{0})[0])
machine := string(bytes.Split(utsname.Machine[:], []byte{0})[0])
release := string(bytes.Split(utsname.Release[:], []byte{0})[0])
swVersion, err := exec.CommandContext(ctx, "sw_vers", "-productVersion").Output()
swVersion, err := exec.Command("sw_vers", "-productVersion").Output()
if err != nil {
log.Warnf("got an error while retrieving macOS version with sw_vers, error: %s. Using darwin version instead.\n", err)
swVersion = []byte(release)

View File

@@ -105,7 +105,7 @@ func isDuplicated(addresses []NetworkAddress, addr NetworkAddress) bool {
}
// checkFileAndProcess checks if the file path exists and if a process is running at that path.
func checkFileAndProcess(_ context.Context, _ []string) ([]File, error) {
func checkFileAndProcess(paths []string) ([]File, error) {
return []File{}, nil
}

View File

@@ -103,7 +103,7 @@ func collectLocationInfo(info *Info) {
}
}
func checkFileAndProcess(_ context.Context, _ []string) ([]File, error) {
func checkFileAndProcess(_ []string) ([]File, error) {
return []File{}, nil
}

View File

@@ -3,7 +3,6 @@ package system
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc/metadata"
@@ -35,20 +34,6 @@ func Test_CustomHostname(t *testing.T) {
assert.Equal(t, want, got.Hostname)
}
func TestGetInfoWithChecksTimeout_Success(t *testing.T) {
info, ok := GetInfoWithChecksTimeout(context.Background(), 30*time.Second, nil)
assert.True(t, ok, "expected gathering to complete within the timeout")
assert.NotNil(t, info)
}
func TestGetInfoWithChecksTimeout_Timeout(t *testing.T) {
// A 1ns budget expires before the (real) system-info gathering can finish, so the
// caller must get (nil, false) instead of blocking on the in-flight goroutine.
info, ok := GetInfoWithChecksTimeout(context.Background(), time.Nanosecond, nil)
assert.False(t, ok, "expected timeout to be reported")
assert.Nil(t, info)
}
func Test_NetAddresses(t *testing.T) {
addr, err := networkAddresses()
if err != nil {

View File

@@ -3,30 +3,24 @@
package system
import (
"context"
"os"
"slices"
"github.com/shirou/gopsutil/v3/process"
)
// getRunningProcesses returns a list of running process paths. The context bounds the work:
// the per-PID loop bails as soon as ctx is done, and the gopsutil calls honor it where they
// can, so a stuck enumeration cannot run unbounded.
func getRunningProcesses(ctx context.Context) ([]string, error) {
processIDs, err := process.PidsWithContext(ctx)
// getRunningProcesses returns a list of running process paths.
func getRunningProcesses() ([]string, error) {
processIDs, err := process.Pids()
if err != nil {
return nil, err
}
processMap := make(map[string]bool)
for _, pID := range processIDs {
if err := ctx.Err(); err != nil {
return nil, err
}
p := &process.Process{Pid: pID}
path, _ := p.ExeWithContext(ctx)
path, _ := p.Exe()
if path != "" {
processMap[path] = false
}
@@ -41,21 +35,18 @@ func getRunningProcesses(ctx context.Context) ([]string, error) {
}
// checkFileAndProcess checks if the file path exists and if a process is running at that path.
func checkFileAndProcess(ctx context.Context, paths []string) ([]File, error) {
func checkFileAndProcess(paths []string) ([]File, error) {
files := make([]File, len(paths))
if len(paths) == 0 {
return files, nil
}
runningProcesses, err := getRunningProcesses(ctx)
runningProcesses, err := getRunningProcesses()
if err != nil {
return nil, err
}
for i, path := range paths {
if err := ctx.Err(); err != nil {
return nil, err
}
file := File{Path: path}
_, err := os.Stat(path)

View File

@@ -1,7 +1,6 @@
package system
import (
"context"
"testing"
"github.com/shirou/gopsutil/v3/process"
@@ -10,7 +9,7 @@ import (
func Benchmark_getRunningProcesses(b *testing.B) {
b.Run("getRunningProcesses new", func(b *testing.B) {
for i := 0; i < b.N; i++ {
ps, err := getRunningProcesses(context.Background())
ps, err := getRunningProcesses()
if err != nil {
b.Fatalf("unexpected error: %v", err)
}
@@ -30,38 +29,12 @@ func Benchmark_getRunningProcesses(b *testing.B) {
}
}
})
s, _ := getRunningProcesses(context.Background())
s, _ := getRunningProcesses()
b.Logf("getRunningProcesses returned %d processes", len(s))
s, _ = getRunningProcessesOld()
b.Logf("getRunningProcessesOld returned %d processes", len(s))
}
func TestCheckFileAndProcess_ContextCanceled(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
// With a canceled context and non-empty paths the gathering must bail with an error
// instead of running the (potentially blocking) process scan / stat loop.
if _, err := checkFileAndProcess(ctx, []string{"/does/not/exist"}); err == nil {
t.Fatal("expected error on canceled context, got nil")
}
}
func TestCheckFileAndProcess_EmptyPaths(t *testing.T) {
// No check paths means no work to do: it must return immediately with no error,
// even on a canceled context (nothing to scan or stat).
ctx, cancel := context.WithCancel(context.Background())
cancel()
files, err := checkFileAndProcess(ctx, nil)
if err != nil {
t.Fatalf("unexpected error for empty paths: %v", err)
}
if len(files) != 0 {
t.Fatalf("expected no files, got %d", len(files))
}
}
func getRunningProcessesOld() ([]string, error) {
processes, err := process.Processes()
if err != nil {

View File

@@ -610,12 +610,10 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
return nil, nil, 0, err
}
startPosture := time.Now()
postureChecks, err := c.getPeerPostureChecks(account, peerID)
if err != nil {
return nil, nil, 0, err
}
log.WithContext(ctx).Debugf("getPeerPostureChecks took %s", time.Since(startPosture))
accountZones, err := c.repo.GetAccountZones(ctx, account.Id)
if err != nil {

View File

@@ -11,7 +11,7 @@ import (
const (
reconnThreshold = 5 * time.Minute
baseBlockDuration = 10 * time.Minute // Duration for which a peer is banned after exceeding the reconnection limit
baseBlockDuration = 30 * time.Minute // Duration for which a peer is banned after exceeding the reconnection limit
reconnLimitForBan = 30 // Number of reconnections within the reconnTreshold that triggers a ban
metaChangeLimit = 3 // Number of reconnections with different metadata that triggers a ban of one peer
)
@@ -139,22 +139,13 @@ func (l *loginFilter) addLogin(wgPubKey string, metaHash uint64) {
state.lastSeen = now
}
func metaHash(meta nbpeer.PeerSystemMeta, pubip string) uint64 {
func metaHash(meta nbpeer.PeerSystemMeta) uint64 {
h := fnv.New64a()
h.Write([]byte(meta.WtVersion))
h.Write([]byte(meta.OSVersion))
h.Write([]byte(meta.KernelVersion))
h.Write([]byte(meta.Hostname))
h.Write([]byte(meta.SystemSerialNumber))
h.Write([]byte(pubip))
macs := uint64(0)
for _, na := range meta.NetworkAddresses {
for _, r := range na.Mac {
macs += uint64(r)
}
}
return h.Sum64() + macs
return h.Sum64()
}

View File

@@ -164,9 +164,7 @@ func BenchmarkHashingMethods(b *testing.B) {
KernelVersion: "5.15.0-76-generic",
Hostname: "prod-server-database-01",
SystemSerialNumber: "PC-1234567890",
NetworkAddresses: []nbpeer.NetworkAddress{{Mac: "00:1B:44:11:3A:B7"}, {Mac: "00:1B:44:11:3A:B8"}},
}
pubip := "8.8.8.8"
var resultString string
var resultUint uint64
@@ -175,7 +173,7 @@ func BenchmarkHashingMethods(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
resultString = builderString(meta, pubip)
resultString = builderString(meta)
}
})
@@ -183,7 +181,7 @@ func BenchmarkHashingMethods(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
resultString = fnvHashToString(meta, pubip)
resultString = fnvHashToString(meta)
}
})
@@ -191,7 +189,7 @@ func BenchmarkHashingMethods(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
resultUint = metaHash(meta, pubip)
resultUint = metaHash(meta)
}
})
@@ -199,29 +197,20 @@ func BenchmarkHashingMethods(b *testing.B) {
_ = resultUint
}
func fnvHashToString(meta nbpeer.PeerSystemMeta, pubip string) string {
func fnvHashToString(meta nbpeer.PeerSystemMeta) string {
h := fnv.New64a()
if len(meta.NetworkAddresses) != 0 {
for _, na := range meta.NetworkAddresses {
h.Write([]byte(na.Mac))
}
}
h.Write([]byte(meta.WtVersion))
h.Write([]byte(meta.OSVersion))
h.Write([]byte(meta.KernelVersion))
h.Write([]byte(meta.Hostname))
h.Write([]byte(meta.SystemSerialNumber))
h.Write([]byte(pubip))
return strconv.FormatUint(h.Sum64(), 16)
}
func builderString(meta nbpeer.PeerSystemMeta, pubip string) string {
mac := getMacAddress(meta.NetworkAddresses)
estimatedSize := len(meta.WtVersion) + len(meta.OSVersion) + len(meta.KernelVersion) + len(meta.Hostname) + len(meta.SystemSerialNumber) +
len(pubip) + len(mac) + 6
func builderString(meta nbpeer.PeerSystemMeta) string {
estimatedSize := len(meta.WtVersion) + len(meta.OSVersion) + len(meta.KernelVersion) + len(meta.Hostname) + len(meta.SystemSerialNumber) + 4
var b strings.Builder
b.Grow(estimatedSize)
@@ -235,23 +224,10 @@ func builderString(meta nbpeer.PeerSystemMeta, pubip string) string {
b.WriteString(meta.Hostname)
b.WriteByte('|')
b.WriteString(meta.SystemSerialNumber)
b.WriteByte('|')
b.WriteString(pubip)
return b.String()
}
func getMacAddress(nas []nbpeer.NetworkAddress) string {
if len(nas) == 0 {
return ""
}
macs := make([]string, 0, len(nas))
for _, na := range nas {
macs = append(macs, na.Mac)
}
return strings.Join(macs, "/")
}
func BenchmarkLoginFilter_ParallelLoad(b *testing.B) {
filter := newLoginFilterWithCfg(testAdvancedCfg())
numKeys := 100000

View File

@@ -254,7 +254,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
return mapError(ctx, err)
}
metahashed := metaHash(peerMeta, sRealIP)
metahashed := metaHash(peerMeta)
if userID == "" && !s.loginFilter.allowLogin(peerKey.String(), metahashed) {
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountSyncRequestBlocked()
@@ -306,7 +306,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
log.WithContext(ctx).Tracef("peer system meta has to be provided on sync. Peer %s, remote addr %s", peerKey.String(), realIP)
}
metahash := metaHash(peerMeta, realIP.String())
metahash := metaHash(peerMeta)
s.loginFilter.addLogin(peerKey.String(), metahash)
peer, netMap, postureChecks, dnsFwdPort, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP, syncStart)
@@ -732,7 +732,7 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
}
peerMeta := extractPeerMeta(ctx, loginReq.GetMeta())
metahashed := metaHash(peerMeta, sRealIP)
metahashed := metaHash(peerMeta)
if !s.loginFilter.allowLogin(peerKey.String(), metahashed) {
if s.logBlockedPeers {
log.WithContext(ctx).Tracef("peer %s with meta hash %d is blocked from login", peerKey.String(), metahashed)
@@ -788,7 +788,11 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
ExtraDNSLabels: loginReq.GetDnsLabels(),
})
if err != nil {
log.WithContext(ctx).Warnf("failed logging in peer %s: %s", peerKey, err)
if errors.Is(err, internalStatus.ErrNoAuthMethodProvided) {
log.WithContext(ctx).Tracef("failed logging in peer %s: %s", peerKey, err)
} else {
log.WithContext(ctx).Warnf("failed logging in peer %s: %s", peerKey, err)
}
return nil, mapError(ctx, err)
}
@@ -1205,7 +1209,7 @@ func (s *Server) SyncMeta(ctx context.Context, req *proto.EncryptedMessage) (*pr
return nil, msg
}
err = s.accountManager.SyncPeerMeta(ctx, peerKey.String(), extractPeerMeta(ctx, syncMetaReq.GetMeta()))
err = s.accountManager.SyncPeerMeta(ctx, peerKey.String(), extractPeerMeta(ctx, syncMetaReq.GetMeta()), realIP)
if err != nil {
return nil, mapError(ctx, err)
}
@@ -1254,7 +1258,10 @@ func (s *Server) Logout(ctx context.Context, req *proto.EncryptedMessage) (*prot
func toProtocolChecks(ctx context.Context, postureChecks []*posture.Checks) []*proto.Checks {
protoChecks := make([]*proto.Checks, 0, len(postureChecks))
for _, postureCheck := range postureChecks {
protoChecks = append(protoChecks, toProtocolCheck(postureCheck))
check := toProtocolCheck(postureCheck)
if check != nil {
protoChecks = append(protoChecks, check)
}
}
return protoChecks
@@ -1278,5 +1285,9 @@ func toProtocolCheck(postureCheck *posture.Checks) *proto.Checks {
}
}
if len(protoCheck.Files) == 0 {
return nil
}
return protoCheck
}

View File

@@ -1889,12 +1889,12 @@ func domainIsUpToDate(domain string, domainCategory string, userAuth auth.UserAu
// concurrent stream that started earlier loses the optimistic-lock race
// in MarkPeerConnected and bails without writing.
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
peer, netMap, postureChecks, dnsfwdPort, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, accountID)
peer, netMap, postureChecks, dnsfwdPort, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, RealIP: realIP}, accountID)
if err != nil {
return nil, nil, nil, 0, fmt.Errorf("error syncing peer: %w", err)
}
if err := am.MarkPeerConnected(ctx, peerPubKey, realIP, accountID, syncTime.UnixNano(), netMap); err != nil {
if err := am.MarkPeerConnected(ctx, peerPubKey, accountID, syncTime.UnixNano(), netMap); err != nil {
log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err)
}
@@ -1914,13 +1914,13 @@ func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, account
return nil
}
func (am *DefaultAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error {
func (am *DefaultAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) error {
accountID, err := am.Store.GetAccountIDByPeerPubKey(ctx, peerPubKey)
if err != nil {
return err
}
_, _, _, _, err = am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, UpdateAccountPeers: true}, accountID)
_, _, _, _, err = am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, RealIP: realIP, UpdateAccountPeers: true}, accountID)
if err != nil {
return err
}

View File

@@ -62,7 +62,7 @@ type Manager interface {
GetUserFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
ListUsers(ctx context.Context, accountID string) ([]*types.User, error)
GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error
MarkPeerConnected(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error
MarkPeerDisconnected(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64) error
DeletePeer(ctx context.Context, accountID, peerID, userID string) error
UpdatePeer(ctx context.Context, accountID, userID string, p *nbpeer.Peer) (*nbpeer.Peer, error)
@@ -123,7 +123,7 @@ type Manager interface {
GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error)
SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string, streamStartTime time.Time) error
SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) error
FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error)
GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error)

View File

@@ -1323,17 +1323,17 @@ func (mr *MockManagerMockRecorder) ExtendPeerSession(ctx, peerPubKey, userID int
}
// MarkPeerConnected mocks base method.
func (m *MockManager) MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error {
func (m *MockManager) MarkPeerConnected(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "MarkPeerConnected", ctx, peerKey, realIP, accountID, sessionStartedAt, nmap)
ret := m.ctrl.Call(m, "MarkPeerConnected", ctx, peerKey, accountID, sessionStartedAt, nmap)
ret0, _ := ret[0].(error)
return ret0
}
// MarkPeerConnected indicates an expected call of MarkPeerConnected.
func (mr *MockManagerMockRecorder) MarkPeerConnected(ctx, peerKey, realIP, accountID, sessionStartedAt, nmap interface{}) *gomock.Call {
func (mr *MockManagerMockRecorder) MarkPeerConnected(ctx, peerKey, accountID, sessionStartedAt, nmap interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerConnected", reflect.TypeOf((*MockManager)(nil).MarkPeerConnected), ctx, peerKey, realIP, accountID, sessionStartedAt, nmap)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerConnected", reflect.TypeOf((*MockManager)(nil).MarkPeerConnected), ctx, peerKey, accountID, sessionStartedAt, nmap)
}
// MarkPeerDisconnected mocks base method.
@@ -1586,17 +1586,17 @@ func (mr *MockManagerMockRecorder) SyncPeer(ctx, sync, accountID interface{}) *g
}
// SyncPeerMeta mocks base method.
func (m *MockManager) SyncPeerMeta(ctx context.Context, peerPubKey string, meta peer.PeerSystemMeta) error {
func (m *MockManager) SyncPeerMeta(ctx context.Context, peerPubKey string, meta peer.PeerSystemMeta, realIP net.IP) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SyncPeerMeta", ctx, peerPubKey, meta)
ret := m.ctrl.Call(m, "SyncPeerMeta", ctx, peerPubKey, meta, realIP)
ret0, _ := ret[0].(error)
return ret0
}
// SyncPeerMeta indicates an expected call of SyncPeerMeta.
func (mr *MockManagerMockRecorder) SyncPeerMeta(ctx, peerPubKey, meta interface{}) *gomock.Call {
func (mr *MockManagerMockRecorder) SyncPeerMeta(ctx, peerPubKey, meta, realIP interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SyncPeerMeta", reflect.TypeOf((*MockManager)(nil).SyncPeerMeta), ctx, peerPubKey, meta)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SyncPeerMeta", reflect.TypeOf((*MockManager)(nil).SyncPeerMeta), ctx, peerPubKey, meta, realIP)
}
// SyncUserJWTGroups mocks base method.

View File

@@ -1836,7 +1836,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
require.NoError(t, err, "unable to get the account")
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano(), nil)
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), accountID, time.Now().UTC().UnixNano(), nil)
require.NoError(t, err, "unable to mark peer connected")
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
@@ -1907,7 +1907,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
require.NoError(t, err, "unable to get the account")
// when we mark peer as connected, the peer login expiration routine should trigger
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano(), nil)
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), accountID, time.Now().UTC().UnixNano(), nil)
require.NoError(t, err, "unable to mark peer connected")
failed := waitTimeout(wg, time.Second)
@@ -1916,6 +1916,117 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
}
}
func TestDefaultAccountManager_MarkPeerDisconnected_SchedulesInactivityExpiration(t *testing.T) {
manager, _, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
require.NoError(t, err, "unable to create an account")
key, err := wgtypes.GenerateKey()
require.NoError(t, err, "unable to generate WireGuard key")
peerPubKey := key.PublicKey().String()
_, _, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
Key: peerPubKey,
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
InactivityExpirationEnabled: true,
}, false)
require.NoError(t, err, "unable to add peer")
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: true,
PeerInactivityExpiration: time.Hour,
PeerInactivityExpirationEnabled: true,
Extra: &types.ExtraSettings{},
})
require.NoError(t, err, "expecting to update account settings successfully but got error")
// Establish a session so the matching-token disconnect is actually applied.
streamStartTime := time.Now().UTC()
err = manager.MarkPeerConnected(context.Background(), peerPubKey, accountID, streamStartTime.UnixNano(), nil)
require.NoError(t, err, "unable to mark peer connected")
// Install the mock only now, so the assertion observes the disconnect, not
// the earlier connect.
scheduled := make(chan struct{}, 1)
manager.peerInactivityExpiry = &MockScheduler{
CancelFunc: func(ctx context.Context, IDs []string) {},
ScheduleFunc: func(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) {
select {
case scheduled <- struct{}{}:
default:
}
},
}
err = manager.MarkPeerDisconnected(context.Background(), peerPubKey, accountID, streamStartTime.UnixNano())
require.NoError(t, err, "unable to mark peer disconnected")
select {
case <-scheduled:
// expected: disconnect re-armed the inactivity expiry timer
case <-time.After(time.Second):
t.Fatal("expected inactivity expiration to be rescheduled when an eligible peer disconnects")
}
}
func TestDefaultAccountManager_MarkPeerDisconnected_SkipsInactivityExpirationWhenDisabled(t *testing.T) {
manager, _, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
require.NoError(t, err, "unable to create an account")
key, err := wgtypes.GenerateKey()
require.NoError(t, err, "unable to generate WireGuard key")
peerPubKey := key.PublicKey().String()
_, _, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
Key: peerPubKey,
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
InactivityExpirationEnabled: true,
}, false)
require.NoError(t, err, "unable to add peer")
// Peer is eligible (SSO + inactivity enabled) but the account-level setting
// stays disabled, so disconnect must not schedule anything.
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: true,
PeerInactivityExpiration: time.Hour,
PeerInactivityExpirationEnabled: false,
Extra: &types.ExtraSettings{},
})
require.NoError(t, err, "expecting to update account settings successfully but got error")
streamStartTime := time.Now().UTC()
err = manager.MarkPeerConnected(context.Background(), peerPubKey, accountID, streamStartTime.UnixNano(), nil)
require.NoError(t, err, "unable to mark peer connected")
scheduled := make(chan struct{}, 1)
manager.peerInactivityExpiry = &MockScheduler{
CancelFunc: func(ctx context.Context, IDs []string) {},
ScheduleFunc: func(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) {
select {
case scheduled <- struct{}{}:
default:
}
},
}
err = manager.MarkPeerDisconnected(context.Background(), peerPubKey, accountID, streamStartTime.UnixNano())
require.NoError(t, err, "unable to mark peer disconnected")
select {
case <-scheduled:
t.Fatal("inactivity expiration must not be scheduled while the account-level setting is disabled")
case <-time.After(200 * time.Millisecond):
// expected: nothing scheduled
}
}
func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
manager, _, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
@@ -1935,7 +2046,7 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
t.Run("disconnect peer when session token matches", func(t *testing.T) {
streamStartTime := time.Now().UTC()
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, streamStartTime.UnixNano(), nil)
err = manager.MarkPeerConnected(context.Background(), peerPubKey, accountID, streamStartTime.UnixNano(), nil)
require.NoError(t, err, "unable to mark peer connected")
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
@@ -1956,7 +2067,7 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
t.Run("skip disconnect when stored session is newer (zombie stream protection)", func(t *testing.T) {
// Newer stream wins on connect (sets SessionStartedAt = now ns).
streamStartTime := time.Now().UTC()
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, streamStartTime.UnixNano(), nil)
err = manager.MarkPeerConnected(context.Background(), peerPubKey, accountID, streamStartTime.UnixNano(), nil)
require.NoError(t, err, "unable to mark peer connected")
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
@@ -1980,7 +2091,7 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
t.Run("skip stale connect when stored session is newer (blocked goroutine protection)", func(t *testing.T) {
node2SyncTime := time.Now().UTC()
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, node2SyncTime.UnixNano(), nil)
err = manager.MarkPeerConnected(context.Background(), peerPubKey, accountID, node2SyncTime.UnixNano(), nil)
require.NoError(t, err, "node 2 should connect peer")
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
@@ -1990,7 +2101,7 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
"SessionStartedAt should equal node2SyncTime token")
node1StaleSyncTime := node2SyncTime.Add(-1 * time.Minute)
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, node1StaleSyncTime.UnixNano(), nil)
err = manager.MarkPeerConnected(context.Background(), peerPubKey, accountID, node1StaleSyncTime.UnixNano(), nil)
require.NoError(t, err, "stale connect should not return error")
peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
@@ -2052,7 +2163,7 @@ func TestDefaultAccountManager_MarkPeerConnected_ConcurrentRace(t *testing.T) {
defer done.Done()
ready.Done()
start.Wait()
errs <- manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, token, nil)
errs <- manager.MarkPeerConnected(context.Background(), peerPubKey, accountID, token, nil)
}()
}
@@ -2093,7 +2204,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
account, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "unable to get the account")
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano(), nil)
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), accountID, time.Now().UTC().UnixNano(), nil)
require.NoError(t, err, "unable to mark peer connected")
wg := &sync.WaitGroup{}

View File

@@ -41,7 +41,7 @@ func TestAffectedPeers_DependencyCoverageMatrix(t *testing.T) {
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
require.NoError(t, err)
return affectedpeers.Change{ChangedPeerIDs: []string{s.routerPeerID}},
[]string{s.sourcePeerID}, []string{s.unrelatedPeerID}
[]string{s.sourcePeerID, s.routerPeerID}, []string{s.unrelatedPeerID}
},
},
{
@@ -106,12 +106,8 @@ func TestAffectedPeers_DependencyCoverageMatrix(t *testing.T) {
change, mustContain, mustExclude := r.build(t, s, ctx)
affected := resolveAffected(t, s.manager.Store, s.accountID, change)
for _, id := range mustContain {
assert.Contains(t, affected, id, "expected peer to be affected")
}
for _, id := range mustExclude {
assert.NotContains(t, affected, id, "peer must not be affected")
}
assert.ElementsMatch(t, affected, mustContain, "expected peer to be affected")
assert.NotContains(t, affected, mustExclude, "peer must not be affected")
})
}
}

View File

@@ -251,7 +251,9 @@ func TestAffectedPeers_E2E_UpdateResource_DestinationResourcePolicy_RefreshesSou
}
}
func TestAffectedPeers_E2E_UpdateResource_DisabledSiblingRouter_StillBridged(t *testing.T) {
// A disabled sibling router routes to nobody, so updating a resource on its network
// must NOT refresh its peer (the enabled router carries the bridge instead).
func TestAffectedPeers_E2E_UpdateResource_DisabledSiblingRouterNotBridged(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
@@ -274,13 +276,18 @@ func TestAffectedPeers_E2E_UpdateResource_DisabledSiblingRouter_StillBridged(t *
require.NoError(t, err)
disabledCh := s.updateManager.CreateChannel(ctx, disabledRouterPeer.ID)
t.Cleanup(func() { s.updateManager.CloseChannel(ctx, disabledRouterPeer.ID) })
enabledCh := s.updateManager.CreateChannel(ctx, s.routerPeerID)
t.Cleanup(func() {
s.updateManager.CloseChannel(ctx, disabledRouterPeer.ID)
s.updateManager.CloseChannel(ctx, s.routerPeerID)
})
settleAffectedUpdates(disabledCh)
settleAffectedUpdates(disabledCh, enabledCh)
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, disabledCh)
peerShouldReceiveUpdate(t, enabledCh)
peerShouldNotReceiveUpdate(t, disabledCh)
close(done)
}()
@@ -298,7 +305,7 @@ func TestAffectedPeers_E2E_UpdateResource_DisabledSiblingRouter_StillBridged(t *
select {
case <-done:
case <-time.After(peerUpdateTimeout):
t.Error("timeout: resource update did not refresh the disabled sibling router's peer")
t.Error("timeout")
}
}

View File

@@ -682,6 +682,9 @@ func TestAffectedPeers_AllRoutingPeers_Network(t *testing.T) {
assert.Contains(t, affected, secondRouterPeer.ID, "second routing peer on the same network must also be affected")
}
// A disabled router in the snapshot routes to nobody, so it is skipped when the
// walk scans existing account data: a policy edit still folds the literal source
// group, but not the disabled router's peer.
func TestAffectedPeers_DisabledRouter(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
@@ -694,11 +697,13 @@ func TestAffectedPeers_DisabledRouter(t *testing.T) {
affected := s.resolvePolicyAffected(ctx, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID))
assert.Contains(t, affected, s.sourcePeerID, "source peer must be affected")
assert.Contains(t, affected, s.routerPeerID,
"disabled router's peer must still be affected: Enabled must not gate affected-peers")
assert.Contains(t, affected, s.sourcePeerID, "source peer (literal policy source group) must be affected")
assert.NotContains(t, affected, s.routerPeerID,
"a disabled router routes to nobody, so its peer must not be folded from snapshot data")
}
// A disabled resource in the snapshot is skipped: the policy edit still folds the
// literal source group, but the resource no longer bridges to its network's router.
func TestAffectedPeers_DisabledResource(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
@@ -710,9 +715,9 @@ func TestAffectedPeers_DisabledResource(t *testing.T) {
affected := s.resolvePolicyAffected(ctx, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID))
assert.Contains(t, affected, s.sourcePeerID, "source peer must be affected")
assert.Contains(t, affected, s.routerPeerID,
"disabled resource must still resolve the routing peer: Enabled must not gate affected-peers")
assert.Contains(t, affected, s.sourcePeerID, "source peer (literal policy source group) must be affected")
assert.NotContains(t, affected, s.routerPeerID,
"a disabled resource routes to nobody, so its network's router must not be folded from snapshot data")
}
func TestAffectedPeers_DisabledRule(t *testing.T) {

View File

@@ -96,33 +96,54 @@ func affectedGroupID(i int) string { return fmt.Sprintf("affected-grp-%d", i)
func affectedGroupName(i int) string { return fmt.Sprintf("AffectedGroup%d", i) }
func TestCollectGroupChange_PolicyLinked(t *testing.T) {
manager, s, accountID, _, groupIDs := setupAffectedPeersTest(t)
manager, s, accountID, peerIDs, groupIDs := setupAffectedPeersTest(t)
ctx := context.Background()
_, err := manager.SavePolicy(ctx, accountID, userID, &types.Policy{
Enabled: true,
Rules: []*types.PolicyRule{
{
Enabled: true,
Sources: []string{groupIDs[0]},
Destinations: []string{groupIDs[1]},
Bidirectional: true,
Action: types.PolicyTrafficActionAccept,
Enabled: true,
Sources: []string{groupIDs[0]},
Destinations: []string{groupIDs[1]},
SourceResource: types.Resource{ID: peerIDs[0], Type: types.ResourceTypePeer},
DestinationResource: types.Resource{ID: peerIDs[1], Type: types.ResourceTypePeer},
Bidirectional: true,
Action: types.PolicyTrafficActionAccept,
},
{
Enabled: true,
Sources: []string{groupIDs[0]},
Destinations: []string{groupIDs[1]},
SourceResource: types.Resource{ID: peerIDs[2], Type: types.ResourceTypeHost},
DestinationResource: types.Resource{ID: peerIDs[3], Type: types.ResourceTypeHost},
Bidirectional: true,
Action: types.PolicyTrafficActionAccept,
},
{
Enabled: true,
Sources: []string{groupIDs[0]},
Destinations: []string{groupIDs[1]},
SourceResource: types.Resource{ID: "", Type: types.ResourceTypePeer},
DestinationResource: types.Resource{ID: "", Type: types.ResourceTypePeer},
Bidirectional: true,
Action: types.PolicyTrafficActionAccept,
},
},
}, true)
require.NoError(t, err)
groups, _ := collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[0]})
assert.Contains(t, groups, groupIDs[0])
assert.Contains(t, groups, groupIDs[1])
groups, directPeers := collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[0]})
assert.ElementsMatch(t, groups, []string{groupIDs[0], groupIDs[1]})
assert.ElementsMatch(t, directPeers, []string{peerIDs[1]})
groups, _ = collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[1]})
assert.Contains(t, groups, groupIDs[0])
assert.Contains(t, groups, groupIDs[1])
groups, directPeers = collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[1]})
assert.ElementsMatch(t, groups, []string{groupIDs[0], groupIDs[1]})
assert.ElementsMatch(t, directPeers, []string{peerIDs[0]})
groups, _ = collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[2]})
groups, directPeers = collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[2]})
assert.Empty(t, groups)
assert.Empty(t, directPeers)
}
func TestCollectGroupChange_PolicyWithDirectPeerResource(t *testing.T) {
@@ -133,20 +154,44 @@ func TestCollectGroupChange_PolicyWithDirectPeerResource(t *testing.T) {
Enabled: true,
Rules: []*types.PolicyRule{
{
Enabled: true,
Sources: []string{groupIDs[0]},
SourceResource: types.Resource{ID: peerIDs[3], Type: types.ResourceTypePeer},
Destinations: []string{groupIDs[1]},
Action: types.PolicyTrafficActionAccept,
Enabled: true,
Sources: []string{groupIDs[0]},
SourceResource: types.Resource{ID: peerIDs[3], Type: types.ResourceTypePeer},
DestinationResource: types.Resource{ID: peerIDs[4], Type: types.ResourceTypePeer},
Destinations: []string{groupIDs[1]},
Action: types.PolicyTrafficActionAccept,
},
{
Enabled: true,
Sources: []string{groupIDs[0]},
SourceResource: types.Resource{ID: peerIDs[1], Type: types.ResourceTypeHost},
DestinationResource: types.Resource{ID: peerIDs[2], Type: types.ResourceTypeHost},
Destinations: []string{groupIDs[1]},
Action: types.PolicyTrafficActionAccept,
},
{
Enabled: true,
Sources: []string{groupIDs[0]},
SourceResource: types.Resource{ID: "", Type: types.ResourceTypePeer},
DestinationResource: types.Resource{ID: "", Type: types.ResourceTypePeer},
Destinations: []string{groupIDs[1]},
Action: types.PolicyTrafficActionAccept,
},
},
}, true)
require.NoError(t, err)
groups, directPeers := collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[0]})
assert.Contains(t, groups, groupIDs[0])
assert.Contains(t, groups, groupIDs[1])
assert.Contains(t, directPeers, peerIDs[3])
assert.ElementsMatch(t, groups, []string{groupIDs[0], groupIDs[1]})
assert.ElementsMatch(t, directPeers, []string{peerIDs[4]})
groups, directPeers = collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[1]})
assert.ElementsMatch(t, groups, []string{groupIDs[0], groupIDs[1]})
assert.ElementsMatch(t, directPeers, []string{peerIDs[3]})
groups, directPeers = collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[2]})
assert.Empty(t, groups)
assert.Empty(t, directPeers)
}
func TestCollectGroupChange_PolicyWithNonPeerResource_NoDirectPeers(t *testing.T) {
@@ -168,8 +213,7 @@ func TestCollectGroupChange_PolicyWithNonPeerResource_NoDirectPeers(t *testing.T
require.NoError(t, err)
groups, directPeers := collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[0]})
assert.Contains(t, groups, groupIDs[0])
assert.Contains(t, groups, groupIDs[1])
assert.ElementsMatch(t, groups, []string{groupIDs[0], groupIDs[1]})
assert.Empty(t, directPeers, "non-peer resources should not produce direct peer IDs")
}
@@ -294,6 +338,7 @@ func TestCollectGroupChange_NetworkRouterLinked(t *testing.T) {
AccountID: accountID,
PeerGroups: []string{groupIDs[0]},
Peer: peerIDs[3],
Enabled: true,
})
require.NoError(t, err)
@@ -324,6 +369,7 @@ func TestCollectGroupChange_NetworkRouterPeerOnlyNoGroups(t *testing.T) {
NetworkID: net1.ID,
AccountID: accountID,
Peer: peerIDs[4],
Enabled: true,
})
require.NoError(t, err)
@@ -373,17 +419,11 @@ func TestCollectGroupChange_MultipleEntities(t *testing.T) {
require.NoError(t, err)
groups, directPeers := collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[0]})
assert.Contains(t, groups, groupIDs[0])
assert.Contains(t, groups, groupIDs[1])
assert.NotContains(t, groups, groupIDs[2])
assert.NotContains(t, groups, groupIDs[3])
assert.ElementsMatch(t, groups, []string{groupIDs[0], groupIDs[1]})
assert.Empty(t, directPeers)
groups, directPeers = collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[3]})
assert.Contains(t, groups, groupIDs[2])
assert.Contains(t, groups, groupIDs[3])
assert.NotContains(t, groups, groupIDs[0])
assert.NotContains(t, groups, groupIDs[1])
assert.ElementsMatch(t, groups, []string{groupIDs[2], groupIDs[3]})
assert.Empty(t, directPeers)
}
@@ -452,8 +492,9 @@ func TestResolveAffectedPeers_PolicyBetweenTwoGroups(t *testing.T) {
result = manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[1]})
assert.ElementsMatch(t, []string{peerIDs[0], peerIDs[1]}, result)
// peerIDs[2] is unrelated to the route; only its own map can change.
result = manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[2]})
assert.Empty(t, result)
assert.ElementsMatch(t, []string{peerIDs[2]}, result)
}
func TestResolveAffectedPeers_PolicyThreeGroups(t *testing.T) {
@@ -474,7 +515,7 @@ func TestResolveAffectedPeers_PolicyThreeGroups(t *testing.T) {
require.NoError(t, err)
result := manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[0]})
assert.ElementsMatch(t, []string{peerIDs[0], peerIDs[1], peerIDs[2]}, result)
assert.ElementsMatch(t, []string{peerIDs[0], peerIDs[2]}, result)
}
func TestResolveAffectedPeers_RoutePeerGroups(t *testing.T) {
@@ -506,8 +547,9 @@ func TestResolveAffectedPeers_RoutePeerGroups(t *testing.T) {
result = manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[1]})
assert.ElementsMatch(t, []string{peerIDs[0], peerIDs[1]}, result)
// peerIDs[2] is in no policy; only its own map can change, so it refreshes itself.
result = manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[2]})
assert.Empty(t, result)
assert.ElementsMatch(t, []string{peerIDs[2]}, result)
}
func TestResolveAffectedPeers_RouteWithDirectPeer(t *testing.T) {
@@ -564,9 +606,9 @@ func TestResolveAffectedPeers_RouteWithAccessControlGroups(t *testing.T) {
result := manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[2]})
assert.ElementsMatch(t, []string{peerIDs[0], peerIDs[1], peerIDs[2]}, result)
// peer3 is unrelated
// peer3 is unrelated to the route; only its own map can change.
result = manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[3]})
assert.Empty(t, result)
assert.ElementsMatch(t, []string{peerIDs[3]}, result)
}
func TestResolveAffectedPeers_NetworkRouter(t *testing.T) {
@@ -587,6 +629,7 @@ func TestResolveAffectedPeers_NetworkRouter(t *testing.T) {
AccountID: accountID,
PeerGroups: []string{groupIDs[0]},
Peer: peerIDs[3],
Enabled: true,
})
require.NoError(t, err)
@@ -659,9 +702,13 @@ func TestResolveAffectedPeers_PeerInMultipleGroups(t *testing.T) {
}, true)
require.NoError(t, err)
// peer0 is in group0 AND group1, so both policies apply
// peer0 is in group0 AND group1, so both policies apply. A peer change folds
// only the changed peer plus the opposite side of each rule: group2 (peer2) via
// the group0 policy and group3 (peer3) via the group1 policy. peer1, a co-member
// of group1, is a sibling of the changed peer and must NOT refresh.
result := manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[0]})
assert.ElementsMatch(t, []string{peerIDs[0], peerIDs[1], peerIDs[2], peerIDs[3]}, result)
assert.ElementsMatch(t, []string{peerIDs[0], peerIDs[2], peerIDs[3]}, result)
assert.NotContains(t, result, peerIDs[1], "co-member of the changed peer's group must not refresh")
}
func TestResolveAffectedPeers_MultipleChangedPeers(t *testing.T) {
@@ -697,7 +744,7 @@ func TestResolveAffectedPeers_MultipleChangedPeers(t *testing.T) {
require.NoError(t, err)
result := manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[0], peerIDs[2]})
assert.ElementsMatch(t, []string{peerIDs[0], peerIDs[1], peerIDs[2], peerIDs[3]}, result)
assert.ElementsMatch(t, []string{peerIDs[0], peerIDs[2], peerIDs[1], peerIDs[3]}, result)
}
func TestResolveAffectedPeers_SharedGroupAcrossPolicyAndRoute(t *testing.T) {
@@ -854,8 +901,9 @@ func TestAffectedPeers_IsolatedPolicies(t *testing.T) {
assert.NotContains(t, result, peerIDs[0])
assert.NotContains(t, result, peerIDs[1])
// peerIDs[4] is in neither isolated policy; only its own map can change.
result = manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[4]})
assert.Empty(t, result)
assert.ElementsMatch(t, []string{peerIDs[4]}, result)
}
func TestAffectedPeers_IsolatedRouteAndPolicy(t *testing.T) {
@@ -977,12 +1025,13 @@ func TestAffectedPeers_GroupUpdateOnlyAffectsLinkedPeers(t *testing.T) {
})
}
func TestAffectedPeers_UnlinkedGroupChange_NoUpdates(t *testing.T) {
// A peer in no policy/route refreshes only itself — no other peer is affected.
func TestAffectedPeers_UnlinkedPeerChange_RefreshesSelfOnly(t *testing.T) {
manager, s, accountID, peerIDs, _ := setupAffectedPeersTest(t)
ctx := context.Background()
result := manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[0]})
assert.Empty(t, result)
assert.ElementsMatch(t, []string{peerIDs[0]}, result)
}
// TestAffectedPeers_PolicyChange_UnrelatedPeerNoUpdate verifies that creating/deleting a
@@ -1332,6 +1381,7 @@ func TestAffectedPeers_NetworkRouterUnlinkedPeerNoUpdate(t *testing.T) {
NetworkID: net1.ID,
AccountID: accountID,
PeerGroups: []string{"nr-grpA"},
Enabled: true,
})
require.NoError(t, err)
@@ -1755,7 +1805,9 @@ func TestCollectAffectedFromProxyServices_GroupContainingTargetPeerChanged(t *te
assert.Contains(t, directPeers, peerIDs[1], "target peer must be refreshed")
}
func TestCollectAffectedFromProxyServices_DisabledServiceStillMatches(t *testing.T) {
// A disabled service in the snapshot proxies nothing, so it is skipped: a changed
// target peer does not pull in the service's proxy peer.
func TestCollectAffectedFromProxyServices_DisabledServiceSkipped(t *testing.T) {
manager, s, accountID, peerIDs, _ := setupAffectedPeersTest(t)
ctx := context.Background()
@@ -1781,8 +1833,7 @@ func TestCollectAffectedFromProxyServices_DisabledServiceStillMatches(t *testing
require.NoError(t, s.CreateService(ctx, svc))
_, directPeers := collectPeerChangeAffectedGroups(ctx, manager.Store, accountID, nil, []string{peerIDs[1]})
assert.Contains(t, directPeers, peerIDs[0], "disabled service should still trigger a refresh so peers are ready when re-enabled")
assert.Contains(t, directPeers, peerIDs[1], "disabled target should still trigger a refresh")
assert.NotContains(t, directPeers, peerIDs[0], "a disabled service proxies nothing, so its proxy peer must not be folded")
}
func TestCollectAffectedFromProxyServices_NonPeerTargetType(t *testing.T) {

View File

@@ -6,7 +6,12 @@
// and before a delete/removal severs the old state).
// - Snapshot.Expand: in-memory walk, no store access. Run AFTER the tx commits.
//
// Enabled is never consulted: toggling it is itself an observable change.
// Enabled handling differs by source. Disabled objects in the SNAPSHOT (existing
// account policies/resources/routers/routes/proxy services and their rules/targets)
// route to nobody and are skipped — they cannot affect any peer's map. Objects in
// the CHANGE itself are processed regardless of Enabled, so disabling one still
// refreshes the peers that lose access (the toggle is the observable change, and the
// update carries the oldnew state).
package affectedpeers
import (
@@ -61,7 +66,8 @@ func Load(ctx context.Context, s store.Store, accountID string, c Change) (*Snap
// loadCollections reads the policy/route/nameserver/dns/router/resource/proxy
// collections a Change can touch, gated to what the walk needs.
func (snap *Snapshot) loadCollections(ctx context.Context, s store.Store, accountID string, c Change) error {
hasGroupOrPeerChange := len(c.ChangedGroupIDs) > 0 || len(c.ChangedPeerIDs) > 0 || len(c.Resources) > 0
// LinkGroups drive the same policy/route/dns walk as a changed group or peer.
hasGroupOrPeerChange := len(c.ChangedGroupIDs) > 0 || len(c.ChangedPeerIDs) > 0 || len(c.LinkGroups) > 0 || len(c.Resources) > 0
hasNetworkObject := len(c.Routers) > 0 || len(c.Resources) > 0 || len(c.Networks) > 0
// the resource<->router bridge can fire for any of these
needsRoutersResources := hasGroupOrPeerChange || len(c.PostureCheckIDs) > 0 || len(c.Policies) > 0 || hasNetworkObject
@@ -76,7 +82,7 @@ func (snap *Snapshot) loadCollections(ctx context.Context, s store.Store, accoun
return err
}
}
if len(c.ChangedGroupIDs) > 0 || len(c.ChangedPeerIDs) > 0 {
if len(c.ChangedGroupIDs) > 0 || len(c.ChangedPeerIDs) > 0 || len(c.LinkGroups) > 0 {
if err := snap.loadDNS(ctx, s, accountID); err != nil {
return err
}
@@ -174,6 +180,24 @@ type Change struct {
// folded in — but only when the group is linked (an unlinked group has no map
// impact), matching how current members are handled.
RemovedPeersByGroup map[string][]string
// OutputPeerIDs are peers folded straight into the result without seeding their
// group memberships into the walk. Use for the peer whose group membership changed:
// the peer itself must refresh, but its OTHER groups did not change, so they must
// not be walked. Contrast ChangedPeerIDs, which seeds ALL of the peer's groups
// (correct when the peer's own attributes changed, e.g. IP/status).
OutputPeerIDs []string
// LinkGroups are groups used ONLY to match policies/routes/routers and walk to the
// OPPOSITE side — they are never expanded to their own members. Use this when a
// peer's group membership changed: pass the peer in ChangedPeerIDs and its
// group(s) here. The opposite side of the policies the group participates in
// refreshes, but the group's other members (siblings) do not — nothing changed for
// them. For an intra-group policy (A→A) the opposite side IS the group, so its
// members still refresh via the opposite-side fold, exactly when they genuinely
// gain/lose the changed peer. Unlike ChangedGroupIDs, a LinkGroup is not added to
// the output, so a one-sided membership change never wakes the whole group.
LinkGroups []string
}
func (c Change) isEmpty() bool {
@@ -186,7 +210,9 @@ func (c Change) isEmpty() bool {
len(c.Networks) == 0 &&
len(c.PostureCheckIDs) == 0 &&
len(c.DistributionGroupIDs) == 0 &&
len(c.RemovedPeersByGroup) == 0
len(c.RemovedPeersByGroup) == 0 &&
len(c.LinkGroups) == 0 &&
len(c.OutputPeerIDs) == 0
}
// Expand returns the deduplicated affected peer IDs from the preloaded Snapshot,
@@ -197,8 +223,8 @@ func (snap *Snapshot) Expand(ctx context.Context, accountID string, c Change) []
return nil
}
r := newResolver(ctx, snap, accountID, c)
log.WithContext(ctx).Tracef("affectedpeers expand start: account=%s changedGroups=%v changedPeers=%v policies=%d routes=%d routers=%d resources=%d networks=%d postureChecks=%v distributionGroups=%v",
accountID, c.ChangedGroupIDs, c.ChangedPeerIDs, len(c.Policies), len(c.Routes), len(c.Routers), len(c.Resources), len(c.Networks), c.PostureCheckIDs, c.DistributionGroupIDs)
log.WithContext(ctx).Tracef("affectedpeers expand start: account=%s changedGroups=%v changedPeers=%v linkGroups=%v policies=%d routes=%d routers=%d resources=%d networks=%d postureChecks=%v distributionGroups=%v",
accountID, c.ChangedGroupIDs, c.ChangedPeerIDs, c.LinkGroups, len(c.Policies), len(c.Routes), len(c.Routers), len(c.Resources), len(c.Networks), c.PostureCheckIDs, c.DistributionGroupIDs)
r.walk()
return r.expand()
}
@@ -216,57 +242,84 @@ func Collect(ctx context.Context, s store.Store, accountID string, c Change) (gr
}
r := newResolver(ctx, snap, accountID, c)
r.walk()
return setToSlice(r.groupSet), setToSlice(r.peerSet)
return setToSlice(r.affectedGroups), setToSlice(r.affectedPeers)
}
func newResolver(ctx context.Context, snap *Snapshot, accountID string, c Change) *resolver {
r := &resolver{
ctx: ctx,
snap: snap,
accountID: accountID,
change: c,
changedGroupSet: toSet(c.ChangedGroupIDs),
changedPeerSet: toSet(c.ChangedPeerIDs),
groupSet: make(map[string]struct{}),
peerSet: make(map[string]struct{}),
networkIDs: make(map[string]struct{}),
ctx: ctx,
snap: snap,
accountID: accountID,
change: c,
linkGroups: toSet(c.ChangedGroupIDs),
outputGroups: toSet(c.ChangedGroupIDs),
changedPeers: toSet(c.ChangedPeerIDs),
affectedGroups: make(map[string]struct{}),
affectedPeers: make(map[string]struct{}),
}
// LinkGroups match policies/routes to find the opposite side but are NOT output:
// they go into linkGroups only, never outputGroups, so their members never fold in.
addAll(r.linkGroups, c.LinkGroups)
// Resolve each changed peer to its groups here so callers pass only ChangedPeerIDs.
r.seedChangedGroupsFromPeers()
r.matchedPolicies = append(r.matchedPolicies, c.Policies...)
return r
}
// seedChangedGroupsFromPeers adds each changed peer's groups to changedGroupSet so
// seedChangedGroupsFromPeers adds each changed peer's groups to linkGroups so
// the group-driven walkers fire for memberships, not just direct peer references.
// These seeded groups are for MATCHING only — folding the changed entity's own
// side is gated on outputGroups (the caller-reported groups), so a seeded group
// never folds its whole membership; only the changed peer itself folds in.
func (r *resolver) seedChangedGroupsFromPeers() {
if len(r.changedPeerSet) == 0 {
if len(r.changedPeers) == 0 {
return
}
for groupID, members := range r.snap.groupPeers {
for pID := range r.changedPeerSet {
for pID := range r.changedPeers {
if _, ok := members[pID]; ok {
r.changedGroupSet[groupID] = struct{}{}
r.linkGroups[groupID] = struct{}{}
break
}
}
}
}
// policySide selects which side of a policy rule to walk.
type policySide int
const (
sideSource policySide = iota
sideDestination
)
func (s policySide) opposite() policySide {
if s == sideSource {
return sideDestination
}
return sideSource
}
// walk resolves affected peers in two buckets, by how far each change propagates.
//
// BOTH-SIDES — the rule itself changed (an explicit policy edit, or a policy whose
// posture check changed). Source AND destination refresh, so each such policy is
// walked on both sides.
//
// OPPOSITE-SIDE — an endpoint moved but no rule changed. For each policy the change
// touches we fold only the side AWAY from the change:
// - a changed peer/group sits ON a policy side -> fold the opposite side;
// - a changed router/resource/network sits on a NETWORK -> fold the SOURCE side of
// the policies whose destination reaches it (and the routers it implies).
//
// Routes, nameserver groups, DNS and embedded-proxy services distribute to their own
// member peers, outside the policy graph, and are folded here too.
func (r *resolver) walk() {
r.collectFromExplicitPolicies()
r.collectFromExplicitRoutes(r.change.Routes)
r.collectFromExplicitRouters(r.change.Routers)
r.collectFromExplicitResources(r.change.Resources)
r.collectFromExplicitNetworks(r.change.Networks)
r.collectFromPostureChecks(r.change.PostureCheckIDs)
for _, policy := range r.bothSidesPolicies() {
r.foldPolicySide(policy, sideSource)
r.foldPolicySide(policy, sideDestination)
}
// Distribution groups (nameserver/DNS) affect only their member peers: fold them
// straight into groupSet so expand() maps them to members, without the policy/
// route walk that changedGroupSet would trigger.
addAll(r.groupSet, r.change.DistributionGroupIDs)
if len(r.changedGroupSet) > 0 || len(r.changedPeerSet) > 0 {
if len(r.linkGroups) > 0 || len(r.changedPeers) > 0 {
r.collectFromPolicies()
r.collectFromRoutes()
r.collectFromNameServers()
@@ -275,7 +328,31 @@ func (r *resolver) walk() {
r.collectFromProxyServices()
}
r.collectResourceRouterBridge()
r.collectFromChangedRoutes(r.change.Routes)
r.collectFromChangedRouters(r.change.Routers)
r.collectFromChangedResources(r.change.Resources)
r.collectFromChangedNetworks(r.change.Networks)
// The explicitly changed peers always refresh their own maps. OnPeersUpdated only
// refreshes the resolver's output (it ignores the separately-passed changed peers),
// so the changed peer reaches its own new map only via here. An offline/deleted
// peer in the set is filtered downstream (filterConnectedAffectedPeers).
addAll(r.affectedPeers, setToSlice(r.changedPeers))
// OutputPeerIDs refresh themselves too, but unlike changedPeers their group
// memberships were not seeded into the walk (only the changed group was).
addAll(r.affectedPeers, r.change.OutputPeerIDs)
// Distribution groups (nameserver/DNS) affect only their member peers: fold them
// straight into affectedGroups so expand() maps them to members, without the
// policy/route walk that linkGroups would trigger.
addAll(r.affectedGroups, r.change.DistributionGroupIDs)
}
// bothSidesPolicies are the policies whose rule changed: the explicitly edited ones
// plus those gated by a changed posture check. walk folds both their sides.
func (r *resolver) bothSidesPolicies() []*types.Policy {
policies := append([]*types.Policy(nil), r.change.Policies...)
return r.appendPoliciesForPostureChecks(policies, r.change.PostureCheckIDs)
}
type resolver struct {
@@ -284,27 +361,71 @@ type resolver struct {
accountID string
change Change
changedGroupSet map[string]struct{}
changedPeerSet map[string]struct{}
// Inputs — what changed. Set once at construction, read-only during the walk
// (except linkGroups, which collectFromExplicitResources also seeds).
//
// linkGroups is the MATCH set: caller-changed groups the groups of changed
// peers changed-resource groups. A rule/route/router matches the change when
// one of its groups is here — used only to find the opposite side to fold.
//
// outputGroups is the FOLD-WHOLE-GROUP set: ONLY Change.ChangedGroupIDs. When a
// matched group is here, its whole membership is affected. A peer-seeded group
// is in linkGroups but NOT outputGroups, so it folds only the changed peer
// (changedPeers), never its siblings.
linkGroups map[string]struct{}
outputGroups map[string]struct{}
changedPeers map[string]struct{}
groupSet map[string]struct{}
peerSet map[string]struct{}
matchedPolicies []*types.Policy
networkIDs map[string]struct{}
// Outputs — the answer. The only sets the walk accumulates into. affectedGroups
// is expanded to its member peers in expand().
affectedGroups map[string]struct{}
affectedPeers map[string]struct{}
}
func (r *resolver) policies() []*types.Policy { return r.snap.policies }
// policies returns the account's ENABLED policies from the snapshot. Disabled
// policies grant no access, so the walk skips them when scanning existing account
// data. Explicitly changed policies (Change.Policies, via bothSidesPolicies) are
// processed regardless of Enabled, so disabling one still refreshes its peers.
func (r *resolver) policies() []*types.Policy {
enabled := make([]*types.Policy, 0, len(r.snap.policies))
for _, policy := range r.snap.policies {
if policy != nil && policy.Enabled {
enabled = append(enabled, policy)
}
}
return enabled
}
func (r *resolver) networkResources() []*resourceTypes.NetworkResource { return r.snap.resources }
// networkResources / networkRouters return the account's ENABLED resources/routers
// from the snapshot. Disabled objects route to nobody, so the walk skips them when
// it scans existing account data. The explicitly changed objects in the Change are
// processed regardless of Enabled (collectFromChanged*), so disabling one still
// refreshes the peers that lose access.
func (r *resolver) networkResources() []*resourceTypes.NetworkResource {
enabled := make([]*resourceTypes.NetworkResource, 0, len(r.snap.resources))
for _, resource := range r.snap.resources {
if resource.Enabled {
enabled = append(enabled, resource)
}
}
return enabled
}
func (r *resolver) networkRouters() []*routerTypes.NetworkRouter { return r.snap.routers }
func (r *resolver) networkRouters() []*routerTypes.NetworkRouter {
enabled := make([]*routerTypes.NetworkRouter, 0, len(r.snap.routers))
for _, router := range r.snap.routers {
if router.Enabled {
enabled = append(enabled, router)
}
}
return enabled
}
// peerIDsForGroups maps a group set to its member peer IDs via the preloaded index.
func (r *resolver) peerIDsForGroups(groupSet map[string]struct{}) []string {
func (r *resolver) peerIDsForGroups(groups map[string]struct{}) []string {
seen := make(map[string]struct{})
var ids []string
for gID := range groupSet {
for gID := range groups {
for pID := range r.snap.groupPeers[gID] {
if _, ok := seen[pID]; ok {
continue
@@ -317,25 +438,25 @@ func (r *resolver) peerIDsForGroups(groupSet map[string]struct{}) []string {
}
func (r *resolver) expand() []string {
peerIDs := r.peerIDsForGroups(r.groupSet)
peerIDs := r.peerIDsForGroups(r.affectedGroups)
log.WithContext(r.ctx).Tracef("affectedpeers expand: account=%s affectedGroups=%v -> %d group-member peers; direct peers=%v",
r.accountID, setToSlice(r.groupSet), len(peerIDs), setToSlice(r.peerSet))
r.accountID, setToSlice(r.affectedGroups), len(peerIDs), setToSlice(r.affectedPeers))
seen := make(map[string]struct{}, len(peerIDs))
for _, id := range peerIDs {
seen[id] = struct{}{}
}
for id := range r.peerSet {
for id := range r.affectedPeers {
if _, ok := seen[id]; !ok {
peerIDs = append(peerIDs, id)
seen[id] = struct{}{}
}
}
// Fold in removed peers only when their group is linked (in groupSet).
// Fold in removed peers only when their group is linked (in affectedGroups).
for groupID, removed := range r.change.RemovedPeersByGroup {
if _, linked := r.groupSet[groupID]; !linked {
if _, linked := r.affectedGroups[groupID]; !linked {
continue
}
for _, id := range removed {
@@ -351,169 +472,309 @@ func (r *resolver) expand() []string {
return peerIDs
}
func (r *resolver) collectFromExplicitPolicies() {
for _, policy := range r.matchedPolicies {
if policy == nil {
continue
}
log.WithContext(r.ctx).Tracef("collectFromExplicitPolicies: changed policy %s (%s) -> folding rule groups %v + direct peers",
policy.ID, policy.Name, policy.RuleGroups())
addAll(r.groupSet, policy.RuleGroups())
collectPolicyDirectPeers(policy, r.peerSet)
// ruleSideGroups / ruleSideResource return the groups and the resource on the given
// side of a rule.
func ruleSideGroups(rule *types.PolicyRule, side policySide) []string {
if side == sideDestination {
return rule.Destinations
}
return rule.Sources
}
func (r *resolver) collectFromExplicitRoutes(routes []*route.Route) {
for _, rt := range routes {
if rt == nil {
continue
}
log.WithContext(r.ctx).Tracef("collectFromExplicitRoutes: changed route %s -> folding groups=%v peerGroups=%v accessControlGroups=%v peer=%q",
rt.ID, rt.Groups, rt.PeerGroups, rt.AccessControlGroups, rt.Peer)
addAll(r.groupSet, rt.Groups, rt.PeerGroups, rt.AccessControlGroups)
if rt.Peer != "" {
r.peerSet[rt.Peer] = struct{}{}
}
func ruleSideResource(rule *types.PolicyRule, side policySide) types.Resource {
if side == sideDestination {
return rule.DestinationResource
}
return rule.SourceResource
}
// collectFromExplicitRouters folds changed routers' peers and marks their networks
// for the bridge. Passing the old router keeps a repointed router's previous peers
// affected without a post-commit read.
func (r *resolver) collectFromExplicitRouters(routers []*routerTypes.NetworkRouter) {
for _, router := range routers {
if router == nil {
continue
}
log.WithContext(r.ctx).Tracef("collectFromExplicitRouters: changed router %s on network %s -> folding peerGroups=%v peer=%q and marking network for source bridge",
router.ID, router.NetworkID, router.PeerGroups, router.Peer)
addAll(r.groupSet, router.PeerGroups)
if router.Peer != "" {
r.peerSet[router.Peer] = struct{}{}
}
if router.NetworkID != "" {
r.networkIDs[router.NetworkID] = struct{}{}
}
}
}
// collectFromExplicitResources marks changed resources' networks for the bridge and
// treats their group IDs as changed, so policies targeting the resource via a
// now-detached (old) group still refresh.
func (r *resolver) collectFromExplicitResources(resources []*resourceTypes.NetworkResource) {
for _, resource := range resources {
if resource == nil {
continue
}
log.WithContext(r.ctx).Tracef("collectFromExplicitResources: changed resource %s on network %s -> marking network for bridge and treating groups %v as changed",
resource.ID, resource.NetworkID, resource.GroupIDs)
addAll(r.changedGroupSet, resource.GroupIDs)
if resource.NetworkID != "" {
r.networkIDs[resource.NetworkID] = struct{}{}
}
}
}
// collectFromExplicitNetworks marks changed networks for the bridge. A network has
// no groups/peers of its own.
func (r *resolver) collectFromExplicitNetworks(networks []*networkTypes.Network) {
for _, network := range networks {
if network == nil {
continue
}
log.WithContext(r.ctx).Tracef("collectFromExplicitNetworks: changed network %s -> marking for bridge", network.ID)
if network.ID != "" {
r.networkIDs[network.ID] = struct{}{}
}
}
}
func (r *resolver) collectFromPostureChecks(postureCheckIDs []string) {
if len(postureCheckIDs) == 0 {
// foldPolicySide folds one side of a policy down to affected peers: its groups
// (resolved to members in expand) and its direct peer. When the side is the
// DESTINATION and references a network resource (directly or via a destination
// group's resources), it also folds the routers that serve that resource's network
// — a destination resource is reached through its routers. A resource on the SOURCE
// side routes to nobody (GetPoliciesForNetworkResource matches destinations only),
// so the router hop is destination-only.
func (r *resolver) foldPolicySide(policy *types.Policy, side policySide) {
if policy == nil {
return
}
for _, rule := range policy.Rules {
addAll(r.affectedGroups, ruleSideGroups(rule, side))
res := ruleSideResource(rule, side)
if res.Type == types.ResourceTypePeer && res.ID != "" {
r.affectedPeers[res.ID] = struct{}{}
}
}
if side == sideDestination {
r.foldRoutersForResources(r.policyDestinationResourceIDs(policy))
}
}
// appendPoliciesForPostureChecks appends every policy that references a changed
// posture check (a rule change, so walk both sides).
func (r *resolver) appendPoliciesForPostureChecks(policies []*types.Policy, postureCheckIDs []string) []*types.Policy {
if len(postureCheckIDs) == 0 {
return policies
}
ids := toSet(postureCheckIDs)
for _, policy := range r.policies() {
if !policyReferencesPostureChecks(policy, ids) {
continue
}
log.WithContext(r.ctx).Tracef("collectFromPostureChecks: policy %s (%s) references changed posture checks %v -> folding rule groups %v + direct peers",
policy.ID, policy.Name, postureCheckIDs, policy.RuleGroups())
addAll(r.groupSet, policy.RuleGroups())
collectPolicyDirectPeers(policy, r.peerSet)
r.matchedPolicies = append(r.matchedPolicies, policy)
log.WithContext(r.ctx).Tracef("appendPoliciesForPostureChecks: policy %s (%s) references changed posture checks %v -> both-sides policy",
policy.ID, policy.Name, postureCheckIDs)
policies = append(policies, policy)
}
return policies
}
// collectFromPolicies folds, for every policy whose rule a changed group or peer
// touches, only the OPPOSITE side (down to peers, incl. destination routers), plus
// the changed entity's own side: the changed group's whole membership when the
// group itself changed (outputGroups), or the changed peer alone when matched via a
// peer-seeded group (never its co-members).
func (r *resolver) collectFromPolicies() {
for _, policy := range r.policies() {
for _, rule := range policy.Rules {
if !rule.Enabled {
continue // a disabled rule grants no access
}
r.foldRuleSideIfChanged(policy, rule, sideSource)
r.foldRuleSideIfChanged(policy, rule, sideDestination)
}
}
}
func (r *resolver) collectFromPolicies() {
for _, policy := range r.policies() {
matchedByGroup := policyReferencesGroups(policy, r.changedGroupSet)
matchedByPeer := len(r.changedPeerSet) > 0 && policyReferencesDirectPeers(policy, r.changedPeerSet)
if !matchedByGroup && !matchedByPeer {
// foldRuleSideIfChanged: when a changed group or direct peer sits on `side` of the
// rule, fold the opposite side fully (groups/peers + destination routers) and fold
// the changed entity's own side (the whole changed group, or the changed peer alone).
func (r *resolver) foldRuleSideIfChanged(policy *types.Policy, rule *types.PolicyRule, side policySide) {
nearGroups := ruleSideGroups(rule, side)
nearResource := ruleSideResource(rule, side)
matchedByGroup := anyInSet(nearGroups, r.linkGroups)
matchedByPeer := isDirectPeerInSet(nearResource, r.changedPeers)
if !matchedByGroup && !matchedByPeer {
return
}
// Opposite side, fully down to peers (a destination opposite also folds routers).
r.foldPolicySideForRule(policy, rule, side.opposite())
// Own side: fold the whole changed group's members only when the group itself
// changed (outputGroups). A peer-seeded or link-only group is not folded here —
// its siblings never refresh. The changed peers themselves are folded once, after
// the walk (see walk()).
for _, gID := range nearGroups {
if _, ok := r.outputGroups[gID]; ok {
r.affectedGroups[gID] = struct{}{}
}
}
// When the changed side IS a destination, the resources it targets are reached
// through their network's routers, so those routers refresh too (e.g. attaching a
// resource to a destination group, or a changed destination group/resource).
if side == sideDestination {
r.foldRoutersForResources(r.ruleDestinationResourceIDs(rule))
}
}
// foldPolicySideForRule folds one side of a single rule (groups + direct peer), and
// for a destination side the routers of that rule's destination resources.
func (r *resolver) foldPolicySideForRule(policy *types.Policy, rule *types.PolicyRule, side policySide) {
addAll(r.affectedGroups, ruleSideGroups(rule, side))
res := ruleSideResource(rule, side)
if res.Type == types.ResourceTypePeer && res.ID != "" {
r.affectedPeers[res.ID] = struct{}{}
}
if side == sideDestination {
r.foldRoutersForResources(r.ruleDestinationResourceIDs(rule))
}
}
// collectFromChangedRoutes folds an explicitly changed route's own groups and peer.
func (r *resolver) collectFromChangedRoutes(routes []*route.Route) {
for _, rt := range routes {
if rt == nil {
continue
}
log.WithContext(r.ctx).Tracef("collectFromPolicies: policy %s (%s) matched (byGroup=%t byPeer=%t) -> folding rule groups %v + direct peers",
policy.ID, policy.Name, matchedByGroup, matchedByPeer, policy.RuleGroups())
addAll(r.groupSet, policy.RuleGroups())
collectPolicyDirectPeers(policy, r.peerSet)
r.matchedPolicies = append(r.matchedPolicies, policy)
log.WithContext(r.ctx).Tracef("collectFromChangedRoutes: changed route %s -> folding groups=%v peerGroups=%v accessControlGroups=%v peer=%q",
rt.ID, rt.Groups, rt.PeerGroups, rt.AccessControlGroups, rt.Peer)
addAll(r.affectedGroups, rt.Groups, rt.PeerGroups, rt.AccessControlGroups)
if rt.Peer != "" {
r.affectedPeers[rt.Peer] = struct{}{}
}
}
}
// collectFromChangedRouters: a changed router refreshes its OWN backing peer/groups
// (the changed entity) and the SOURCE side of every policy reaching a resource on
// its network (the router serves the whole network). Sibling routers on the network
// are independent and are NOT folded. Passing the old router state keeps a repointed
// router's previous backing affected without a post-commit read.
func (r *resolver) collectFromChangedRouters(routers []*routerTypes.NetworkRouter) {
for _, router := range routers {
if router == nil {
continue
}
log.WithContext(r.ctx).Tracef("collectFromChangedRouters: changed router %s on network %s -> folding its own peerGroups=%v peer=%q + sources reaching network resources",
router.ID, router.NetworkID, router.PeerGroups, router.Peer)
addAll(r.affectedGroups, router.PeerGroups)
if router.Peer != "" {
r.affectedPeers[router.Peer] = struct{}{}
}
if router.NetworkID != "" {
r.foldPolicySourcesForResources(r.networkResourceIDs(router.NetworkID))
}
}
}
// collectFromChangedResources: a changed resource refreshes the SOURCE side of the
// policies targeting EXACTLY that resource — directly, or via one of the resource's
// own groups (oldnew across the change, so a now-detached group's sources still
// refresh) — plus the routers serving its network (the resource is reached through
// them). It does not touch sibling resources on the same network.
func (r *resolver) collectFromChangedResources(resources []*resourceTypes.NetworkResource) {
for _, resource := range resources {
if resource == nil {
continue
}
log.WithContext(r.ctx).Tracef("collectFromChangedResources: changed resource %s on network %s (groups %v) -> folding sources of policies targeting it + its network's routers",
resource.ID, resource.NetworkID, resource.GroupIDs)
r.foldPolicySourcesForResource(resource.ID, resource.GroupIDs)
if resource.NetworkID != "" {
r.foldRoutersOnNetworks(map[string]struct{}{resource.NetworkID: {}})
}
}
}
// foldPolicySourcesForResource folds the source side of every policy whose
// destination is the given resource — referenced directly, or via any of the given
// groups (the resource's own oldnew groups, which captures a detached group).
func (r *resolver) foldPolicySourcesForResource(resourceID string, groupIDs []string) {
groups := toSet(groupIDs)
for _, policy := range r.policies() {
if !policyTargetsResourceOrGroups(policy, resourceID, groups) {
continue
}
log.WithContext(r.ctx).Tracef("foldPolicySourcesForResource: policy %s (%s) targets changed resource %s -> folding its source groups/peers", policy.ID, policy.Name, resourceID)
collectPolicySources(policy, r.affectedGroups, r.affectedPeers)
}
}
// policyTargetsResourceOrGroups reports whether a policy's destination is the given
// resource directly, or one of the given destination groups.
func policyTargetsResourceOrGroups(policy *types.Policy, resourceID string, groups map[string]struct{}) bool {
if policy == nil {
return false
}
for _, rule := range policy.Rules {
if !rule.Enabled {
continue
}
if rule.DestinationResource.Type != types.ResourceTypePeer && rule.DestinationResource.ID == resourceID && resourceID != "" {
return true
}
if anyInSet(rule.Destinations, groups) {
return true
}
}
return false
}
// collectFromChangedNetworks: a changed network refreshes the SOURCE side of the
// policies reaching any of its resources, plus its routers. A network has no
// groups/peers of its own.
func (r *resolver) collectFromChangedNetworks(networks []*networkTypes.Network) {
for _, network := range networks {
if network == nil || network.ID == "" {
continue
}
log.WithContext(r.ctx).Tracef("collectFromChangedNetworks: changed network %s -> folding sources reaching its resources + its routers", network.ID)
resourceIDs := r.networkResourceIDs(network.ID)
r.foldPolicySourcesForResources(resourceIDs)
r.foldRoutersOnNetworks(map[string]struct{}{network.ID: {}})
}
}
// foldPolicySourcesForResources folds the source groups/peers of every policy whose
// destination targets one of resourceIDs (directly or via a destination group).
func (r *resolver) foldPolicySourcesForResources(resourceIDs map[string]struct{}) {
if len(resourceIDs) == 0 {
return
}
for _, policy := range r.policies() {
if r.policyTargetsResources(policy, resourceIDs) {
log.WithContext(r.ctx).Tracef("foldPolicySourcesForResources: policy %s (%s) targets a changed resource -> folding its source groups/peers", policy.ID, policy.Name)
collectPolicySources(policy, r.affectedGroups, r.affectedPeers)
}
}
}
func (r *resolver) collectFromRoutes() {
for _, rt := range r.snap.routes {
matchedByGroup := anyInSet(rt.Groups, r.changedGroupSet) || anyInSet(rt.PeerGroups, r.changedGroupSet) || anyInSet(rt.AccessControlGroups, r.changedGroupSet)
matchedByPeer := rt.Peer != "" && len(r.changedPeerSet) > 0 && isInSet(rt.Peer, r.changedPeerSet)
if !rt.Enabled {
continue // disabled routes route to nobody; skip existing account data
}
matchedByGroup := anyInSet(rt.Groups, r.linkGroups) || anyInSet(rt.PeerGroups, r.linkGroups) || anyInSet(rt.AccessControlGroups, r.linkGroups)
matchedByPeer := rt.Peer != "" && len(r.changedPeers) > 0 && isInSet(rt.Peer, r.changedPeers)
if !matchedByGroup && !matchedByPeer {
continue
}
log.WithContext(r.ctx).Tracef("collectFromRoutes: route %s matched (byGroup=%t byPeer=%t) -> folding groups=%v peerGroups=%v accessControlGroups=%v peer=%q",
rt.ID, matchedByGroup, matchedByPeer, rt.Groups, rt.PeerGroups, rt.AccessControlGroups, rt.Peer)
addAll(r.groupSet, rt.Groups, rt.PeerGroups, rt.AccessControlGroups)
addAll(r.affectedGroups, rt.Groups, rt.PeerGroups, rt.AccessControlGroups)
if rt.Peer != "" {
r.peerSet[rt.Peer] = struct{}{}
r.affectedPeers[rt.Peer] = struct{}{}
}
}
}
func (r *resolver) collectFromNameServers() {
if len(r.changedGroupSet) == 0 {
if len(r.linkGroups) == 0 {
return
}
for _, ns := range r.snap.nsGroups {
if anyInSet(ns.Groups, r.changedGroupSet) {
if anyInSet(ns.Groups, r.linkGroups) {
log.WithContext(r.ctx).Tracef("collectFromNameServers: nameserver group %s references a changed group -> folding its groups %v", ns.ID, ns.Groups)
addAll(r.groupSet, ns.Groups)
addAll(r.affectedGroups, ns.Groups)
}
}
}
func (r *resolver) collectFromDNSSettings() {
if len(r.changedGroupSet) == 0 || r.snap.dnsSettings == nil {
if len(r.linkGroups) == 0 || r.snap.dnsSettings == nil {
return
}
for _, gID := range r.snap.dnsSettings.DisabledManagementGroups {
if _, ok := r.changedGroupSet[gID]; ok {
if _, ok := r.linkGroups[gID]; ok {
log.WithContext(r.ctx).Tracef("collectFromDNSSettings: changed group %s is in DisabledManagementGroups -> folding it", gID)
r.groupSet[gID] = struct{}{}
r.affectedGroups[gID] = struct{}{}
}
}
}
// collectFromNetworkRouters handles a changed group/peer that BACKS a router (the
// routing peer set moved): the router's own peers refresh and so do the sources of
// the policies reaching its network's resources. Sibling routers on the network are
// independent and are not folded.
func (r *resolver) collectFromNetworkRouters() {
for _, router := range r.networkRouters() {
matchedByGroup := anyInSet(router.PeerGroups, r.changedGroupSet)
matchedByPeer := router.Peer != "" && len(r.changedPeerSet) > 0 && isInSet(router.Peer, r.changedPeerSet)
matchedByGroup := anyInSet(router.PeerGroups, r.linkGroups)
matchedByPeer := router.Peer != "" && len(r.changedPeers) > 0 && isInSet(router.Peer, r.changedPeers)
if !matchedByGroup && !matchedByPeer {
continue
}
log.WithContext(r.ctx).Tracef("collectFromNetworkRouters: router %s on network %s matched (byGroup=%t byPeer=%t) -> folding peerGroups=%v peer=%q and marking network for source bridge",
log.WithContext(r.ctx).Tracef("collectFromNetworkRouters: router %s on network %s matched (byGroup=%t byPeer=%t) -> folding its peerGroups=%v peer=%q + sources reaching network resources",
router.ID, router.NetworkID, matchedByGroup, matchedByPeer, router.PeerGroups, router.Peer)
addAll(r.groupSet, router.PeerGroups)
addAll(r.affectedGroups, router.PeerGroups)
if router.Peer != "" {
r.peerSet[router.Peer] = struct{}{}
r.affectedPeers[router.Peer] = struct{}{}
}
if router.NetworkID != "" {
r.foldPolicySourcesForResources(r.networkResourceIDs(router.NetworkID))
}
r.networkIDs[router.NetworkID] = struct{}{}
}
}
@@ -526,42 +787,45 @@ func (r *resolver) collectFromProxyServices() {
expanded := r.expandChangedPeersWithGroups()
for _, svc := range services {
if svc == nil {
continue
if svc == nil || !svc.Enabled {
continue // a disabled service proxies nothing; skip existing account data
}
proxyPeers := proxyByCluster[svc.ProxyCluster]
if len(proxyPeers) == 0 {
continue
}
matchedByPeer := serviceMatchesChangedPeers(svc, proxyPeers, expanded)
matchedByAccessGroup := anyInSet(svc.AccessGroups, r.changedGroupSet)
matchedByAccessGroup := anyInSet(svc.AccessGroups, r.linkGroups)
if !matchedByPeer && !matchedByAccessGroup {
continue
}
log.WithContext(r.ctx).Tracef("collectFromProxyServices: service %s (cluster=%s) matched (byProxyOrTargetPeer=%t byAccessGroup=%t) -> folding %d proxy peers, peer targets and access groups %v",
svc.ID, svc.ProxyCluster, matchedByPeer, matchedByAccessGroup, len(proxyPeers), svc.AccessGroups)
for _, pid := range proxyPeers {
r.peerSet[pid] = struct{}{}
r.affectedPeers[pid] = struct{}{}
}
for _, target := range svc.Targets {
if !target.Enabled {
continue // a disabled target forwards nothing
}
if target.TargetType == rpservice.TargetTypePeer && target.TargetId != "" {
r.peerSet[target.TargetId] = struct{}{}
r.affectedPeers[target.TargetId] = struct{}{}
}
}
addAll(r.groupSet, svc.AccessGroups)
addAll(r.affectedGroups, svc.AccessGroups)
}
}
func (r *resolver) expandChangedPeersWithGroups() map[string]struct{} {
if len(r.changedGroupSet) == 0 {
return r.changedPeerSet
if len(r.linkGroups) == 0 {
return r.changedPeers
}
ids := r.peerIDsForGroups(r.changedGroupSet)
ids := r.peerIDsForGroups(r.linkGroups)
if len(ids) == 0 {
return r.changedPeerSet
return r.changedPeers
}
merged := make(map[string]struct{}, len(r.changedPeerSet)+len(ids))
for id := range r.changedPeerSet {
merged := make(map[string]struct{}, len(r.changedPeers)+len(ids))
for id := range r.changedPeers {
merged[id] = struct{}{}
}
for _, id := range ids {
@@ -570,54 +834,36 @@ func (r *resolver) expandChangedPeersWithGroups() map[string]struct{} {
return merged
}
// collectResourceRouterBridge crosses between source peers and routing peers, which
// are reachable only via resource -> network -> router, not through the policy's own
// groups: source -> router (targeted resources' networks), then router -> source.
func (r *resolver) collectResourceRouterBridge() {
r.bridgeSourceToRouters()
r.bridgeRoutersToSources()
}
func (r *resolver) bridgeSourceToRouters() {
resourceIDs := r.policyDestinationResourceIDs(r.matchedPolicies...)
// foldRoutersForResources folds the routers serving the networks of the given
// resources (a destination resource is reached through its network's routers). It is
// the resource -> network -> router hop used by foldPolicySide for a destination.
func (r *resolver) foldRoutersForResources(resourceIDs map[string]struct{}) {
if len(resourceIDs) == 0 {
return
}
networkIDs := r.resourceNetworkIDs(resourceIDs)
log.WithContext(r.ctx).Tracef("bridgeSourceToRouters: targeted resources %v -> networks %v (their routers become affected via the router->source pass)",
setToSlice(resourceIDs), setToSlice(networkIDs))
for id := range networkIDs {
r.networkIDs[id] = struct{}{}
}
r.foldRoutersOnNetworks(r.resourceNetworkIDs(resourceIDs))
}
func (r *resolver) bridgeRoutersToSources() {
if len(r.networkIDs) == 0 {
return
// ruleDestinationResourceIDs returns the destination resource IDs of a single rule:
// the direct DestinationResource plus the resources of its destination groups.
func (r *resolver) ruleDestinationResourceIDs(rule *types.PolicyRule) map[string]struct{} {
resourceIDs := make(map[string]struct{})
if rule.DestinationResource.Type != types.ResourceTypePeer && rule.DestinationResource.ID != "" {
resourceIDs[rule.DestinationResource.ID] = struct{}{}
}
r.addGroupResourceIDs(toSet(rule.Destinations), resourceIDs)
return resourceIDs
}
log.WithContext(r.ctx).Tracef("bridgeRoutersToSources: affected networks %v -> folding their routing peers and the source peers of policies targeting their resources",
setToSlice(r.networkIDs))
r.foldRoutersOnNetworks(r.networkIDs)
// networkResourceIDs returns the IDs of all resources on the given network.
func (r *resolver) networkResourceIDs(networkID string) map[string]struct{} {
resourceIDs := make(map[string]struct{})
for _, resource := range r.networkResources() {
if _, ok := r.networkIDs[resource.NetworkID]; ok {
if resource.NetworkID == networkID {
resourceIDs[resource.ID] = struct{}{}
}
}
if len(resourceIDs) == 0 {
return
}
for _, policy := range r.policies() {
if r.policyTargetsResources(policy, resourceIDs) {
log.WithContext(r.ctx).Tracef("bridgeRoutersToSources: policy %s (%s) targets an affected-network resource -> folding its source groups/peers", policy.ID, policy.Name)
collectPolicySources(policy, r.groupSet, r.peerSet)
}
}
return resourceIDs
}
func (r *resolver) foldRoutersOnNetworks(networkIDs map[string]struct{}) {
@@ -627,9 +873,9 @@ func (r *resolver) foldRoutersOnNetworks(networkIDs map[string]struct{}) {
}
log.WithContext(r.ctx).Tracef("bridgeRoutersToSources: router %s serves affected network %s -> folding peerGroups=%v peer=%q",
router.ID, router.NetworkID, router.PeerGroups, router.Peer)
addAll(r.groupSet, router.PeerGroups)
addAll(r.affectedGroups, router.PeerGroups)
if router.Peer != "" {
r.peerSet[router.Peer] = struct{}{}
r.affectedPeers[router.Peer] = struct{}{}
}
}
}
@@ -650,6 +896,9 @@ func (r *resolver) policyTargetsResources(policy *types.Policy, resourceIDs map[
}
destGroupSet := make(map[string]struct{})
for _, rule := range policy.Rules {
if !rule.Enabled {
continue
}
if rule.DestinationResource.Type != types.ResourceTypePeer && isInSet(rule.DestinationResource.ID, resourceIDs) {
return true
}
@@ -714,44 +963,20 @@ func (r *resolver) addGroupResourceIDs(groupIDs map[string]struct{}, resourceIDs
}
}
func collectPolicyDirectPeers(policy *types.Policy, peerSet map[string]struct{}) {
// collectPolicySources folds the source groups/peers of a snapshot policy's enabled
// rules (a disabled rule grants no access).
func collectPolicySources(policy *types.Policy, groups, peers map[string]struct{}) {
for _, rule := range policy.Rules {
if !rule.Enabled {
continue
}
addAll(groups, rule.Sources)
if rule.SourceResource.Type == types.ResourceTypePeer && rule.SourceResource.ID != "" {
peerSet[rule.SourceResource.ID] = struct{}{}
}
if rule.DestinationResource.Type == types.ResourceTypePeer && rule.DestinationResource.ID != "" {
peerSet[rule.DestinationResource.ID] = struct{}{}
peers[rule.SourceResource.ID] = struct{}{}
}
}
}
func collectPolicySources(policy *types.Policy, groupSet, peerSet map[string]struct{}) {
for _, rule := range policy.Rules {
addAll(groupSet, rule.Sources)
if rule.SourceResource.Type == types.ResourceTypePeer && rule.SourceResource.ID != "" {
peerSet[rule.SourceResource.ID] = struct{}{}
}
}
}
func policyReferencesGroups(policy *types.Policy, groupSet map[string]struct{}) bool {
for _, rule := range policy.Rules {
if anyInSet(rule.Sources, groupSet) || anyInSet(rule.Destinations, groupSet) {
return true
}
}
return false
}
func policyReferencesDirectPeers(policy *types.Policy, changedSet map[string]struct{}) bool {
for _, rule := range policy.Rules {
if isDirectPeerInSet(rule.SourceResource, changedSet) || isDirectPeerInSet(rule.DestinationResource, changedSet) {
return true
}
}
return false
}
func policyReferencesPostureChecks(policy *types.Policy, ids map[string]struct{}) bool {
for _, id := range policy.SourcePostureChecks {
if _, ok := ids[id]; ok {
@@ -776,7 +1001,7 @@ func serviceMatchesChangedPeers(svc *rpservice.Service, proxyPeers []string, cha
}
}
for _, target := range svc.Targets {
if target.TargetType != rpservice.TargetTypePeer || target.TargetId == "" {
if !target.Enabled || target.TargetType != rpservice.TargetTypePeer || target.TargetId == "" {
continue
}
if _, ok := changedPeers[target.TargetId]; ok {

View File

@@ -10,8 +10,8 @@ import (
"github.com/netbirdio/netbird/management/server/types"
)
// policyGroupsAndPeers mirrors the explicit-policy extraction (RuleGroups +
// direct peers) the resolver folds in, for asserting the pure logic.
// policyGroupsAndPeers mirrors the both-sides extraction (RuleGroups + direct peers)
// the resolver folds in for a changed policy, for asserting the pure logic.
func policyGroupsAndPeers(policies ...*types.Policy) (groups []string, peers []string) {
peerSet := map[string]struct{}{}
for _, p := range policies {
@@ -19,7 +19,14 @@ func policyGroupsAndPeers(policies ...*types.Policy) (groups []string, peers []s
continue
}
groups = append(groups, p.RuleGroups()...)
collectPolicyDirectPeers(p, peerSet)
for _, rule := range p.Rules {
if rule.SourceResource.Type == types.ResourceTypePeer && rule.SourceResource.ID != "" {
peerSet[rule.SourceResource.ID] = struct{}{}
}
if rule.DestinationResource.Type == types.ResourceTypePeer && rule.DestinationResource.ID != "" {
peerSet[rule.DestinationResource.ID] = struct{}{}
}
}
}
for id := range peerSet {
peers = append(peers, id)
@@ -80,26 +87,6 @@ func TestChangeIsEmpty(t *testing.T) {
assert.False(t, Change{PostureCheckIDs: []string{"pc"}}.isEmpty())
}
func TestPolicyReferencesGroups(t *testing.T) {
policy := &types.Policy{Rules: []*types.PolicyRule{{Sources: []string{"g1", "g2"}, Destinations: []string{"g3"}}}}
assert.True(t, policyReferencesGroups(policy, map[string]struct{}{"g1": {}}))
assert.True(t, policyReferencesGroups(policy, map[string]struct{}{"g3": {}}))
assert.False(t, policyReferencesGroups(policy, map[string]struct{}{"g4": {}}))
assert.False(t, policyReferencesGroups(policy, map[string]struct{}{}))
}
func TestPolicyReferencesDirectPeers(t *testing.T) {
policy := &types.Policy{Rules: []*types.PolicyRule{{
SourceResource: types.Resource{Type: types.ResourceTypePeer, ID: "p1"},
DestinationResource: types.Resource{Type: types.ResourceTypeHost, ID: "r1"},
}}}
assert.True(t, policyReferencesDirectPeers(policy, map[string]struct{}{"p1": {}}))
assert.False(t, policyReferencesDirectPeers(policy, map[string]struct{}{"r1": {}}))
assert.False(t, policyReferencesDirectPeers(policy, map[string]struct{}{"p2": {}}))
}
func TestPolicyReferencesPostureChecks(t *testing.T) {
policy := &types.Policy{SourcePostureChecks: []string{"pc1", "pc2"}}
@@ -107,24 +94,9 @@ func TestPolicyReferencesPostureChecks(t *testing.T) {
assert.False(t, policyReferencesPostureChecks(policy, map[string]struct{}{"pc3": {}}))
}
func TestCollectPolicyDirectPeers(t *testing.T) {
policy := &types.Policy{Rules: []*types.PolicyRule{{
SourceResource: types.Resource{Type: types.ResourceTypePeer, ID: "p1"},
DestinationResource: types.Resource{Type: types.ResourceTypePeer, ID: "p2"},
}, {
DestinationResource: types.Resource{Type: types.ResourceTypeHost, ID: "r1"},
}}}
peerSet := map[string]struct{}{}
collectPolicyDirectPeers(policy, peerSet)
assert.Contains(t, peerSet, "p1")
assert.Contains(t, peerSet, "p2")
assert.NotContains(t, peerSet, "r1")
}
func TestCollectPolicySources(t *testing.T) {
policy := &types.Policy{Rules: []*types.PolicyRule{{
Enabled: true,
Sources: []string{"g1"},
SourceResource: types.Resource{Type: types.ResourceTypePeer, ID: "p1"},
Destinations: []string{"g2"},

View File

@@ -520,7 +520,12 @@ func collectDeletableGroups(ctx context.Context, transaction store.Store, accoun
// GroupAddPeer appends peer to the group
func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
var snap *affectedpeers.Snapshot
change := affectedpeers.Change{ChangedGroupIDs: []string{groupID}}
// A membership change affects only the peer itself and the opposite side of THIS
// group's policies — not the group's other members, and not the peer's other
// groups. LinkGroups walks only this group (matched, not expanded); OutputPeerIDs
// refreshes the peer without seeding its other group memberships. For an
// intra-group policy the opposite side is the group, so its members still refresh.
change := affectedpeers.Change{OutputPeerIDs: []string{peerID}, LinkGroups: []string{groupID}}
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err := transaction.AddPeerToGroup(ctx, accountID, peerID, groupID); err != nil {
@@ -586,10 +591,11 @@ func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID
// GroupDeletePeer removes peer from the group
func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, groupID, peerID string) error {
var snap *affectedpeers.Snapshot
change := affectedpeers.Change{
ChangedGroupIDs: []string{groupID},
RemovedPeersByGroup: map[string][]string{groupID: {peerID}},
}
// Same as GroupAddPeer: the removed peer and the opposite side of THIS group's
// policies refresh, not the group's other members or the peer's other groups. The
// peer is no longer in the group's index, but LinkGroups still drives the
// opposite-side walk, and OutputPeerIDs refreshes the removed peer itself.
change := affectedpeers.Change{OutputPeerIDs: []string{peerID}, LinkGroups: []string{groupID}}
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err := transaction.RemovePeerFromGroup(ctx, peerID, groupID); err != nil {
@@ -600,8 +606,6 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
return err
}
// The removed peer is carried in change.RemovedPeersByGroup and folded in
// only when the group is linked, so loading post-removal is correct.
var err error
if snap, err = affectedpeers.Load(ctx, transaction, accountID, change); err != nil {
return err

View File

@@ -220,7 +220,7 @@ func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request) {
}
includeServiceUser, err := strconv.ParseBool(serviceUser)
log.WithContext(r.Context()).Debugf("Should include service user: %v", includeServiceUser)
log.WithContext(r.Context()).Tracef("Should include service user: %v", includeServiceUser)
if err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid service_user query parameter"), w)
return

View File

@@ -39,7 +39,7 @@ type MockAccountManager struct {
GetUserFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
ListUsersFunc func(ctx context.Context, accountID string) ([]*types.User, error)
GetPeersFunc func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error
MarkPeerDisconnectedFunc func(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64) error
SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error
@@ -114,7 +114,7 @@ type MockAccountManager struct {
GetIdpManagerFunc func() idp.Manager
UpdateIntegratedValidatorFunc func(ctx context.Context, accountID, userID, validator string, groups []string) error
GroupValidationFunc func(ctx context.Context, accountId string, groups []string) (bool, error)
SyncPeerMetaFunc func(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
SyncPeerMetaFunc func(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) error
FindExistingPostureCheckFunc func(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
GetAccountIDForPeerKeyFunc func(ctx context.Context, peerKey string) (string, error)
GetAccountByIDFunc func(ctx context.Context, accountID string, userID string) (*types.Account, error)
@@ -345,9 +345,9 @@ func (am *MockAccountManager) GetAccountIDByUserID(ctx context.Context, userAuth
}
// MarkPeerConnected mock implementation of MarkPeerConnected from server.AccountManager interface
func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error {
func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error {
if am.MarkPeerConnectedFunc != nil {
return am.MarkPeerConnectedFunc(ctx, peerKey, realIP, accountID, sessionStartedAt, nmap)
return am.MarkPeerConnectedFunc(ctx, peerKey, accountID, sessionStartedAt, nmap)
}
return status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
}
@@ -975,9 +975,9 @@ func (am *MockAccountManager) GroupValidation(ctx context.Context, accountId str
}
// SyncPeerMeta mocks SyncPeerMeta of the AccountManager interface
func (am *MockAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error {
func (am *MockAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) error {
if am.SyncPeerMetaFunc != nil {
return am.SyncPeerMetaFunc(ctx, peerPubKey, meta)
return am.SyncPeerMetaFunc(ctx, peerPubKey, meta, realIP)
}
return status.Errorf(codes.Unimplemented, "method SyncPeerMeta is not implemented")
}

View File

@@ -74,7 +74,7 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
//
// Disconnects use MarkPeerDisconnected and require the session to match
// exactly; see PeerStatus.SessionStartedAt for the protocol.
func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, realIP net.IP, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error {
func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error {
start := time.Now()
defer func() {
am.metrics.AccountManagerMetrics().RecordPeerStatusUpdateDuration(telemetry.PeerStatusConnect, time.Since(start))
@@ -102,10 +102,6 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
}
am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusConnect, telemetry.PeerStatusApplied)
if am.geo != nil && realIP != nil {
am.updatePeerLocationIfChanged(ctx, accountID, peer, realIP)
}
if err = am.schedulePeerExpirations(ctx, accountID, peer); err != nil {
return err
}
@@ -192,27 +188,40 @@ func (am *DefaultAccountManager) MarkPeerDisconnected(ctx context.Context, peerP
}
}
if peer.AddedWithSSOLogin() && peer.InactivityExpirationEnabled {
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Warnf("failed getting account settings to schedule inactivity expiration for peer %s: %v", peer.ID, err)
} else if settings.PeerInactivityExpirationEnabled {
am.checkAndSchedulePeerInactivityExpiration(ctx, accountID)
}
}
return nil
}
// updatePeerLocationIfChanged refreshes the geolocation on a separate
// row update, only when the connection IP actually changed. Geo lookups
// are expensive so we skip same-IP reconnects.
func (am *DefaultAccountManager) updatePeerLocationIfChanged(ctx context.Context, accountID string, peer *nbpeer.Peer, realIP net.IP) {
// resolvePeerLocation looks up the geo location for realIP, returning nil when
// there is nothing to apply: geo disabled, no real IP, the IP is unchanged from
// what the peer already has, or the lookup failed. Geo lookups are skipped on
// same-IP reconnects since they are comparatively expensive. The returned value
// is applied by Peer.UpdateMetaIfNew so the change is persisted by its peer save.
func (am *DefaultAccountManager) resolvePeerLocation(ctx context.Context, peer *nbpeer.Peer, realIP net.IP) *nbpeer.Location {
if am.geo == nil || realIP == nil {
return nil
}
if peer.Location.ConnectionIP != nil && peer.Location.ConnectionIP.Equal(realIP) {
return
return nil
}
location, err := am.geo.Lookup(realIP)
if err != nil {
log.WithContext(ctx).Warnf("failed to get location for peer %s realip: [%s]: %v", peer.ID, realIP.String(), err)
return
return nil
}
peer.Location.ConnectionIP = realIP
peer.Location.CountryCode = location.Country.ISOCode
peer.Location.CityName = location.City.Names.En
peer.Location.GeoNameID = location.City.GeonameID
if err := am.Store.SavePeerLocation(ctx, accountID, peer); err != nil {
log.WithContext(ctx).Warnf("could not store location for peer %s: %s", peer.ID, err)
return &nbpeer.Location{
ConnectionIP: realIP,
CountryCode: location.Country.ISOCode,
CityName: location.City.Names.En,
GeoNameID: location.City.GeonameID,
}
}
@@ -721,7 +730,7 @@ func (am *DefaultAccountManager) handleSetupKeyAddedPeer(ctx context.Context, en
func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKey, userID string, peer *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.Network, []*posture.Checks, bool, error) {
if setupKey == "" && userID == "" && !peer.ProxyMeta.Embedded {
// no auth method provided => reject access
return nil, nil, nil, false, status.Errorf(status.Unauthenticated, "no peer auth method provided, please use a setup key or interactive SSO login")
return nil, nil, nil, false, status.ErrNoAuthMethodProvided
}
upperKey := strings.ToUpper(setupKey)
@@ -980,7 +989,8 @@ func getPeerIPDNSLabel(ip netip.Addr, peerHostName string) (string, error) {
// SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible
func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
var peer *nbpeer.Peer
var updated, versionChanged, ipv6CapabilityChanged bool
var ipv6CapabilityChanged bool
var metaDiff nbpeer.MetaDiff
var err error
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
@@ -1010,9 +1020,10 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
}
oldHasIPv6Cap := peer.HasCapability(nbpeer.PeerCapabilityIPv6Overlay)
updated, versionChanged = peer.UpdateMetaIfNew(ctx, sync.Meta)
newLocation := am.resolvePeerLocation(ctx, peer, sync.RealIP)
metaDiff = peer.UpdateMetaIfNew(ctx, sync.Meta, newLocation)
ipv6CapabilityChanged = oldHasIPv6Cap != peer.HasCapability(nbpeer.PeerCapabilityIPv6Overlay)
if updated {
if metaDiff.Updated() {
am.metrics.AccountManagerMetrics().CountPeerMetUpdate()
log.WithContext(ctx).Tracef("peer %s metadata updated", peer.ID)
if err = transaction.SavePeer(ctx, accountID, peer); err != nil {
@@ -1040,9 +1051,10 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
return nil, nil, nil, 0, err
}
if isStatusChanged || sync.UpdateAccountPeers || ipv6CapabilityChanged || (updated && (len(resPostureChecks) > 0 || versionChanged)) {
metaDiffAffectsPosture := posture.AffectsPosture(ctx, &metaDiff, resPostureChecks)
if isStatusChanged || sync.UpdateAccountPeers || ipv6CapabilityChanged || metaDiffAffectsPosture || metaDiff.VersionChanged() || metaDiff.HostnameChanged() {
changedPeerIDs := []string{peer.ID}
affectedPeerIDs := am.syncPeerAffectedPeers(ctx, accountID, peer.ID, nmap, peerNotValid, updated, len(resPostureChecks) > 0)
affectedPeerIDs := am.syncPeerAffectedPeers(ctx, accountID, peer.ID, nmap, peerNotValid, metaDiffAffectsPosture)
if err = am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs); err != nil {
return nil, nil, nil, 0, fmt.Errorf("notify network map controller of peer update: %w", err)
}
@@ -1059,8 +1071,8 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
// metadata change that flips a posture result removes this peer from others'
// maps asymmetrically; that case (and an invalid peer, whose map is empty) falls
// back to the resolver.
func (am *DefaultAccountManager) syncPeerAffectedPeers(ctx context.Context, accountID, peerID string, nmap *types.NetworkMap, peerNotValid, metaUpdated, hasPostureChecks bool) []string {
if peerNotValid || (metaUpdated && hasPostureChecks) {
func (am *DefaultAccountManager) syncPeerAffectedPeers(ctx context.Context, accountID, peerID string, nmap *types.NetworkMap, peerNotValid, metaChangeAffectedPosture bool) []string {
if peerNotValid || metaChangeAffectedPosture {
return am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, []string{peerID})
}
return affectedPeerIDsFromNetworkMap(nmap, peerID)

View File

@@ -107,6 +107,15 @@ type Location struct {
GeoNameID uint // city level geoname id
}
// equal reports whether two locations match. ConnectionIP is a net.IP slice, so it uses
// IP.Equal, not ==.
func (l Location) equal(other Location) bool {
return l.CountryCode == other.CountryCode &&
l.CityName == other.CityName &&
l.GeoNameID == other.GeoNameID &&
l.ConnectionIP.Equal(other.ConnectionIP)
}
// NetworkAddress is the IP address with network and MAC address of a network interface
type NetworkAddress struct {
NetIP netip.Prefix `gorm:"serializer:json"`
@@ -256,50 +265,88 @@ func (p *Peer) Copy() *Peer {
}
}
// UpdateMetaIfNew updates peer's system metadata if new information is provided
// returns true if meta was updated, false otherwise
func (p *Peer) UpdateMetaIfNew(ctx context.Context, meta PeerSystemMeta) (updated, versionChanged bool) {
// UpdateMetaIfNew updates peer's system metadata and connection geo location if
// new information is provided. newLocation is the geo location resolved from the
// peer's current connection IP, or nil when there is nothing to apply (geo
// disabled, no real IP, or the IP is unchanged); the caller owns the expensive
// lookup and the same-IP guard. It returns a MetaDiff describing what changed;
// diff.Updated() reports whether the peer needs to be persisted.
func (p *Peer) UpdateMetaIfNew(ctx context.Context, meta PeerSystemMeta, newLocation *Location) MetaDiff {
if meta.isEmpty() {
return updated, versionChanged
return MetaDiff{}
}
versionChanged = p.Meta.WtVersion != meta.WtVersion
// Avoid overwriting UIVersion if the update was triggered sole by the CLI client
if meta.UIVersion == "" {
meta.UIVersion = p.Meta.UIVersion
}
oldVersion := p.Meta.WtVersion
effectiveLocation := p.Location
if newLocation != nil {
effectiveLocation = *newLocation
}
diff := metaDiff(p.Meta, meta)
if len(diff) != 0 {
diff := diffMeta(p.Meta, meta, p.Location, effectiveLocation)
if diff.Updated() {
p.Meta = meta
updated = true
}
p.Location = effectiveLocation
if diff.Updated() {
log.WithContext(ctx).Debug(diff.LogSummary())
}
versionInfo := ""
if versionChanged {
versionInfo = fmt.Sprintf("version changed: %s -> %s, ", oldVersion, meta.WtVersion)
}
if len(diff) > 0 || versionChanged {
log.WithContext(ctx).
Debugf("peer meta updated, %s%d field(s) changed: %s", versionInfo, len(diff), strings.Join(diff, ", "))
}
return updated, versionChanged
return diff
}
// MetaDiff holds a peer's full before/after state across a sync: both metas and both
// connection locations (the location lives on Peer, not PeerSystemMeta, but posture
// checks read it). Changed lists what moved, for logging and the persistence decision;
// the snapshots let a posture check be replayed against old and new. Everything is derived
// from these fields, so there are no parallel per-field flags to keep in sync.
type MetaDiff struct {
OldMeta PeerSystemMeta
NewMeta PeerSystemMeta
OldLocation Location
NewLocation Location
Changed []string
}
// Updated reports whether anything changed and the peer must be persisted. diffMeta fills
// Changed in the pass that builds the diff, so this is a length check, not a re-comparison.
// Pointer receiver: MetaDiff embeds two metas, so copying it per call is wasteful.
func (d *MetaDiff) Updated() bool {
return len(d.Changed) != 0
}
// VersionChanged reports whether the WireGuard client version changed (a client upgrade).
func (d *MetaDiff) VersionChanged() bool {
return d.OldMeta.WtVersion != d.NewMeta.WtVersion
}
// HostnameChanged reports whether the peer's hostname changed.
func (d *MetaDiff) HostnameChanged() bool {
return d.OldMeta.Hostname != d.NewMeta.Hostname
}
// LogSummary renders the changed fields as a single human-readable line.
func (d *MetaDiff) LogSummary() string {
return fmt.Sprintf("peer meta updated, %d field(s) changed: %s",
len(d.Changed), strings.Join(d.Changed, ", "))
}
// metaDiff returns a human-readable list of the fields that differ between the
// old and new meta, each formatted as `field: <old> -> <new>`. It is the single
// source of truth for meta comparison: isEqual reports equality as an empty
// diff, so the log line can never disagree with the change decision. Slices are
// cloned before sorting, so callers' meta is not mutated.
func metaDiff(oldMeta, newMeta PeerSystemMeta) []string {
var diff []string
return diffMeta(oldMeta, newMeta, Location{}, Location{}).Changed
}
// diffMeta snapshots a peer's old and new state and records a Changed entry per field that
// moved. It is the single source of truth for the comparison: isEqual is an empty Changed
// list, so the log line and the persistence decision can never disagree.
func diffMeta(oldMeta, newMeta PeerSystemMeta, oldLocation, newLocation Location) MetaDiff {
d := MetaDiff{OldMeta: oldMeta, NewMeta: newMeta, OldLocation: oldLocation, NewLocation: newLocation}
add := func(field string, oldVal, newVal any) {
diff = append(diff, fmt.Sprintf("%s: %v -> %v", field, oldVal, newVal))
d.Changed = append(d.Changed, fmt.Sprintf("%s: %v -> %v", field, oldVal, newVal))
}
if oldMeta.Hostname != newMeta.Hostname {
@@ -353,16 +400,18 @@ func metaDiff(oldMeta, newMeta PeerSystemMeta) []string {
if !capabilitiesEqual(oldMeta.Capabilities, newMeta.Capabilities) {
add("capabilities", oldMeta.Capabilities, newMeta.Capabilities)
}
if !sameMultiset(oldMeta.NetworkAddresses, newMeta.NetworkAddresses) {
add("network_addresses", fmt.Sprintf("%v", oldMeta.NetworkAddresses), fmt.Sprintf("%v", newMeta.NetworkAddresses))
}
if !sameMultiset(oldMeta.Files, newMeta.Files) {
add("files", fmt.Sprintf("%v", oldMeta.Files), fmt.Sprintf("%v", newMeta.Files))
}
return diff
if !oldLocation.equal(newLocation) {
add("connection_ip", oldLocation.ConnectionIP, newLocation.ConnectionIP)
}
return d
}
// sameMultiset reports whether two slices contain the same elements with the

View File

@@ -0,0 +1,202 @@
package posture
import (
"context"
"net"
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
)
// diffFrom builds a MetaDiff from the old/new snapshots AffectsPosture replays against.
func diffFrom(oldMeta, newMeta nbpeer.PeerSystemMeta, oldLoc, newLoc nbpeer.Location) *nbpeer.MetaDiff {
return &nbpeer.MetaDiff{
OldMeta: oldMeta,
NewMeta: newMeta,
OldLocation: oldLoc,
NewLocation: newLoc,
}
}
func checks(def ChecksDefinition) []*Checks {
return []*Checks{{Checks: def}}
}
func TestAffectsPosture_NilDiff(t *testing.T) {
assert.False(t, AffectsPosture(context.Background(), nil, checks(ChecksDefinition{
NBVersionCheck: &NBVersionCheck{MinVersion: "1.0.0"},
})))
}
func TestAffectsPosture_NBVersion(t *testing.T) {
c := checks(ChecksDefinition{NBVersionCheck: &NBVersionCheck{MinVersion: "1.2.0"}})
tests := []struct {
name string
oldVer, newVer string
want bool
}{
{"both above min, no flip", "1.3.0", "1.4.0", false},
{"both below min, no flip", "1.0.0", "1.1.0", false},
{"crosses up below->above", "1.1.0", "1.3.0", true},
{"crosses down above->below", "1.3.0", "1.1.0", true},
{"unparsable old only -> flip", "garbage", "1.3.0", true},
{"unparsable both -> no flip", "garbage", "junk", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
diff := diffFrom(
nbpeer.PeerSystemMeta{WtVersion: tt.oldVer},
nbpeer.PeerSystemMeta{WtVersion: tt.newVer},
nbpeer.Location{}, nbpeer.Location{},
)
assert.Equal(t, tt.want, AffectsPosture(context.Background(), diff, c))
})
}
}
func TestAffectsPosture_OSVersion_KernelBumpWithinMin(t *testing.T) {
c := checks(ChecksDefinition{OSVersionCheck: &OSVersionCheck{
Linux: &MinKernelVersionCheck{MinKernelVersion: "5.0.0"},
}})
// Kernel moves but stays above the minimum: verdict stays pass -> not affected.
withinMin := diffFrom(
nbpeer.PeerSystemMeta{GoOS: "linux", KernelVersion: "5.10.0-arch1"},
nbpeer.PeerSystemMeta{GoOS: "linux", KernelVersion: "5.15.0-arch2"},
nbpeer.Location{}, nbpeer.Location{},
)
assert.False(t, AffectsPosture(context.Background(), withinMin, c))
// Kernel drops below the minimum: verdict flips pass -> fail -> affected.
crossesDown := diffFrom(
nbpeer.PeerSystemMeta{GoOS: "linux", KernelVersion: "5.10.0-arch1"},
nbpeer.PeerSystemMeta{GoOS: "linux", KernelVersion: "4.19.0-arch1"},
nbpeer.Location{}, nbpeer.Location{},
)
assert.True(t, AffectsPosture(context.Background(), crossesDown, c))
}
func TestAffectsPosture_OSVersion_GoOSSwitchFlipsVerdict(t *testing.T) {
// Only Linux is constrained. An OS outside the switch (freebsd) passes; switching to a
// failing linux kernel flips the verdict pass -> fail.
c := checks(ChecksDefinition{OSVersionCheck: &OSVersionCheck{
Linux: &MinKernelVersionCheck{MinKernelVersion: "6.0.0"},
}})
diff := diffFrom(
nbpeer.PeerSystemMeta{GoOS: "freebsd"},
nbpeer.PeerSystemMeta{GoOS: "linux", KernelVersion: "4.19.0"},
nbpeer.Location{}, nbpeer.Location{},
)
assert.True(t, AffectsPosture(context.Background(), diff, c))
}
func TestAffectsPosture_Process_GoOSSwitchFlipsVerdict(t *testing.T) {
// Process runs at a linux path. Switching GoOS to windows (no WindowsPath configured)
// flips the verdict.
c := checks(ChecksDefinition{ProcessCheck: &ProcessCheck{
Processes: []Process{{LinuxPath: "/usr/bin/foo"}},
}})
files := []nbpeer.File{{Path: "/usr/bin/foo", ProcessIsRunning: true}}
diff := diffFrom(
nbpeer.PeerSystemMeta{GoOS: "linux", Files: files},
nbpeer.PeerSystemMeta{GoOS: "windows", Files: files},
nbpeer.Location{}, nbpeer.Location{},
)
assert.True(t, AffectsPosture(context.Background(), diff, c))
}
func TestAffectsPosture_Process_UnrelatedFileChange(t *testing.T) {
// A tracked process stays running while an unrelated file is added: the verdict does
// not move, so posture is not affected.
c := checks(ChecksDefinition{ProcessCheck: &ProcessCheck{
Processes: []Process{{LinuxPath: "/usr/bin/foo"}},
}})
diff := diffFrom(
nbpeer.PeerSystemMeta{GoOS: "linux", Files: []nbpeer.File{
{Path: "/usr/bin/foo", ProcessIsRunning: true},
}},
nbpeer.PeerSystemMeta{GoOS: "linux", Files: []nbpeer.File{
{Path: "/usr/bin/foo", ProcessIsRunning: true},
{Path: "/usr/bin/bar", ProcessIsRunning: true},
}},
nbpeer.Location{}, nbpeer.Location{},
)
assert.False(t, AffectsPosture(context.Background(), diff, c))
}
func TestAffectsPosture_GeoLocation(t *testing.T) {
c := checks(ChecksDefinition{GeoLocationCheck: &GeoLocationCheck{
Action: CheckActionAllow,
Locations: []Location{{CountryCode: "DE"}},
}})
// Moving within allowed countries keeps the verdict; moving out flips it.
stayAllowed := diffFrom(
nbpeer.PeerSystemMeta{}, nbpeer.PeerSystemMeta{},
nbpeer.Location{CountryCode: "DE", CityName: "Berlin"},
nbpeer.Location{CountryCode: "DE", CityName: "Munich"},
)
assert.False(t, AffectsPosture(context.Background(), stayAllowed, c))
moveOut := diffFrom(
nbpeer.PeerSystemMeta{}, nbpeer.PeerSystemMeta{},
nbpeer.Location{CountryCode: "DE"},
nbpeer.Location{CountryCode: "FR"},
)
assert.True(t, AffectsPosture(context.Background(), moveOut, c))
}
func TestAffectsPosture_PeerNetworkRange_ConnectionIP(t *testing.T) {
// The check reads the connection IP. Moving out of the allowed range flips the verdict;
// moving within it does not.
_, allowed, _ := net.ParseCIDR("10.0.0.0/8")
c := checks(ChecksDefinition{PeerNetworkRangeCheck: &PeerNetworkRangeCheck{
Action: CheckActionAllow,
Ranges: []netip.Prefix{netip.MustParsePrefix(allowed.String())},
}})
movesOutOfRange := diffFrom(
nbpeer.PeerSystemMeta{}, nbpeer.PeerSystemMeta{},
nbpeer.Location{ConnectionIP: net.ParseIP("10.1.2.3")},
nbpeer.Location{ConnectionIP: net.ParseIP("8.8.8.8")},
)
assert.True(t, AffectsPosture(context.Background(), movesOutOfRange, c))
staysInRange := diffFrom(
nbpeer.PeerSystemMeta{}, nbpeer.PeerSystemMeta{},
nbpeer.Location{ConnectionIP: net.ParseIP("10.1.2.3")},
nbpeer.Location{ConnectionIP: net.ParseIP("10.9.9.9")},
)
assert.False(t, AffectsPosture(context.Background(), staysInRange, c))
}
func TestAffectsPosture_IrrelevantFieldChange(t *testing.T) {
// Hostname changes but no check reads it: not affected even with checks present.
c := checks(ChecksDefinition{
NBVersionCheck: &NBVersionCheck{MinVersion: "1.0.0"},
GeoLocationCheck: &GeoLocationCheck{Action: CheckActionAllow, Locations: []Location{{CountryCode: "DE"}}},
})
diff := diffFrom(
nbpeer.PeerSystemMeta{Hostname: "old", WtVersion: "1.5.0"},
nbpeer.PeerSystemMeta{Hostname: "new", WtVersion: "1.5.0"},
nbpeer.Location{CountryCode: "DE"}, nbpeer.Location{CountryCode: "DE"},
)
assert.False(t, AffectsPosture(context.Background(), diff, c))
}
func TestAffectsPosture_NoChecks(t *testing.T) {
diff := diffFrom(
nbpeer.PeerSystemMeta{WtVersion: "1.0.0"},
nbpeer.PeerSystemMeta{WtVersion: "2.0.0"},
nbpeer.Location{}, nbpeer.Location{},
)
assert.False(t, AffectsPosture(context.Background(), diff, nil))
}

View File

@@ -7,6 +7,8 @@ import (
"regexp"
"github.com/hashicorp/go-version"
log "github.com/sirupsen/logrus"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/status"
@@ -51,6 +53,46 @@ type Checks struct {
Checks ChecksDefinition `gorm:"serializer:json"`
}
// AffectsPosture reports whether the change in diff flips the verdict of any check. It
// replays each check against the peer's old and new state and compares verdicts, so a
// change that moves a field but stays the right side of a threshold (e.g. a kernel bump
// still above the minimum) does not force a re-evaluation. See verdictChanged for how an
// evaluation error counts.
func AffectsPosture(ctx context.Context, diff *nbpeer.MetaDiff, checks []*Checks) bool {
if diff == nil {
return false
}
oldPeer := nbpeer.Peer{Meta: diff.OldMeta, Location: diff.OldLocation}
newPeer := nbpeer.Peer{Meta: diff.NewMeta, Location: diff.NewLocation}
for _, c := range checks {
for _, check := range c.GetChecks() {
if verdictChanged(ctx, check, oldPeer, newPeer) {
return true
}
}
}
return false
}
// verdictChanged replays check against old and new state and reports whether the verdict
// differs. Like callers, it treats an evaluation error as deny: two errors are the same
// verdict (no change), an error on one side only is a flip.
func verdictChanged(ctx context.Context, check Check, oldPeer, newPeer nbpeer.Peer) bool {
oldPass, oldErr := check.Check(ctx, oldPeer)
newPass, newErr := check.Check(ctx, newPeer)
oldVerdict := oldPass && (oldErr == nil)
newVerdict := newPass && (newErr == nil)
changed := oldVerdict != newVerdict
log.WithContext(ctx).Tracef("posture check %s replay: verdict %t -> %t (changed=%t), errs: %v -> %v",
check.Name(), oldVerdict, newVerdict, changed, oldErr, newErr)
return changed
}
// ChecksDefinition contains definition of actual check
type ChecksDefinition struct {
NBVersionCheck *NBVersionCheck `json:",omitempty"`

View File

@@ -489,6 +489,7 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) {
policy := &types.Policy{
AccountID: account.Id,
Enabled: true,
Rules: []*types.PolicyRule{
{
Enabled: true,

View File

@@ -581,28 +581,6 @@ func (s *SqlStore) MarkPeerDisconnectedIfSameSession(ctx context.Context, accoun
return result.RowsAffected > 0, nil
}
func (s *SqlStore) SavePeerLocation(ctx context.Context, accountID string, peerWithLocation *nbpeer.Peer) error {
// To maintain data integrity, we create a copy of the peer's location to prevent unintended updates to other fields.
var peerCopy nbpeer.Peer
// Since the location field has been migrated to JSON serialization,
// updating the struct ensures the correct data format is inserted into the database.
peerCopy.Location = peerWithLocation.Location
result := s.db.Model(&nbpeer.Peer{}).
Where(accountAndIDQueryCondition, accountID, peerWithLocation.ID).
Updates(peerCopy)
if result.Error != nil {
return status.Errorf(status.Internal, "failed to save peer locations to store: %v", result.Error)
}
if result.RowsAffected == 0 {
return status.Errorf(status.NotFound, peerNotFoundFMT, peerWithLocation.ID)
}
return nil
}
// ApproveAccountPeers marks all peers that currently require approval in the given account as approved.
func (s *SqlStore) ApproveAccountPeers(ctx context.Context, accountID string) (int, error) {
result := s.db.Model(&nbpeer.Peer{}).

View File

@@ -618,56 +618,6 @@ func TestSqlStore_SavePeerStatus(t *testing.T) {
assert.WithinDurationf(t, newStatus.LastSeen, actual.LastSeen.UTC(), time.Millisecond, "LastSeen should be equal")
}
func TestSqlStore_SavePeerLocation(t *testing.T) {
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
t.Cleanup(cleanUp)
assert.NoError(t, err)
account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b")
require.NoError(t, err)
peer := &nbpeer.Peer{
AccountID: account.Id,
ID: "testpeer",
Location: nbpeer.Location{
ConnectionIP: net.ParseIP("0.0.0.0"),
CountryCode: "YY",
CityName: "City",
GeoNameID: 1,
},
CreatedAt: time.Now().UTC(),
Meta: nbpeer.PeerSystemMeta{},
}
// error is expected as peer is not in store yet
err = store.SavePeerLocation(context.Background(), account.Id, peer)
assert.Error(t, err)
account.Peers[peer.ID] = peer
err = store.SaveAccount(context.Background(), account)
require.NoError(t, err)
peer.Location.ConnectionIP = net.ParseIP("35.1.1.1")
peer.Location.CountryCode = "DE"
peer.Location.CityName = "Berlin"
peer.Location.GeoNameID = 2950159
err = store.SavePeerLocation(context.Background(), account.Id, account.Peers[peer.ID])
assert.NoError(t, err)
account, err = store.GetAccount(context.Background(), account.Id)
require.NoError(t, err)
actual := account.Peers[peer.ID].Location
assert.Equal(t, peer.Location, actual)
peer.ID = "non-existing-peer"
err = store.SavePeerLocation(context.Background(), account.Id, peer)
assert.Error(t, err)
parsedErr, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error")
}
func Test_TestGetAccountByPrivateDomain(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")

View File

@@ -185,7 +185,6 @@ type Store interface {
// recorded by the database. Returns true when the update happened,
// false when a newer session has taken over.
MarkPeerDisconnectedIfSameSession(ctx context.Context, accountID, peerID string, sessionStartedAt int64) (bool, error)
SavePeerLocation(ctx context.Context, accountID string, peer *nbpeer.Peer) error
ApproveAccountPeers(ctx context.Context, accountID string) (int, error)
DeletePeer(ctx context.Context, accountID string, peerID string) error

View File

@@ -2968,20 +2968,6 @@ func (mr *MockStoreMockRecorder) SavePeer(ctx, accountID, peer interface{}) *gom
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SavePeer", reflect.TypeOf((*MockStore)(nil).SavePeer), ctx, accountID, peer)
}
// SavePeerLocation mocks base method.
func (m *MockStore) SavePeerLocation(ctx context.Context, accountID string, peer *peer.Peer) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SavePeerLocation", ctx, accountID, peer)
ret0, _ := ret[0].(error)
return ret0
}
// SavePeerLocation indicates an expected call of SavePeerLocation.
func (mr *MockStoreMockRecorder) SavePeerLocation(ctx, accountID, peer interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SavePeerLocation", reflect.TypeOf((*MockStore)(nil).SavePeerLocation), ctx, accountID, peer)
}
// SavePeerStatus mocks base method.
func (m *MockStore) SavePeerStatus(ctx context.Context, accountID, peerID string, status peer.PeerStatus) error {
m.ctrl.T.Helper()

View File

@@ -12,6 +12,9 @@ type PeerSync struct {
WireGuardPubKey string
// Meta is the system information passed by peer, must be always present
Meta nbpeer.PeerSystemMeta
// RealIP is the peer's connection IP, used to refresh its geo location.
// May be nil when the request has no associated connection IP.
RealIP net.IP
// UpdateAccountPeers indicate updating account peers,
// which occurs when the peer's metadata is updated
UpdateAccountPeers bool

View File

@@ -1059,8 +1059,8 @@ func (am *DefaultAccountManager) BuildUserInfosForAccount(ctx context.Context, a
if err != nil {
return nil, err
}
log.WithContext(ctx).Debugf("Got %d users from ExternalCache for account %s", len(usersFromIntegration), accountID)
log.WithContext(ctx).Debugf("Got %d users from InternalCache for account %s", len(queriedUsers), accountID)
log.WithContext(ctx).Tracef("Got %d users from ExternalCache for account %s", len(usersFromIntegration), accountID)
log.WithContext(ctx).Tracef("Got %d users from InternalCache for account %s", len(queriedUsers), accountID)
queriedUsers = append(queriedUsers, usersFromIntegration...)
}

View File

@@ -48,6 +48,10 @@ type Type int32
var (
ErrExtraSettingsNotFound = errors.New("extra settings not found")
ErrPeerAlreadyLoggedIn = errors.New("peer with the same public key is already logged in")
// ErrNoAuthMethodProvided is returned when a peer login attempt carries neither a
// setup key nor an SSO token. Match it with errors.Is.
ErrNoAuthMethodProvided = Errorf(Unauthenticated, "no peer auth method provided, please use a setup key or interactive SSO login")
)
// Error is an internal error
@@ -66,6 +70,16 @@ func (e *Error) Error() string {
return e.Message
}
// Is reports whether target is an *Error with the same type and message,
// enabling matching with errors.Is against sentinel errors.
func (e *Error) Is(target error) bool {
var t *Error
if !errors.As(target, &t) {
return false
}
return e.ErrorType == t.ErrorType && e.Message == t.Message
}
// Errorf returns Error(ErrorType, fmt.Sprintf(format, a...)).
func Errorf(errorType Type, format string, a ...interface{}) error {
return &Error{