Compare commits

..

8 Commits

Author SHA1 Message Date
pascal
62ffa08744 split networkIDs to check 2026-06-24 22:39:35 +02:00
Dmitri Dolguikh
d8e7f2e9e6 a couple of fixes
Signed-off-by: Dmitri Dolguikh <dmitri.external@netbird.io>
2026-06-24 15:40:44 +02:00
Dmitri Dolguikh
1205641b44 fixed test
Signed-off-by: Dmitri Dolguikh <dmitri.external@netbird.io>
2026-06-24 15:17:41 +02:00
Dmitri Dolguikh
56e8215ebe updated 'resource-routing-bridge/router-peer-change refreshes policy sources' test to expect router peer among changed peer ids
Signed-off-by: Dmitri Dolguikh <dmitri.external@netbird.io>
2026-06-24 14:29:23 +02:00
Dmitri Dolguikh
9b768d1773 fixed a bug in collectFromPolicies
Signed-off-by: Dmitri Dolguikh <dmitri.external@netbird.io>
2026-06-24 14:10:37 +02:00
Dmitri Dolguikh
33954ea15e fixing tests + adding tests
Signed-off-by: Dmitri Dolguikh <dmitri.external@netbird.io>
2026-06-24 13:07:53 +02:00
Dmitri Dolguikh
4c4434a871 fixed a few tests
Signed-off-by: Dmitri Dolguikh <dmitri.external@netbird.io>
2026-06-24 09:48:45 +02:00
Dmitri Dolguikh
7873f337df when collecting group and peer IDs from policies, do so directionally
Signed-off-by: Dmitri Dolguikh <dmitri.external@netbird.io>
2026-06-23 18:39:53 +02:00
7 changed files with 205 additions and 390 deletions

View File

