Compare commits

..

10 Commits

Author SHA1 Message Date
pascal
42867c7a59 ignore siblings when add/remove peer from group 2026-06-25 01:48:39 +02:00
pascal
c2db940a8c simplify affected peers walk 2026-06-25 00:01:17 +02:00
pascal
62ffa08744 split networkIDs to check 2026-06-24 22:39:35 +02:00
Dmitri Dolguikh
d8e7f2e9e6 a couple of fixes
Signed-off-by: Dmitri Dolguikh <dmitri.external@netbird.io>
2026-06-24 15:40:44 +02:00
Dmitri Dolguikh
1205641b44 fixed test
Signed-off-by: Dmitri Dolguikh <dmitri.external@netbird.io>
2026-06-24 15:17:41 +02:00
Dmitri Dolguikh
56e8215ebe updated 'resource-routing-bridge/router-peer-change refreshes policy sources' test to expect router peer among changed peer ids
Signed-off-by: Dmitri Dolguikh <dmitri.external@netbird.io>
2026-06-24 14:29:23 +02:00
Dmitri Dolguikh
9b768d1773 fixed a bug in collectFromPolicies
Signed-off-by: Dmitri Dolguikh <dmitri.external@netbird.io>
2026-06-24 14:10:37 +02:00
Dmitri Dolguikh
33954ea15e fixing tests + adding tests
Signed-off-by: Dmitri Dolguikh <dmitri.external@netbird.io>
2026-06-24 13:07:53 +02:00
Dmitri Dolguikh
4c4434a871 fixed a few tests
Signed-off-by: Dmitri Dolguikh <dmitri.external@netbird.io>
2026-06-24 09:48:45 +02:00
Dmitri Dolguikh
7873f337df when collecting group and peer IDs from policies, do so directionally
Signed-off-by: Dmitri Dolguikh <dmitri.external@netbird.io>
2026-06-23 18:39:53 +02:00
13 changed files with 627 additions and 883 deletions

View File

@@ -51,20 +51,13 @@ type cachedRecord struct {
}
// Resolver caches critical NetBird infrastructure domains.
// records, refreshing, failedResolves, mgmtDomain and serverDomains are all
// guarded by mutex.
// records, refreshing, mgmtDomain and serverDomains are all guarded by mutex.
type Resolver struct {
records map[dns.Question]*cachedRecord
mgmtDomain *domain.Domain
serverDomains *dnsconfig.ServerDomains
mutex sync.RWMutex
// failedResolves records the last failed initial resolve per domain so a
// domain that never resolves isn't retried on every server-domains update
// until refreshBackoff elapses. Entries are cleared on success and pruned
// to the current server-domains set.
failedResolves map[domain.Domain]time.Time
chain ChainResolver
chainMaxPriority int
refreshGroup singleflight.Group
@@ -83,10 +76,9 @@ type Resolver struct {
// NewResolver creates a new management domains cache resolver.
func NewResolver() *Resolver {
return &Resolver{
records: make(map[dns.Question]*cachedRecord),
refreshing: make(map[dns.Question]*atomic.Bool),
failedResolves: make(map[domain.Domain]time.Time),
cacheTTL: resolveCacheTTL(),
records: make(map[dns.Question]*cachedRecord),
refreshing: make(map[dns.Question]*atomic.Bool),
cacheTTL: resolveCacheTTL(),
}
}
@@ -181,9 +173,7 @@ func (m *Resolver) continueToNext(w dns.ResponseWriter, r *dns.Msg) {
// AddDomain resolves a domain and stores its A/AAAA records in the cache.
// A family that resolves NODATA (nil err, zero records) evicts any stale
// entry for that qtype. When one family hard-errors while the other succeeds,
// the resolved family is still cached but AddDomain returns an error so the
// caller retries the incomplete resolve rather than treating it as complete.
// entry for that qtype.
func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
@@ -213,10 +203,6 @@ func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
log.Debugf("added/updated domain=%s with %d A records and %d AAAA records",
d.SafeString(), len(aRecords), len(aaaaRecords))
if errA != nil || errAAAA != nil {
return fmt.Errorf("resolve %s: incomplete, a family failed: %w", d.SafeString(), errors.Join(errA, errAAAA))
}
return nil
}
@@ -476,7 +462,6 @@ func (m *Resolver) RemoveDomain(d domain.Domain) error {
delete(m.records, qAAAA)
delete(m.refreshing, qA)
delete(m.refreshing, qAAAA)
delete(m.failedResolves, d)
log.Debugf("removed domain=%s from cache", d.SafeString())
return nil
@@ -520,7 +505,6 @@ func (m *Resolver) UpdateFromServerDomains(ctx context.Context, serverDomains dn
allDomains := m.extractDomainsFromServerDomains(updatedServerDomains)
currentDomains := m.GetCachedDomains()
removedDomains = m.removeStaleDomains(currentDomains, allDomains)
m.pruneFailedResolves(allDomains)
}
m.addNewDomains(ctx, newDomains)
@@ -593,85 +577,13 @@ func (m *Resolver) isManagementDomain(domain domain.Domain) bool {
return m.mgmtDomain != nil && domain == *m.mgmtDomain
}
// addNewDomains resolves and caches domains that are not yet in the cache,
// running the lookups concurrently. Domains already cached are skipped and left
// to the stale-while-revalidate refresh path, so a sync never re-resolves them
// synchronously: once NetBird owns the OS resolver the resolve runs through the
// handler chain and would otherwise dial the managed upstreams under the engine
// sync lock on every update.
// addNewDomains resolves and caches all domains from the update
func (m *Resolver) addNewDomains(ctx context.Context, newDomains domain.List) {
var wg sync.WaitGroup
seen := make(map[domain.Domain]struct{}, len(newDomains))
for _, newDomain := range newDomains {
if _, dup := seen[newDomain]; dup {
continue
}
seen[newDomain] = struct{}{}
if !m.needsResolve(newDomain) {
continue
}
wg.Add(1)
go func(d domain.Domain) {
defer wg.Done()
if err := m.AddDomain(ctx, d); err != nil {
m.markResolveFailed(d)
log.Warnf("failed to add/update domain=%s: %v", d.SafeString(), err)
return
}
m.clearResolveFailed(d)
log.Debugf("added/updated management cache domain=%s", d.SafeString())
}(newDomain)
}
wg.Wait()
}
// needsResolve reports whether d should be resolved now. A recent failed or
// incomplete resolve gates retries on the backoff even when one family is
// already cached, so a transiently-failed family is retried instead of being
// treated as fully resolved. Otherwise a domain with any cached record is left
// to the stale-while-revalidate refresh path.
func (m *Resolver) needsResolve(d domain.Domain) bool {
dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
m.mutex.RLock()
defer m.mutex.RUnlock()
if failedAt, ok := m.failedResolves[d]; ok {
return time.Since(failedAt) >= refreshBackoff
}
for _, qtype := range []uint16{dns.TypeA, dns.TypeAAAA} {
q := dns.Question{Name: dnsName, Qtype: qtype, Qclass: dns.ClassINET}
if _, ok := m.records[q]; ok {
return false
}
}
return true
}
func (m *Resolver) markResolveFailed(d domain.Domain) {
m.mutex.Lock()
m.failedResolves[d] = time.Now()
m.mutex.Unlock()
}
func (m *Resolver) clearResolveFailed(d domain.Domain) {
m.mutex.Lock()
delete(m.failedResolves, d)
m.mutex.Unlock()
}
// pruneFailedResolves drops failure markers for domains no longer present in
// the server-domains set, keeping the map bounded to the current set (a
// failed-only domain has no cached record, so RemoveDomain never sees it).
func (m *Resolver) pruneFailedResolves(domains domain.List) {
m.mutex.Lock()
defer m.mutex.Unlock()
for d := range m.failedResolves {
if !slices.Contains(domains, d) {
delete(m.failedResolves, d)
if err := m.AddDomain(ctx, newDomain); err != nil {
log.Warnf("failed to add/update domain=%s: %v", newDomain.SafeString(), err)
} else {
log.Debugf("added/updated management cache domain=%s", newDomain.SafeString())
}
}
}

View File

@@ -21,7 +21,6 @@ type fakeChain struct {
mu sync.Mutex
calls map[string]int
answers map[string][]dns.RR
qErr map[string]error
err error
hasRoot bool
onLookup func()
@@ -31,7 +30,6 @@ func newFakeChain() *fakeChain {
return &fakeChain{
calls: map[string]int{},
answers: map[string][]dns.RR{},
qErr: map[string]error{},
hasRoot: true,
}
}
@@ -49,9 +47,6 @@ func (f *fakeChain) ResolveInternal(ctx context.Context, msg *dns.Msg, maxPriori
f.calls[key]++
answers := f.answers[key]
err := f.err
if err == nil {
err = f.qErr[key]
}
onLookup := f.onLookup
f.mu.Unlock()
@@ -80,12 +75,6 @@ func (f *fakeChain) setAnswer(name string, qtype uint16, ip string) {
}
}
func (f *fakeChain) setErr(name string, qtype uint16, err error) {
f.mu.Lock()
defer f.mu.Unlock()
f.qErr[name+"|"+dns.TypeToString[qtype]] = err
}
func (f *fakeChain) callCount(name string, qtype uint16) int {
f.mu.Lock()
defer f.mu.Unlock()

View File

@@ -1,183 +0,0 @@
package mgmt
import (
"context"
"errors"
"sync/atomic"
"testing"
"time"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
"github.com/netbirdio/netbird/shared/management/domain"
)
// A domain already in the cache must not be re-resolved on a subsequent server
// domains update; it is left to the stale-while-revalidate refresh path.
func TestResolver_UpdateFromServerDomains_SkipsCached(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.setAnswer("signal.example.com.", dns.TypeA, "10.0.0.2")
r.SetChainResolver(chain, 50)
sd := dnsconfig.ServerDomains{Signal: domain.Domain("signal.example.com")}
_, err := r.UpdateFromServerDomains(context.Background(), sd)
require.NoError(t, err)
require.Equal(t, 1, chain.callCount("signal.example.com.", dns.TypeA),
"first update must resolve the domain")
_, err = r.UpdateFromServerDomains(context.Background(), sd)
require.NoError(t, err)
assert.Equal(t, 1, chain.callCount("signal.example.com.", dns.TypeA),
"cached domain must not be re-resolved on a subsequent update")
}
// New domains in a single update must resolve concurrently rather than serially.
func TestResolver_AddNewDomains_ResolvesConcurrently(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
var inflight, maxInflight atomic.Int32
chain.onLookup = func() {
n := inflight.Add(1)
for {
old := maxInflight.Load()
if n <= old || maxInflight.CompareAndSwap(old, n) {
break
}
}
time.Sleep(50 * time.Millisecond)
inflight.Add(-1)
}
relays := []domain.Domain{"a.example.com", "b.example.com", "c.example.com", "d.example.com"}
for _, d := range relays {
chain.setAnswer(dns.Fqdn(string(d)), dns.TypeA, "10.0.0.2")
}
r.SetChainResolver(chain, 50)
start := time.Now()
_, err := r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Relay: relays})
require.NoError(t, err)
elapsed := time.Since(start)
assert.GreaterOrEqual(t, int(maxInflight.Load()), 2, "domains must resolve concurrently")
// Serial resolution of 4 domains would take at least 4*50ms; concurrent is far less.
assert.Less(t, elapsed, 300*time.Millisecond, "resolution should not be serial")
}
// A domain that fails to resolve must not be retried on every update; the
// failure backoff suppresses re-resolution until it expires.
func TestResolver_UpdateFromServerDomains_BacksOffFailures(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.err = errors.New("resolve boom")
r.SetChainResolver(chain, 50)
sd := dnsconfig.ServerDomains{Signal: domain.Domain("signal.example.com")}
_, err := r.UpdateFromServerDomains(context.Background(), sd)
require.NoError(t, err)
require.Equal(t, 1, chain.callCount("signal.example.com.", dns.TypeA),
"first update must attempt the resolve")
_, err = r.UpdateFromServerDomains(context.Background(), sd)
require.NoError(t, err)
assert.Equal(t, 1, chain.callCount("signal.example.com.", dns.TypeA),
"failed resolve must back off and not retry on the next update")
}
// A domain listed under more than one server-domain type (e.g. STUN and TURN on
// the same host) must be resolved once per update, not once per occurrence.
func TestResolver_AddNewDomains_DedupesDuplicateDomains(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.setAnswer("dup.example.com.", dns.TypeA, "10.0.0.9")
r.SetChainResolver(chain, 50)
sd := dnsconfig.ServerDomains{
Stuns: []domain.Domain{"dup.example.com"},
Turns: []domain.Domain{"dup.example.com"},
}
_, err := r.UpdateFromServerDomains(context.Background(), sd)
require.NoError(t, err)
assert.Equal(t, 1, chain.callCount("dup.example.com.", dns.TypeA),
"a domain appearing under multiple server-domain types must resolve once")
}
// A failure marker must be dropped once its domain leaves the server-domains set
// so the map stays bounded to the current set.
func TestResolver_UpdateFromServerDomains_PrunesFailedResolves(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.err = errors.New("resolve boom")
r.SetChainResolver(chain, 50)
_, err := r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Signal: domain.Domain("gone.example.com")})
require.NoError(t, err)
r.mutex.RLock()
_, marked := r.failedResolves[domain.Domain("gone.example.com")]
r.mutex.RUnlock()
require.True(t, marked, "failed resolve must be recorded")
_, err = r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Signal: domain.Domain("other.example.com")})
require.NoError(t, err)
r.mutex.RLock()
_, stillMarked := r.failedResolves[domain.Domain("gone.example.com")]
r.mutex.RUnlock()
assert.False(t, stillMarked, "failure marker for a domain no longer in the set must be pruned")
}
// When one family hard-errors while the other resolves, the domain is cached
// for the working family but recorded as incomplete so the failed family is
// retried under backoff instead of being treated as fully resolved forever.
func TestResolver_AddNewDomains_RetriesPartialFamilyFailure(t *testing.T) {
d := domain.Domain("relay.example.com")
r := NewResolver()
chain := newFakeChain()
chain.setAnswer("relay.example.com.", dns.TypeA, "10.0.0.2")
chain.setErr("relay.example.com.", dns.TypeAAAA, errors.New("servfail"))
r.SetChainResolver(chain, 50)
_, err := r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Relay: []domain.Domain{d}})
require.NoError(t, err)
r.mutex.RLock()
_, aCached := r.records[dns.Question{Name: "relay.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}]
_, marked := r.failedResolves[d]
r.mutex.RUnlock()
require.True(t, aCached, "the working family must still be cached")
require.True(t, marked, "a partial failure must be recorded so the failed family is retried")
assert.False(t, r.needsResolve(d), "within the backoff window the domain is not retried")
r.mutex.Lock()
r.failedResolves[d] = time.Now().Add(-2 * refreshBackoff)
r.mutex.Unlock()
assert.True(t, r.needsResolve(d), "after the backoff elapses the domain is retried to pick up the missing family")
}
// A family that returns NODATA (legitimately absent, e.g. an IPv4-only host) is
// not a failure: the domain must not be marked for retry, otherwise it would be
// re-resolved on every sync.
func TestResolver_AddNewDomains_NodataIsNotFailure(t *testing.T) {
d := domain.Domain("v4only.example.com")
r := NewResolver()
chain := newFakeChain()
chain.setAnswer("v4only.example.com.", dns.TypeA, "10.0.0.2")
r.SetChainResolver(chain, 50)
_, err := r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Relay: []domain.Domain{d}})
require.NoError(t, err)
r.mutex.RLock()
_, marked := r.failedResolves[d]
r.mutex.RUnlock()
assert.False(t, marked, "a NODATA family must not be recorded as a failure")
assert.False(t, r.needsResolve(d), "an IPv4-only host must not be re-resolved on later syncs")
}