@@ -51,20 +51,13 @@ type cachedRecord struct {
}
// Resolver caches critical NetBird infrastructure domains.
// records, refreshing, failedResolves, mgmtDomain and serverDomains are all
// guarded by mutex.
// records, refreshing, mgmtDomain and serverDomains are all guarded by mutex.
type Resolver struct {
records map[dns.Question]*cachedRecord
mgmtDomain *domain.Domain
serverDomains *dnsconfig.ServerDomains
mutex sync.RWMutex
// failedResolves records the last failed initial resolve per domain so a
// domain that never resolves isn't retried on every server-domains update
// until refreshBackoff elapses. Entries are cleared on success and pruned
// to the current server-domains set.
failedResolves map[domain.Domain]time.Time
chain ChainResolver
chainMaxPriority int
refreshGroup singleflight.Group
@@ -83,10 +76,9 @@ type Resolver struct {
// NewResolver creates a new management domains cache resolver.
func NewResolver() *Resolver {
return &Resolver{
records: make(map[dns.Question]*cachedRecord),
refreshing: make(map[dns.Question]*atomic.Bool),
failedResolves: make(map[domain.Domain]time.Time),
cacheTTL: resolveCacheTTL(),
records: make(map[dns.Question]*cachedRecord),
refreshing: make(map[dns.Question]*atomic.Bool),
cacheTTL: resolveCacheTTL(),
}
}
@@ -181,9 +173,7 @@ func (m *Resolver) continueToNext(w dns.ResponseWriter, r *dns.Msg) {
// AddDomain resolves a domain and stores its A/AAAA records in the cache.
// A family that resolves NODATA (nil err, zero records) evicts any stale
// entry for that qtype. When one family hard-errors while the other succeeds,
// the resolved family is still cached but AddDomain returns an error so the
// caller retries the incomplete resolve rather than treating it as complete.
// entry for that qtype.
func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
@@ -213,10 +203,6 @@ func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
log.Debugf("added/updated domain=%s with %d A records and %d AAAA records",
d.SafeString(), len(aRecords), len(aaaaRecords))
if errA != nil || errAAAA != nil {
return fmt.Errorf("resolve %s: incomplete, a family failed: %w", d.SafeString(), errors.Join(errA, errAAAA))
}
return nil
}
@@ -476,7 +462,6 @@ func (m *Resolver) RemoveDomain(d domain.Domain) error {
delete(m.records, qAAAA)
delete(m.refreshing, qA)
delete(m.refreshing, qAAAA)
delete(m.failedResolves, d)
log.Debugf("removed domain=%s from cache", d.SafeString())
return nil
@@ -520,7 +505,6 @@ func (m *Resolver) UpdateFromServerDomains(ctx context.Context, serverDomains dn
allDomains := m.extractDomainsFromServerDomains(updatedServerDomains)
currentDomains := m.GetCachedDomains()
removedDomains = m.removeStaleDomains(currentDomains, allDomains)
m.pruneFailedResolves(allDomains)
}
m.addNewDomains(ctx, newDomains)
@@ -593,85 +577,13 @@ func (m *Resolver) isManagementDomain(domain domain.Domain) bool {
return m.mgmtDomain != nil && domain == *m.mgmtDomain
}
// addNewDomains resolves and caches domains that are not yet in the cache,
// running the lookups concurrently. Domains already cached are skipped and left
// to the stale-while-revalidate refresh path, so a sync never re-resolves them
// synchronously: once NetBird owns the OS resolver the resolve runs through the
// handler chain and would otherwise dial the managed upstreams under the engine
// sync lock on every update.
// addNewDomains resolves and caches all domains from the update
func (m *Resolver) addNewDomains(ctx context.Context, newDomains domain.List) {
var wg sync.WaitGroup
seen := make(map[domain.Domain]struct{}, len(newDomains))
for _, newDomain := range newDomains {
if _, dup := seen[newDomain]; dup {
continue
}
seen[newDomain] = struct{}{}
if !m.needsResolve(newDomain) {
continue
}
wg.Add(1)
go func(d domain.Domain) {
defer wg.Done()
if err := m.AddDomain(ctx, d); err != nil {
m.markResolveFailed(d)
log.Warnf("failed to add/update domain=%s: %v", d.SafeString(), err)
return
}
m.clearResolveFailed(d)
log.Debugf("added/updated management cache domain=%s", d.SafeString())
}(newDomain)
}
wg.Wait()
}
// needsResolve reports whether d should be resolved now. A recent failed or
// incomplete resolve gates retries on the backoff even when one family is
// already cached, so a transiently-failed family is retried instead of being
// treated as fully resolved. Otherwise a domain with any cached record is left
// to the stale-while-revalidate refresh path.
func (m *Resolver) needsResolve(d domain.Domain) bool {
dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
m.mutex.RLock()
defer m.mutex.RUnlock()
if failedAt, ok := m.failedResolves[d]; ok {
return time.Since(failedAt) >= refreshBackoff
}
for _, qtype := range []uint16{dns.TypeA, dns.TypeAAAA} {
q := dns.Question{Name: dnsName, Qtype: qtype, Qclass: dns.ClassINET}
if _, ok := m.records[q]; ok {
return false
}
}
return true
}
func (m *Resolver) markResolveFailed(d domain.Domain) {
m.mutex.Lock()
m.failedResolves[d] = time.Now()
m.mutex.Unlock()
}
func (m *Resolver) clearResolveFailed(d domain.Domain) {
m.mutex.Lock()
delete(m.failedResolves, d)
m.mutex.Unlock()
}
// pruneFailedResolves drops failure markers for domains no longer present in
// the server-domains set, keeping the map bounded to the current set (a
// failed-only domain has no cached record, so RemoveDomain never sees it).
func (m *Resolver) pruneFailedResolves(domains domain.List) {
m.mutex.Lock()
defer m.mutex.Unlock()
for d := range m.failedResolves {
if !slices.Contains(domains, d) {
delete(m.failedResolves, d)
if err := m.AddDomain(ctx, newDomain); err != nil {
log.Warnf("failed to add/update domain=%s: %v", newDomain.SafeString(), err)
} else {
log.Debugf("added/updated management cache domain=%s", newDomain.SafeString())
}
}
}

View File

@@ -21,7 +21,6 @@ type fakeChain struct {
mu sync.Mutex
calls map[string]int
answers map[string][]dns.RR
qErr map[string]error
err error
hasRoot bool
onLookup func()
@@ -31,7 +30,6 @@ func newFakeChain() *fakeChain {
return &fakeChain{
calls: map[string]int{},
answers: map[string][]dns.RR{},
qErr: map[string]error{},
hasRoot: true,
}
}
@@ -49,9 +47,6 @@ func (f *fakeChain) ResolveInternal(ctx context.Context, msg *dns.Msg, maxPriori
f.calls[key]++
answers := f.answers[key]
err := f.err
if err == nil {
err = f.qErr[key]
}
onLookup := f.onLookup
f.mu.Unlock()
@@ -80,12 +75,6 @@ func (f *fakeChain) setAnswer(name string, qtype uint16, ip string) {
}
}
func (f *fakeChain) setErr(name string, qtype uint16, err error) {
f.mu.Lock()
defer f.mu.Unlock()
f.qErr[name+"|"+dns.TypeToString[qtype]] = err
}
func (f *fakeChain) callCount(name string, qtype uint16) int {
f.mu.Lock()
defer f.mu.Unlock()

View File

@@ -1,183 +0,0 @@
package mgmt
import (
"context"
"errors"
"sync/atomic"
"testing"
"time"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
"github.com/netbirdio/netbird/shared/management/domain"
)
// A domain already in the cache must not be re-resolved on a subsequent server
// domains update; it is left to the stale-while-revalidate refresh path.
func TestResolver_UpdateFromServerDomains_SkipsCached(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.setAnswer("signal.example.com.", dns.TypeA, "10.0.0.2")
r.SetChainResolver(chain, 50)
sd := dnsconfig.ServerDomains{Signal: domain.Domain("signal.example.com")}
_, err := r.UpdateFromServerDomains(context.Background(), sd)
require.NoError(t, err)
require.Equal(t, 1, chain.callCount("signal.example.com.", dns.TypeA),
"first update must resolve the domain")
_, err = r.UpdateFromServerDomains(context.Background(), sd)
require.NoError(t, err)
assert.Equal(t, 1, chain.callCount("signal.example.com.", dns.TypeA),
"cached domain must not be re-resolved on a subsequent update")
}
// New domains in a single update must resolve concurrently rather than serially.
func TestResolver_AddNewDomains_ResolvesConcurrently(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
var inflight, maxInflight atomic.Int32
chain.onLookup = func() {
n := inflight.Add(1)
for {
old := maxInflight.Load()
if n <= old || maxInflight.CompareAndSwap(old, n) {
break
}
}
time.Sleep(50 * time.Millisecond)
inflight.Add(-1)
}
relays := []domain.Domain{"a.example.com", "b.example.com", "c.example.com", "d.example.com"}
for _, d := range relays {
chain.setAnswer(dns.Fqdn(string(d)), dns.TypeA, "10.0.0.2")
}
r.SetChainResolver(chain, 50)
start := time.Now()
_, err := r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Relay: relays})
require.NoError(t, err)
elapsed := time.Since(start)
assert.GreaterOrEqual(t, int(maxInflight.Load()), 2, "domains must resolve concurrently")
// Serial resolution of 4 domains would take at least 4*50ms; concurrent is far less.
assert.Less(t, elapsed, 300*time.Millisecond, "resolution should not be serial")
}
// A domain that fails to resolve must not be retried on every update; the
// failure backoff suppresses re-resolution until it expires.
func TestResolver_UpdateFromServerDomains_BacksOffFailures(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.err = errors.New("resolve boom")
r.SetChainResolver(chain, 50)
sd := dnsconfig.ServerDomains{Signal: domain.Domain("signal.example.com")}
_, err := r.UpdateFromServerDomains(context.Background(), sd)
require.NoError(t, err)
require.Equal(t, 1, chain.callCount("signal.example.com.", dns.TypeA),
"first update must attempt the resolve")
_, err = r.UpdateFromServerDomains(context.Background(), sd)
require.NoError(t, err)
assert.Equal(t, 1, chain.callCount("signal.example.com.", dns.TypeA),
"failed resolve must back off and not retry on the next update")
}
// A domain listed under more than one server-domain type (e.g. STUN and TURN on
// the same host) must be resolved once per update, not once per occurrence.
func TestResolver_AddNewDomains_DedupesDuplicateDomains(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.setAnswer("dup.example.com.", dns.TypeA, "10.0.0.9")
r.SetChainResolver(chain, 50)
sd := dnsconfig.ServerDomains{
Stuns: []domain.Domain{"dup.example.com"},
Turns: []domain.Domain{"dup.example.com"},
}
_, err := r.UpdateFromServerDomains(context.Background(), sd)
require.NoError(t, err)
assert.Equal(t, 1, chain.callCount("dup.example.com.", dns.TypeA),
"a domain appearing under multiple server-domain types must resolve once")
}
// A failure marker must be dropped once its domain leaves the server-domains set
// so the map stays bounded to the current set.
func TestResolver_UpdateFromServerDomains_PrunesFailedResolves(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.err = errors.New("resolve boom")
r.SetChainResolver(chain, 50)
_, err := r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Signal: domain.Domain("gone.example.com")})
require.NoError(t, err)
r.mutex.RLock()
_, marked := r.failedResolves[domain.Domain("gone.example.com")]
r.mutex.RUnlock()
require.True(t, marked, "failed resolve must be recorded")
_, err = r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Signal: domain.Domain("other.example.com")})
require.NoError(t, err)
r.mutex.RLock()
_, stillMarked := r.failedResolves[domain.Domain("gone.example.com")]
r.mutex.RUnlock()
assert.False(t, stillMarked, "failure marker for a domain no longer in the set must be pruned")
}
// When one family hard-errors while the other resolves, the domain is cached
// for the working family but recorded as incomplete so the failed family is
// retried under backoff instead of being treated as fully resolved forever.
func TestResolver_AddNewDomains_RetriesPartialFamilyFailure(t *testing.T) {
d := domain.Domain("relay.example.com")
r := NewResolver()
chain := newFakeChain()
chain.setAnswer("relay.example.com.", dns.TypeA, "10.0.0.2")
chain.setErr("relay.example.com.", dns.TypeAAAA, errors.New("servfail"))
r.SetChainResolver(chain, 50)
_, err := r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Relay: []domain.Domain{d}})
require.NoError(t, err)
r.mutex.RLock()
_, aCached := r.records[dns.Question{Name: "relay.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}]
_, marked := r.failedResolves[d]
r.mutex.RUnlock()
require.True(t, aCached, "the working family must still be cached")
require.True(t, marked, "a partial failure must be recorded so the failed family is retried")
assert.False(t, r.needsResolve(d), "within the backoff window the domain is not retried")
r.mutex.Lock()
r.failedResolves[d] = time.Now().Add(-2 * refreshBackoff)
r.mutex.Unlock()
assert.True(t, r.needsResolve(d), "after the backoff elapses the domain is retried to pick up the missing family")
}
// A family that returns NODATA (legitimately absent, e.g. an IPv4-only host) is
// not a failure: the domain must not be marked for retry, otherwise it would be
// re-resolved on every sync.
func TestResolver_AddNewDomains_NodataIsNotFailure(t *testing.T) {
d := domain.Domain("v4only.example.com")
r := NewResolver()
chain := newFakeChain()
chain.setAnswer("v4only.example.com.", dns.TypeA, "10.0.0.2")
r.SetChainResolver(chain, 50)
_, err := r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Relay: []domain.Domain{d}})
require.NoError(t, err)
r.mutex.RLock()
_, marked := r.failedResolves[d]
r.mutex.RUnlock()
assert.False(t, marked, "a NODATA family must not be recorded as a failure")
assert.False(t, r.needsResolve(d), "an IPv4-only host must not be re-resolved on later syncs")
}