View File

@@ -610,10 +610,12 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
return nil, nil, 0, err
}
startPosture := time.Now()
postureChecks, err := c.getPeerPostureChecks(account, peerID)
if err != nil {
return nil, nil, 0, err
}
log.WithContext(ctx).Debugf("getPeerPostureChecks took %s", time.Since(startPosture))
accountZones, err := c.repo.GetAccountZones(ctx, account.Id)
if err != nil {

View File

@@ -41,7 +41,7 @@ func TestAffectedPeers_DependencyCoverageMatrix(t *testing.T) {
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
require.NoError(t, err)
return affectedpeers.Change{ChangedPeerIDs: []string{s.routerPeerID}},
[]string{s.sourcePeerID}, []string{s.unrelatedPeerID}
[]string{s.sourcePeerID, s.routerPeerID}, []string{s.unrelatedPeerID}
},
},
{
@@ -106,12 +106,8 @@ func TestAffectedPeers_DependencyCoverageMatrix(t *testing.T) {
change, mustContain, mustExclude := r.build(t, s, ctx)
affected := resolveAffected(t, s.manager.Store, s.accountID, change)
for _, id := range mustContain {
assert.Contains(t, affected, id, "expected peer to be affected")
}
for _, id := range mustExclude {
assert.NotContains(t, affected, id, "peer must not be affected")
}
assert.ElementsMatch(t, affected, mustContain, "expected peer to be affected")
assert.NotContains(t, affected, mustExclude, "peer must not be affected")
})
}
}

View File

@@ -96,33 +96,54 @@ func affectedGroupID(i int) string { return fmt.Sprintf("affected-grp-%d", i)
func affectedGroupName(i int) string { return fmt.Sprintf("AffectedGroup%d", i) }
func TestCollectGroupChange_PolicyLinked(t *testing.T) {
manager, s, accountID, _, groupIDs := setupAffectedPeersTest(t)
manager, s, accountID, peerIDs, groupIDs := setupAffectedPeersTest(t)
ctx := context.Background()
_, err := manager.SavePolicy(ctx, accountID, userID, &types.Policy{
Enabled: true,
Rules: []*types.PolicyRule{
{
Enabled: true,
Sources: []string{groupIDs[0]},
Destinations: []string{groupIDs[1]},
Bidirectional: true,
Action: types.PolicyTrafficActionAccept,
Enabled: true,
Sources: []string{groupIDs[0]},
Destinations: []string{groupIDs[1]},
SourceResource: types.Resource{ID: peerIDs[0], Type: types.ResourceTypePeer},
DestinationResource: types.Resource{ID: peerIDs[1], Type: types.ResourceTypePeer},
Bidirectional: true,
Action: types.PolicyTrafficActionAccept,
},
{
Enabled: true,
Sources: []string{groupIDs[0]},
Destinations: []string{groupIDs[1]},
SourceResource: types.Resource{ID: peerIDs[2], Type: types.ResourceTypeHost},
DestinationResource: types.Resource{ID: peerIDs[3], Type: types.ResourceTypeHost},
Bidirectional: true,
Action: types.PolicyTrafficActionAccept,
},
{
Enabled: true,
Sources: []string{groupIDs[0]},
Destinations: []string{groupIDs[1]},
SourceResource: types.Resource{ID: "", Type: types.ResourceTypePeer},
DestinationResource: types.Resource{ID: "", Type: types.ResourceTypePeer},
Bidirectional: true,
Action: types.PolicyTrafficActionAccept,
},
},
}, true)
require.NoError(t, err)
groups, _ := collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[0]})
assert.Contains(t, groups, groupIDs[0])
assert.Contains(t, groups, groupIDs[1])
groups, directPeers := collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[0]})
assert.ElementsMatch(t, groups, []string{groupIDs[0], groupIDs[1]})
assert.ElementsMatch(t, directPeers, []string{peerIDs[1]})
groups, _ = collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[1]})
assert.Contains(t, groups, groupIDs[0])
assert.Contains(t, groups, groupIDs[1])
groups, directPeers = collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[1]})
assert.ElementsMatch(t, groups, []string{groupIDs[0], groupIDs[1]})
assert.ElementsMatch(t, directPeers, []string{peerIDs[0]})
groups, _ = collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[2]})
groups, directPeers = collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[2]})
assert.Empty(t, groups)
assert.Empty(t, directPeers)
}
func TestCollectGroupChange_PolicyWithDirectPeerResource(t *testing.T) {
@@ -133,20 +154,44 @@ func TestCollectGroupChange_PolicyWithDirectPeerResource(t *testing.T) {
Enabled: true,
Rules: []*types.PolicyRule{
{
Enabled: true,
Sources: []string{groupIDs[0]},
SourceResource: types.Resource{ID: peerIDs[3], Type: types.ResourceTypePeer},
Destinations: []string{groupIDs[1]},
Action: types.PolicyTrafficActionAccept,
Enabled: true,
Sources: []string{groupIDs[0]},
SourceResource: types.Resource{ID: peerIDs[3], Type: types.ResourceTypePeer},
DestinationResource: types.Resource{ID: peerIDs[4], Type: types.ResourceTypePeer},
Destinations: []string{groupIDs[1]},
Action: types.PolicyTrafficActionAccept,
},
{
Enabled: true,
Sources: []string{groupIDs[0]},
SourceResource: types.Resource{ID: peerIDs[1], Type: types.ResourceTypeHost},
DestinationResource: types.Resource{ID: peerIDs[2], Type: types.ResourceTypeHost},
Destinations: []string{groupIDs[1]},
Action: types.PolicyTrafficActionAccept,
},
{
Enabled: true,
Sources: []string{groupIDs[0]},
SourceResource: types.Resource{ID: "", Type: types.ResourceTypePeer},
DestinationResource: types.Resource{ID: "", Type: types.ResourceTypePeer},
Destinations: []string{groupIDs[1]},
Action: types.PolicyTrafficActionAccept,
},
},
}, true)
require.NoError(t, err)
groups, directPeers := collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[0]})
assert.Contains(t, groups, groupIDs[0])
assert.Contains(t, groups, groupIDs[1])
assert.Contains(t, directPeers, peerIDs[3])
assert.ElementsMatch(t, groups, []string{groupIDs[0], groupIDs[1]})
assert.ElementsMatch(t, directPeers, []string{peerIDs[4]})
groups, directPeers = collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[1]})
assert.ElementsMatch(t, groups, []string{groupIDs[0], groupIDs[1]})
assert.ElementsMatch(t, directPeers, []string{peerIDs[3]})
groups, directPeers = collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[2]})
assert.Empty(t, groups)
assert.Empty(t, directPeers)
}
func TestCollectGroupChange_PolicyWithNonPeerResource_NoDirectPeers(t *testing.T) {
@@ -168,8 +213,7 @@ func TestCollectGroupChange_PolicyWithNonPeerResource_NoDirectPeers(t *testing.T
require.NoError(t, err)
groups, directPeers := collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[0]})
assert.Contains(t, groups, groupIDs[0])
assert.Contains(t, groups, groupIDs[1])
assert.ElementsMatch(t, groups, []string{groupIDs[0], groupIDs[1]})
assert.Empty(t, directPeers, "non-peer resources should not produce direct peer IDs")
}
@@ -373,17 +417,11 @@ func TestCollectGroupChange_MultipleEntities(t *testing.T) {
require.NoError(t, err)
groups, directPeers := collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[0]})
assert.Contains(t, groups, groupIDs[0])
assert.Contains(t, groups, groupIDs[1])
assert.NotContains(t, groups, groupIDs[2])
assert.NotContains(t, groups, groupIDs[3])
assert.ElementsMatch(t, groups, []string{groupIDs[0], groupIDs[1]})
assert.Empty(t, directPeers)
groups, directPeers = collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[3]})
assert.Contains(t, groups, groupIDs[2])
assert.Contains(t, groups, groupIDs[3])
assert.NotContains(t, groups, groupIDs[0])
assert.NotContains(t, groups, groupIDs[1])
assert.ElementsMatch(t, groups, []string{groupIDs[2], groupIDs[3]})
assert.Empty(t, directPeers)
}
@@ -452,8 +490,9 @@ func TestResolveAffectedPeers_PolicyBetweenTwoGroups(t *testing.T) {
result = manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[1]})
assert.ElementsMatch(t, []string{peerIDs[0], peerIDs[1]}, result)
// peerIDs[2] is unrelated to the route; only its own map can change.
result = manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[2]})
assert.Empty(t, result)
assert.ElementsMatch(t, []string{peerIDs[2]}, result)
}
func TestResolveAffectedPeers_PolicyThreeGroups(t *testing.T) {
@@ -474,7 +513,7 @@ func TestResolveAffectedPeers_PolicyThreeGroups(t *testing.T) {
require.NoError(t, err)
result := manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[0]})
assert.ElementsMatch(t, []string{peerIDs[0], peerIDs[1], peerIDs[2]}, result)
assert.ElementsMatch(t, []string{peerIDs[0], peerIDs[2]}, result)
}
func TestResolveAffectedPeers_RoutePeerGroups(t *testing.T) {
@@ -506,8 +545,9 @@ func TestResolveAffectedPeers_RoutePeerGroups(t *testing.T) {
result = manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[1]})
assert.ElementsMatch(t, []string{peerIDs[0], peerIDs[1]}, result)
// peerIDs[2] is in no policy; only its own map can change, so it refreshes itself.
result = manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[2]})
assert.Empty(t, result)
assert.ElementsMatch(t, []string{peerIDs[2]}, result)
}
func TestResolveAffectedPeers_RouteWithDirectPeer(t *testing.T) {
@@ -564,9 +604,9 @@ func TestResolveAffectedPeers_RouteWithAccessControlGroups(t *testing.T) {
result := manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[2]})
assert.ElementsMatch(t, []string{peerIDs[0], peerIDs[1], peerIDs[2]}, result)
// peer3 is unrelated
// peer3 is unrelated to the route; only its own map can change.
result = manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[3]})
assert.Empty(t, result)
assert.ElementsMatch(t, []string{peerIDs[3]}, result)
}
func TestResolveAffectedPeers_NetworkRouter(t *testing.T) {
@@ -659,9 +699,13 @@ func TestResolveAffectedPeers_PeerInMultipleGroups(t *testing.T) {
}, true)
require.NoError(t, err)
// peer0 is in group0 AND group1, so both policies apply
// peer0 is in group0 AND group1, so both policies apply. A peer change folds
// only the changed peer plus the opposite side of each rule: group2 (peer2) via
// the group0 policy and group3 (peer3) via the group1 policy. peer1, a co-member
// of group1, is a sibling of the changed peer and must NOT refresh.
result := manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[0]})
assert.ElementsMatch(t, []string{peerIDs[0], peerIDs[1], peerIDs[2], peerIDs[3]}, result)
assert.ElementsMatch(t, []string{peerIDs[0], peerIDs[2], peerIDs[3]}, result)
assert.NotContains(t, result, peerIDs[1], "co-member of the changed peer's group must not refresh")
}
func TestResolveAffectedPeers_MultipleChangedPeers(t *testing.T) {
@@ -697,7 +741,7 @@ func TestResolveAffectedPeers_MultipleChangedPeers(t *testing.T) {
require.NoError(t, err)
result := manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[0], peerIDs[2]})
assert.ElementsMatch(t, []string{peerIDs[0], peerIDs[1], peerIDs[2], peerIDs[3]}, result)
assert.ElementsMatch(t, []string{peerIDs[0], peerIDs[2], peerIDs[1], peerIDs[3]}, result)
}
func TestResolveAffectedPeers_SharedGroupAcrossPolicyAndRoute(t *testing.T) {
@@ -854,8 +898,9 @@ func TestAffectedPeers_IsolatedPolicies(t *testing.T) {
assert.NotContains(t, result, peerIDs[0])
assert.NotContains(t, result, peerIDs[1])
// peerIDs[4] is in neither isolated policy; only its own map can change.
result = manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[4]})
assert.Empty(t, result)
assert.ElementsMatch(t, []string{peerIDs[4]}, result)
}
func TestAffectedPeers_IsolatedRouteAndPolicy(t *testing.T) {
@@ -977,12 +1022,13 @@ func TestAffectedPeers_GroupUpdateOnlyAffectsLinkedPeers(t *testing.T) {
})
}
func TestAffectedPeers_UnlinkedGroupChange_NoUpdates(t *testing.T) {
// A peer in no policy/route refreshes only itself — no other peer is affected.
func TestAffectedPeers_UnlinkedPeerChange_RefreshesSelfOnly(t *testing.T) {
manager, s, accountID, peerIDs, _ := setupAffectedPeersTest(t)
ctx := context.Background()
result := manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[0]})
assert.Empty(t, result)
assert.ElementsMatch(t, []string{peerIDs[0]}, result)
}
// TestAffectedPeers_PolicyChange_UnrelatedPeerNoUpdate verifies that creating/deleting a

View File

@@ -61,7 +61,8 @@ func Load(ctx context.Context, s store.Store, accountID string, c Change) (*Snap
// loadCollections reads the policy/route/nameserver/dns/router/resource/proxy
// collections a Change can touch, gated to what the walk needs.
func (snap *Snapshot) loadCollections(ctx context.Context, s store.Store, accountID string, c Change) error {
hasGroupOrPeerChange := len(c.ChangedGroupIDs) > 0 || len(c.ChangedPeerIDs) > 0 || len(c.Resources) > 0
// LinkGroups drive the same policy/route/dns walk as a changed group or peer.
hasGroupOrPeerChange := len(c.ChangedGroupIDs) > 0 || len(c.ChangedPeerIDs) > 0 || len(c.LinkGroups) > 0 || len(c.Resources) > 0
hasNetworkObject := len(c.Routers) > 0 || len(c.Resources) > 0 || len(c.Networks) > 0
// the resource<->router bridge can fire for any of these
needsRoutersResources := hasGroupOrPeerChange || len(c.PostureCheckIDs) > 0 || len(c.Policies) > 0 || hasNetworkObject
@@ -76,7 +77,7 @@ func (snap *Snapshot) loadCollections(ctx context.Context, s store.Store, accoun
return err
}
}
if len(c.ChangedGroupIDs) > 0 || len(c.ChangedPeerIDs) > 0 {
if len(c.ChangedGroupIDs) > 0 || len(c.ChangedPeerIDs) > 0 || len(c.LinkGroups) > 0 {
if err := snap.loadDNS(ctx, s, accountID); err != nil {
return err
}
@@ -174,6 +175,24 @@ type Change struct {
// folded in — but only when the group is linked (an unlinked group has no map
// impact), matching how current members are handled.
RemovedPeersByGroup map[string][]string
// OutputPeerIDs are peers folded straight into the result without seeding their
// group memberships into the walk. Use for the peer whose group membership changed:
// the peer itself must refresh, but its OTHER groups did not change, so they must
// not be walked. Contrast ChangedPeerIDs, which seeds ALL of the peer's groups
// (correct when the peer's own attributes changed, e.g. IP/status).
OutputPeerIDs []string
// LinkGroups are groups used ONLY to match policies/routes/routers and walk to the
// OPPOSITE side — they are never expanded to their own members. Use this when a
// peer's group membership changed: pass the peer in ChangedPeerIDs and its
// group(s) here. The opposite side of the policies the group participates in
// refreshes, but the group's other members (siblings) do not — nothing changed for
// them. For an intra-group policy (A→A) the opposite side IS the group, so its
// members still refresh via the opposite-side fold, exactly when they genuinely
// gain/lose the changed peer. Unlike ChangedGroupIDs, a LinkGroup is not added to
// the output, so a one-sided membership change never wakes the whole group.
LinkGroups []string
}
func (c Change) isEmpty() bool {
@@ -186,7 +205,9 @@ func (c Change) isEmpty() bool {
len(c.Networks) == 0 &&
len(c.PostureCheckIDs) == 0 &&
len(c.DistributionGroupIDs) == 0 &&
len(c.RemovedPeersByGroup) == 0
len(c.RemovedPeersByGroup) == 0 &&
len(c.LinkGroups) == 0 &&
len(c.OutputPeerIDs) == 0
}
// Expand returns the deduplicated affected peer IDs from the preloaded Snapshot,
@@ -197,8 +218,8 @@ func (snap *Snapshot) Expand(ctx context.Context, accountID string, c Change) []
return nil
}
r := newResolver(ctx, snap, accountID, c)
log.WithContext(ctx).Tracef("affectedpeers expand start: account=%s changedGroups=%v changedPeers=%v policies=%d routes=%d routers=%d resources=%d networks=%d postureChecks=%v distributionGroups=%v",
accountID, c.ChangedGroupIDs, c.ChangedPeerIDs, len(c.Policies), len(c.Routes), len(c.Routers), len(c.Resources), len(c.Networks), c.PostureCheckIDs, c.DistributionGroupIDs)
log.WithContext(ctx).Tracef("affectedpeers expand start: account=%s changedGroups=%v changedPeers=%v linkGroups=%v policies=%d routes=%d routers=%d resources=%d networks=%d postureChecks=%v distributionGroups=%v",
accountID, c.ChangedGroupIDs, c.ChangedPeerIDs, c.LinkGroups, len(c.Policies), len(c.Routes), len(c.Routers), len(c.Resources), len(c.Networks), c.PostureCheckIDs, c.DistributionGroupIDs)
r.walk()
return r.expand()
}
@@ -216,57 +237,84 @@ func Collect(ctx context.Context, s store.Store, accountID string, c Change) (gr
}
r := newResolver(ctx, snap, accountID, c)
r.walk()
return setToSlice(r.groupSet), setToSlice(r.peerSet)
return setToSlice(r.affectedGroups), setToSlice(r.affectedPeers)
}
func newResolver(ctx context.Context, snap *Snapshot, accountID string, c Change) *resolver {
r := &resolver{
ctx: ctx,
snap: snap,
accountID: accountID,
change: c,
changedGroupSet: toSet(c.ChangedGroupIDs),
changedPeerSet: toSet(c.ChangedPeerIDs),
groupSet: make(map[string]struct{}),
peerSet: make(map[string]struct{}),
networkIDs: make(map[string]struct{}),
ctx: ctx,
snap: snap,
accountID: accountID,
change: c,
linkGroups: toSet(c.ChangedGroupIDs),
outputGroups: toSet(c.ChangedGroupIDs),
changedPeers: toSet(c.ChangedPeerIDs),
affectedGroups: make(map[string]struct{}),
affectedPeers: make(map[string]struct{}),
}
// LinkGroups match policies/routes to find the opposite side but are NOT output:
// they go into linkGroups only, never outputGroups, so their members never fold in.
addAll(r.linkGroups, c.LinkGroups)
// Resolve each changed peer to its groups here so callers pass only ChangedPeerIDs.
r.seedChangedGroupsFromPeers()
r.matchedPolicies = append(r.matchedPolicies, c.Policies...)
return r
}
// seedChangedGroupsFromPeers adds each changed peer's groups to changedGroupSet so
// seedChangedGroupsFromPeers adds each changed peer's groups to linkGroups so
// the group-driven walkers fire for memberships, not just direct peer references.
// These seeded groups are for MATCHING only — folding the changed entity's own
// side is gated on outputGroups (the caller-reported groups), so a seeded group
// never folds its whole membership; only the changed peer itself folds in.
func (r *resolver) seedChangedGroupsFromPeers() {
if len(r.changedPeerSet) == 0 {
if len(r.changedPeers) == 0 {
return
}
for groupID, members := range r.snap.groupPeers {
for pID := range r.changedPeerSet {
for pID := range r.changedPeers {
if _, ok := members[pID]; ok {
r.changedGroupSet[groupID] = struct{}{}
r.linkGroups[groupID] = struct{}{}
break
}
}
}
}
// policySide selects which side of a policy rule to walk.
type policySide int
const (
sideSource policySide = iota
sideDestination
)
func (s policySide) opposite() policySide {
if s == sideSource {
return sideDestination
}
return sideSource
}
// walk resolves affected peers in two buckets, by how far each change propagates.
//
// BOTH-SIDES — the rule itself changed (an explicit policy edit, or a policy whose
// posture check changed). Source AND destination refresh, so each such policy is
// walked on both sides.
//
// OPPOSITE-SIDE — an endpoint moved but no rule changed. For each policy the change
// touches we fold only the side AWAY from the change:
// - a changed peer/group sits ON a policy side -> fold the opposite side;
// - a changed router/resource/network sits on a NETWORK -> fold the SOURCE side of
// the policies whose destination reaches it (and the routers it implies).
//
// Routes, nameserver groups, DNS and embedded-proxy services distribute to their own
// member peers, outside the policy graph, and are folded here too.
func (r *resolver) walk() {
r.collectFromExplicitPolicies()
r.collectFromExplicitRoutes(r.change.Routes)
r.collectFromExplicitRouters(r.change.Routers)
r.collectFromExplicitResources(r.change.Resources)
r.collectFromExplicitNetworks(r.change.Networks)
r.collectFromPostureChecks(r.change.PostureCheckIDs)
for _, policy := range r.bothSidesPolicies() {
r.foldPolicySide(policy, sideSource)
r.foldPolicySide(policy, sideDestination)
}
// Distribution groups (nameserver/DNS) affect only their member peers: fold them
// straight into groupSet so expand() maps them to members, without the policy/
// route walk that changedGroupSet would trigger.
addAll(r.groupSet, r.change.DistributionGroupIDs)
if len(r.changedGroupSet) > 0 || len(r.changedPeerSet) > 0 {
if len(r.linkGroups) > 0 || len(r.changedPeers) > 0 {
r.collectFromPolicies()
r.collectFromRoutes()
r.collectFromNameServers()
@@ -275,7 +323,31 @@ func (r *resolver) walk() {
r.collectFromProxyServices()
}
r.collectResourceRouterBridge()
r.collectFromChangedRoutes(r.change.Routes)
r.collectFromChangedRouters(r.change.Routers)
r.collectFromChangedResources(r.change.Resources)
r.collectFromChangedNetworks(r.change.Networks)
// The explicitly changed peers always refresh their own maps. OnPeersUpdated only
// refreshes the resolver's output (it ignores the separately-passed changed peers),
// so the changed peer reaches its own new map only via here. An offline/deleted
// peer in the set is filtered downstream (filterConnectedAffectedPeers).
addAll(r.affectedPeers, setToSlice(r.changedPeers))
// OutputPeerIDs refresh themselves too, but unlike changedPeers their group
// memberships were not seeded into the walk (only the changed group was).
addAll(r.affectedPeers, r.change.OutputPeerIDs)
// Distribution groups (nameserver/DNS) affect only their member peers: fold them
// straight into affectedGroups so expand() maps them to members, without the
// policy/route walk that linkGroups would trigger.
addAll(r.affectedGroups, r.change.DistributionGroupIDs)
}
// bothSidesPolicies are the policies whose rule changed: the explicitly edited ones
// plus those gated by a changed posture check. walk folds both their sides.
func (r *resolver) bothSidesPolicies() []*types.Policy {
policies := append([]*types.Policy(nil), r.change.Policies...)
return r.appendPoliciesForPostureChecks(policies, r.change.PostureCheckIDs)
}
type resolver struct {
@@ -284,14 +356,25 @@ type resolver struct {
accountID string
change Change
changedGroupSet map[string]struct{}
changedPeerSet map[string]struct{}
// Inputs — what changed. Set once at construction, read-only during the walk
// (except linkGroups, which collectFromExplicitResources also seeds).
//
// linkGroups is the MATCH set: caller-changed groups the groups of changed
// peers changed-resource groups. A rule/route/router matches the change when
// one of its groups is here — used only to find the opposite side to fold.
//
// outputGroups is the FOLD-WHOLE-GROUP set: ONLY Change.ChangedGroupIDs. When a
// matched group is here, its whole membership is affected. A peer-seeded group
// is in linkGroups but NOT outputGroups, so it folds only the changed peer
// (changedPeers), never its siblings.
linkGroups map[string]struct{}
outputGroups map[string]struct{}
changedPeers map[string]struct{}
groupSet map[string]struct{}
peerSet map[string]struct{}
matchedPolicies []*types.Policy
networkIDs map[string]struct{}
// Outputs — the answer. The only sets the walk accumulates into. affectedGroups
// is expanded to its member peers in expand().
affectedGroups map[string]struct{}
affectedPeers map[string]struct{}
}
func (r *resolver) policies() []*types.Policy { return r.snap.policies }
@@ -301,10 +384,10 @@ func (r *resolver) networkResources() []*resourceTypes.NetworkResource { return
func (r *resolver) networkRouters() []*routerTypes.NetworkRouter { return r.snap.routers }
// peerIDsForGroups maps a group set to its member peer IDs via the preloaded index.
func (r *resolver) peerIDsForGroups(groupSet map[string]struct{}) []string {
func (r *resolver) peerIDsForGroups(groups map[string]struct{}) []string {
seen := make(map[string]struct{})
var ids []string
for gID := range groupSet {
for gID := range groups {
for pID := range r.snap.groupPeers[gID] {
if _, ok := seen[pID]; ok {
continue
@@ -317,25 +400,25 @@ func (r *resolver) peerIDsForGroups(groupSet map[string]struct{}) []string {
}
func (r *resolver) expand() []string {
peerIDs := r.peerIDsForGroups(r.groupSet)
peerIDs := r.peerIDsForGroups(r.affectedGroups)
log.WithContext(r.ctx).Tracef("affectedpeers expand: account=%s affectedGroups=%v -> %d group-member peers; direct peers=%v",
r.accountID, setToSlice(r.groupSet), len(peerIDs), setToSlice(r.peerSet))
r.accountID, setToSlice(r.affectedGroups), len(peerIDs), setToSlice(r.affectedPeers))
seen := make(map[string]struct{}, len(peerIDs))
for _, id := range peerIDs {
seen[id] = struct{}{}
}
for id := range r.peerSet {
for id := range r.affectedPeers {
if _, ok := seen[id]; !ok {
peerIDs = append(peerIDs, id)
seen[id] = struct{}{}
}
}
// Fold in removed peers only when their group is linked (in groupSet).
// Fold in removed peers only when their group is linked (in affectedGroups).
for groupID, removed := range r.change.RemovedPeersByGroup {
if _, linked := r.groupSet[groupID]; !linked {
if _, linked := r.affectedGroups[groupID]; !linked {
continue
}
for _, id := range removed {
@@ -351,169 +434,300 @@ func (r *resolver) expand() []string {
return peerIDs
}
func (r *resolver) collectFromExplicitPolicies() {
for _, policy := range r.matchedPolicies {
if policy == nil {
continue
}
log.WithContext(r.ctx).Tracef("collectFromExplicitPolicies: changed policy %s (%s) -> folding rule groups %v + direct peers",
policy.ID, policy.Name, policy.RuleGroups())
addAll(r.groupSet, policy.RuleGroups())
collectPolicyDirectPeers(policy, r.peerSet)
// ruleSideGroups / ruleSideResource return the groups and the resource on the given
// side of a rule.
func ruleSideGroups(rule *types.PolicyRule, side policySide) []string {
if side == sideDestination {
return rule.Destinations
}
return rule.Sources
}
func (r *resolver) collectFromExplicitRoutes(routes []*route.Route) {
for _, rt := range routes {
if rt == nil {
continue
}
log.WithContext(r.ctx).Tracef("collectFromExplicitRoutes: changed route %s -> folding groups=%v peerGroups=%v accessControlGroups=%v peer=%q",
rt.ID, rt.Groups, rt.PeerGroups, rt.AccessControlGroups, rt.Peer)
addAll(r.groupSet, rt.Groups, rt.PeerGroups, rt.AccessControlGroups)
if rt.Peer != "" {
r.peerSet[rt.Peer] = struct{}{}
}
func ruleSideResource(rule *types.PolicyRule, side policySide) types.Resource {
if side == sideDestination {
return rule.DestinationResource
}
return rule.SourceResource
}
// collectFromExplicitRouters folds changed routers' peers and marks their networks
// for the bridge. Passing the old router keeps a repointed router's previous peers
// affected without a post-commit read.
func (r *resolver) collectFromExplicitRouters(routers []*routerTypes.NetworkRouter) {
for _, router := range routers {
if router == nil {
continue
}
log.WithContext(r.ctx).Tracef("collectFromExplicitRouters: changed router %s on network %s -> folding peerGroups=%v peer=%q and marking network for source bridge",
router.ID, router.NetworkID, router.PeerGroups, router.Peer)
addAll(r.groupSet, router.PeerGroups)
if router.Peer != "" {
r.peerSet[router.Peer] = struct{}{}
}
if router.NetworkID != "" {
r.networkIDs[router.NetworkID] = struct{}{}
}
}
}
// collectFromExplicitResources marks changed resources' networks for the bridge and
// treats their group IDs as changed, so policies targeting the resource via a
// now-detached (old) group still refresh.
func (r *resolver) collectFromExplicitResources(resources []*resourceTypes.NetworkResource) {
for _, resource := range resources {
if resource == nil {
continue
}
log.WithContext(r.ctx).Tracef("collectFromExplicitResources: changed resource %s on network %s -> marking network for bridge and treating groups %v as changed",
resource.ID, resource.NetworkID, resource.GroupIDs)
addAll(r.changedGroupSet, resource.GroupIDs)
if resource.NetworkID != "" {
r.networkIDs[resource.NetworkID] = struct{}{}
}
}
}
// collectFromExplicitNetworks marks changed networks for the bridge. A network has
// no groups/peers of its own.
func (r *resolver) collectFromExplicitNetworks(networks []*networkTypes.Network) {
for _, network := range networks {
if network == nil {
continue
}
log.WithContext(r.ctx).Tracef("collectFromExplicitNetworks: changed network %s -> marking for bridge", network.ID)
if network.ID != "" {
r.networkIDs[network.ID] = struct{}{}
}
}
}
func (r *resolver) collectFromPostureChecks(postureCheckIDs []string) {
if len(postureCheckIDs) == 0 {
// foldPolicySide folds one side of a policy down to affected peers: its groups
// (resolved to members in expand) and its direct peer. When the side is the
// DESTINATION and references a network resource (directly or via a destination
// group's resources), it also folds the routers that serve that resource's network
// — a destination resource is reached through its routers. A resource on the SOURCE
// side routes to nobody (GetPoliciesForNetworkResource matches destinations only),
// so the router hop is destination-only.
func (r *resolver) foldPolicySide(policy *types.Policy, side policySide) {
if policy == nil {
return
}
for _, rule := range policy.Rules {
addAll(r.affectedGroups, ruleSideGroups(rule, side))
res := ruleSideResource(rule, side)
if res.Type == types.ResourceTypePeer && res.ID != "" {
r.affectedPeers[res.ID] = struct{}{}
}
}
if side == sideDestination {
r.foldRoutersForResources(r.policyDestinationResourceIDs(policy))
}
}
// appendPoliciesForPostureChecks appends every policy that references a changed
// posture check (a rule change, so walk both sides).
func (r *resolver) appendPoliciesForPostureChecks(policies []*types.Policy, postureCheckIDs []string) []*types.Policy {
if len(postureCheckIDs) == 0 {
return policies
}
ids := toSet(postureCheckIDs)
for _, policy := range r.policies() {
if !policyReferencesPostureChecks(policy, ids) {
continue
}
log.WithContext(r.ctx).Tracef("collectFromPostureChecks: policy %s (%s) references changed posture checks %v -> folding rule groups %v + direct peers",
policy.ID, policy.Name, postureCheckIDs, policy.RuleGroups())
addAll(r.groupSet, policy.RuleGroups())
collectPolicyDirectPeers(policy, r.peerSet)
r.matchedPolicies = append(r.matchedPolicies, policy)
log.WithContext(r.ctx).Tracef("appendPoliciesForPostureChecks: policy %s (%s) references changed posture checks %v -> both-sides policy",
policy.ID, policy.Name, postureCheckIDs)
policies = append(policies, policy)
}
return policies
}
// collectFromPolicies folds, for every policy whose rule a changed group or peer
// touches, only the OPPOSITE side (down to peers, incl. destination routers), plus
// the changed entity's own side: the changed group's whole membership when the
// group itself changed (outputGroups), or the changed peer alone when matched via a
// peer-seeded group (never its co-members).
func (r *resolver) collectFromPolicies() {
for _, policy := range r.policies() {
for _, rule := range policy.Rules {
r.foldRuleSideIfChanged(policy, rule, sideSource)
r.foldRuleSideIfChanged(policy, rule, sideDestination)
}
}
}
func (r *resolver) collectFromPolicies() {
for _, policy := range r.policies() {
matchedByGroup := policyReferencesGroups(policy, r.changedGroupSet)
matchedByPeer := len(r.changedPeerSet) > 0 && policyReferencesDirectPeers(policy, r.changedPeerSet)
if !matchedByGroup && !matchedByPeer {
// foldRuleSideIfChanged: when a changed group or direct peer sits on `side` of the
// rule, fold the opposite side fully (groups/peers + destination routers) and fold
// the changed entity's own side (the whole changed group, or the changed peer alone).
func (r *resolver) foldRuleSideIfChanged(policy *types.Policy, rule *types.PolicyRule, side policySide) {
nearGroups := ruleSideGroups(rule, side)
nearResource := ruleSideResource(rule, side)
matchedByGroup := anyInSet(nearGroups, r.linkGroups)
matchedByPeer := isDirectPeerInSet(nearResource, r.changedPeers)
if !matchedByGroup && !matchedByPeer {
return
}
// Opposite side, fully down to peers (a destination opposite also folds routers).
r.foldPolicySideForRule(policy, rule, side.opposite())
// Own side: fold the whole changed group's members only when the group itself
// changed (outputGroups). A peer-seeded or link-only group is not folded here —
// its siblings never refresh. The changed peers themselves are folded once, after
// the walk (see walk()).
for _, gID := range nearGroups {
if _, ok := r.outputGroups[gID]; ok {
r.affectedGroups[gID] = struct{}{}
}
}
// When the changed side IS a destination, the resources it targets are reached
// through their network's routers, so those routers refresh too (e.g. attaching a
// resource to a destination group, or a changed destination group/resource).
if side == sideDestination {
r.foldRoutersForResources(r.ruleDestinationResourceIDs(rule))
}
}
// foldPolicySideForRule folds one side of a single rule (groups + direct peer), and
// for a destination side the routers of that rule's destination resources.
func (r *resolver) foldPolicySideForRule(policy *types.Policy, rule *types.PolicyRule, side policySide) {
addAll(r.affectedGroups, ruleSideGroups(rule, side))
res := ruleSideResource(rule, side)
if res.Type == types.ResourceTypePeer && res.ID != "" {
r.affectedPeers[res.ID] = struct{}{}
}
if side == sideDestination {
r.foldRoutersForResources(r.ruleDestinationResourceIDs(rule))
}
}
// collectFromChangedRoutes folds an explicitly changed route's own groups and peer.
func (r *resolver) collectFromChangedRoutes(routes []*route.Route) {
for _, rt := range routes {
if rt == nil {
continue
}
log.WithContext(r.ctx).Tracef("collectFromPolicies: policy %s (%s) matched (byGroup=%t byPeer=%t) -> folding rule groups %v + direct peers",
policy.ID, policy.Name, matchedByGroup, matchedByPeer, policy.RuleGroups())
addAll(r.groupSet, policy.RuleGroups())
collectPolicyDirectPeers(policy, r.peerSet)
r.matchedPolicies = append(r.matchedPolicies, policy)
log.WithContext(r.ctx).Tracef("collectFromChangedRoutes: changed route %s -> folding groups=%v peerGroups=%v accessControlGroups=%v peer=%q",
rt.ID, rt.Groups, rt.PeerGroups, rt.AccessControlGroups, rt.Peer)
addAll(r.affectedGroups, rt.Groups, rt.PeerGroups, rt.AccessControlGroups)
if rt.Peer != "" {
r.affectedPeers[rt.Peer] = struct{}{}
}
}
}
// collectFromChangedRouters: a changed router refreshes its OWN backing peer/groups
// (the changed entity) and the SOURCE side of every policy reaching a resource on
// its network (the router serves the whole network). Sibling routers on the network
// are independent and are NOT folded. Passing the old router state keeps a repointed
// router's previous backing affected without a post-commit read.
func (r *resolver) collectFromChangedRouters(routers []*routerTypes.NetworkRouter) {
for _, router := range routers {
if router == nil {
continue
}
log.WithContext(r.ctx).Tracef("collectFromChangedRouters: changed router %s on network %s -> folding its own peerGroups=%v peer=%q + sources reaching network resources",
router.ID, router.NetworkID, router.PeerGroups, router.Peer)
addAll(r.affectedGroups, router.PeerGroups)
if router.Peer != "" {
r.affectedPeers[router.Peer] = struct{}{}
}
if router.NetworkID != "" {
r.foldPolicySourcesForResources(r.networkResourceIDs(router.NetworkID))
}
}
}
// collectFromChangedResources: a changed resource refreshes the SOURCE side of the
// policies targeting EXACTLY that resource — directly, or via one of the resource's
// own groups (oldnew across the change, so a now-detached group's sources still
// refresh) — plus the routers serving its network (the resource is reached through
// them). It does not touch sibling resources on the same network.
func (r *resolver) collectFromChangedResources(resources []*resourceTypes.NetworkResource) {
for _, resource := range resources {
if resource == nil {
continue
}
log.WithContext(r.ctx).Tracef("collectFromChangedResources: changed resource %s on network %s (groups %v) -> folding sources of policies targeting it + its network's routers",
resource.ID, resource.NetworkID, resource.GroupIDs)
r.foldPolicySourcesForResource(resource.ID, resource.GroupIDs)
if resource.NetworkID != "" {
r.foldRoutersOnNetworks(map[string]struct{}{resource.NetworkID: {}})
}
}
}
// foldPolicySourcesForResource folds the source side of every policy whose
// destination is the given resource — referenced directly, or via any of the given
// groups (the resource's own oldnew groups, which captures a detached group).
func (r *resolver) foldPolicySourcesForResource(resourceID string, groupIDs []string) {
groups := toSet(groupIDs)
for _, policy := range r.policies() {
if !policyTargetsResourceOrGroups(policy, resourceID, groups) {
continue
}
log.WithContext(r.ctx).Tracef("foldPolicySourcesForResource: policy %s (%s) targets changed resource %s -> folding its source groups/peers", policy.ID, policy.Name, resourceID)
collectPolicySources(policy, r.affectedGroups, r.affectedPeers)
}
}
// policyTargetsResourceOrGroups reports whether a policy's destination is the given
// resource directly, or one of the given destination groups.
func policyTargetsResourceOrGroups(policy *types.Policy, resourceID string, groups map[string]struct{}) bool {
if policy == nil {
return false
}
for _, rule := range policy.Rules {
if rule.DestinationResource.Type != types.ResourceTypePeer && rule.DestinationResource.ID == resourceID && resourceID != "" {
return true
}
if anyInSet(rule.Destinations, groups) {
return true
}
}
return false
}
// collectFromChangedNetworks: a changed network refreshes the SOURCE side of the
// policies reaching any of its resources, plus its routers. A network has no
// groups/peers of its own.
func (r *resolver) collectFromChangedNetworks(networks []*networkTypes.Network) {
for _, network := range networks {
if network == nil || network.ID == "" {
continue
}
log.WithContext(r.ctx).Tracef("collectFromChangedNetworks: changed network %s -> folding sources reaching its resources + its routers", network.ID)
resourceIDs := r.networkResourceIDs(network.ID)
r.foldPolicySourcesForResources(resourceIDs)
r.foldRoutersOnNetworks(map[string]struct{}{network.ID: {}})
}
}
// foldPolicySourcesForResources folds the source groups/peers of every policy whose
// destination targets one of resourceIDs (directly or via a destination group).
func (r *resolver) foldPolicySourcesForResources(resourceIDs map[string]struct{}) {
if len(resourceIDs) == 0 {
return
}
for _, policy := range r.policies() {
if r.policyTargetsResources(policy, resourceIDs) {
log.WithContext(r.ctx).Tracef("foldPolicySourcesForResources: policy %s (%s) targets a changed resource -> folding its source groups/peers", policy.ID, policy.Name)
collectPolicySources(policy, r.affectedGroups, r.affectedPeers)
}
}
}
func (r *resolver) collectFromRoutes() {
for _, rt := range r.snap.routes {
matchedByGroup := anyInSet(rt.Groups, r.changedGroupSet) || anyInSet(rt.PeerGroups, r.changedGroupSet) || anyInSet(rt.AccessControlGroups, r.changedGroupSet)
matchedByPeer := rt.Peer != "" && len(r.changedPeerSet) > 0 && isInSet(rt.Peer, r.changedPeerSet)
matchedByGroup := anyInSet(rt.Groups, r.linkGroups) || anyInSet(rt.PeerGroups, r.linkGroups) || anyInSet(rt.AccessControlGroups, r.linkGroups)
matchedByPeer := rt.Peer != "" && len(r.changedPeers) > 0 && isInSet(rt.Peer, r.changedPeers)
if !matchedByGroup && !matchedByPeer {
continue
}
log.WithContext(r.ctx).Tracef("collectFromRoutes: route %s matched (byGroup=%t byPeer=%t) -> folding groups=%v peerGroups=%v accessControlGroups=%v peer=%q",
rt.ID, matchedByGroup, matchedByPeer, rt.Groups, rt.PeerGroups, rt.AccessControlGroups, rt.Peer)
addAll(r.groupSet, rt.Groups, rt.PeerGroups, rt.AccessControlGroups)
addAll(r.affectedGroups, rt.Groups, rt.PeerGroups, rt.AccessControlGroups)
if rt.Peer != "" {
r.peerSet[rt.Peer] = struct{}{}
r.affectedPeers[rt.Peer] = struct{}{}
}
}
}
func (r *resolver) collectFromNameServers() {
if len(r.changedGroupSet) == 0 {
if len(r.linkGroups) == 0 {
return
}
for _, ns := range r.snap.nsGroups {
if anyInSet(ns.Groups, r.changedGroupSet) {
if anyInSet(ns.Groups, r.linkGroups) {
log.WithContext(r.ctx).Tracef("collectFromNameServers: nameserver group %s references a changed group -> folding its groups %v", ns.ID, ns.Groups)
addAll(r.groupSet, ns.Groups)
addAll(r.affectedGroups, ns.Groups)
}
}
}
func (r *resolver) collectFromDNSSettings() {
if len(r.changedGroupSet) == 0 || r.snap.dnsSettings == nil {
if len(r.linkGroups) == 0 || r.snap.dnsSettings == nil {
return
}
for _, gID := range r.snap.dnsSettings.DisabledManagementGroups {
if _, ok := r.changedGroupSet[gID]; ok {
if _, ok := r.linkGroups[gID]; ok {
log.WithContext(r.ctx).Tracef("collectFromDNSSettings: changed group %s is in DisabledManagementGroups -> folding it", gID)
r.groupSet[gID] = struct{}{}
r.affectedGroups[gID] = struct{}{}
}
}
}
// collectFromNetworkRouters handles a changed group/peer that BACKS a router (the
// routing peer set moved): the router's own peers refresh and so do the sources of
// the policies reaching its network's resources. Sibling routers on the network are
// independent and are not folded.
func (r *resolver) collectFromNetworkRouters() {
for _, router := range r.networkRouters() {
matchedByGroup := anyInSet(router.PeerGroups, r.changedGroupSet)
matchedByPeer := router.Peer != "" && len(r.changedPeerSet) > 0 && isInSet(router.Peer, r.changedPeerSet)
matchedByGroup := anyInSet(router.PeerGroups, r.linkGroups)
matchedByPeer := router.Peer != "" && len(r.changedPeers) > 0 && isInSet(router.Peer, r.changedPeers)
if !matchedByGroup && !matchedByPeer {
continue
}
log.WithContext(r.ctx).Tracef("collectFromNetworkRouters: router %s on network %s matched (byGroup=%t byPeer=%t) -> folding peerGroups=%v peer=%q and marking network for source bridge",
log.WithContext(r.ctx).Tracef("collectFromNetworkRouters: router %s on network %s matched (byGroup=%t byPeer=%t) -> folding its peerGroups=%v peer=%q + sources reaching network resources",
router.ID, router.NetworkID, matchedByGroup, matchedByPeer, router.PeerGroups, router.Peer)
addAll(r.groupSet, router.PeerGroups)
addAll(r.affectedGroups, router.PeerGroups)
if router.Peer != "" {
r.peerSet[router.Peer] = struct{}{}
r.affectedPeers[router.Peer] = struct{}{}
}
if router.NetworkID != "" {
r.foldPolicySourcesForResources(r.networkResourceIDs(router.NetworkID))
}
r.networkIDs[router.NetworkID] = struct{}{}
}
}
@@ -534,34 +748,34 @@ func (r *resolver) collectFromProxyServices() {
continue
}
matchedByPeer := serviceMatchesChangedPeers(svc, proxyPeers, expanded)
matchedByAccessGroup := anyInSet(svc.AccessGroups, r.changedGroupSet)
matchedByAccessGroup := anyInSet(svc.AccessGroups, r.linkGroups)
if !matchedByPeer && !matchedByAccessGroup {
continue
}
log.WithContext(r.ctx).Tracef("collectFromProxyServices: service %s (cluster=%s) matched (byProxyOrTargetPeer=%t byAccessGroup=%t) -> folding %d proxy peers, peer targets and access groups %v",
svc.ID, svc.ProxyCluster, matchedByPeer, matchedByAccessGroup, len(proxyPeers), svc.AccessGroups)
for _, pid := range proxyPeers {
r.peerSet[pid] = struct{}{}
r.affectedPeers[pid] = struct{}{}
}
for _, target := range svc.Targets {
if target.TargetType == rpservice.TargetTypePeer && target.TargetId != "" {
r.peerSet[target.TargetId] = struct{}{}
r.affectedPeers[target.TargetId] = struct{}{}
}
}
addAll(r.groupSet, svc.AccessGroups)
addAll(r.affectedGroups, svc.AccessGroups)
}
}
func (r *resolver) expandChangedPeersWithGroups() map[string]struct{} {
if len(r.changedGroupSet) == 0 {
return r.changedPeerSet
if len(r.linkGroups) == 0 {
return r.changedPeers
}
ids := r.peerIDsForGroups(r.changedGroupSet)
ids := r.peerIDsForGroups(r.linkGroups)
if len(ids) == 0 {
return r.changedPeerSet
return r.changedPeers
}
merged := make(map[string]struct{}, len(r.changedPeerSet)+len(ids))
for id := range r.changedPeerSet {
merged := make(map[string]struct{}, len(r.changedPeers)+len(ids))
for id := range r.changedPeers {
merged[id] = struct{}{}
}
for _, id := range ids {
@@ -570,54 +784,36 @@ func (r *resolver) expandChangedPeersWithGroups() map[string]struct{} {
return merged
}
// collectResourceRouterBridge crosses between source peers and routing peers, which
// are reachable only via resource -> network -> router, not through the policy's own
// groups: source -> router (targeted resources' networks), then router -> source.
func (r *resolver) collectResourceRouterBridge() {
r.bridgeSourceToRouters()
r.bridgeRoutersToSources()
}
func (r *resolver) bridgeSourceToRouters() {
resourceIDs := r.policyDestinationResourceIDs(r.matchedPolicies...)
// foldRoutersForResources folds the routers serving the networks of the given
// resources (a destination resource is reached through its network's routers). It is
// the resource -> network -> router hop used by foldPolicySide for a destination.
func (r *resolver) foldRoutersForResources(resourceIDs map[string]struct{}) {
if len(resourceIDs) == 0 {
return
}
networkIDs := r.resourceNetworkIDs(resourceIDs)
log.WithContext(r.ctx).Tracef("bridgeSourceToRouters: targeted resources %v -> networks %v (their routers become affected via the router->source pass)",
setToSlice(resourceIDs), setToSlice(networkIDs))
for id := range networkIDs {
r.networkIDs[id] = struct{}{}
}
r.foldRoutersOnNetworks(r.resourceNetworkIDs(resourceIDs))
}
func (r *resolver) bridgeRoutersToSources() {
if len(r.networkIDs) == 0 {
return
// ruleDestinationResourceIDs returns the destination resource IDs of a single rule:
// the direct DestinationResource plus the resources of its destination groups.
func (r *resolver) ruleDestinationResourceIDs(rule *types.PolicyRule) map[string]struct{} {
resourceIDs := make(map[string]struct{})
if rule.DestinationResource.Type != types.ResourceTypePeer && rule.DestinationResource.ID != "" {
resourceIDs[rule.DestinationResource.ID] = struct{}{}
}
r.addGroupResourceIDs(toSet(rule.Destinations), resourceIDs)
return resourceIDs
}
log.WithContext(r.ctx).Tracef("bridgeRoutersToSources: affected networks %v -> folding their routing peers and the source peers of policies targeting their resources",
setToSlice(r.networkIDs))
r.foldRoutersOnNetworks(r.networkIDs)
// networkResourceIDs returns the IDs of all resources on the given network.
func (r *resolver) networkResourceIDs(networkID string) map[string]struct{} {
resourceIDs := make(map[string]struct{})
for _, resource := range r.networkResources() {
if _, ok := r.networkIDs[resource.NetworkID]; ok {
if resource.NetworkID == networkID {
resourceIDs[resource.ID] = struct{}{}
}
}
if len(resourceIDs) == 0 {
return
}
for _, policy := range r.policies() {
if r.policyTargetsResources(policy, resourceIDs) {
log.WithContext(r.ctx).Tracef("bridgeRoutersToSources: policy %s (%s) targets an affected-network resource -> folding its source groups/peers", policy.ID, policy.Name)
collectPolicySources(policy, r.groupSet, r.peerSet)
}
}
return resourceIDs
}
func (r *resolver) foldRoutersOnNetworks(networkIDs map[string]struct{}) {
@@ -627,9 +823,9 @@ func (r *resolver) foldRoutersOnNetworks(networkIDs map[string]struct{}) {
}
log.WithContext(r.ctx).Tracef("bridgeRoutersToSources: router %s serves affected network %s -> folding peerGroups=%v peer=%q",
router.ID, router.NetworkID, router.PeerGroups, router.Peer)
addAll(r.groupSet, router.PeerGroups)
addAll(r.affectedGroups, router.PeerGroups)
if router.Peer != "" {
r.peerSet[router.Peer] = struct{}{}
r.affectedPeers[router.Peer] = struct{}{}
}
}
}
@@ -714,44 +910,26 @@ func (r *resolver) addGroupResourceIDs(groupIDs map[string]struct{}, resourceIDs
}
}
func collectPolicyDirectPeers(policy *types.Policy, peerSet map[string]struct{}) {
func collectPolicyDirectPeers(policy *types.Policy, peers map[string]struct{}) {
for _, rule := range policy.Rules {
if rule.SourceResource.Type == types.ResourceTypePeer && rule.SourceResource.ID != "" {
peerSet[rule.SourceResource.ID] = struct{}{}
peers[rule.SourceResource.ID] = struct{}{}
}
if rule.DestinationResource.Type == types.ResourceTypePeer && rule.DestinationResource.ID != "" {
peerSet[rule.DestinationResource.ID] = struct{}{}
peers[rule.DestinationResource.ID] = struct{}{}
}
}
}
func collectPolicySources(policy *types.Policy, groupSet, peerSet map[string]struct{}) {
func collectPolicySources(policy *types.Policy, groups, peers map[string]struct{}) {
for _, rule := range policy.Rules {
addAll(groupSet, rule.Sources)
addAll(groups, rule.Sources)
if rule.SourceResource.Type == types.ResourceTypePeer && rule.SourceResource.ID != "" {
peerSet[rule.SourceResource.ID] = struct{}{}
peers[rule.SourceResource.ID] = struct{}{}
}
}
}
func policyReferencesGroups(policy *types.Policy, groupSet map[string]struct{}) bool {
for _, rule := range policy.Rules {
if anyInSet(rule.Sources, groupSet) || anyInSet(rule.Destinations, groupSet) {
return true
}
}
return false
}
func policyReferencesDirectPeers(policy *types.Policy, changedSet map[string]struct{}) bool {
for _, rule := range policy.Rules {
if isDirectPeerInSet(rule.SourceResource, changedSet) || isDirectPeerInSet(rule.DestinationResource, changedSet) {
return true
}
}
return false
}
func policyReferencesPostureChecks(policy *types.Policy, ids map[string]struct{}) bool {
for _, id := range policy.SourcePostureChecks {
if _, ok := ids[id]; ok {

View File

@@ -80,26 +80,6 @@ func TestChangeIsEmpty(t *testing.T) {
assert.False(t, Change{PostureCheckIDs: []string{"pc"}}.isEmpty())
}
func TestPolicyReferencesGroups(t *testing.T) {
policy := &types.Policy{Rules: []*types.PolicyRule{{Sources: []string{"g1", "g2"}, Destinations: []string{"g3"}}}}
assert.True(t, policyReferencesGroups(policy, map[string]struct{}{"g1": {}}))
assert.True(t, policyReferencesGroups(policy, map[string]struct{}{"g3": {}}))
assert.False(t, policyReferencesGroups(policy, map[string]struct{}{"g4": {}}))
assert.False(t, policyReferencesGroups(policy, map[string]struct{}{}))
}
func TestPolicyReferencesDirectPeers(t *testing.T) {
policy := &types.Policy{Rules: []*types.PolicyRule{{
SourceResource: types.Resource{Type: types.ResourceTypePeer, ID: "p1"},
DestinationResource: types.Resource{Type: types.ResourceTypeHost, ID: "r1"},
}}}
assert.True(t, policyReferencesDirectPeers(policy, map[string]struct{}{"p1": {}}))
assert.False(t, policyReferencesDirectPeers(policy, map[string]struct{}{"r1": {}}))
assert.False(t, policyReferencesDirectPeers(policy, map[string]struct{}{"p2": {}}))
}
func TestPolicyReferencesPostureChecks(t *testing.T) {
policy := &types.Policy{SourcePostureChecks: []string{"pc1", "pc2"}}

View File

@@ -520,7 +520,12 @@ func collectDeletableGroups(ctx context.Context, transaction store.Store, accoun
// GroupAddPeer appends peer to the group
func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
var snap *affectedpeers.Snapshot
change := affectedpeers.Change{ChangedGroupIDs: []string{groupID}}
// A membership change affects only the peer itself and the opposite side of THIS
// group's policies — not the group's other members, and not the peer's other
// groups. LinkGroups walks only this group (matched, not expanded); OutputPeerIDs
// refreshes the peer without seeding its other group memberships. For an
// intra-group policy the opposite side is the group, so its members still refresh.
change := affectedpeers.Change{OutputPeerIDs: []string{peerID}, LinkGroups: []string{groupID}}
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err := transaction.AddPeerToGroup(ctx, accountID, peerID, groupID); err != nil {
@@ -586,10 +591,11 @@ func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID
// GroupDeletePeer removes peer from the group
func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, groupID, peerID string) error {
var snap *affectedpeers.Snapshot
change := affectedpeers.Change{
ChangedGroupIDs: []string{groupID},
RemovedPeersByGroup: map[string][]string{groupID: {peerID}},
}
// Same as GroupAddPeer: the removed peer and the opposite side of THIS group's
// policies refresh, not the group's other members or the peer's other groups. The
// peer is no longer in the group's index, but LinkGroups still drives the
// opposite-side walk, and OutputPeerIDs refreshes the removed peer itself.
change := affectedpeers.Change{OutputPeerIDs: []string{peerID}, LinkGroups: []string{groupID}}
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err := transaction.RemovePeerFromGroup(ctx, peerID, groupID); err != nil {
@@ -600,8 +606,6 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
return err
}
// The removed peer is carried in change.RemovedPeersByGroup and folded in
// only when the group is linked, so loading post-removal is correct.
var err error
if snap, err = affectedpeers.Load(ctx, transaction, accountID, change); err != nil {
return err

View File

@@ -1051,8 +1051,8 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
return nil, nil, nil, 0, err
}
metaDiffAffectsPosture := posture.AffectsPosture(ctx, &metaDiff, resPostureChecks)
if isStatusChanged || sync.UpdateAccountPeers || ipv6CapabilityChanged || metaDiffAffectsPosture || metaDiff.VersionChanged() || metaDiff.HostnameChanged() {
metaDiffAffectsPosture := posture.AffectsPosture(&metaDiff, resPostureChecks)
if isStatusChanged || sync.UpdateAccountPeers || ipv6CapabilityChanged || metaDiffAffectsPosture || metaDiff.VersionChanged || metaDiff.Hostname {
changedPeerIDs := []string{peer.ID}
affectedPeerIDs := am.syncPeerAffectedPeers(ctx, accountID, peer.ID, nmap, peerNotValid, metaDiffAffectsPosture)
if err = am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs); err != nil {

View File

@@ -107,15 +107,6 @@ type Location struct {
GeoNameID uint // city level geoname id
}
// equal reports whether two locations match. ConnectionIP is a net.IP slice, so it uses
// IP.Equal, not ==.
func (l Location) equal(other Location) bool {
return l.CountryCode == other.CountryCode &&
l.CityName == other.CityName &&
l.GeoNameID == other.GeoNameID &&
l.ConnectionIP.Equal(other.ConnectionIP)
}
// NetworkAddress is the IP address with network and MAC address of a network interface
type NetworkAddress struct {
NetIP netip.Prefix `gorm:"serializer:json"`
@@ -276,139 +267,183 @@ func (p *Peer) UpdateMetaIfNew(ctx context.Context, meta PeerSystemMeta, newLoca
return MetaDiff{}
}
versionChanged := p.Meta.WtVersion != meta.WtVersion
// Avoid overwriting UIVersion if the update was triggered sole by the CLI client
if meta.UIVersion == "" {
meta.UIVersion = p.Meta.UIVersion
}
effectiveLocation := p.Location
if newLocation != nil {
effectiveLocation = *newLocation
}
oldVersion := p.Meta.WtVersion
diff := diffMeta(p.Meta, meta, p.Location, effectiveLocation)
if diff.Updated() {
diff := diffMeta(p.Meta, meta)
if diff.Any() {
p.Meta = meta
}
p.Location = effectiveLocation
diff.VersionChanged = versionChanged
if diff.Updated() {
log.WithContext(ctx).Debug(diff.LogSummary())
locationInfo := ""
if newLocation != nil {
p.Location = *newLocation
diff.LocationChanged = true
locationInfo = fmt.Sprintf("location changed to %s, ", newLocation.ConnectionIP)
}
versionInfo := ""
if diff.VersionChanged {
versionInfo = fmt.Sprintf("version changed: %s -> %s, ", oldVersion, meta.WtVersion)
}
if diff.Any() || diff.VersionChanged || diff.LocationChanged {
log.WithContext(ctx).
Debugf("peer meta updated, %s%s%d field(s) changed: %s", versionInfo, locationInfo, len(diff.Changed), strings.Join(diff.Changed, ", "))
}
return diff
}
// MetaDiff holds a peer's full before/after state across a sync: both metas and both
// connection locations (the location lives on Peer, not PeerSystemMeta, but posture
// checks read it). Changed lists what moved, for logging and the persistence decision;
// the snapshots let a posture check be replayed against old and new. Everything is derived
// from these fields, so there are no parallel per-field flags to keep in sync.
// MetaDiff records which PeerSystemMeta fields differ between two metas. Each bool
// maps to a single struct field, except Environment, which is split into Cloud and
// Platform. Changed holds the human-readable `field: <old> -> <new>` entries so the
// existing log line and isEqual can be derived from the same comparison.
//
// VersionChanged and LocationChanged sit outside the per-meta-field set:
// VersionChanged tracks the WireGuard client version specifically (compared before
// the UIVersion fixup, to signal client upgrades) and LocationChanged tracks the
// peer's connection geo location, which lives on Peer rather than PeerSystemMeta.
// Neither contributes an entry to Changed, so the field-coverage accounting stays
// driven purely by the PeerSystemMeta comparison.
type MetaDiff struct {
OldMeta PeerSystemMeta
NewMeta PeerSystemMeta
OldLocation Location
NewLocation Location
Hostname bool
GoOS bool
Kernel bool
KernelVersion bool
Core bool
Platform bool
OS bool
OSVersion bool
WtVersion bool
UIVersion bool
SystemSerialNumber bool
SystemProductName bool
SystemManufacturer bool
EnvironmentCloud bool
EnvironmentPlatform bool
Flags bool
Capabilities bool
NetworkAddresses bool
Files bool
VersionChanged bool
LocationChanged bool
Changed []string
}
// Updated reports whether anything changed and the peer must be persisted. diffMeta fills
// Changed in the pass that builds the diff, so this is a length check, not a re-comparison.
// Pointer receiver: MetaDiff embeds two metas, so copying it per call is wasteful.
func (d *MetaDiff) Updated() bool {
// Any reports whether any PeerSystemMeta field changed.
func (d MetaDiff) Any() bool {
return len(d.Changed) != 0
}
// VersionChanged reports whether the WireGuard client version changed (a client upgrade).
func (d *MetaDiff) VersionChanged() bool {
return d.OldMeta.WtVersion != d.NewMeta.WtVersion
}
// HostnameChanged reports whether the peer's hostname changed.
func (d *MetaDiff) HostnameChanged() bool {
return d.OldMeta.Hostname != d.NewMeta.Hostname
}
// LogSummary renders the changed fields as a single human-readable line.
func (d *MetaDiff) LogSummary() string {
return fmt.Sprintf("peer meta updated, %d field(s) changed: %s",
len(d.Changed), strings.Join(d.Changed, ", "))
// Updated reports whether the peer needs to be persisted: any meta field changed
// or the geo location changed. The version flag alone does not imply a write,
// since a version change is also reflected in the WtVersion meta field.
func (d MetaDiff) Updated() bool {
return d.Any() || d.LocationChanged || d.VersionChanged
}
func metaDiff(oldMeta, newMeta PeerSystemMeta) []string {
return diffMeta(oldMeta, newMeta, Location{}, Location{}).Changed
return diffMeta(oldMeta, newMeta).Changed
}
// diffMeta snapshots a peer's old and new state and records a Changed entry per field that
// moved. It is the single source of truth for the comparison: isEqual is an empty Changed
// list, so the log line and the persistence decision can never disagree.
func diffMeta(oldMeta, newMeta PeerSystemMeta, oldLocation, newLocation Location) MetaDiff {
d := MetaDiff{OldMeta: oldMeta, NewMeta: newMeta, OldLocation: oldLocation, NewLocation: newLocation}
// diffMeta compares two metas field by field, returning both a per-field flag set
// (for callers that need to know exactly what changed, e.g. matching against
// posture checks) and the human-readable Changed list. It is the single source of
// truth for meta comparison: isEqual reports equality as an empty diff, so the log
// line, the change decision, and the flags can never disagree.
func diffMeta(oldMeta, newMeta PeerSystemMeta) MetaDiff {
var d MetaDiff
add := func(field string, oldVal, newVal any) {
d.Changed = append(d.Changed, fmt.Sprintf("%s: %v -> %v", field, oldVal, newVal))
}
if oldMeta.Hostname != newMeta.Hostname {
d.Hostname = true
add("hostname", oldMeta.Hostname, newMeta.Hostname)
}
if oldMeta.GoOS != newMeta.GoOS {
d.GoOS = true
add("goos", oldMeta.GoOS, newMeta.GoOS)
}
if oldMeta.Kernel != newMeta.Kernel {
d.Kernel = true
add("kernel", oldMeta.Kernel, newMeta.Kernel)
}
if oldMeta.KernelVersion != newMeta.KernelVersion {
d.KernelVersion = true
add("kernel_version", oldMeta.KernelVersion, newMeta.KernelVersion)
}
if oldMeta.Core != newMeta.Core {
d.Core = true
add("core", oldMeta.Core, newMeta.Core)
}
if oldMeta.Platform != newMeta.Platform {
d.Platform = true
add("platform", oldMeta.Platform, newMeta.Platform)
}
if oldMeta.OS != newMeta.OS {
d.OS = true
add("os", oldMeta.OS, newMeta.OS)
}
if oldMeta.OSVersion != newMeta.OSVersion {
d.OSVersion = true
add("os_version", oldMeta.OSVersion, newMeta.OSVersion)
}
if oldMeta.WtVersion != newMeta.WtVersion {
d.WtVersion = true
add("wt_version", oldMeta.WtVersion, newMeta.WtVersion)
}
if oldMeta.UIVersion != newMeta.UIVersion {
d.UIVersion = true
add("ui_version", oldMeta.UIVersion, newMeta.UIVersion)
}
if oldMeta.SystemSerialNumber != newMeta.SystemSerialNumber {
d.SystemSerialNumber = true
add("system_serial_number", oldMeta.SystemSerialNumber, newMeta.SystemSerialNumber)
}
if oldMeta.SystemProductName != newMeta.SystemProductName {
d.SystemProductName = true
add("system_product_name", oldMeta.SystemProductName, newMeta.SystemProductName)
}
if oldMeta.SystemManufacturer != newMeta.SystemManufacturer {
d.SystemManufacturer = true
add("system_manufacturer", oldMeta.SystemManufacturer, newMeta.SystemManufacturer)
}
if oldMeta.Environment.Cloud != newMeta.Environment.Cloud {
d.EnvironmentCloud = true
add("environment_cloud", oldMeta.Environment.Cloud, newMeta.Environment.Cloud)
}
if oldMeta.Environment.Platform != newMeta.Environment.Platform {
d.EnvironmentPlatform = true
add("environment_platform", oldMeta.Environment.Platform, newMeta.Environment.Platform)
}
if !oldMeta.Flags.isEqual(newMeta.Flags) {
d.Flags = true
add("flags", fmt.Sprintf("%+v", oldMeta.Flags), fmt.Sprintf("%+v", newMeta.Flags))
}
if !capabilitiesEqual(oldMeta.Capabilities, newMeta.Capabilities) {
d.Capabilities = true
add("capabilities", oldMeta.Capabilities, newMeta.Capabilities)
}
if !sameMultiset(oldMeta.NetworkAddresses, newMeta.NetworkAddresses) {
d.NetworkAddresses = true
add("network_addresses", fmt.Sprintf("%v", oldMeta.NetworkAddresses), fmt.Sprintf("%v", newMeta.NetworkAddresses))
}
if !sameMultiset(oldMeta.Files, newMeta.Files) {
add("files", fmt.Sprintf("%v", oldMeta.Files), fmt.Sprintf("%v", newMeta.Files))
}
if !oldLocation.equal(newLocation) {
add("connection_ip", oldLocation.ConnectionIP, newLocation.ConnectionIP)
if !sameMultiset(oldMeta.Files, newMeta.Files) {
d.Files = true
add("files", fmt.Sprintf("%v", oldMeta.Files), fmt.Sprintf("%v", newMeta.Files))
}
return d

View File

@@ -1,202 +0,0 @@
package posture
import (
"context"
"net"
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
)
// diffFrom builds a MetaDiff from the old/new snapshots AffectsPosture replays against.
func diffFrom(oldMeta, newMeta nbpeer.PeerSystemMeta, oldLoc, newLoc nbpeer.Location) *nbpeer.MetaDiff {
return &nbpeer.MetaDiff{
OldMeta: oldMeta,
NewMeta: newMeta,
OldLocation: oldLoc,
NewLocation: newLoc,
}
}
func checks(def ChecksDefinition) []*Checks {
return []*Checks{{Checks: def}}
}
func TestAffectsPosture_NilDiff(t *testing.T) {
assert.False(t, AffectsPosture(context.Background(), nil, checks(ChecksDefinition{
NBVersionCheck: &NBVersionCheck{MinVersion: "1.0.0"},
})))
}
func TestAffectsPosture_NBVersion(t *testing.T) {
c := checks(ChecksDefinition{NBVersionCheck: &NBVersionCheck{MinVersion: "1.2.0"}})
tests := []struct {
name string
oldVer, newVer string
want bool
}{
{"both above min, no flip", "1.3.0", "1.4.0", false},
{"both below min, no flip", "1.0.0", "1.1.0", false},
{"crosses up below->above", "1.1.0", "1.3.0", true},
{"crosses down above->below", "1.3.0", "1.1.0", true},
{"unparsable old only -> flip", "garbage", "1.3.0", true},
{"unparsable both -> no flip", "garbage", "junk", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
diff := diffFrom(
nbpeer.PeerSystemMeta{WtVersion: tt.oldVer},
nbpeer.PeerSystemMeta{WtVersion: tt.newVer},
nbpeer.Location{}, nbpeer.Location{},
)
assert.Equal(t, tt.want, AffectsPosture(context.Background(), diff, c))
})
}
}
func TestAffectsPosture_OSVersion_KernelBumpWithinMin(t *testing.T) {
c := checks(ChecksDefinition{OSVersionCheck: &OSVersionCheck{
Linux: &MinKernelVersionCheck{MinKernelVersion: "5.0.0"},
}})
// Kernel moves but stays above the minimum: verdict stays pass -> not affected.
withinMin := diffFrom(
nbpeer.PeerSystemMeta{GoOS: "linux", KernelVersion: "5.10.0-arch1"},
nbpeer.PeerSystemMeta{GoOS: "linux", KernelVersion: "5.15.0-arch2"},
nbpeer.Location{}, nbpeer.Location{},
)
assert.False(t, AffectsPosture(context.Background(), withinMin, c))
// Kernel drops below the minimum: verdict flips pass -> fail -> affected.
crossesDown := diffFrom(
nbpeer.PeerSystemMeta{GoOS: "linux", KernelVersion: "5.10.0-arch1"},
nbpeer.PeerSystemMeta{GoOS: "linux", KernelVersion: "4.19.0-arch1"},
nbpeer.Location{}, nbpeer.Location{},
)
assert.True(t, AffectsPosture(context.Background(), crossesDown, c))
}
func TestAffectsPosture_OSVersion_GoOSSwitchFlipsVerdict(t *testing.T) {
// Only Linux is constrained. An OS outside the switch (freebsd) passes; switching to a
// failing linux kernel flips the verdict pass -> fail.
c := checks(ChecksDefinition{OSVersionCheck: &OSVersionCheck{
Linux: &MinKernelVersionCheck{MinKernelVersion: "6.0.0"},
}})
diff := diffFrom(
nbpeer.PeerSystemMeta{GoOS: "freebsd"},
nbpeer.PeerSystemMeta{GoOS: "linux", KernelVersion: "4.19.0"},
nbpeer.Location{}, nbpeer.Location{},
)
assert.True(t, AffectsPosture(context.Background(), diff, c))
}
func TestAffectsPosture_Process_GoOSSwitchFlipsVerdict(t *testing.T) {
// Process runs at a linux path. Switching GoOS to windows (no WindowsPath configured)
// flips the verdict.
c := checks(ChecksDefinition{ProcessCheck: &ProcessCheck{
Processes: []Process{{LinuxPath: "/usr/bin/foo"}},
}})
files := []nbpeer.File{{Path: "/usr/bin/foo", ProcessIsRunning: true}}
diff := diffFrom(
nbpeer.PeerSystemMeta{GoOS: "linux", Files: files},
nbpeer.PeerSystemMeta{GoOS: "windows", Files: files},
nbpeer.Location{}, nbpeer.Location{},
)
assert.True(t, AffectsPosture(context.Background(), diff, c))
}
func TestAffectsPosture_Process_UnrelatedFileChange(t *testing.T) {
// A tracked process stays running while an unrelated file is added: the verdict does
// not move, so posture is not affected.
c := checks(ChecksDefinition{ProcessCheck: &ProcessCheck{
Processes: []Process{{LinuxPath: "/usr/bin/foo"}},
}})
diff := diffFrom(
nbpeer.PeerSystemMeta{GoOS: "linux", Files: []nbpeer.File{
{Path: "/usr/bin/foo", ProcessIsRunning: true},
}},
nbpeer.PeerSystemMeta{GoOS: "linux", Files: []nbpeer.File{
{Path: "/usr/bin/foo", ProcessIsRunning: true},
{Path: "/usr/bin/bar", ProcessIsRunning: true},
}},
nbpeer.Location{}, nbpeer.Location{},
)
assert.False(t, AffectsPosture(context.Background(), diff, c))
}
func TestAffectsPosture_GeoLocation(t *testing.T) {
c := checks(ChecksDefinition{GeoLocationCheck: &GeoLocationCheck{
Action: CheckActionAllow,
Locations: []Location{{CountryCode: "DE"}},
}})
// Moving within allowed countries keeps the verdict; moving out flips it.
stayAllowed := diffFrom(
nbpeer.PeerSystemMeta{}, nbpeer.PeerSystemMeta{},
nbpeer.Location{CountryCode: "DE", CityName: "Berlin"},
nbpeer.Location{CountryCode: "DE", CityName: "Munich"},
)
assert.False(t, AffectsPosture(context.Background(), stayAllowed, c))
moveOut := diffFrom(
nbpeer.PeerSystemMeta{}, nbpeer.PeerSystemMeta{},
nbpeer.Location{CountryCode: "DE"},
nbpeer.Location{CountryCode: "FR"},
)
assert.True(t, AffectsPosture(context.Background(), moveOut, c))
}
func TestAffectsPosture_PeerNetworkRange_ConnectionIP(t *testing.T) {
// The check reads the connection IP. Moving out of the allowed range flips the verdict;
// moving within it does not.
_, allowed, _ := net.ParseCIDR("10.0.0.0/8")
c := checks(ChecksDefinition{PeerNetworkRangeCheck: &PeerNetworkRangeCheck{
Action: CheckActionAllow,
Ranges: []netip.Prefix{netip.MustParsePrefix(allowed.String())},
}})
movesOutOfRange := diffFrom(
nbpeer.PeerSystemMeta{}, nbpeer.PeerSystemMeta{},
nbpeer.Location{ConnectionIP: net.ParseIP("10.1.2.3")},
nbpeer.Location{ConnectionIP: net.ParseIP("8.8.8.8")},
)
assert.True(t, AffectsPosture(context.Background(), movesOutOfRange, c))
staysInRange := diffFrom(
nbpeer.PeerSystemMeta{}, nbpeer.PeerSystemMeta{},
nbpeer.Location{ConnectionIP: net.ParseIP("10.1.2.3")},
nbpeer.Location{ConnectionIP: net.ParseIP("10.9.9.9")},
)
assert.False(t, AffectsPosture(context.Background(), staysInRange, c))
}
func TestAffectsPosture_IrrelevantFieldChange(t *testing.T) {
// Hostname changes but no check reads it: not affected even with checks present.
c := checks(ChecksDefinition{
NBVersionCheck: &NBVersionCheck{MinVersion: "1.0.0"},
GeoLocationCheck: &GeoLocationCheck{Action: CheckActionAllow, Locations: []Location{{CountryCode: "DE"}}},
})
diff := diffFrom(
nbpeer.PeerSystemMeta{Hostname: "old", WtVersion: "1.5.0"},
nbpeer.PeerSystemMeta{Hostname: "new", WtVersion: "1.5.0"},
nbpeer.Location{CountryCode: "DE"}, nbpeer.Location{CountryCode: "DE"},
)
assert.False(t, AffectsPosture(context.Background(), diff, c))
}
func TestAffectsPosture_NoChecks(t *testing.T) {
diff := diffFrom(
nbpeer.PeerSystemMeta{WtVersion: "1.0.0"},
nbpeer.PeerSystemMeta{WtVersion: "2.0.0"},
nbpeer.Location{}, nbpeer.Location{},
)
assert.False(t, AffectsPosture(context.Background(), diff, nil))
}

View File

@@ -7,7 +7,6 @@ import (
"regexp"
"github.com/hashicorp/go-version"
log "github.com/sirupsen/logrus"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/shared/management/http/api"
@@ -53,46 +52,34 @@ type Checks struct {
Checks ChecksDefinition `gorm:"serializer:json"`
}
// AffectsPosture reports whether the change in diff flips the verdict of any check. It
// replays each check against the peer's old and new state and compares verdicts, so a
// change that moves a field but stays the right side of a threshold (e.g. a kernel bump
// still above the minimum) does not force a re-evaluation. See verdictChanged for how an
// evaluation error counts.
func AffectsPosture(ctx context.Context, diff *nbpeer.MetaDiff, checks []*Checks) bool {
// AffectsPosture reports whether the peer metadata changes described by diff can
// alter the outcome of any of the given posture checks. It maps each check kind to
// the metadata fields it inspects, so an unrelated change (e.g. a hostname update)
// does not force a posture re-evaluation.
func AffectsPosture(diff *nbpeer.MetaDiff, checks []*Checks) bool {
if diff == nil {
return false
}
oldPeer := nbpeer.Peer{Meta: diff.OldMeta, Location: diff.OldLocation}
newPeer := nbpeer.Peer{Meta: diff.NewMeta, Location: diff.NewLocation}
for _, c := range checks {
for _, check := range c.GetChecks() {
if verdictChanged(ctx, check, oldPeer, newPeer) {
return true
}
if c.Checks.ProcessCheck != nil && diff.Files {
return true
}
if c.Checks.OSVersionCheck != nil && (diff.OSVersion || diff.OS || diff.KernelVersion) {
return true
}
if c.Checks.NBVersionCheck != nil && diff.WtVersion {
return true
}
if c.Checks.GeoLocationCheck != nil && diff.LocationChanged {
return true
}
if c.Checks.PeerNetworkRangeCheck != nil && diff.NetworkAddresses {
return true
}
}
return false
}
// verdictChanged replays check against old and new state and reports whether the verdict
// differs. Like callers, it treats an evaluation error as deny: two errors are the same
// verdict (no change), an error on one side only is a flip.
func verdictChanged(ctx context.Context, check Check, oldPeer, newPeer nbpeer.Peer) bool {
oldPass, oldErr := check.Check(ctx, oldPeer)
newPass, newErr := check.Check(ctx, newPeer)
oldVerdict := oldPass && oldErr == nil
newVerdict := newPass && newErr == nil
changed := oldVerdict != newVerdict
log.WithContext(ctx).Tracef("posture check %s replay: verdict %t -> %t (changed=%t), errs: %v -> %v",
check.Name(), oldVerdict, newVerdict, changed, oldErr, newErr)
return changed
}
// ChecksDefinition contains definition of actual check
type ChecksDefinition struct {
NBVersionCheck *NBVersionCheck `json:",omitempty"`