View File

@@ -41,7 +41,7 @@ func TestAffectedPeers_DependencyCoverageMatrix(t *testing.T) {
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
require.NoError(t, err)
return affectedpeers.Change{ChangedPeerIDs: []string{s.routerPeerID}},
[]string{s.sourcePeerID}, []string{s.unrelatedPeerID}
[]string{s.sourcePeerID, s.routerPeerID}, []string{s.unrelatedPeerID}
},
},
{
@@ -106,12 +106,8 @@ func TestAffectedPeers_DependencyCoverageMatrix(t *testing.T) {
change, mustContain, mustExclude := r.build(t, s, ctx)
affected := resolveAffected(t, s.manager.Store, s.accountID, change)
for _, id := range mustContain {
assert.Contains(t, affected, id, "expected peer to be affected")
}
for _, id := range mustExclude {
assert.NotContains(t, affected, id, "peer must not be affected")
}
assert.ElementsMatch(t, affected, mustContain, "expected peer to be affected")
assert.NotContains(t, affected, mustExclude, "peer must not be affected")
})
}
}

View File

@@ -96,33 +96,54 @@ func affectedGroupID(i int) string { return fmt.Sprintf("affected-grp-%d", i)
func affectedGroupName(i int) string { return fmt.Sprintf("AffectedGroup%d", i) }
func TestCollectGroupChange_PolicyLinked(t *testing.T) {
manager, s, accountID, _, groupIDs := setupAffectedPeersTest(t)
manager, s, accountID, peerIDs, groupIDs := setupAffectedPeersTest(t)
ctx := context.Background()
_, err := manager.SavePolicy(ctx, accountID, userID, &types.Policy{
Enabled: true,
Rules: []*types.PolicyRule{
{
Enabled: true,
Sources: []string{groupIDs[0]},
Destinations: []string{groupIDs[1]},
Bidirectional: true,
Action: types.PolicyTrafficActionAccept,
Enabled: true,
Sources: []string{groupIDs[0]},
Destinations: []string{groupIDs[1]},
SourceResource: types.Resource{ID: peerIDs[0], Type: types.ResourceTypePeer},
DestinationResource: types.Resource{ID: peerIDs[1], Type: types.ResourceTypePeer},
Bidirectional: true,
Action: types.PolicyTrafficActionAccept,
},
{
Enabled: true,
Sources: []string{groupIDs[0]},
Destinations: []string{groupIDs[1]},
SourceResource: types.Resource{ID: peerIDs[2], Type: types.ResourceTypeHost},
DestinationResource: types.Resource{ID: peerIDs[3], Type: types.ResourceTypeHost},
Bidirectional: true,
Action: types.PolicyTrafficActionAccept,
},
{
Enabled: true,
Sources: []string{groupIDs[0]},
Destinations: []string{groupIDs[1]},
SourceResource: types.Resource{ID: "", Type: types.ResourceTypePeer},
DestinationResource: types.Resource{ID: "", Type: types.ResourceTypePeer},
Bidirectional: true,
Action: types.PolicyTrafficActionAccept,
},
},
}, true)
require.NoError(t, err)
groups, _ := collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[0]})
assert.Contains(t, groups, groupIDs[0])
assert.Contains(t, groups, groupIDs[1])
groups, directPeers := collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[0]})
assert.ElementsMatch(t, groups, []string{groupIDs[0], groupIDs[1]})
assert.ElementsMatch(t, directPeers, []string{peerIDs[1]})
groups, _ = collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[1]})
assert.Contains(t, groups, groupIDs[0])
assert.Contains(t, groups, groupIDs[1])
groups, directPeers = collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[1]})
assert.ElementsMatch(t, groups, []string{groupIDs[0], groupIDs[1]})
assert.ElementsMatch(t, directPeers, []string{peerIDs[0]})
groups, _ = collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[2]})
groups, directPeers = collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[2]})
assert.Empty(t, groups)
assert.Empty(t, directPeers)
}
func TestCollectGroupChange_PolicyWithDirectPeerResource(t *testing.T) {
@@ -133,20 +154,44 @@ func TestCollectGroupChange_PolicyWithDirectPeerResource(t *testing.T) {
Enabled: true,
Rules: []*types.PolicyRule{
{
Enabled: true,
Sources: []string{groupIDs[0]},
SourceResource: types.Resource{ID: peerIDs[3], Type: types.ResourceTypePeer},
Destinations: []string{groupIDs[1]},
Action: types.PolicyTrafficActionAccept,
Enabled: true,
Sources: []string{groupIDs[0]},
SourceResource: types.Resource{ID: peerIDs[3], Type: types.ResourceTypePeer},
DestinationResource: types.Resource{ID: peerIDs[4], Type: types.ResourceTypePeer},
Destinations: []string{groupIDs[1]},
Action: types.PolicyTrafficActionAccept,
},
{
Enabled: true,
Sources: []string{groupIDs[0]},
SourceResource: types.Resource{ID: peerIDs[1], Type: types.ResourceTypeHost},
DestinationResource: types.Resource{ID: peerIDs[2], Type: types.ResourceTypeHost},
Destinations: []string{groupIDs[1]},
Action: types.PolicyTrafficActionAccept,
},
{
Enabled: true,
Sources: []string{groupIDs[0]},
SourceResource: types.Resource{ID: "", Type: types.ResourceTypePeer},
DestinationResource: types.Resource{ID: "", Type: types.ResourceTypePeer},
Destinations: []string{groupIDs[1]},
Action: types.PolicyTrafficActionAccept,
},
},
}, true)
require.NoError(t, err)
groups, directPeers := collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[0]})
assert.Contains(t, groups, groupIDs[0])
assert.Contains(t, groups, groupIDs[1])
assert.Contains(t, directPeers, peerIDs[3])
assert.ElementsMatch(t, groups, []string{groupIDs[0], groupIDs[1]})
assert.ElementsMatch(t, directPeers, []string{peerIDs[4]})
groups, directPeers = collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[1]})
assert.ElementsMatch(t, groups, []string{groupIDs[0], groupIDs[1]})
assert.ElementsMatch(t, directPeers, []string{peerIDs[3]})
groups, directPeers = collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[2]})
assert.Empty(t, groups)
assert.Empty(t, directPeers)
}
func TestCollectGroupChange_PolicyWithNonPeerResource_NoDirectPeers(t *testing.T) {
@@ -168,8 +213,7 @@ func TestCollectGroupChange_PolicyWithNonPeerResource_NoDirectPeers(t *testing.T
require.NoError(t, err)
groups, directPeers := collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[0]})
assert.Contains(t, groups, groupIDs[0])
assert.Contains(t, groups, groupIDs[1])
assert.ElementsMatch(t, groups, []string{groupIDs[0], groupIDs[1]})
assert.Empty(t, directPeers, "non-peer resources should not produce direct peer IDs")
}
@@ -373,17 +417,11 @@ func TestCollectGroupChange_MultipleEntities(t *testing.T) {
require.NoError(t, err)
groups, directPeers := collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[0]})
assert.Contains(t, groups, groupIDs[0])
assert.Contains(t, groups, groupIDs[1])
assert.NotContains(t, groups, groupIDs[2])
assert.NotContains(t, groups, groupIDs[3])
assert.ElementsMatch(t, groups, []string{groupIDs[0], groupIDs[1]})
assert.Empty(t, directPeers)
groups, directPeers = collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[3]})
assert.Contains(t, groups, groupIDs[2])
assert.Contains(t, groups, groupIDs[3])
assert.NotContains(t, groups, groupIDs[0])
assert.NotContains(t, groups, groupIDs[1])
assert.ElementsMatch(t, groups, []string{groupIDs[2], groupIDs[3]})
assert.Empty(t, directPeers)
}
@@ -474,7 +512,7 @@ func TestResolveAffectedPeers_PolicyThreeGroups(t *testing.T) {
require.NoError(t, err)
result := manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[0]})
assert.ElementsMatch(t, []string{peerIDs[0], peerIDs[1], peerIDs[2]}, result)
assert.ElementsMatch(t, []string{peerIDs[0], peerIDs[2]}, result)
}
func TestResolveAffectedPeers_RoutePeerGroups(t *testing.T) {
@@ -659,9 +697,13 @@ func TestResolveAffectedPeers_PeerInMultipleGroups(t *testing.T) {
}, true)
require.NoError(t, err)
// peer0 is in group0 AND group1, so both policies apply
// peer0 is in group0 AND group1, so both policies apply. A peer change folds
// only the changed peer plus the opposite side of each rule: group2 (peer2) via
// the group0 policy and group3 (peer3) via the group1 policy. peer1, a co-member
// of group1, is a sibling of the changed peer and must NOT refresh.
result := manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[0]})
assert.ElementsMatch(t, []string{peerIDs[0], peerIDs[1], peerIDs[2], peerIDs[3]}, result)
assert.ElementsMatch(t, []string{peerIDs[0], peerIDs[2], peerIDs[3]}, result)
assert.NotContains(t, result, peerIDs[1], "co-member of the changed peer's group must not refresh")
}
func TestResolveAffectedPeers_MultipleChangedPeers(t *testing.T) {
@@ -697,7 +739,7 @@ func TestResolveAffectedPeers_MultipleChangedPeers(t *testing.T) {
require.NoError(t, err)
result := manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[0], peerIDs[2]})
assert.ElementsMatch(t, []string{peerIDs[0], peerIDs[1], peerIDs[2], peerIDs[3]}, result)
assert.ElementsMatch(t, []string{peerIDs[0], peerIDs[2], peerIDs[1], peerIDs[3]}, result)
}
func TestResolveAffectedPeers_SharedGroupAcrossPolicyAndRoute(t *testing.T) {

View File

@@ -221,15 +221,17 @@ func Collect(ctx context.Context, s store.Store, accountID string, c Change) (gr
func newResolver(ctx context.Context, snap *Snapshot, accountID string, c Change) *resolver {
r := &resolver{
ctx: ctx,
snap: snap,
accountID: accountID,
change: c,
changedGroupSet: toSet(c.ChangedGroupIDs),
changedPeerSet: toSet(c.ChangedPeerIDs),
groupSet: make(map[string]struct{}),
peerSet: make(map[string]struct{}),
networkIDs: make(map[string]struct{}),
ctx: ctx,
snap: snap,
accountID: accountID,
change: c,
changedGroupSet: toSet(c.ChangedGroupIDs),
changedPeerSet: toSet(c.ChangedPeerIDs),
groupSet: make(map[string]struct{}),
peerSet: make(map[string]struct{}),
networkIDs: make(map[string]struct{}),
sourceOriginatedNetworkIDs: make(map[string]struct{}),
changedGroupIDs: toSet(c.ChangedGroupIDs),
}
// Resolve each changed peer to its groups here so callers pass only ChangedPeerIDs.
r.seedChangedGroupsFromPeers()
@@ -239,6 +241,9 @@ func newResolver(ctx context.Context, snap *Snapshot, accountID string, c Change
// seedChangedGroupsFromPeers adds each changed peer's groups to changedGroupSet so
// the group-driven walkers fire for memberships, not just direct peer references.
// These seeded groups are for MATCHING only — folding the changed entity's own
// side is gated on changedGroupIDs (the caller-reported groups), so a seeded group
// never folds its whole membership; only the changed peer itself folds in.
func (r *resolver) seedChangedGroupsFromPeers() {
if len(r.changedPeerSet) == 0 {
return
@@ -292,6 +297,18 @@ type resolver struct {
matchedPolicies []*types.Policy
networkIDs map[string]struct{}
// sourceOriginatedNetworkIDs are networks marked affected only because a
// source-side change targets a resource on them (bridgeSourceToRouters). Their
// routers must refresh, but the policy sources must not be folded back: a
// changed source propagates only to the opposite (router) side, never to its
// co-sources. Networks marked by a router/resource/network change are absent
// here and do fold sources, since the destination side itself changed.
sourceOriginatedNetworkIDs map[string]struct{}
// changedGroupIDs are the groups the caller reported as changed via
// Change.ChangedGroupIDs (NOT the peer-seeded ones in changedGroupSet). Only
// these fold their whole membership; a peer-seeded group folds the peer alone.
changedGroupIDs map[string]struct{}
}
func (r *resolver) policies() []*types.Policy { return r.snap.policies }
@@ -445,21 +462,88 @@ func (r *resolver) collectFromPostureChecks(postureCheckIDs []string) {
}
}
// collectFromPolicies folds, for every policy a changed group or peer touches:
// the opposite side of the matching rule, the changed entity's own side (the
// changed group itself, or the changed peer alone — never the changed side's
// sibling groups or co-members), and records the policy for the resource<->router
// bridge. A changed peer is mapped to its groups in changedGroupSet up front (see
// seedChangedGroupsFromPeers); changedGroupIDs holds only the caller-reported
// groups, so a peer-seeded group does not fold its whole membership.
func (r *resolver) collectFromPolicies() {
for _, policy := range r.policies() {
matchedByGroup := policyReferencesGroups(policy, r.changedGroupSet)
matchedByPeer := len(r.changedPeerSet) > 0 && policyReferencesDirectPeers(policy, r.changedPeerSet)
if !matchedByGroup && !matchedByPeer {
if !r.collectPolicyDirectional(policy) {
continue
}
log.WithContext(r.ctx).Tracef("collectFromPolicies: policy %s (%s) matched (byGroup=%t byPeer=%t) -> folding rule groups %v + direct peers",
policy.ID, policy.Name, matchedByGroup, matchedByPeer, policy.RuleGroups())
addAll(r.groupSet, policy.RuleGroups())
collectPolicyDirectPeers(policy, r.peerSet)
log.WithContext(r.ctx).Tracef("collectFromPolicies: policy %s (%s) matched directionally", policy.ID, policy.Name)
r.matchedPolicies = append(r.matchedPolicies, policy)
}
}
// collectPolicyDirectional folds one policy's affected groups/peers and reports
// whether it matched a changed group or peer at all (so the caller can record it
// for the bridge even when the opposite side is a resource, not a group).
func (r *resolver) collectPolicyDirectional(policy *types.Policy) bool {
matched := false
for _, rule := range policy.Rules {
matched = r.foldRuleSide(rule.Sources, rule.Destinations, rule.DestinationResource) || matched
matched = r.foldRuleSide(rule.Destinations, rule.Sources, rule.SourceResource) || matched
if isDirectPeerInSet(rule.SourceResource, r.changedPeerSet) {
r.peerSet[rule.SourceResource.ID] = struct{}{}
addAll(r.groupSet, rule.Destinations)
if rule.DestinationResource.Type == types.ResourceTypePeer && rule.DestinationResource.ID != "" {
r.peerSet[rule.DestinationResource.ID] = struct{}{}
}
matched = true
}
if isDirectPeerInSet(rule.DestinationResource, r.changedPeerSet) {
r.peerSet[rule.DestinationResource.ID] = struct{}{}
addAll(r.groupSet, rule.Sources)
if rule.SourceResource.Type == types.ResourceTypePeer && rule.SourceResource.ID != "" {
r.peerSet[rule.SourceResource.ID] = struct{}{}
}
matched = true
}
}
return matched
}
// foldRuleSide handles a changed group on `near` (Sources or Destinations): it
// folds the `far` (opposite) groups and far resource peer, the changed group(s)
// themselves (caller-reported groups only — not seeded ones, so a changed peer's
// group does not pull in its members), and the changed peers seeded from those
// groups (the peer alone). Returns whether the side matched.
func (r *resolver) foldRuleSide(near, far []string, farResource types.Resource) bool {
if !anyInSet(near, r.changedGroupSet) {
return false
}
addAll(r.groupSet, far)
if farResource.Type == types.ResourceTypePeer && farResource.ID != "" {
r.peerSet[farResource.ID] = struct{}{}
}
for _, gID := range near {
if _, ok := r.changedGroupIDs[gID]; ok {
r.groupSet[gID] = struct{}{} // changed group itself -> its members
}
r.foldChangedPeersInGroup(gID) // a changed peer in this group -> the peer alone
}
return true
}
// foldChangedPeersInGroup folds changed peers that belong to groupID directly into
// peerSet (the peer only, never its co-members).
func (r *resolver) foldChangedPeersInGroup(groupID string) {
if len(r.changedPeerSet) == 0 {
return
}
members := r.snap.groupPeers[groupID]
for pID := range r.changedPeerSet {
if _, ok := members[pID]; ok {
r.peerSet[pID] = struct{}{}
}
}
}
func (r *resolver) collectFromRoutes() {
for _, rt := range r.snap.routes {
matchedByGroup := anyInSet(rt.Groups, r.changedGroupSet) || anyInSet(rt.PeerGroups, r.changedGroupSet) || anyInSet(rt.AccessControlGroups, r.changedGroupSet)
@@ -588,6 +672,11 @@ func (r *resolver) bridgeSourceToRouters() {
log.WithContext(r.ctx).Tracef("bridgeSourceToRouters: targeted resources %v -> networks %v (their routers become affected via the router->source pass)",
setToSlice(resourceIDs), setToSlice(networkIDs))
for id := range networkIDs {
// Mark source-originated unless a router/resource/network change already
// marked this network directly (then it folds sources back).
if _, ok := r.networkIDs[id]; !ok {
r.sourceOriginatedNetworkIDs[id] = struct{}{}
}
r.networkIDs[id] = struct{}{}
}
}
@@ -602,11 +691,19 @@ func (r *resolver) bridgeRoutersToSources() {
r.foldRoutersOnNetworks(r.networkIDs)
// Sources are folded back only for networks the destination side itself changed
// (router/resource/network change). Networks reached only because a source-side
// change targets their resource must not refresh the policy's sources — the
// changed source propagates to the router side, not back to its co-sources.
resourceIDs := make(map[string]struct{})
for _, resource := range r.networkResources() {
if _, ok := r.networkIDs[resource.NetworkID]; ok {
resourceIDs[resource.ID] = struct{}{}
if _, ok := r.networkIDs[resource.NetworkID]; !ok {
continue
}
if _, sourceOriginated := r.sourceOriginatedNetworkIDs[resource.NetworkID]; sourceOriginated {
continue
}
resourceIDs[resource.ID] = struct{}{}
}
if len(resourceIDs) == 0 {
return
@@ -734,24 +831,6 @@ func collectPolicySources(policy *types.Policy, groupSet, peerSet map[string]str
}
}
func policyReferencesGroups(policy *types.Policy, groupSet map[string]struct{}) bool {
for _, rule := range policy.Rules {
if anyInSet(rule.Sources, groupSet) || anyInSet(rule.Destinations, groupSet) {
return true
}
}
return false
}
func policyReferencesDirectPeers(policy *types.Policy, changedSet map[string]struct{}) bool {
for _, rule := range policy.Rules {
if isDirectPeerInSet(rule.SourceResource, changedSet) || isDirectPeerInSet(rule.DestinationResource, changedSet) {
return true
}
}
return false
}
func policyReferencesPostureChecks(policy *types.Policy, ids map[string]struct{}) bool {
for _, id := range policy.SourcePostureChecks {
if _, ok := ids[id]; ok {

View File

@@ -80,26 +80,6 @@ func TestChangeIsEmpty(t *testing.T) {
assert.False(t, Change{PostureCheckIDs: []string{"pc"}}.isEmpty())
}
func TestPolicyReferencesGroups(t *testing.T) {
policy := &types.Policy{Rules: []*types.PolicyRule{{Sources: []string{"g1", "g2"}, Destinations: []string{"g3"}}}}
assert.True(t, policyReferencesGroups(policy, map[string]struct{}{"g1": {}}))
assert.True(t, policyReferencesGroups(policy, map[string]struct{}{"g3": {}}))
assert.False(t, policyReferencesGroups(policy, map[string]struct{}{"g4": {}}))
assert.False(t, policyReferencesGroups(policy, map[string]struct{}{}))
}
func TestPolicyReferencesDirectPeers(t *testing.T) {
policy := &types.Policy{Rules: []*types.PolicyRule{{
SourceResource: types.Resource{Type: types.ResourceTypePeer, ID: "p1"},
DestinationResource: types.Resource{Type: types.ResourceTypeHost, ID: "r1"},
}}}
assert.True(t, policyReferencesDirectPeers(policy, map[string]struct{}{"p1": {}}))
assert.False(t, policyReferencesDirectPeers(policy, map[string]struct{}{"r1": {}}))
assert.False(t, policyReferencesDirectPeers(policy, map[string]struct{}{"p2": {}}))
}
func TestPolicyReferencesPostureChecks(t *testing.T) {
policy := &types.Policy{SourcePostureChecks: []string{"pc1", "pc2"}}