mirror of
https://github.com/netbirdio/netbird.git
synced 2026-06-25 17:29:54 +00:00
Compare commits
30 Commits
t850
...
test/affec
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e048f221f0 | ||
|
|
4e8780e0c3 | ||
|
|
a1933f792a | ||
|
|
ccf1e965c3 | ||
|
|
38bd53eb71 | ||
|
|
845dd0c9bb | ||
|
|
2f0093a7e1 | ||
|
|
c5d26106f2 | ||
|
|
9746bc2a08 | ||
|
|
42867c7a59 | ||
|
|
d7740f9868 | ||
|
|
c2db940a8c | ||
|
|
62ffa08744 | ||
|
|
d8e7f2e9e6 | ||
|
|
1205641b44 | ||
|
|
de0cf0fc7a | ||
|
|
57f475c5a9 | ||
|
|
56e8215ebe | ||
|
|
5ae36cd260 | ||
|
|
9b768d1773 | ||
|
|
33954ea15e | ||
|
|
4c4434a871 | ||
|
|
9eadf50f4c | ||
|
|
be06016ad2 | ||
|
|
7873f337df | ||
|
|
17b2044596 | ||
|
|
07101c59ac | ||
|
|
51b6f6291b | ||
|
|
2ebf26006a | ||
|
|
211a26019a |
@@ -130,7 +130,7 @@ func debugConfigDump(cmd *cobra.Command, _ []string) error {
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
resp, err := client.GetConfig(cmd.Context(), &proto.GetConfigRequest{
|
||||
ProfileName: activeProf.Name,
|
||||
ProfileName: string(activeProf.ID),
|
||||
Username: currUser.Username,
|
||||
})
|
||||
if err != nil {
|
||||
|
||||
@@ -51,13 +51,20 @@ type cachedRecord struct {
|
||||
}
|
||||
|
||||
// Resolver caches critical NetBird infrastructure domains.
|
||||
// records, refreshing, mgmtDomain and serverDomains are all guarded by mutex.
|
||||
// records, refreshing, failedResolves, mgmtDomain and serverDomains are all
|
||||
// guarded by mutex.
|
||||
type Resolver struct {
|
||||
records map[dns.Question]*cachedRecord
|
||||
mgmtDomain *domain.Domain
|
||||
serverDomains *dnsconfig.ServerDomains
|
||||
mutex sync.RWMutex
|
||||
|
||||
// failedResolves records the last failed initial resolve per domain so a
|
||||
// domain that never resolves isn't retried on every server-domains update
|
||||
// until refreshBackoff elapses. Entries are cleared on success and pruned
|
||||
// to the current server-domains set.
|
||||
failedResolves map[domain.Domain]time.Time
|
||||
|
||||
chain ChainResolver
|
||||
chainMaxPriority int
|
||||
refreshGroup singleflight.Group
|
||||
@@ -76,9 +83,10 @@ type Resolver struct {
|
||||
// NewResolver creates a new management domains cache resolver.
|
||||
func NewResolver() *Resolver {
|
||||
return &Resolver{
|
||||
records: make(map[dns.Question]*cachedRecord),
|
||||
refreshing: make(map[dns.Question]*atomic.Bool),
|
||||
cacheTTL: resolveCacheTTL(),
|
||||
records: make(map[dns.Question]*cachedRecord),
|
||||
refreshing: make(map[dns.Question]*atomic.Bool),
|
||||
failedResolves: make(map[domain.Domain]time.Time),
|
||||
cacheTTL: resolveCacheTTL(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -173,7 +181,9 @@ func (m *Resolver) continueToNext(w dns.ResponseWriter, r *dns.Msg) {
|
||||
|
||||
// AddDomain resolves a domain and stores its A/AAAA records in the cache.
|
||||
// A family that resolves NODATA (nil err, zero records) evicts any stale
|
||||
// entry for that qtype.
|
||||
// entry for that qtype. When one family hard-errors while the other succeeds,
|
||||
// the resolved family is still cached but AddDomain returns an error so the
|
||||
// caller retries the incomplete resolve rather than treating it as complete.
|
||||
func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
|
||||
dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
|
||||
|
||||
@@ -203,6 +213,10 @@ func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
|
||||
log.Debugf("added/updated domain=%s with %d A records and %d AAAA records",
|
||||
d.SafeString(), len(aRecords), len(aaaaRecords))
|
||||
|
||||
if errA != nil || errAAAA != nil {
|
||||
return fmt.Errorf("resolve %s: incomplete, a family failed: %w", d.SafeString(), errors.Join(errA, errAAAA))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -462,6 +476,7 @@ func (m *Resolver) RemoveDomain(d domain.Domain) error {
|
||||
delete(m.records, qAAAA)
|
||||
delete(m.refreshing, qA)
|
||||
delete(m.refreshing, qAAAA)
|
||||
delete(m.failedResolves, d)
|
||||
|
||||
log.Debugf("removed domain=%s from cache", d.SafeString())
|
||||
return nil
|
||||
@@ -505,6 +520,7 @@ func (m *Resolver) UpdateFromServerDomains(ctx context.Context, serverDomains dn
|
||||
allDomains := m.extractDomainsFromServerDomains(updatedServerDomains)
|
||||
currentDomains := m.GetCachedDomains()
|
||||
removedDomains = m.removeStaleDomains(currentDomains, allDomains)
|
||||
m.pruneFailedResolves(allDomains)
|
||||
}
|
||||
|
||||
m.addNewDomains(ctx, newDomains)
|
||||
@@ -577,13 +593,85 @@ func (m *Resolver) isManagementDomain(domain domain.Domain) bool {
|
||||
return m.mgmtDomain != nil && domain == *m.mgmtDomain
|
||||
}
|
||||
|
||||
// addNewDomains resolves and caches all domains from the update
|
||||
// addNewDomains resolves and caches domains that are not yet in the cache,
|
||||
// running the lookups concurrently. Domains already cached are skipped and left
|
||||
// to the stale-while-revalidate refresh path, so a sync never re-resolves them
|
||||
// synchronously: once NetBird owns the OS resolver the resolve runs through the
|
||||
// handler chain and would otherwise dial the managed upstreams under the engine
|
||||
// sync lock on every update.
|
||||
func (m *Resolver) addNewDomains(ctx context.Context, newDomains domain.List) {
|
||||
var wg sync.WaitGroup
|
||||
seen := make(map[domain.Domain]struct{}, len(newDomains))
|
||||
for _, newDomain := range newDomains {
|
||||
if err := m.AddDomain(ctx, newDomain); err != nil {
|
||||
log.Warnf("failed to add/update domain=%s: %v", newDomain.SafeString(), err)
|
||||
} else {
|
||||
log.Debugf("added/updated management cache domain=%s", newDomain.SafeString())
|
||||
if _, dup := seen[newDomain]; dup {
|
||||
continue
|
||||
}
|
||||
seen[newDomain] = struct{}{}
|
||||
|
||||
if !m.needsResolve(newDomain) {
|
||||
continue
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
go func(d domain.Domain) {
|
||||
defer wg.Done()
|
||||
if err := m.AddDomain(ctx, d); err != nil {
|
||||
m.markResolveFailed(d)
|
||||
log.Warnf("failed to add/update domain=%s: %v", d.SafeString(), err)
|
||||
return
|
||||
}
|
||||
m.clearResolveFailed(d)
|
||||
log.Debugf("added/updated management cache domain=%s", d.SafeString())
|
||||
}(newDomain)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// needsResolve reports whether d should be resolved now. A recent failed or
|
||||
// incomplete resolve gates retries on the backoff even when one family is
|
||||
// already cached, so a transiently-failed family is retried instead of being
|
||||
// treated as fully resolved. Otherwise a domain with any cached record is left
|
||||
// to the stale-while-revalidate refresh path.
|
||||
func (m *Resolver) needsResolve(d domain.Domain) bool {
|
||||
dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
|
||||
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
|
||||
if failedAt, ok := m.failedResolves[d]; ok {
|
||||
return time.Since(failedAt) >= refreshBackoff
|
||||
}
|
||||
|
||||
for _, qtype := range []uint16{dns.TypeA, dns.TypeAAAA} {
|
||||
q := dns.Question{Name: dnsName, Qtype: qtype, Qclass: dns.ClassINET}
|
||||
if _, ok := m.records[q]; ok {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *Resolver) markResolveFailed(d domain.Domain) {
|
||||
m.mutex.Lock()
|
||||
m.failedResolves[d] = time.Now()
|
||||
m.mutex.Unlock()
|
||||
}
|
||||
|
||||
func (m *Resolver) clearResolveFailed(d domain.Domain) {
|
||||
m.mutex.Lock()
|
||||
delete(m.failedResolves, d)
|
||||
m.mutex.Unlock()
|
||||
}
|
||||
|
||||
// pruneFailedResolves drops failure markers for domains no longer present in
|
||||
// the server-domains set, keeping the map bounded to the current set (a
|
||||
// failed-only domain has no cached record, so RemoveDomain never sees it).
|
||||
func (m *Resolver) pruneFailedResolves(domains domain.List) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
for d := range m.failedResolves {
|
||||
if !slices.Contains(domains, d) {
|
||||
delete(m.failedResolves, d)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -21,6 +21,7 @@ type fakeChain struct {
|
||||
mu sync.Mutex
|
||||
calls map[string]int
|
||||
answers map[string][]dns.RR
|
||||
qErr map[string]error
|
||||
err error
|
||||
hasRoot bool
|
||||
onLookup func()
|
||||
@@ -30,6 +31,7 @@ func newFakeChain() *fakeChain {
|
||||
return &fakeChain{
|
||||
calls: map[string]int{},
|
||||
answers: map[string][]dns.RR{},
|
||||
qErr: map[string]error{},
|
||||
hasRoot: true,
|
||||
}
|
||||
}
|
||||
@@ -47,6 +49,9 @@ func (f *fakeChain) ResolveInternal(ctx context.Context, msg *dns.Msg, maxPriori
|
||||
f.calls[key]++
|
||||
answers := f.answers[key]
|
||||
err := f.err
|
||||
if err == nil {
|
||||
err = f.qErr[key]
|
||||
}
|
||||
onLookup := f.onLookup
|
||||
f.mu.Unlock()
|
||||
|
||||
@@ -75,6 +80,12 @@ func (f *fakeChain) setAnswer(name string, qtype uint16, ip string) {
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fakeChain) setErr(name string, qtype uint16, err error) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
f.qErr[name+"|"+dns.TypeToString[qtype]] = err
|
||||
}
|
||||
|
||||
func (f *fakeChain) callCount(name string, qtype uint16) int {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
|
||||
183
client/internal/dns/mgmt/mgmt_resolve_test.go
Normal file
183
client/internal/dns/mgmt/mgmt_resolve_test.go
Normal file
@@ -0,0 +1,183 @@
|
||||
package mgmt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
// A domain already in the cache must not be re-resolved on a subsequent server
|
||||
// domains update; it is left to the stale-while-revalidate refresh path.
|
||||
func TestResolver_UpdateFromServerDomains_SkipsCached(t *testing.T) {
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
chain.setAnswer("signal.example.com.", dns.TypeA, "10.0.0.2")
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
sd := dnsconfig.ServerDomains{Signal: domain.Domain("signal.example.com")}
|
||||
|
||||
_, err := r.UpdateFromServerDomains(context.Background(), sd)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, chain.callCount("signal.example.com.", dns.TypeA),
|
||||
"first update must resolve the domain")
|
||||
|
||||
_, err = r.UpdateFromServerDomains(context.Background(), sd)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, chain.callCount("signal.example.com.", dns.TypeA),
|
||||
"cached domain must not be re-resolved on a subsequent update")
|
||||
}
|
||||
|
||||
// New domains in a single update must resolve concurrently rather than serially.
|
||||
func TestResolver_AddNewDomains_ResolvesConcurrently(t *testing.T) {
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
|
||||
var inflight, maxInflight atomic.Int32
|
||||
chain.onLookup = func() {
|
||||
n := inflight.Add(1)
|
||||
for {
|
||||
old := maxInflight.Load()
|
||||
if n <= old || maxInflight.CompareAndSwap(old, n) {
|
||||
break
|
||||
}
|
||||
}
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
inflight.Add(-1)
|
||||
}
|
||||
|
||||
relays := []domain.Domain{"a.example.com", "b.example.com", "c.example.com", "d.example.com"}
|
||||
for _, d := range relays {
|
||||
chain.setAnswer(dns.Fqdn(string(d)), dns.TypeA, "10.0.0.2")
|
||||
}
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
start := time.Now()
|
||||
_, err := r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Relay: relays})
|
||||
require.NoError(t, err)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
assert.GreaterOrEqual(t, int(maxInflight.Load()), 2, "domains must resolve concurrently")
|
||||
// Serial resolution of 4 domains would take at least 4*50ms; concurrent is far less.
|
||||
assert.Less(t, elapsed, 300*time.Millisecond, "resolution should not be serial")
|
||||
}
|
||||
|
||||
// A domain that fails to resolve must not be retried on every update; the
|
||||
// failure backoff suppresses re-resolution until it expires.
|
||||
func TestResolver_UpdateFromServerDomains_BacksOffFailures(t *testing.T) {
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
chain.err = errors.New("resolve boom")
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
sd := dnsconfig.ServerDomains{Signal: domain.Domain("signal.example.com")}
|
||||
|
||||
_, err := r.UpdateFromServerDomains(context.Background(), sd)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, chain.callCount("signal.example.com.", dns.TypeA),
|
||||
"first update must attempt the resolve")
|
||||
|
||||
_, err = r.UpdateFromServerDomains(context.Background(), sd)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, chain.callCount("signal.example.com.", dns.TypeA),
|
||||
"failed resolve must back off and not retry on the next update")
|
||||
}
|
||||
|
||||
// A domain listed under more than one server-domain type (e.g. STUN and TURN on
|
||||
// the same host) must be resolved once per update, not once per occurrence.
|
||||
func TestResolver_AddNewDomains_DedupesDuplicateDomains(t *testing.T) {
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
chain.setAnswer("dup.example.com.", dns.TypeA, "10.0.0.9")
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
sd := dnsconfig.ServerDomains{
|
||||
Stuns: []domain.Domain{"dup.example.com"},
|
||||
Turns: []domain.Domain{"dup.example.com"},
|
||||
}
|
||||
|
||||
_, err := r.UpdateFromServerDomains(context.Background(), sd)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, chain.callCount("dup.example.com.", dns.TypeA),
|
||||
"a domain appearing under multiple server-domain types must resolve once")
|
||||
}
|
||||
|
||||
// A failure marker must be dropped once its domain leaves the server-domains set
|
||||
// so the map stays bounded to the current set.
|
||||
func TestResolver_UpdateFromServerDomains_PrunesFailedResolves(t *testing.T) {
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
chain.err = errors.New("resolve boom")
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
_, err := r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Signal: domain.Domain("gone.example.com")})
|
||||
require.NoError(t, err)
|
||||
r.mutex.RLock()
|
||||
_, marked := r.failedResolves[domain.Domain("gone.example.com")]
|
||||
r.mutex.RUnlock()
|
||||
require.True(t, marked, "failed resolve must be recorded")
|
||||
|
||||
_, err = r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Signal: domain.Domain("other.example.com")})
|
||||
require.NoError(t, err)
|
||||
r.mutex.RLock()
|
||||
_, stillMarked := r.failedResolves[domain.Domain("gone.example.com")]
|
||||
r.mutex.RUnlock()
|
||||
assert.False(t, stillMarked, "failure marker for a domain no longer in the set must be pruned")
|
||||
}
|
||||
|
||||
// When one family hard-errors while the other resolves, the domain is cached
|
||||
// for the working family but recorded as incomplete so the failed family is
|
||||
// retried under backoff instead of being treated as fully resolved forever.
|
||||
func TestResolver_AddNewDomains_RetriesPartialFamilyFailure(t *testing.T) {
|
||||
d := domain.Domain("relay.example.com")
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
chain.setAnswer("relay.example.com.", dns.TypeA, "10.0.0.2")
|
||||
chain.setErr("relay.example.com.", dns.TypeAAAA, errors.New("servfail"))
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
_, err := r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Relay: []domain.Domain{d}})
|
||||
require.NoError(t, err)
|
||||
|
||||
r.mutex.RLock()
|
||||
_, aCached := r.records[dns.Question{Name: "relay.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}]
|
||||
_, marked := r.failedResolves[d]
|
||||
r.mutex.RUnlock()
|
||||
require.True(t, aCached, "the working family must still be cached")
|
||||
require.True(t, marked, "a partial failure must be recorded so the failed family is retried")
|
||||
|
||||
assert.False(t, r.needsResolve(d), "within the backoff window the domain is not retried")
|
||||
|
||||
r.mutex.Lock()
|
||||
r.failedResolves[d] = time.Now().Add(-2 * refreshBackoff)
|
||||
r.mutex.Unlock()
|
||||
assert.True(t, r.needsResolve(d), "after the backoff elapses the domain is retried to pick up the missing family")
|
||||
}
|
||||
|
||||
// A family that returns NODATA (legitimately absent, e.g. an IPv4-only host) is
|
||||
// not a failure: the domain must not be marked for retry, otherwise it would be
|
||||
// re-resolved on every sync.
|
||||
func TestResolver_AddNewDomains_NodataIsNotFailure(t *testing.T) {
|
||||
d := domain.Domain("v4only.example.com")
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
chain.setAnswer("v4only.example.com.", dns.TypeA, "10.0.0.2")
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
_, err := r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Relay: []domain.Domain{d}})
|
||||
require.NoError(t, err)
|
||||
|
||||
r.mutex.RLock()
|
||||
_, marked := r.failedResolves[d]
|
||||
r.mutex.RUnlock()
|
||||
assert.False(t, marked, "a NODATA family must not be recorded as a failure")
|
||||
assert.False(t, r.needsResolve(d), "an IPv4-only host must not be re-resolved on later syncs")
|
||||
}
|
||||
@@ -82,12 +82,6 @@ const (
|
||||
PeerConnectionTimeoutMax = 45000 // ms
|
||||
PeerConnectionTimeoutMin = 30000 // ms
|
||||
disableAutoUpdate = "disabled"
|
||||
|
||||
// systemInfoTimeout bounds how long the sync loop waits for system info / posture
|
||||
// check gathering. The gathering runs uncancellable system calls (process scan,
|
||||
// exec, os.Stat); without this bound a single stuck call freezes handleSync, and
|
||||
// thus syncMsgMux, for as long as the call hangs (observed multi-minute freezes).
|
||||
systemInfoTimeout = 15 * time.Second
|
||||
)
|
||||
|
||||
var ErrResetConnection = fmt.Errorf("reset connection")
|
||||
@@ -1072,22 +1066,11 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
|
||||
}
|
||||
e.checks = checks
|
||||
|
||||
info, ok := system.GetInfoWithChecksTimeout(e.ctx, systemInfoTimeout, checks)
|
||||
if !ok {
|
||||
// Gathering timed out; skip the meta sync this cycle rather than blocking the
|
||||
// sync loop (and syncMsgMux) on a stuck system call. A later sync will retry.
|
||||
return nil
|
||||
info, err := system.GetInfoWithChecks(e.ctx, checks)
|
||||
if err != nil {
|
||||
log.Warnf("failed to get system info with checks: %v", err)
|
||||
info = system.GetInfo(e.ctx)
|
||||
}
|
||||
e.applyInfoFlags(info)
|
||||
|
||||
if err := e.mgmClient.SyncMeta(info); err != nil {
|
||||
return fmt.Errorf("could not sync meta: error %s", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// applyInfoFlags sets the engine's config-derived feature flags on the gathered system info.
|
||||
func (e *Engine) applyInfoFlags(info *system.Info) {
|
||||
info.SetFlags(
|
||||
e.config.RosenpassEnabled,
|
||||
e.config.RosenpassPermissive,
|
||||
@@ -1106,6 +1089,12 @@ func (e *Engine) applyInfoFlags(info *system.Info) {
|
||||
e.config.EnableSSHRemotePortForwarding,
|
||||
e.config.DisableSSHAuth,
|
||||
)
|
||||
|
||||
if err := e.mgmClient.SyncMeta(info); err != nil {
|
||||
log.Errorf("could not sync meta: error %s", err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
||||
@@ -1251,15 +1240,31 @@ func (e *Engine) receiveManagementEvents() {
|
||||
e.shutdownWg.Add(1)
|
||||
go func() {
|
||||
defer e.shutdownWg.Done()
|
||||
info, ok := system.GetInfoWithChecksTimeout(e.ctx, systemInfoTimeout, e.checks)
|
||||
if !ok {
|
||||
// Gathering timed out; connect the stream with base info so management
|
||||
// connectivity still comes up rather than blocking here.
|
||||
info, err := system.GetInfoWithChecks(e.ctx, e.checks)
|
||||
if err != nil {
|
||||
log.Warnf("failed to get system info with checks: %v", err)
|
||||
info = system.GetInfo(e.ctx)
|
||||
}
|
||||
e.applyInfoFlags(info)
|
||||
info.SetFlags(
|
||||
e.config.RosenpassEnabled,
|
||||
e.config.RosenpassPermissive,
|
||||
&e.config.ServerSSHAllowed,
|
||||
e.config.DisableClientRoutes,
|
||||
e.config.DisableServerRoutes,
|
||||
e.config.DisableDNS,
|
||||
e.config.DisableFirewall,
|
||||
e.config.BlockLANAccess,
|
||||
e.config.BlockInbound,
|
||||
e.config.DisableIPv6,
|
||||
e.config.LazyConnectionEnabled,
|
||||
e.config.EnableSSHRoot,
|
||||
e.config.EnableSSHSFTP,
|
||||
e.config.EnableSSHLocalPortForwarding,
|
||||
e.config.EnableSSHRemotePortForwarding,
|
||||
e.config.DisableSSHAuth,
|
||||
)
|
||||
|
||||
err := e.mgmClient.Sync(e.ctx, info, e.handleSync)
|
||||
err = e.mgmClient.Sync(e.ctx, info, e.handleSync)
|
||||
if err != nil {
|
||||
// happens if management is unavailable for a long time.
|
||||
// We want to cancel the operation of the whole client
|
||||
|
||||
@@ -2,10 +2,8 @@ package system
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc/metadata"
|
||||
@@ -156,7 +154,7 @@ func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks) (*Info, erro
|
||||
processCheckPaths = append(processCheckPaths, check.GetFiles()...)
|
||||
}
|
||||
|
||||
files, err := checkFileAndProcess(ctx, processCheckPaths)
|
||||
files, err := checkFileAndProcess(processCheckPaths)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -168,39 +166,3 @@ func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks) (*Info, erro
|
||||
log.Debugf("all system information gathered successfully")
|
||||
return info, nil
|
||||
}
|
||||
|
||||
// GetInfoWithChecksTimeout is GetInfoWithChecks bounded by timeout. Posture-check gathering
|
||||
// runs uncancellable system calls (process enumeration, os.Stat), so calling it inline can
|
||||
// block the caller for as long as such a call hangs. It runs in a goroutine instead: if it
|
||||
// does not return within timeout the caller gets (nil, false) and should proceed with
|
||||
// degraded behavior rather than block. On a gathering error it falls back to base GetInfo.
|
||||
//
|
||||
// The buffered channel lets the abandoned goroutine finish and exit once its blocking call
|
||||
// returns, so it does not leak beyond the duration of that call.
|
||||
func GetInfoWithChecksTimeout(ctx context.Context, timeout time.Duration, checks []*proto.Checks) (*Info, bool) {
|
||||
ctx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
infoCh := make(chan *Info, 1)
|
||||
go func() {
|
||||
info, err := GetInfoWithChecks(ctx, checks)
|
||||
if err != nil {
|
||||
log.Warnf("failed to get system info with checks: %v", err)
|
||||
info = GetInfo(ctx)
|
||||
}
|
||||
infoCh <- info
|
||||
}()
|
||||
|
||||
select {
|
||||
case info := <-infoCh:
|
||||
return info, true
|
||||
case <-ctx.Done():
|
||||
if errors.Is(ctx.Err(), context.DeadlineExceeded) {
|
||||
log.Warnf("gathering system info with checks timed out after %s", timeout)
|
||||
} else {
|
||||
// Parent context canceled (e.g. shutdown), not a timeout.
|
||||
log.Warnf("gathering system info with checks canceled: %v", ctx.Err())
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
|
||||
@@ -50,7 +50,7 @@ func GetInfo(ctx context.Context) *Info {
|
||||
}
|
||||
|
||||
// checkFileAndProcess checks if the file path exists and if a process is running at that path.
|
||||
func checkFileAndProcess(_ context.Context, _ []string) ([]File, error) {
|
||||
func checkFileAndProcess(paths []string) ([]File, error) {
|
||||
return []File{}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -32,7 +32,7 @@ func GetInfo(ctx context.Context) *Info {
|
||||
sysName := string(bytes.Split(utsname.Sysname[:], []byte{0})[0])
|
||||
machine := string(bytes.Split(utsname.Machine[:], []byte{0})[0])
|
||||
release := string(bytes.Split(utsname.Release[:], []byte{0})[0])
|
||||
swVersion, err := exec.CommandContext(ctx, "sw_vers", "-productVersion").Output()
|
||||
swVersion, err := exec.Command("sw_vers", "-productVersion").Output()
|
||||
if err != nil {
|
||||
log.Warnf("got an error while retrieving macOS version with sw_vers, error: %s. Using darwin version instead.\n", err)
|
||||
swVersion = []byte(release)
|
||||
|
||||
@@ -105,7 +105,7 @@ func isDuplicated(addresses []NetworkAddress, addr NetworkAddress) bool {
|
||||
}
|
||||
|
||||
// checkFileAndProcess checks if the file path exists and if a process is running at that path.
|
||||
func checkFileAndProcess(_ context.Context, _ []string) ([]File, error) {
|
||||
func checkFileAndProcess(paths []string) ([]File, error) {
|
||||
return []File{}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -103,7 +103,7 @@ func collectLocationInfo(info *Info) {
|
||||
}
|
||||
}
|
||||
|
||||
func checkFileAndProcess(_ context.Context, _ []string) ([]File, error) {
|
||||
func checkFileAndProcess(_ []string) ([]File, error) {
|
||||
return []File{}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@ package system
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/grpc/metadata"
|
||||
@@ -35,20 +34,6 @@ func Test_CustomHostname(t *testing.T) {
|
||||
assert.Equal(t, want, got.Hostname)
|
||||
}
|
||||
|
||||
func TestGetInfoWithChecksTimeout_Success(t *testing.T) {
|
||||
info, ok := GetInfoWithChecksTimeout(context.Background(), 30*time.Second, nil)
|
||||
assert.True(t, ok, "expected gathering to complete within the timeout")
|
||||
assert.NotNil(t, info)
|
||||
}
|
||||
|
||||
func TestGetInfoWithChecksTimeout_Timeout(t *testing.T) {
|
||||
// A 1ns budget expires before the (real) system-info gathering can finish, so the
|
||||
// caller must get (nil, false) instead of blocking on the in-flight goroutine.
|
||||
info, ok := GetInfoWithChecksTimeout(context.Background(), time.Nanosecond, nil)
|
||||
assert.False(t, ok, "expected timeout to be reported")
|
||||
assert.Nil(t, info)
|
||||
}
|
||||
|
||||
func Test_NetAddresses(t *testing.T) {
|
||||
addr, err := networkAddresses()
|
||||
if err != nil {
|
||||
|
||||
@@ -3,30 +3,24 @@
|
||||
package system
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"slices"
|
||||
|
||||
"github.com/shirou/gopsutil/v3/process"
|
||||
)
|
||||
|
||||
// getRunningProcesses returns a list of running process paths. The context bounds the work:
|
||||
// the per-PID loop bails as soon as ctx is done, and the gopsutil calls honor it where they
|
||||
// can, so a stuck enumeration cannot run unbounded.
|
||||
func getRunningProcesses(ctx context.Context) ([]string, error) {
|
||||
processIDs, err := process.PidsWithContext(ctx)
|
||||
// getRunningProcesses returns a list of running process paths.
|
||||
func getRunningProcesses() ([]string, error) {
|
||||
processIDs, err := process.Pids()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
processMap := make(map[string]bool)
|
||||
for _, pID := range processIDs {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
p := &process.Process{Pid: pID}
|
||||
|
||||
path, _ := p.ExeWithContext(ctx)
|
||||
path, _ := p.Exe()
|
||||
if path != "" {
|
||||
processMap[path] = false
|
||||
}
|
||||
@@ -41,21 +35,18 @@ func getRunningProcesses(ctx context.Context) ([]string, error) {
|
||||
}
|
||||
|
||||
// checkFileAndProcess checks if the file path exists and if a process is running at that path.
|
||||
func checkFileAndProcess(ctx context.Context, paths []string) ([]File, error) {
|
||||
func checkFileAndProcess(paths []string) ([]File, error) {
|
||||
files := make([]File, len(paths))
|
||||
if len(paths) == 0 {
|
||||
return files, nil
|
||||
}
|
||||
|
||||
runningProcesses, err := getRunningProcesses(ctx)
|
||||
runningProcesses, err := getRunningProcesses()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for i, path := range paths {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
file := File{Path: path}
|
||||
|
||||
_, err := os.Stat(path)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package system
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/shirou/gopsutil/v3/process"
|
||||
@@ -10,7 +9,7 @@ import (
|
||||
func Benchmark_getRunningProcesses(b *testing.B) {
|
||||
b.Run("getRunningProcesses new", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
ps, err := getRunningProcesses(context.Background())
|
||||
ps, err := getRunningProcesses()
|
||||
if err != nil {
|
||||
b.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
@@ -30,38 +29,12 @@ func Benchmark_getRunningProcesses(b *testing.B) {
|
||||
}
|
||||
}
|
||||
})
|
||||
s, _ := getRunningProcesses(context.Background())
|
||||
s, _ := getRunningProcesses()
|
||||
b.Logf("getRunningProcesses returned %d processes", len(s))
|
||||
s, _ = getRunningProcessesOld()
|
||||
b.Logf("getRunningProcessesOld returned %d processes", len(s))
|
||||
}
|
||||
|
||||
func TestCheckFileAndProcess_ContextCanceled(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
// With a canceled context and non-empty paths the gathering must bail with an error
|
||||
// instead of running the (potentially blocking) process scan / stat loop.
|
||||
if _, err := checkFileAndProcess(ctx, []string{"/does/not/exist"}); err == nil {
|
||||
t.Fatal("expected error on canceled context, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckFileAndProcess_EmptyPaths(t *testing.T) {
|
||||
// No check paths means no work to do: it must return immediately with no error,
|
||||
// even on a canceled context (nothing to scan or stat).
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
files, err := checkFileAndProcess(ctx, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error for empty paths: %v", err)
|
||||
}
|
||||
if len(files) != 0 {
|
||||
t.Fatalf("expected no files, got %d", len(files))
|
||||
}
|
||||
}
|
||||
|
||||
func getRunningProcessesOld() ([]string, error) {
|
||||
processes, err := process.Processes()
|
||||
if err != nil {
|
||||
|
||||
@@ -610,12 +610,10 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
|
||||
return nil, nil, 0, err
|
||||
}
|
||||
|
||||
startPosture := time.Now()
|
||||
postureChecks, err := c.getPeerPostureChecks(account, peerID)
|
||||
if err != nil {
|
||||
return nil, nil, 0, err
|
||||
}
|
||||
log.WithContext(ctx).Debugf("getPeerPostureChecks took %s", time.Since(startPosture))
|
||||
|
||||
accountZones, err := c.repo.GetAccountZones(ctx, account.Id)
|
||||
if err != nil {
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
|
||||
const (
|
||||
reconnThreshold = 5 * time.Minute
|
||||
baseBlockDuration = 10 * time.Minute // Duration for which a peer is banned after exceeding the reconnection limit
|
||||
baseBlockDuration = 30 * time.Minute // Duration for which a peer is banned after exceeding the reconnection limit
|
||||
reconnLimitForBan = 30 // Number of reconnections within the reconnTreshold that triggers a ban
|
||||
metaChangeLimit = 3 // Number of reconnections with different metadata that triggers a ban of one peer
|
||||
)
|
||||
@@ -139,22 +139,13 @@ func (l *loginFilter) addLogin(wgPubKey string, metaHash uint64) {
|
||||
state.lastSeen = now
|
||||
}
|
||||
|
||||
func metaHash(meta nbpeer.PeerSystemMeta, pubip string) uint64 {
|
||||
func metaHash(meta nbpeer.PeerSystemMeta) uint64 {
|
||||
h := fnv.New64a()
|
||||
|
||||
h.Write([]byte(meta.WtVersion))
|
||||
h.Write([]byte(meta.OSVersion))
|
||||
h.Write([]byte(meta.KernelVersion))
|
||||
h.Write([]byte(meta.Hostname))
|
||||
h.Write([]byte(meta.SystemSerialNumber))
|
||||
h.Write([]byte(pubip))
|
||||
|
||||
macs := uint64(0)
|
||||
for _, na := range meta.NetworkAddresses {
|
||||
for _, r := range na.Mac {
|
||||
macs += uint64(r)
|
||||
}
|
||||
}
|
||||
|
||||
return h.Sum64() + macs
|
||||
return h.Sum64()
|
||||
}
|
||||
|
||||
@@ -164,9 +164,7 @@ func BenchmarkHashingMethods(b *testing.B) {
|
||||
KernelVersion: "5.15.0-76-generic",
|
||||
Hostname: "prod-server-database-01",
|
||||
SystemSerialNumber: "PC-1234567890",
|
||||
NetworkAddresses: []nbpeer.NetworkAddress{{Mac: "00:1B:44:11:3A:B7"}, {Mac: "00:1B:44:11:3A:B8"}},
|
||||
}
|
||||
pubip := "8.8.8.8"
|
||||
|
||||
var resultString string
|
||||
var resultUint uint64
|
||||
@@ -175,7 +173,7 @@ func BenchmarkHashingMethods(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
resultString = builderString(meta, pubip)
|
||||
resultString = builderString(meta)
|
||||
}
|
||||
})
|
||||
|
||||
@@ -183,7 +181,7 @@ func BenchmarkHashingMethods(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
resultString = fnvHashToString(meta, pubip)
|
||||
resultString = fnvHashToString(meta)
|
||||
}
|
||||
})
|
||||
|
||||
@@ -191,7 +189,7 @@ func BenchmarkHashingMethods(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
resultUint = metaHash(meta, pubip)
|
||||
resultUint = metaHash(meta)
|
||||
}
|
||||
})
|
||||
|
||||
@@ -199,29 +197,20 @@ func BenchmarkHashingMethods(b *testing.B) {
|
||||
_ = resultUint
|
||||
}
|
||||
|
||||
func fnvHashToString(meta nbpeer.PeerSystemMeta, pubip string) string {
|
||||
func fnvHashToString(meta nbpeer.PeerSystemMeta) string {
|
||||
h := fnv.New64a()
|
||||
|
||||
if len(meta.NetworkAddresses) != 0 {
|
||||
for _, na := range meta.NetworkAddresses {
|
||||
h.Write([]byte(na.Mac))
|
||||
}
|
||||
}
|
||||
|
||||
h.Write([]byte(meta.WtVersion))
|
||||
h.Write([]byte(meta.OSVersion))
|
||||
h.Write([]byte(meta.KernelVersion))
|
||||
h.Write([]byte(meta.Hostname))
|
||||
h.Write([]byte(meta.SystemSerialNumber))
|
||||
h.Write([]byte(pubip))
|
||||
|
||||
return strconv.FormatUint(h.Sum64(), 16)
|
||||
}
|
||||
|
||||
func builderString(meta nbpeer.PeerSystemMeta, pubip string) string {
|
||||
mac := getMacAddress(meta.NetworkAddresses)
|
||||
estimatedSize := len(meta.WtVersion) + len(meta.OSVersion) + len(meta.KernelVersion) + len(meta.Hostname) + len(meta.SystemSerialNumber) +
|
||||
len(pubip) + len(mac) + 6
|
||||
func builderString(meta nbpeer.PeerSystemMeta) string {
|
||||
estimatedSize := len(meta.WtVersion) + len(meta.OSVersion) + len(meta.KernelVersion) + len(meta.Hostname) + len(meta.SystemSerialNumber) + 4
|
||||
|
||||
var b strings.Builder
|
||||
b.Grow(estimatedSize)
|
||||
@@ -235,23 +224,10 @@ func builderString(meta nbpeer.PeerSystemMeta, pubip string) string {
|
||||
b.WriteString(meta.Hostname)
|
||||
b.WriteByte('|')
|
||||
b.WriteString(meta.SystemSerialNumber)
|
||||
b.WriteByte('|')
|
||||
b.WriteString(pubip)
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func getMacAddress(nas []nbpeer.NetworkAddress) string {
|
||||
if len(nas) == 0 {
|
||||
return ""
|
||||
}
|
||||
macs := make([]string, 0, len(nas))
|
||||
for _, na := range nas {
|
||||
macs = append(macs, na.Mac)
|
||||
}
|
||||
return strings.Join(macs, "/")
|
||||
}
|
||||
|
||||
func BenchmarkLoginFilter_ParallelLoad(b *testing.B) {
|
||||
filter := newLoginFilterWithCfg(testAdvancedCfg())
|
||||
numKeys := 100000
|
||||
|
||||
@@ -254,7 +254,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
||||
return mapError(ctx, err)
|
||||
}
|
||||
|
||||
metahashed := metaHash(peerMeta, sRealIP)
|
||||
metahashed := metaHash(peerMeta)
|
||||
if userID == "" && !s.loginFilter.allowLogin(peerKey.String(), metahashed) {
|
||||
if s.appMetrics != nil {
|
||||
s.appMetrics.GRPCMetrics().CountSyncRequestBlocked()
|
||||
@@ -306,7 +306,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
||||
log.WithContext(ctx).Tracef("peer system meta has to be provided on sync. Peer %s, remote addr %s", peerKey.String(), realIP)
|
||||
}
|
||||
|
||||
metahash := metaHash(peerMeta, realIP.String())
|
||||
metahash := metaHash(peerMeta)
|
||||
s.loginFilter.addLogin(peerKey.String(), metahash)
|
||||
|
||||
peer, netMap, postureChecks, dnsFwdPort, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP, syncStart)
|
||||
@@ -732,7 +732,7 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
|
||||
}
|
||||
|
||||
peerMeta := extractPeerMeta(ctx, loginReq.GetMeta())
|
||||
metahashed := metaHash(peerMeta, sRealIP)
|
||||
metahashed := metaHash(peerMeta)
|
||||
if !s.loginFilter.allowLogin(peerKey.String(), metahashed) {
|
||||
if s.logBlockedPeers {
|
||||
log.WithContext(ctx).Tracef("peer %s with meta hash %d is blocked from login", peerKey.String(), metahashed)
|
||||
@@ -788,7 +788,11 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
|
||||
ExtraDNSLabels: loginReq.GetDnsLabels(),
|
||||
})
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed logging in peer %s: %s", peerKey, err)
|
||||
if errors.Is(err, internalStatus.ErrNoAuthMethodProvided) {
|
||||
log.WithContext(ctx).Tracef("failed logging in peer %s: %s", peerKey, err)
|
||||
} else {
|
||||
log.WithContext(ctx).Warnf("failed logging in peer %s: %s", peerKey, err)
|
||||
}
|
||||
return nil, mapError(ctx, err)
|
||||
}
|
||||
|
||||
@@ -1205,7 +1209,7 @@ func (s *Server) SyncMeta(ctx context.Context, req *proto.EncryptedMessage) (*pr
|
||||
return nil, msg
|
||||
}
|
||||
|
||||
err = s.accountManager.SyncPeerMeta(ctx, peerKey.String(), extractPeerMeta(ctx, syncMetaReq.GetMeta()))
|
||||
err = s.accountManager.SyncPeerMeta(ctx, peerKey.String(), extractPeerMeta(ctx, syncMetaReq.GetMeta()), realIP)
|
||||
if err != nil {
|
||||
return nil, mapError(ctx, err)
|
||||
}
|
||||
@@ -1254,7 +1258,10 @@ func (s *Server) Logout(ctx context.Context, req *proto.EncryptedMessage) (*prot
|
||||
func toProtocolChecks(ctx context.Context, postureChecks []*posture.Checks) []*proto.Checks {
|
||||
protoChecks := make([]*proto.Checks, 0, len(postureChecks))
|
||||
for _, postureCheck := range postureChecks {
|
||||
protoChecks = append(protoChecks, toProtocolCheck(postureCheck))
|
||||
check := toProtocolCheck(postureCheck)
|
||||
if check != nil {
|
||||
protoChecks = append(protoChecks, check)
|
||||
}
|
||||
}
|
||||
|
||||
return protoChecks
|
||||
@@ -1278,5 +1285,9 @@ func toProtocolCheck(postureCheck *posture.Checks) *proto.Checks {
|
||||
}
|
||||
}
|
||||
|
||||
if len(protoCheck.Files) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return protoCheck
|
||||
}
|
||||
|
||||
@@ -1889,12 +1889,12 @@ func domainIsUpToDate(domain string, domainCategory string, userAuth auth.UserAu
|
||||
// concurrent stream that started earlier loses the optimistic-lock race
|
||||
// in MarkPeerConnected and bails without writing.
|
||||
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
|
||||
peer, netMap, postureChecks, dnsfwdPort, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, accountID)
|
||||
peer, netMap, postureChecks, dnsfwdPort, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, RealIP: realIP}, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, 0, fmt.Errorf("error syncing peer: %w", err)
|
||||
}
|
||||
|
||||
if err := am.MarkPeerConnected(ctx, peerPubKey, realIP, accountID, syncTime.UnixNano(), netMap); err != nil {
|
||||
if err := am.MarkPeerConnected(ctx, peerPubKey, accountID, syncTime.UnixNano(), netMap); err != nil {
|
||||
log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err)
|
||||
}
|
||||
|
||||
@@ -1914,13 +1914,13 @@ func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, account
|
||||
return nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error {
|
||||
func (am *DefaultAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) error {
|
||||
accountID, err := am.Store.GetAccountIDByPeerPubKey(ctx, peerPubKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, _, _, _, err = am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, UpdateAccountPeers: true}, accountID)
|
||||
_, _, _, _, err = am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, RealIP: realIP, UpdateAccountPeers: true}, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -62,7 +62,7 @@ type Manager interface {
|
||||
GetUserFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
|
||||
ListUsers(ctx context.Context, accountID string) ([]*types.User, error)
|
||||
GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
|
||||
MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error
|
||||
MarkPeerConnected(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error
|
||||
MarkPeerDisconnected(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64) error
|
||||
DeletePeer(ctx context.Context, accountID, peerID, userID string) error
|
||||
UpdatePeer(ctx context.Context, accountID, userID string, p *nbpeer.Peer) (*nbpeer.Peer, error)
|
||||
@@ -123,7 +123,7 @@ type Manager interface {
|
||||
GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error)
|
||||
SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
|
||||
OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string, streamStartTime time.Time) error
|
||||
SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
|
||||
SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) error
|
||||
FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
||||
GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error)
|
||||
GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error)
|
||||
|
||||
@@ -1323,17 +1323,17 @@ func (mr *MockManagerMockRecorder) ExtendPeerSession(ctx, peerPubKey, userID int
|
||||
}
|
||||
|
||||
// MarkPeerConnected mocks base method.
|
||||
func (m *MockManager) MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error {
|
||||
func (m *MockManager) MarkPeerConnected(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "MarkPeerConnected", ctx, peerKey, realIP, accountID, sessionStartedAt, nmap)
|
||||
ret := m.ctrl.Call(m, "MarkPeerConnected", ctx, peerKey, accountID, sessionStartedAt, nmap)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// MarkPeerConnected indicates an expected call of MarkPeerConnected.
|
||||
func (mr *MockManagerMockRecorder) MarkPeerConnected(ctx, peerKey, realIP, accountID, sessionStartedAt, nmap interface{}) *gomock.Call {
|
||||
func (mr *MockManagerMockRecorder) MarkPeerConnected(ctx, peerKey, accountID, sessionStartedAt, nmap interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerConnected", reflect.TypeOf((*MockManager)(nil).MarkPeerConnected), ctx, peerKey, realIP, accountID, sessionStartedAt, nmap)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerConnected", reflect.TypeOf((*MockManager)(nil).MarkPeerConnected), ctx, peerKey, accountID, sessionStartedAt, nmap)
|
||||
}
|
||||
|
||||
// MarkPeerDisconnected mocks base method.
|
||||
@@ -1586,17 +1586,17 @@ func (mr *MockManagerMockRecorder) SyncPeer(ctx, sync, accountID interface{}) *g
|
||||
}
|
||||
|
||||
// SyncPeerMeta mocks base method.
|
||||
func (m *MockManager) SyncPeerMeta(ctx context.Context, peerPubKey string, meta peer.PeerSystemMeta) error {
|
||||
func (m *MockManager) SyncPeerMeta(ctx context.Context, peerPubKey string, meta peer.PeerSystemMeta, realIP net.IP) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SyncPeerMeta", ctx, peerPubKey, meta)
|
||||
ret := m.ctrl.Call(m, "SyncPeerMeta", ctx, peerPubKey, meta, realIP)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// SyncPeerMeta indicates an expected call of SyncPeerMeta.
|
||||
func (mr *MockManagerMockRecorder) SyncPeerMeta(ctx, peerPubKey, meta interface{}) *gomock.Call {
|
||||
func (mr *MockManagerMockRecorder) SyncPeerMeta(ctx, peerPubKey, meta, realIP interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SyncPeerMeta", reflect.TypeOf((*MockManager)(nil).SyncPeerMeta), ctx, peerPubKey, meta)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SyncPeerMeta", reflect.TypeOf((*MockManager)(nil).SyncPeerMeta), ctx, peerPubKey, meta, realIP)
|
||||
}
|
||||
|
||||
// SyncUserJWTGroups mocks base method.
|
||||
|
||||
@@ -1836,7 +1836,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
|
||||
require.NoError(t, err, "unable to get the account")
|
||||
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano(), nil)
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), accountID, time.Now().UTC().UnixNano(), nil)
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
|
||||
@@ -1907,7 +1907,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
|
||||
require.NoError(t, err, "unable to get the account")
|
||||
|
||||
// when we mark peer as connected, the peer login expiration routine should trigger
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano(), nil)
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), accountID, time.Now().UTC().UnixNano(), nil)
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
failed := waitTimeout(wg, time.Second)
|
||||
@@ -1916,6 +1916,117 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_MarkPeerDisconnected_SchedulesInactivityExpiration(t *testing.T) {
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
|
||||
require.NoError(t, err, "unable to create an account")
|
||||
|
||||
key, err := wgtypes.GenerateKey()
|
||||
require.NoError(t, err, "unable to generate WireGuard key")
|
||||
peerPubKey := key.PublicKey().String()
|
||||
|
||||
_, _, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
Key: peerPubKey,
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
|
||||
InactivityExpirationEnabled: true,
|
||||
}, false)
|
||||
require.NoError(t, err, "unable to add peer")
|
||||
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
|
||||
PeerLoginExpiration: time.Hour,
|
||||
PeerLoginExpirationEnabled: true,
|
||||
PeerInactivityExpiration: time.Hour,
|
||||
PeerInactivityExpirationEnabled: true,
|
||||
Extra: &types.ExtraSettings{},
|
||||
})
|
||||
require.NoError(t, err, "expecting to update account settings successfully but got error")
|
||||
|
||||
// Establish a session so the matching-token disconnect is actually applied.
|
||||
streamStartTime := time.Now().UTC()
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, accountID, streamStartTime.UnixNano(), nil)
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
// Install the mock only now, so the assertion observes the disconnect, not
|
||||
// the earlier connect.
|
||||
scheduled := make(chan struct{}, 1)
|
||||
manager.peerInactivityExpiry = &MockScheduler{
|
||||
CancelFunc: func(ctx context.Context, IDs []string) {},
|
||||
ScheduleFunc: func(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) {
|
||||
select {
|
||||
case scheduled <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
err = manager.MarkPeerDisconnected(context.Background(), peerPubKey, accountID, streamStartTime.UnixNano())
|
||||
require.NoError(t, err, "unable to mark peer disconnected")
|
||||
|
||||
select {
|
||||
case <-scheduled:
|
||||
// expected: disconnect re-armed the inactivity expiry timer
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("expected inactivity expiration to be rescheduled when an eligible peer disconnects")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_MarkPeerDisconnected_SkipsInactivityExpirationWhenDisabled(t *testing.T) {
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
|
||||
require.NoError(t, err, "unable to create an account")
|
||||
|
||||
key, err := wgtypes.GenerateKey()
|
||||
require.NoError(t, err, "unable to generate WireGuard key")
|
||||
peerPubKey := key.PublicKey().String()
|
||||
|
||||
_, _, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
Key: peerPubKey,
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
|
||||
InactivityExpirationEnabled: true,
|
||||
}, false)
|
||||
require.NoError(t, err, "unable to add peer")
|
||||
|
||||
// Peer is eligible (SSO + inactivity enabled) but the account-level setting
|
||||
// stays disabled, so disconnect must not schedule anything.
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
|
||||
PeerLoginExpiration: time.Hour,
|
||||
PeerLoginExpirationEnabled: true,
|
||||
PeerInactivityExpiration: time.Hour,
|
||||
PeerInactivityExpirationEnabled: false,
|
||||
Extra: &types.ExtraSettings{},
|
||||
})
|
||||
require.NoError(t, err, "expecting to update account settings successfully but got error")
|
||||
|
||||
streamStartTime := time.Now().UTC()
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, accountID, streamStartTime.UnixNano(), nil)
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
scheduled := make(chan struct{}, 1)
|
||||
manager.peerInactivityExpiry = &MockScheduler{
|
||||
CancelFunc: func(ctx context.Context, IDs []string) {},
|
||||
ScheduleFunc: func(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) {
|
||||
select {
|
||||
case scheduled <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
err = manager.MarkPeerDisconnected(context.Background(), peerPubKey, accountID, streamStartTime.UnixNano())
|
||||
require.NoError(t, err, "unable to mark peer disconnected")
|
||||
|
||||
select {
|
||||
case <-scheduled:
|
||||
t.Fatal("inactivity expiration must not be scheduled while the account-level setting is disabled")
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
// expected: nothing scheduled
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
@@ -1935,7 +2046,7 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
|
||||
|
||||
t.Run("disconnect peer when session token matches", func(t *testing.T) {
|
||||
streamStartTime := time.Now().UTC()
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, streamStartTime.UnixNano(), nil)
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, accountID, streamStartTime.UnixNano(), nil)
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
||||
@@ -1956,7 +2067,7 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
|
||||
t.Run("skip disconnect when stored session is newer (zombie stream protection)", func(t *testing.T) {
|
||||
// Newer stream wins on connect (sets SessionStartedAt = now ns).
|
||||
streamStartTime := time.Now().UTC()
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, streamStartTime.UnixNano(), nil)
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, accountID, streamStartTime.UnixNano(), nil)
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
||||
@@ -1980,7 +2091,7 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
|
||||
|
||||
t.Run("skip stale connect when stored session is newer (blocked goroutine protection)", func(t *testing.T) {
|
||||
node2SyncTime := time.Now().UTC()
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, node2SyncTime.UnixNano(), nil)
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, accountID, node2SyncTime.UnixNano(), nil)
|
||||
require.NoError(t, err, "node 2 should connect peer")
|
||||
|
||||
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
||||
@@ -1990,7 +2101,7 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
|
||||
"SessionStartedAt should equal node2SyncTime token")
|
||||
|
||||
node1StaleSyncTime := node2SyncTime.Add(-1 * time.Minute)
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, node1StaleSyncTime.UnixNano(), nil)
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, accountID, node1StaleSyncTime.UnixNano(), nil)
|
||||
require.NoError(t, err, "stale connect should not return error")
|
||||
|
||||
peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
||||
@@ -2052,7 +2163,7 @@ func TestDefaultAccountManager_MarkPeerConnected_ConcurrentRace(t *testing.T) {
|
||||
defer done.Done()
|
||||
ready.Done()
|
||||
start.Wait()
|
||||
errs <- manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, token, nil)
|
||||
errs <- manager.MarkPeerConnected(context.Background(), peerPubKey, accountID, token, nil)
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -2093,7 +2204,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
|
||||
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||
require.NoError(t, err, "unable to get the account")
|
||||
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano(), nil)
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), accountID, time.Now().UTC().UnixNano(), nil)
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
|
||||
@@ -41,7 +41,7 @@ func TestAffectedPeers_DependencyCoverageMatrix(t *testing.T) {
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
return affectedpeers.Change{ChangedPeerIDs: []string{s.routerPeerID}},
|
||||
[]string{s.sourcePeerID}, []string{s.unrelatedPeerID}
|
||||
[]string{s.sourcePeerID, s.routerPeerID}, []string{s.unrelatedPeerID}
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -106,12 +106,8 @@ func TestAffectedPeers_DependencyCoverageMatrix(t *testing.T) {
|
||||
change, mustContain, mustExclude := r.build(t, s, ctx)
|
||||
affected := resolveAffected(t, s.manager.Store, s.accountID, change)
|
||||
|
||||
for _, id := range mustContain {
|
||||
assert.Contains(t, affected, id, "expected peer to be affected")
|
||||
}
|
||||
for _, id := range mustExclude {
|
||||
assert.NotContains(t, affected, id, "peer must not be affected")
|
||||
}
|
||||
assert.ElementsMatch(t, affected, mustContain, "expected peer to be affected")
|
||||
assert.NotContains(t, affected, mustExclude, "peer must not be affected")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -251,7 +251,9 @@ func TestAffectedPeers_E2E_UpdateResource_DestinationResourcePolicy_RefreshesSou
|
||||
}
|
||||
}
|
||||
|
||||
func TestAffectedPeers_E2E_UpdateResource_DisabledSiblingRouter_StillBridged(t *testing.T) {
|
||||
// A disabled sibling router routes to nobody, so updating a resource on its network
|
||||
// must NOT refresh its peer (the enabled router carries the bridge instead).
|
||||
func TestAffectedPeers_E2E_UpdateResource_DisabledSiblingRouterNotBridged(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
@@ -274,13 +276,18 @@ func TestAffectedPeers_E2E_UpdateResource_DisabledSiblingRouter_StillBridged(t *
|
||||
require.NoError(t, err)
|
||||
|
||||
disabledCh := s.updateManager.CreateChannel(ctx, disabledRouterPeer.ID)
|
||||
t.Cleanup(func() { s.updateManager.CloseChannel(ctx, disabledRouterPeer.ID) })
|
||||
enabledCh := s.updateManager.CreateChannel(ctx, s.routerPeerID)
|
||||
t.Cleanup(func() {
|
||||
s.updateManager.CloseChannel(ctx, disabledRouterPeer.ID)
|
||||
s.updateManager.CloseChannel(ctx, s.routerPeerID)
|
||||
})
|
||||
|
||||
settleAffectedUpdates(disabledCh)
|
||||
settleAffectedUpdates(disabledCh, enabledCh)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, disabledCh)
|
||||
peerShouldReceiveUpdate(t, enabledCh)
|
||||
peerShouldNotReceiveUpdate(t, disabledCh)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -298,7 +305,7 @@ func TestAffectedPeers_E2E_UpdateResource_DisabledSiblingRouter_StillBridged(t *
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout: resource update did not refresh the disabled sibling router's peer")
|
||||
t.Error("timeout")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -682,6 +682,9 @@ func TestAffectedPeers_AllRoutingPeers_Network(t *testing.T) {
|
||||
assert.Contains(t, affected, secondRouterPeer.ID, "second routing peer on the same network must also be affected")
|
||||
}
|
||||
|
||||
// A disabled router in the snapshot routes to nobody, so it is skipped when the
|
||||
// walk scans existing account data: a policy edit still folds the literal source
|
||||
// group, but not the disabled router's peer.
|
||||
func TestAffectedPeers_DisabledRouter(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
@@ -694,11 +697,13 @@ func TestAffectedPeers_DisabledRouter(t *testing.T) {
|
||||
|
||||
affected := s.resolvePolicyAffected(ctx, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID))
|
||||
|
||||
assert.Contains(t, affected, s.sourcePeerID, "source peer must be affected")
|
||||
assert.Contains(t, affected, s.routerPeerID,
|
||||
"disabled router's peer must still be affected: Enabled must not gate affected-peers")
|
||||
assert.Contains(t, affected, s.sourcePeerID, "source peer (literal policy source group) must be affected")
|
||||
assert.NotContains(t, affected, s.routerPeerID,
|
||||
"a disabled router routes to nobody, so its peer must not be folded from snapshot data")
|
||||
}
|
||||
|
||||
// A disabled resource in the snapshot is skipped: the policy edit still folds the
|
||||
// literal source group, but the resource no longer bridges to its network's router.
|
||||
func TestAffectedPeers_DisabledResource(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
@@ -710,9 +715,9 @@ func TestAffectedPeers_DisabledResource(t *testing.T) {
|
||||
|
||||
affected := s.resolvePolicyAffected(ctx, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID))
|
||||
|
||||
assert.Contains(t, affected, s.sourcePeerID, "source peer must be affected")
|
||||
assert.Contains(t, affected, s.routerPeerID,
|
||||
"disabled resource must still resolve the routing peer: Enabled must not gate affected-peers")
|
||||
assert.Contains(t, affected, s.sourcePeerID, "source peer (literal policy source group) must be affected")
|
||||
assert.NotContains(t, affected, s.routerPeerID,
|
||||
"a disabled resource routes to nobody, so its network's router must not be folded from snapshot data")
|
||||
}
|
||||
|
||||
func TestAffectedPeers_DisabledRule(t *testing.T) {
|
||||
|
||||
@@ -96,33 +96,54 @@ func affectedGroupID(i int) string { return fmt.Sprintf("affected-grp-%d", i)
|
||||
func affectedGroupName(i int) string { return fmt.Sprintf("AffectedGroup%d", i) }
|
||||
|
||||
func TestCollectGroupChange_PolicyLinked(t *testing.T) {
|
||||
manager, s, accountID, _, groupIDs := setupAffectedPeersTest(t)
|
||||
manager, s, accountID, peerIDs, groupIDs := setupAffectedPeersTest(t)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := manager.SavePolicy(ctx, accountID, userID, &types.Policy{
|
||||
Enabled: true,
|
||||
Rules: []*types.PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{groupIDs[0]},
|
||||
Destinations: []string{groupIDs[1]},
|
||||
Bidirectional: true,
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
Enabled: true,
|
||||
Sources: []string{groupIDs[0]},
|
||||
Destinations: []string{groupIDs[1]},
|
||||
SourceResource: types.Resource{ID: peerIDs[0], Type: types.ResourceTypePeer},
|
||||
DestinationResource: types.Resource{ID: peerIDs[1], Type: types.ResourceTypePeer},
|
||||
Bidirectional: true,
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
},
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{groupIDs[0]},
|
||||
Destinations: []string{groupIDs[1]},
|
||||
SourceResource: types.Resource{ID: peerIDs[2], Type: types.ResourceTypeHost},
|
||||
DestinationResource: types.Resource{ID: peerIDs[3], Type: types.ResourceTypeHost},
|
||||
Bidirectional: true,
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
},
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{groupIDs[0]},
|
||||
Destinations: []string{groupIDs[1]},
|
||||
SourceResource: types.Resource{ID: "", Type: types.ResourceTypePeer},
|
||||
DestinationResource: types.Resource{ID: "", Type: types.ResourceTypePeer},
|
||||
Bidirectional: true,
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
groups, _ := collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[0]})
|
||||
assert.Contains(t, groups, groupIDs[0])
|
||||
assert.Contains(t, groups, groupIDs[1])
|
||||
groups, directPeers := collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[0]})
|
||||
assert.ElementsMatch(t, groups, []string{groupIDs[0], groupIDs[1]})
|
||||
assert.ElementsMatch(t, directPeers, []string{peerIDs[1]})
|
||||
|
||||
groups, _ = collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[1]})
|
||||
assert.Contains(t, groups, groupIDs[0])
|
||||
assert.Contains(t, groups, groupIDs[1])
|
||||
groups, directPeers = collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[1]})
|
||||
assert.ElementsMatch(t, groups, []string{groupIDs[0], groupIDs[1]})
|
||||
assert.ElementsMatch(t, directPeers, []string{peerIDs[0]})
|
||||
|
||||
groups, _ = collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[2]})
|
||||
groups, directPeers = collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[2]})
|
||||
assert.Empty(t, groups)
|
||||
assert.Empty(t, directPeers)
|
||||
}
|
||||
|
||||
func TestCollectGroupChange_PolicyWithDirectPeerResource(t *testing.T) {
|
||||
@@ -133,20 +154,44 @@ func TestCollectGroupChange_PolicyWithDirectPeerResource(t *testing.T) {
|
||||
Enabled: true,
|
||||
Rules: []*types.PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{groupIDs[0]},
|
||||
SourceResource: types.Resource{ID: peerIDs[3], Type: types.ResourceTypePeer},
|
||||
Destinations: []string{groupIDs[1]},
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
Enabled: true,
|
||||
Sources: []string{groupIDs[0]},
|
||||
SourceResource: types.Resource{ID: peerIDs[3], Type: types.ResourceTypePeer},
|
||||
DestinationResource: types.Resource{ID: peerIDs[4], Type: types.ResourceTypePeer},
|
||||
Destinations: []string{groupIDs[1]},
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
},
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{groupIDs[0]},
|
||||
SourceResource: types.Resource{ID: peerIDs[1], Type: types.ResourceTypeHost},
|
||||
DestinationResource: types.Resource{ID: peerIDs[2], Type: types.ResourceTypeHost},
|
||||
Destinations: []string{groupIDs[1]},
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
},
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{groupIDs[0]},
|
||||
SourceResource: types.Resource{ID: "", Type: types.ResourceTypePeer},
|
||||
DestinationResource: types.Resource{ID: "", Type: types.ResourceTypePeer},
|
||||
Destinations: []string{groupIDs[1]},
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
groups, directPeers := collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[0]})
|
||||
assert.Contains(t, groups, groupIDs[0])
|
||||
assert.Contains(t, groups, groupIDs[1])
|
||||
assert.Contains(t, directPeers, peerIDs[3])
|
||||
assert.ElementsMatch(t, groups, []string{groupIDs[0], groupIDs[1]})
|
||||
assert.ElementsMatch(t, directPeers, []string{peerIDs[4]})
|
||||
|
||||
groups, directPeers = collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[1]})
|
||||
assert.ElementsMatch(t, groups, []string{groupIDs[0], groupIDs[1]})
|
||||
assert.ElementsMatch(t, directPeers, []string{peerIDs[3]})
|
||||
|
||||
groups, directPeers = collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[2]})
|
||||
assert.Empty(t, groups)
|
||||
assert.Empty(t, directPeers)
|
||||
}
|
||||
|
||||
func TestCollectGroupChange_PolicyWithNonPeerResource_NoDirectPeers(t *testing.T) {
|
||||
@@ -168,8 +213,7 @@ func TestCollectGroupChange_PolicyWithNonPeerResource_NoDirectPeers(t *testing.T
|
||||
require.NoError(t, err)
|
||||
|
||||
groups, directPeers := collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[0]})
|
||||
assert.Contains(t, groups, groupIDs[0])
|
||||
assert.Contains(t, groups, groupIDs[1])
|
||||
assert.ElementsMatch(t, groups, []string{groupIDs[0], groupIDs[1]})
|
||||
assert.Empty(t, directPeers, "non-peer resources should not produce direct peer IDs")
|
||||
}
|
||||
|
||||
@@ -294,6 +338,7 @@ func TestCollectGroupChange_NetworkRouterLinked(t *testing.T) {
|
||||
AccountID: accountID,
|
||||
PeerGroups: []string{groupIDs[0]},
|
||||
Peer: peerIDs[3],
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -324,6 +369,7 @@ func TestCollectGroupChange_NetworkRouterPeerOnlyNoGroups(t *testing.T) {
|
||||
NetworkID: net1.ID,
|
||||
AccountID: accountID,
|
||||
Peer: peerIDs[4],
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -373,17 +419,11 @@ func TestCollectGroupChange_MultipleEntities(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
groups, directPeers := collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[0]})
|
||||
assert.Contains(t, groups, groupIDs[0])
|
||||
assert.Contains(t, groups, groupIDs[1])
|
||||
assert.NotContains(t, groups, groupIDs[2])
|
||||
assert.NotContains(t, groups, groupIDs[3])
|
||||
assert.ElementsMatch(t, groups, []string{groupIDs[0], groupIDs[1]})
|
||||
assert.Empty(t, directPeers)
|
||||
|
||||
groups, directPeers = collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[3]})
|
||||
assert.Contains(t, groups, groupIDs[2])
|
||||
assert.Contains(t, groups, groupIDs[3])
|
||||
assert.NotContains(t, groups, groupIDs[0])
|
||||
assert.NotContains(t, groups, groupIDs[1])
|
||||
assert.ElementsMatch(t, groups, []string{groupIDs[2], groupIDs[3]})
|
||||
assert.Empty(t, directPeers)
|
||||
}
|
||||
|
||||
@@ -452,8 +492,9 @@ func TestResolveAffectedPeers_PolicyBetweenTwoGroups(t *testing.T) {
|
||||
result = manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[1]})
|
||||
assert.ElementsMatch(t, []string{peerIDs[0], peerIDs[1]}, result)
|
||||
|
||||
// peerIDs[2] is unrelated to the route; only its own map can change.
|
||||
result = manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[2]})
|
||||
assert.Empty(t, result)
|
||||
assert.ElementsMatch(t, []string{peerIDs[2]}, result)
|
||||
}
|
||||
|
||||
func TestResolveAffectedPeers_PolicyThreeGroups(t *testing.T) {
|
||||
@@ -474,7 +515,7 @@ func TestResolveAffectedPeers_PolicyThreeGroups(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
result := manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[0]})
|
||||
assert.ElementsMatch(t, []string{peerIDs[0], peerIDs[1], peerIDs[2]}, result)
|
||||
assert.ElementsMatch(t, []string{peerIDs[0], peerIDs[2]}, result)
|
||||
}
|
||||
|
||||
func TestResolveAffectedPeers_RoutePeerGroups(t *testing.T) {
|
||||
@@ -506,8 +547,9 @@ func TestResolveAffectedPeers_RoutePeerGroups(t *testing.T) {
|
||||
result = manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[1]})
|
||||
assert.ElementsMatch(t, []string{peerIDs[0], peerIDs[1]}, result)
|
||||
|
||||
// peerIDs[2] is in no policy; only its own map can change, so it refreshes itself.
|
||||
result = manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[2]})
|
||||
assert.Empty(t, result)
|
||||
assert.ElementsMatch(t, []string{peerIDs[2]}, result)
|
||||
}
|
||||
|
||||
func TestResolveAffectedPeers_RouteWithDirectPeer(t *testing.T) {
|
||||
@@ -564,9 +606,9 @@ func TestResolveAffectedPeers_RouteWithAccessControlGroups(t *testing.T) {
|
||||
result := manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[2]})
|
||||
assert.ElementsMatch(t, []string{peerIDs[0], peerIDs[1], peerIDs[2]}, result)
|
||||
|
||||
// peer3 is unrelated
|
||||
// peer3 is unrelated to the route; only its own map can change.
|
||||
result = manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[3]})
|
||||
assert.Empty(t, result)
|
||||
assert.ElementsMatch(t, []string{peerIDs[3]}, result)
|
||||
}
|
||||
|
||||
func TestResolveAffectedPeers_NetworkRouter(t *testing.T) {
|
||||
@@ -587,6 +629,7 @@ func TestResolveAffectedPeers_NetworkRouter(t *testing.T) {
|
||||
AccountID: accountID,
|
||||
PeerGroups: []string{groupIDs[0]},
|
||||
Peer: peerIDs[3],
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -659,9 +702,13 @@ func TestResolveAffectedPeers_PeerInMultipleGroups(t *testing.T) {
|
||||
}, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
// peer0 is in group0 AND group1, so both policies apply
|
||||
// peer0 is in group0 AND group1, so both policies apply. A peer change folds
|
||||
// only the changed peer plus the opposite side of each rule: group2 (peer2) via
|
||||
// the group0 policy and group3 (peer3) via the group1 policy. peer1, a co-member
|
||||
// of group1, is a sibling of the changed peer and must NOT refresh.
|
||||
result := manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[0]})
|
||||
assert.ElementsMatch(t, []string{peerIDs[0], peerIDs[1], peerIDs[2], peerIDs[3]}, result)
|
||||
assert.ElementsMatch(t, []string{peerIDs[0], peerIDs[2], peerIDs[3]}, result)
|
||||
assert.NotContains(t, result, peerIDs[1], "co-member of the changed peer's group must not refresh")
|
||||
}
|
||||
|
||||
func TestResolveAffectedPeers_MultipleChangedPeers(t *testing.T) {
|
||||
@@ -697,7 +744,7 @@ func TestResolveAffectedPeers_MultipleChangedPeers(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
result := manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[0], peerIDs[2]})
|
||||
assert.ElementsMatch(t, []string{peerIDs[0], peerIDs[1], peerIDs[2], peerIDs[3]}, result)
|
||||
assert.ElementsMatch(t, []string{peerIDs[0], peerIDs[2], peerIDs[1], peerIDs[3]}, result)
|
||||
}
|
||||
|
||||
func TestResolveAffectedPeers_SharedGroupAcrossPolicyAndRoute(t *testing.T) {
|
||||
@@ -854,8 +901,9 @@ func TestAffectedPeers_IsolatedPolicies(t *testing.T) {
|
||||
assert.NotContains(t, result, peerIDs[0])
|
||||
assert.NotContains(t, result, peerIDs[1])
|
||||
|
||||
// peerIDs[4] is in neither isolated policy; only its own map can change.
|
||||
result = manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[4]})
|
||||
assert.Empty(t, result)
|
||||
assert.ElementsMatch(t, []string{peerIDs[4]}, result)
|
||||
}
|
||||
|
||||
func TestAffectedPeers_IsolatedRouteAndPolicy(t *testing.T) {
|
||||
@@ -977,12 +1025,13 @@ func TestAffectedPeers_GroupUpdateOnlyAffectsLinkedPeers(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestAffectedPeers_UnlinkedGroupChange_NoUpdates(t *testing.T) {
|
||||
// A peer in no policy/route refreshes only itself — no other peer is affected.
|
||||
func TestAffectedPeers_UnlinkedPeerChange_RefreshesSelfOnly(t *testing.T) {
|
||||
manager, s, accountID, peerIDs, _ := setupAffectedPeersTest(t)
|
||||
ctx := context.Background()
|
||||
|
||||
result := manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[0]})
|
||||
assert.Empty(t, result)
|
||||
assert.ElementsMatch(t, []string{peerIDs[0]}, result)
|
||||
}
|
||||
|
||||
// TestAffectedPeers_PolicyChange_UnrelatedPeerNoUpdate verifies that creating/deleting a
|
||||
@@ -1332,6 +1381,7 @@ func TestAffectedPeers_NetworkRouterUnlinkedPeerNoUpdate(t *testing.T) {
|
||||
NetworkID: net1.ID,
|
||||
AccountID: accountID,
|
||||
PeerGroups: []string{"nr-grpA"},
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -1755,7 +1805,9 @@ func TestCollectAffectedFromProxyServices_GroupContainingTargetPeerChanged(t *te
|
||||
assert.Contains(t, directPeers, peerIDs[1], "target peer must be refreshed")
|
||||
}
|
||||
|
||||
func TestCollectAffectedFromProxyServices_DisabledServiceStillMatches(t *testing.T) {
|
||||
// A disabled service in the snapshot proxies nothing, so it is skipped: a changed
|
||||
// target peer does not pull in the service's proxy peer.
|
||||
func TestCollectAffectedFromProxyServices_DisabledServiceSkipped(t *testing.T) {
|
||||
manager, s, accountID, peerIDs, _ := setupAffectedPeersTest(t)
|
||||
ctx := context.Background()
|
||||
|
||||
@@ -1781,8 +1833,7 @@ func TestCollectAffectedFromProxyServices_DisabledServiceStillMatches(t *testing
|
||||
require.NoError(t, s.CreateService(ctx, svc))
|
||||
|
||||
_, directPeers := collectPeerChangeAffectedGroups(ctx, manager.Store, accountID, nil, []string{peerIDs[1]})
|
||||
assert.Contains(t, directPeers, peerIDs[0], "disabled service should still trigger a refresh so peers are ready when re-enabled")
|
||||
assert.Contains(t, directPeers, peerIDs[1], "disabled target should still trigger a refresh")
|
||||
assert.NotContains(t, directPeers, peerIDs[0], "a disabled service proxies nothing, so its proxy peer must not be folded")
|
||||
}
|
||||
|
||||
func TestCollectAffectedFromProxyServices_NonPeerTargetType(t *testing.T) {
|
||||
|
||||
@@ -6,7 +6,12 @@
|
||||
// and before a delete/removal severs the old state).
|
||||
// - Snapshot.Expand: in-memory walk, no store access. Run AFTER the tx commits.
|
||||
//
|
||||
// Enabled is never consulted: toggling it is itself an observable change.
|
||||
// Enabled handling differs by source. Disabled objects in the SNAPSHOT (existing
|
||||
// account policies/resources/routers/routes/proxy services and their rules/targets)
|
||||
// route to nobody and are skipped — they cannot affect any peer's map. Objects in
|
||||
// the CHANGE itself are processed regardless of Enabled, so disabling one still
|
||||
// refreshes the peers that lose access (the toggle is the observable change, and the
|
||||
// update carries the old∪new state).
|
||||
package affectedpeers
|
||||
|
||||
import (
|
||||
@@ -61,7 +66,8 @@ func Load(ctx context.Context, s store.Store, accountID string, c Change) (*Snap
|
||||
// loadCollections reads the policy/route/nameserver/dns/router/resource/proxy
|
||||
// collections a Change can touch, gated to what the walk needs.
|
||||
func (snap *Snapshot) loadCollections(ctx context.Context, s store.Store, accountID string, c Change) error {
|
||||
hasGroupOrPeerChange := len(c.ChangedGroupIDs) > 0 || len(c.ChangedPeerIDs) > 0 || len(c.Resources) > 0
|
||||
// LinkGroups drive the same policy/route/dns walk as a changed group or peer.
|
||||
hasGroupOrPeerChange := len(c.ChangedGroupIDs) > 0 || len(c.ChangedPeerIDs) > 0 || len(c.LinkGroups) > 0 || len(c.Resources) > 0
|
||||
hasNetworkObject := len(c.Routers) > 0 || len(c.Resources) > 0 || len(c.Networks) > 0
|
||||
// the resource<->router bridge can fire for any of these
|
||||
needsRoutersResources := hasGroupOrPeerChange || len(c.PostureCheckIDs) > 0 || len(c.Policies) > 0 || hasNetworkObject
|
||||
@@ -76,7 +82,7 @@ func (snap *Snapshot) loadCollections(ctx context.Context, s store.Store, accoun
|
||||
return err
|
||||
}
|
||||
}
|
||||
if len(c.ChangedGroupIDs) > 0 || len(c.ChangedPeerIDs) > 0 {
|
||||
if len(c.ChangedGroupIDs) > 0 || len(c.ChangedPeerIDs) > 0 || len(c.LinkGroups) > 0 {
|
||||
if err := snap.loadDNS(ctx, s, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -174,6 +180,24 @@ type Change struct {
|
||||
// folded in — but only when the group is linked (an unlinked group has no map
|
||||
// impact), matching how current members are handled.
|
||||
RemovedPeersByGroup map[string][]string
|
||||
|
||||
// OutputPeerIDs are peers folded straight into the result without seeding their
|
||||
// group memberships into the walk. Use for the peer whose group membership changed:
|
||||
// the peer itself must refresh, but its OTHER groups did not change, so they must
|
||||
// not be walked. Contrast ChangedPeerIDs, which seeds ALL of the peer's groups
|
||||
// (correct when the peer's own attributes changed, e.g. IP/status).
|
||||
OutputPeerIDs []string
|
||||
|
||||
// LinkGroups are groups used ONLY to match policies/routes/routers and walk to the
|
||||
// OPPOSITE side — they are never expanded to their own members. Use this when a
|
||||
// peer's group membership changed: pass the peer in ChangedPeerIDs and its
|
||||
// group(s) here. The opposite side of the policies the group participates in
|
||||
// refreshes, but the group's other members (siblings) do not — nothing changed for
|
||||
// them. For an intra-group policy (A→A) the opposite side IS the group, so its
|
||||
// members still refresh via the opposite-side fold, exactly when they genuinely
|
||||
// gain/lose the changed peer. Unlike ChangedGroupIDs, a LinkGroup is not added to
|
||||
// the output, so a one-sided membership change never wakes the whole group.
|
||||
LinkGroups []string
|
||||
}
|
||||
|
||||
func (c Change) isEmpty() bool {
|
||||
@@ -186,7 +210,9 @@ func (c Change) isEmpty() bool {
|
||||
len(c.Networks) == 0 &&
|
||||
len(c.PostureCheckIDs) == 0 &&
|
||||
len(c.DistributionGroupIDs) == 0 &&
|
||||
len(c.RemovedPeersByGroup) == 0
|
||||
len(c.RemovedPeersByGroup) == 0 &&
|
||||
len(c.LinkGroups) == 0 &&
|
||||
len(c.OutputPeerIDs) == 0
|
||||
}
|
||||
|
||||
// Expand returns the deduplicated affected peer IDs from the preloaded Snapshot,
|
||||
@@ -197,8 +223,8 @@ func (snap *Snapshot) Expand(ctx context.Context, accountID string, c Change) []
|
||||
return nil
|
||||
}
|
||||
r := newResolver(ctx, snap, accountID, c)
|
||||
log.WithContext(ctx).Tracef("affectedpeers expand start: account=%s changedGroups=%v changedPeers=%v policies=%d routes=%d routers=%d resources=%d networks=%d postureChecks=%v distributionGroups=%v",
|
||||
accountID, c.ChangedGroupIDs, c.ChangedPeerIDs, len(c.Policies), len(c.Routes), len(c.Routers), len(c.Resources), len(c.Networks), c.PostureCheckIDs, c.DistributionGroupIDs)
|
||||
log.WithContext(ctx).Tracef("affectedpeers expand start: account=%s changedGroups=%v changedPeers=%v linkGroups=%v policies=%d routes=%d routers=%d resources=%d networks=%d postureChecks=%v distributionGroups=%v",
|
||||
accountID, c.ChangedGroupIDs, c.ChangedPeerIDs, c.LinkGroups, len(c.Policies), len(c.Routes), len(c.Routers), len(c.Resources), len(c.Networks), c.PostureCheckIDs, c.DistributionGroupIDs)
|
||||
r.walk()
|
||||
return r.expand()
|
||||
}
|
||||
@@ -216,57 +242,84 @@ func Collect(ctx context.Context, s store.Store, accountID string, c Change) (gr
|
||||
}
|
||||
r := newResolver(ctx, snap, accountID, c)
|
||||
r.walk()
|
||||
return setToSlice(r.groupSet), setToSlice(r.peerSet)
|
||||
return setToSlice(r.affectedGroups), setToSlice(r.affectedPeers)
|
||||
}
|
||||
|
||||
func newResolver(ctx context.Context, snap *Snapshot, accountID string, c Change) *resolver {
|
||||
r := &resolver{
|
||||
ctx: ctx,
|
||||
snap: snap,
|
||||
accountID: accountID,
|
||||
change: c,
|
||||
changedGroupSet: toSet(c.ChangedGroupIDs),
|
||||
changedPeerSet: toSet(c.ChangedPeerIDs),
|
||||
groupSet: make(map[string]struct{}),
|
||||
peerSet: make(map[string]struct{}),
|
||||
networkIDs: make(map[string]struct{}),
|
||||
ctx: ctx,
|
||||
snap: snap,
|
||||
accountID: accountID,
|
||||
change: c,
|
||||
linkGroups: toSet(c.ChangedGroupIDs),
|
||||
outputGroups: toSet(c.ChangedGroupIDs),
|
||||
changedPeers: toSet(c.ChangedPeerIDs),
|
||||
affectedGroups: make(map[string]struct{}),
|
||||
affectedPeers: make(map[string]struct{}),
|
||||
}
|
||||
// LinkGroups match policies/routes to find the opposite side but are NOT output:
|
||||
// they go into linkGroups only, never outputGroups, so their members never fold in.
|
||||
addAll(r.linkGroups, c.LinkGroups)
|
||||
// Resolve each changed peer to its groups here so callers pass only ChangedPeerIDs.
|
||||
r.seedChangedGroupsFromPeers()
|
||||
r.matchedPolicies = append(r.matchedPolicies, c.Policies...)
|
||||
return r
|
||||
}
|
||||
|
||||
// seedChangedGroupsFromPeers adds each changed peer's groups to changedGroupSet so
|
||||
// seedChangedGroupsFromPeers adds each changed peer's groups to linkGroups so
|
||||
// the group-driven walkers fire for memberships, not just direct peer references.
|
||||
// These seeded groups are for MATCHING only — folding the changed entity's own
|
||||
// side is gated on outputGroups (the caller-reported groups), so a seeded group
|
||||
// never folds its whole membership; only the changed peer itself folds in.
|
||||
func (r *resolver) seedChangedGroupsFromPeers() {
|
||||
if len(r.changedPeerSet) == 0 {
|
||||
if len(r.changedPeers) == 0 {
|
||||
return
|
||||
}
|
||||
for groupID, members := range r.snap.groupPeers {
|
||||
for pID := range r.changedPeerSet {
|
||||
for pID := range r.changedPeers {
|
||||
if _, ok := members[pID]; ok {
|
||||
r.changedGroupSet[groupID] = struct{}{}
|
||||
r.linkGroups[groupID] = struct{}{}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// policySide selects which side of a policy rule to walk.
|
||||
type policySide int
|
||||
|
||||
const (
|
||||
sideSource policySide = iota
|
||||
sideDestination
|
||||
)
|
||||
|
||||
func (s policySide) opposite() policySide {
|
||||
if s == sideSource {
|
||||
return sideDestination
|
||||
}
|
||||
return sideSource
|
||||
}
|
||||
|
||||
// walk resolves affected peers in two buckets, by how far each change propagates.
|
||||
//
|
||||
// BOTH-SIDES — the rule itself changed (an explicit policy edit, or a policy whose
|
||||
// posture check changed). Source AND destination refresh, so each such policy is
|
||||
// walked on both sides.
|
||||
//
|
||||
// OPPOSITE-SIDE — an endpoint moved but no rule changed. For each policy the change
|
||||
// touches we fold only the side AWAY from the change:
|
||||
// - a changed peer/group sits ON a policy side -> fold the opposite side;
|
||||
// - a changed router/resource/network sits on a NETWORK -> fold the SOURCE side of
|
||||
// the policies whose destination reaches it (and the routers it implies).
|
||||
//
|
||||
// Routes, nameserver groups, DNS and embedded-proxy services distribute to their own
|
||||
// member peers, outside the policy graph, and are folded here too.
|
||||
func (r *resolver) walk() {
|
||||
r.collectFromExplicitPolicies()
|
||||
r.collectFromExplicitRoutes(r.change.Routes)
|
||||
r.collectFromExplicitRouters(r.change.Routers)
|
||||
r.collectFromExplicitResources(r.change.Resources)
|
||||
r.collectFromExplicitNetworks(r.change.Networks)
|
||||
r.collectFromPostureChecks(r.change.PostureCheckIDs)
|
||||
for _, policy := range r.bothSidesPolicies() {
|
||||
r.foldPolicySide(policy, sideSource)
|
||||
r.foldPolicySide(policy, sideDestination)
|
||||
}
|
||||
|
||||
// Distribution groups (nameserver/DNS) affect only their member peers: fold them
|
||||
// straight into groupSet so expand() maps them to members, without the policy/
|
||||
// route walk that changedGroupSet would trigger.
|
||||
addAll(r.groupSet, r.change.DistributionGroupIDs)
|
||||
|
||||
if len(r.changedGroupSet) > 0 || len(r.changedPeerSet) > 0 {
|
||||
if len(r.linkGroups) > 0 || len(r.changedPeers) > 0 {
|
||||
r.collectFromPolicies()
|
||||
r.collectFromRoutes()
|
||||
r.collectFromNameServers()
|
||||
@@ -275,7 +328,31 @@ func (r *resolver) walk() {
|
||||
r.collectFromProxyServices()
|
||||
}
|
||||
|
||||
r.collectResourceRouterBridge()
|
||||
r.collectFromChangedRoutes(r.change.Routes)
|
||||
r.collectFromChangedRouters(r.change.Routers)
|
||||
r.collectFromChangedResources(r.change.Resources)
|
||||
r.collectFromChangedNetworks(r.change.Networks)
|
||||
|
||||
// The explicitly changed peers always refresh their own maps. OnPeersUpdated only
|
||||
// refreshes the resolver's output (it ignores the separately-passed changed peers),
|
||||
// so the changed peer reaches its own new map only via here. An offline/deleted
|
||||
// peer in the set is filtered downstream (filterConnectedAffectedPeers).
|
||||
addAll(r.affectedPeers, setToSlice(r.changedPeers))
|
||||
// OutputPeerIDs refresh themselves too, but unlike changedPeers their group
|
||||
// memberships were not seeded into the walk (only the changed group was).
|
||||
addAll(r.affectedPeers, r.change.OutputPeerIDs)
|
||||
|
||||
// Distribution groups (nameserver/DNS) affect only their member peers: fold them
|
||||
// straight into affectedGroups so expand() maps them to members, without the
|
||||
// policy/route walk that linkGroups would trigger.
|
||||
addAll(r.affectedGroups, r.change.DistributionGroupIDs)
|
||||
}
|
||||
|
||||
// bothSidesPolicies are the policies whose rule changed: the explicitly edited ones
|
||||
// plus those gated by a changed posture check. walk folds both their sides.
|
||||
func (r *resolver) bothSidesPolicies() []*types.Policy {
|
||||
policies := append([]*types.Policy(nil), r.change.Policies...)
|
||||
return r.appendPoliciesForPostureChecks(policies, r.change.PostureCheckIDs)
|
||||
}
|
||||
|
||||
type resolver struct {
|
||||
@@ -284,27 +361,71 @@ type resolver struct {
|
||||
accountID string
|
||||
change Change
|
||||
|
||||
changedGroupSet map[string]struct{}
|
||||
changedPeerSet map[string]struct{}
|
||||
// Inputs — what changed. Set once at construction, read-only during the walk
|
||||
// (except linkGroups, which collectFromExplicitResources also seeds).
|
||||
//
|
||||
// linkGroups is the MATCH set: caller-changed groups ∪ the groups of changed
|
||||
// peers ∪ changed-resource groups. A rule/route/router matches the change when
|
||||
// one of its groups is here — used only to find the opposite side to fold.
|
||||
//
|
||||
// outputGroups is the FOLD-WHOLE-GROUP set: ONLY Change.ChangedGroupIDs. When a
|
||||
// matched group is here, its whole membership is affected. A peer-seeded group
|
||||
// is in linkGroups but NOT outputGroups, so it folds only the changed peer
|
||||
// (changedPeers), never its siblings.
|
||||
linkGroups map[string]struct{}
|
||||
outputGroups map[string]struct{}
|
||||
changedPeers map[string]struct{}
|
||||
|
||||
groupSet map[string]struct{}
|
||||
peerSet map[string]struct{}
|
||||
|
||||
matchedPolicies []*types.Policy
|
||||
networkIDs map[string]struct{}
|
||||
// Outputs — the answer. The only sets the walk accumulates into. affectedGroups
|
||||
// is expanded to its member peers in expand().
|
||||
affectedGroups map[string]struct{}
|
||||
affectedPeers map[string]struct{}
|
||||
}
|
||||
|
||||
func (r *resolver) policies() []*types.Policy { return r.snap.policies }
|
||||
// policies returns the account's ENABLED policies from the snapshot. Disabled
|
||||
// policies grant no access, so the walk skips them when scanning existing account
|
||||
// data. Explicitly changed policies (Change.Policies, via bothSidesPolicies) are
|
||||
// processed regardless of Enabled, so disabling one still refreshes its peers.
|
||||
func (r *resolver) policies() []*types.Policy {
|
||||
enabled := make([]*types.Policy, 0, len(r.snap.policies))
|
||||
for _, policy := range r.snap.policies {
|
||||
if policy != nil && policy.Enabled {
|
||||
enabled = append(enabled, policy)
|
||||
}
|
||||
}
|
||||
return enabled
|
||||
}
|
||||
|
||||
func (r *resolver) networkResources() []*resourceTypes.NetworkResource { return r.snap.resources }
|
||||
// networkResources / networkRouters return the account's ENABLED resources/routers
|
||||
// from the snapshot. Disabled objects route to nobody, so the walk skips them when
|
||||
// it scans existing account data. The explicitly changed objects in the Change are
|
||||
// processed regardless of Enabled (collectFromChanged*), so disabling one still
|
||||
// refreshes the peers that lose access.
|
||||
func (r *resolver) networkResources() []*resourceTypes.NetworkResource {
|
||||
enabled := make([]*resourceTypes.NetworkResource, 0, len(r.snap.resources))
|
||||
for _, resource := range r.snap.resources {
|
||||
if resource.Enabled {
|
||||
enabled = append(enabled, resource)
|
||||
}
|
||||
}
|
||||
return enabled
|
||||
}
|
||||
|
||||
func (r *resolver) networkRouters() []*routerTypes.NetworkRouter { return r.snap.routers }
|
||||
func (r *resolver) networkRouters() []*routerTypes.NetworkRouter {
|
||||
enabled := make([]*routerTypes.NetworkRouter, 0, len(r.snap.routers))
|
||||
for _, router := range r.snap.routers {
|
||||
if router.Enabled {
|
||||
enabled = append(enabled, router)
|
||||
}
|
||||
}
|
||||
return enabled
|
||||
}
|
||||
|
||||
// peerIDsForGroups maps a group set to its member peer IDs via the preloaded index.
|
||||
func (r *resolver) peerIDsForGroups(groupSet map[string]struct{}) []string {
|
||||
func (r *resolver) peerIDsForGroups(groups map[string]struct{}) []string {
|
||||
seen := make(map[string]struct{})
|
||||
var ids []string
|
||||
for gID := range groupSet {
|
||||
for gID := range groups {
|
||||
for pID := range r.snap.groupPeers[gID] {
|
||||
if _, ok := seen[pID]; ok {
|
||||
continue
|
||||
@@ -317,25 +438,25 @@ func (r *resolver) peerIDsForGroups(groupSet map[string]struct{}) []string {
|
||||
}
|
||||
|
||||
func (r *resolver) expand() []string {
|
||||
peerIDs := r.peerIDsForGroups(r.groupSet)
|
||||
peerIDs := r.peerIDsForGroups(r.affectedGroups)
|
||||
|
||||
log.WithContext(r.ctx).Tracef("affectedpeers expand: account=%s affectedGroups=%v -> %d group-member peers; direct peers=%v",
|
||||
r.accountID, setToSlice(r.groupSet), len(peerIDs), setToSlice(r.peerSet))
|
||||
r.accountID, setToSlice(r.affectedGroups), len(peerIDs), setToSlice(r.affectedPeers))
|
||||
|
||||
seen := make(map[string]struct{}, len(peerIDs))
|
||||
for _, id := range peerIDs {
|
||||
seen[id] = struct{}{}
|
||||
}
|
||||
for id := range r.peerSet {
|
||||
for id := range r.affectedPeers {
|
||||
if _, ok := seen[id]; !ok {
|
||||
peerIDs = append(peerIDs, id)
|
||||
seen[id] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
// Fold in removed peers only when their group is linked (in groupSet).
|
||||
// Fold in removed peers only when their group is linked (in affectedGroups).
|
||||
for groupID, removed := range r.change.RemovedPeersByGroup {
|
||||
if _, linked := r.groupSet[groupID]; !linked {
|
||||
if _, linked := r.affectedGroups[groupID]; !linked {
|
||||
continue
|
||||
}
|
||||
for _, id := range removed {
|
||||
@@ -351,169 +472,309 @@ func (r *resolver) expand() []string {
|
||||
return peerIDs
|
||||
}
|
||||
|
||||
func (r *resolver) collectFromExplicitPolicies() {
|
||||
for _, policy := range r.matchedPolicies {
|
||||
if policy == nil {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromExplicitPolicies: changed policy %s (%s) -> folding rule groups %v + direct peers",
|
||||
policy.ID, policy.Name, policy.RuleGroups())
|
||||
addAll(r.groupSet, policy.RuleGroups())
|
||||
collectPolicyDirectPeers(policy, r.peerSet)
|
||||
// ruleSideGroups / ruleSideResource return the groups and the resource on the given
|
||||
// side of a rule.
|
||||
func ruleSideGroups(rule *types.PolicyRule, side policySide) []string {
|
||||
if side == sideDestination {
|
||||
return rule.Destinations
|
||||
}
|
||||
return rule.Sources
|
||||
}
|
||||
|
||||
func (r *resolver) collectFromExplicitRoutes(routes []*route.Route) {
|
||||
for _, rt := range routes {
|
||||
if rt == nil {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromExplicitRoutes: changed route %s -> folding groups=%v peerGroups=%v accessControlGroups=%v peer=%q",
|
||||
rt.ID, rt.Groups, rt.PeerGroups, rt.AccessControlGroups, rt.Peer)
|
||||
addAll(r.groupSet, rt.Groups, rt.PeerGroups, rt.AccessControlGroups)
|
||||
if rt.Peer != "" {
|
||||
r.peerSet[rt.Peer] = struct{}{}
|
||||
}
|
||||
func ruleSideResource(rule *types.PolicyRule, side policySide) types.Resource {
|
||||
if side == sideDestination {
|
||||
return rule.DestinationResource
|
||||
}
|
||||
return rule.SourceResource
|
||||
}
|
||||
|
||||
// collectFromExplicitRouters folds changed routers' peers and marks their networks
|
||||
// for the bridge. Passing the old router keeps a repointed router's previous peers
|
||||
// affected without a post-commit read.
|
||||
func (r *resolver) collectFromExplicitRouters(routers []*routerTypes.NetworkRouter) {
|
||||
for _, router := range routers {
|
||||
if router == nil {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromExplicitRouters: changed router %s on network %s -> folding peerGroups=%v peer=%q and marking network for source bridge",
|
||||
router.ID, router.NetworkID, router.PeerGroups, router.Peer)
|
||||
addAll(r.groupSet, router.PeerGroups)
|
||||
if router.Peer != "" {
|
||||
r.peerSet[router.Peer] = struct{}{}
|
||||
}
|
||||
if router.NetworkID != "" {
|
||||
r.networkIDs[router.NetworkID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// collectFromExplicitResources marks changed resources' networks for the bridge and
|
||||
// treats their group IDs as changed, so policies targeting the resource via a
|
||||
// now-detached (old) group still refresh.
|
||||
func (r *resolver) collectFromExplicitResources(resources []*resourceTypes.NetworkResource) {
|
||||
for _, resource := range resources {
|
||||
if resource == nil {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromExplicitResources: changed resource %s on network %s -> marking network for bridge and treating groups %v as changed",
|
||||
resource.ID, resource.NetworkID, resource.GroupIDs)
|
||||
addAll(r.changedGroupSet, resource.GroupIDs)
|
||||
if resource.NetworkID != "" {
|
||||
r.networkIDs[resource.NetworkID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// collectFromExplicitNetworks marks changed networks for the bridge. A network has
|
||||
// no groups/peers of its own.
|
||||
func (r *resolver) collectFromExplicitNetworks(networks []*networkTypes.Network) {
|
||||
for _, network := range networks {
|
||||
if network == nil {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromExplicitNetworks: changed network %s -> marking for bridge", network.ID)
|
||||
if network.ID != "" {
|
||||
r.networkIDs[network.ID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resolver) collectFromPostureChecks(postureCheckIDs []string) {
|
||||
if len(postureCheckIDs) == 0 {
|
||||
// foldPolicySide folds one side of a policy down to affected peers: its groups
|
||||
// (resolved to members in expand) and its direct peer. When the side is the
|
||||
// DESTINATION and references a network resource (directly or via a destination
|
||||
// group's resources), it also folds the routers that serve that resource's network
|
||||
// — a destination resource is reached through its routers. A resource on the SOURCE
|
||||
// side routes to nobody (GetPoliciesForNetworkResource matches destinations only),
|
||||
// so the router hop is destination-only.
|
||||
func (r *resolver) foldPolicySide(policy *types.Policy, side policySide) {
|
||||
if policy == nil {
|
||||
return
|
||||
}
|
||||
for _, rule := range policy.Rules {
|
||||
addAll(r.affectedGroups, ruleSideGroups(rule, side))
|
||||
res := ruleSideResource(rule, side)
|
||||
if res.Type == types.ResourceTypePeer && res.ID != "" {
|
||||
r.affectedPeers[res.ID] = struct{}{}
|
||||
}
|
||||
}
|
||||
if side == sideDestination {
|
||||
r.foldRoutersForResources(r.policyDestinationResourceIDs(policy))
|
||||
}
|
||||
}
|
||||
|
||||
// appendPoliciesForPostureChecks appends every policy that references a changed
|
||||
// posture check (a rule change, so walk both sides).
|
||||
func (r *resolver) appendPoliciesForPostureChecks(policies []*types.Policy, postureCheckIDs []string) []*types.Policy {
|
||||
if len(postureCheckIDs) == 0 {
|
||||
return policies
|
||||
}
|
||||
ids := toSet(postureCheckIDs)
|
||||
for _, policy := range r.policies() {
|
||||
if !policyReferencesPostureChecks(policy, ids) {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromPostureChecks: policy %s (%s) references changed posture checks %v -> folding rule groups %v + direct peers",
|
||||
policy.ID, policy.Name, postureCheckIDs, policy.RuleGroups())
|
||||
addAll(r.groupSet, policy.RuleGroups())
|
||||
collectPolicyDirectPeers(policy, r.peerSet)
|
||||
r.matchedPolicies = append(r.matchedPolicies, policy)
|
||||
log.WithContext(r.ctx).Tracef("appendPoliciesForPostureChecks: policy %s (%s) references changed posture checks %v -> both-sides policy",
|
||||
policy.ID, policy.Name, postureCheckIDs)
|
||||
policies = append(policies, policy)
|
||||
}
|
||||
return policies
|
||||
}
|
||||
|
||||
// collectFromPolicies folds, for every policy whose rule a changed group or peer
|
||||
// touches, only the OPPOSITE side (down to peers, incl. destination routers), plus
|
||||
// the changed entity's own side: the changed group's whole membership when the
|
||||
// group itself changed (outputGroups), or the changed peer alone when matched via a
|
||||
// peer-seeded group (never its co-members).
|
||||
func (r *resolver) collectFromPolicies() {
|
||||
for _, policy := range r.policies() {
|
||||
for _, rule := range policy.Rules {
|
||||
if !rule.Enabled {
|
||||
continue // a disabled rule grants no access
|
||||
}
|
||||
r.foldRuleSideIfChanged(policy, rule, sideSource)
|
||||
r.foldRuleSideIfChanged(policy, rule, sideDestination)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resolver) collectFromPolicies() {
|
||||
for _, policy := range r.policies() {
|
||||
matchedByGroup := policyReferencesGroups(policy, r.changedGroupSet)
|
||||
matchedByPeer := len(r.changedPeerSet) > 0 && policyReferencesDirectPeers(policy, r.changedPeerSet)
|
||||
if !matchedByGroup && !matchedByPeer {
|
||||
// foldRuleSideIfChanged: when a changed group or direct peer sits on `side` of the
|
||||
// rule, fold the opposite side fully (groups/peers + destination routers) and fold
|
||||
// the changed entity's own side (the whole changed group, or the changed peer alone).
|
||||
func (r *resolver) foldRuleSideIfChanged(policy *types.Policy, rule *types.PolicyRule, side policySide) {
|
||||
nearGroups := ruleSideGroups(rule, side)
|
||||
nearResource := ruleSideResource(rule, side)
|
||||
|
||||
matchedByGroup := anyInSet(nearGroups, r.linkGroups)
|
||||
matchedByPeer := isDirectPeerInSet(nearResource, r.changedPeers)
|
||||
if !matchedByGroup && !matchedByPeer {
|
||||
return
|
||||
}
|
||||
|
||||
// Opposite side, fully down to peers (a destination opposite also folds routers).
|
||||
r.foldPolicySideForRule(policy, rule, side.opposite())
|
||||
|
||||
// Own side: fold the whole changed group's members only when the group itself
|
||||
// changed (outputGroups). A peer-seeded or link-only group is not folded here —
|
||||
// its siblings never refresh. The changed peers themselves are folded once, after
|
||||
// the walk (see walk()).
|
||||
for _, gID := range nearGroups {
|
||||
if _, ok := r.outputGroups[gID]; ok {
|
||||
r.affectedGroups[gID] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
// When the changed side IS a destination, the resources it targets are reached
|
||||
// through their network's routers, so those routers refresh too (e.g. attaching a
|
||||
// resource to a destination group, or a changed destination group/resource).
|
||||
if side == sideDestination {
|
||||
r.foldRoutersForResources(r.ruleDestinationResourceIDs(rule))
|
||||
}
|
||||
}
|
||||
|
||||
// foldPolicySideForRule folds one side of a single rule (groups + direct peer), and
|
||||
// for a destination side the routers of that rule's destination resources.
|
||||
func (r *resolver) foldPolicySideForRule(policy *types.Policy, rule *types.PolicyRule, side policySide) {
|
||||
addAll(r.affectedGroups, ruleSideGroups(rule, side))
|
||||
res := ruleSideResource(rule, side)
|
||||
if res.Type == types.ResourceTypePeer && res.ID != "" {
|
||||
r.affectedPeers[res.ID] = struct{}{}
|
||||
}
|
||||
if side == sideDestination {
|
||||
r.foldRoutersForResources(r.ruleDestinationResourceIDs(rule))
|
||||
}
|
||||
}
|
||||
|
||||
// collectFromChangedRoutes folds an explicitly changed route's own groups and peer.
|
||||
func (r *resolver) collectFromChangedRoutes(routes []*route.Route) {
|
||||
for _, rt := range routes {
|
||||
if rt == nil {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromPolicies: policy %s (%s) matched (byGroup=%t byPeer=%t) -> folding rule groups %v + direct peers",
|
||||
policy.ID, policy.Name, matchedByGroup, matchedByPeer, policy.RuleGroups())
|
||||
addAll(r.groupSet, policy.RuleGroups())
|
||||
collectPolicyDirectPeers(policy, r.peerSet)
|
||||
r.matchedPolicies = append(r.matchedPolicies, policy)
|
||||
log.WithContext(r.ctx).Tracef("collectFromChangedRoutes: changed route %s -> folding groups=%v peerGroups=%v accessControlGroups=%v peer=%q",
|
||||
rt.ID, rt.Groups, rt.PeerGroups, rt.AccessControlGroups, rt.Peer)
|
||||
addAll(r.affectedGroups, rt.Groups, rt.PeerGroups, rt.AccessControlGroups)
|
||||
if rt.Peer != "" {
|
||||
r.affectedPeers[rt.Peer] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// collectFromChangedRouters: a changed router refreshes its OWN backing peer/groups
|
||||
// (the changed entity) and the SOURCE side of every policy reaching a resource on
|
||||
// its network (the router serves the whole network). Sibling routers on the network
|
||||
// are independent and are NOT folded. Passing the old router state keeps a repointed
|
||||
// router's previous backing affected without a post-commit read.
|
||||
func (r *resolver) collectFromChangedRouters(routers []*routerTypes.NetworkRouter) {
|
||||
for _, router := range routers {
|
||||
if router == nil {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromChangedRouters: changed router %s on network %s -> folding its own peerGroups=%v peer=%q + sources reaching network resources",
|
||||
router.ID, router.NetworkID, router.PeerGroups, router.Peer)
|
||||
addAll(r.affectedGroups, router.PeerGroups)
|
||||
if router.Peer != "" {
|
||||
r.affectedPeers[router.Peer] = struct{}{}
|
||||
}
|
||||
if router.NetworkID != "" {
|
||||
r.foldPolicySourcesForResources(r.networkResourceIDs(router.NetworkID))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// collectFromChangedResources: a changed resource refreshes the SOURCE side of the
|
||||
// policies targeting EXACTLY that resource — directly, or via one of the resource's
|
||||
// own groups (old∪new across the change, so a now-detached group's sources still
|
||||
// refresh) — plus the routers serving its network (the resource is reached through
|
||||
// them). It does not touch sibling resources on the same network.
|
||||
func (r *resolver) collectFromChangedResources(resources []*resourceTypes.NetworkResource) {
|
||||
for _, resource := range resources {
|
||||
if resource == nil {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromChangedResources: changed resource %s on network %s (groups %v) -> folding sources of policies targeting it + its network's routers",
|
||||
resource.ID, resource.NetworkID, resource.GroupIDs)
|
||||
r.foldPolicySourcesForResource(resource.ID, resource.GroupIDs)
|
||||
if resource.NetworkID != "" {
|
||||
r.foldRoutersOnNetworks(map[string]struct{}{resource.NetworkID: {}})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// foldPolicySourcesForResource folds the source side of every policy whose
|
||||
// destination is the given resource — referenced directly, or via any of the given
|
||||
// groups (the resource's own old∪new groups, which captures a detached group).
|
||||
func (r *resolver) foldPolicySourcesForResource(resourceID string, groupIDs []string) {
|
||||
groups := toSet(groupIDs)
|
||||
for _, policy := range r.policies() {
|
||||
if !policyTargetsResourceOrGroups(policy, resourceID, groups) {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("foldPolicySourcesForResource: policy %s (%s) targets changed resource %s -> folding its source groups/peers", policy.ID, policy.Name, resourceID)
|
||||
collectPolicySources(policy, r.affectedGroups, r.affectedPeers)
|
||||
}
|
||||
}
|
||||
|
||||
// policyTargetsResourceOrGroups reports whether a policy's destination is the given
|
||||
// resource directly, or one of the given destination groups.
|
||||
func policyTargetsResourceOrGroups(policy *types.Policy, resourceID string, groups map[string]struct{}) bool {
|
||||
if policy == nil {
|
||||
return false
|
||||
}
|
||||
for _, rule := range policy.Rules {
|
||||
if !rule.Enabled {
|
||||
continue
|
||||
}
|
||||
if rule.DestinationResource.Type != types.ResourceTypePeer && rule.DestinationResource.ID == resourceID && resourceID != "" {
|
||||
return true
|
||||
}
|
||||
if anyInSet(rule.Destinations, groups) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// collectFromChangedNetworks: a changed network refreshes the SOURCE side of the
|
||||
// policies reaching any of its resources, plus its routers. A network has no
|
||||
// groups/peers of its own.
|
||||
func (r *resolver) collectFromChangedNetworks(networks []*networkTypes.Network) {
|
||||
for _, network := range networks {
|
||||
if network == nil || network.ID == "" {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromChangedNetworks: changed network %s -> folding sources reaching its resources + its routers", network.ID)
|
||||
resourceIDs := r.networkResourceIDs(network.ID)
|
||||
r.foldPolicySourcesForResources(resourceIDs)
|
||||
r.foldRoutersOnNetworks(map[string]struct{}{network.ID: {}})
|
||||
}
|
||||
}
|
||||
|
||||
// foldPolicySourcesForResources folds the source groups/peers of every policy whose
|
||||
// destination targets one of resourceIDs (directly or via a destination group).
|
||||
func (r *resolver) foldPolicySourcesForResources(resourceIDs map[string]struct{}) {
|
||||
if len(resourceIDs) == 0 {
|
||||
return
|
||||
}
|
||||
for _, policy := range r.policies() {
|
||||
if r.policyTargetsResources(policy, resourceIDs) {
|
||||
log.WithContext(r.ctx).Tracef("foldPolicySourcesForResources: policy %s (%s) targets a changed resource -> folding its source groups/peers", policy.ID, policy.Name)
|
||||
collectPolicySources(policy, r.affectedGroups, r.affectedPeers)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resolver) collectFromRoutes() {
|
||||
for _, rt := range r.snap.routes {
|
||||
matchedByGroup := anyInSet(rt.Groups, r.changedGroupSet) || anyInSet(rt.PeerGroups, r.changedGroupSet) || anyInSet(rt.AccessControlGroups, r.changedGroupSet)
|
||||
matchedByPeer := rt.Peer != "" && len(r.changedPeerSet) > 0 && isInSet(rt.Peer, r.changedPeerSet)
|
||||
if !rt.Enabled {
|
||||
continue // disabled routes route to nobody; skip existing account data
|
||||
}
|
||||
matchedByGroup := anyInSet(rt.Groups, r.linkGroups) || anyInSet(rt.PeerGroups, r.linkGroups) || anyInSet(rt.AccessControlGroups, r.linkGroups)
|
||||
matchedByPeer := rt.Peer != "" && len(r.changedPeers) > 0 && isInSet(rt.Peer, r.changedPeers)
|
||||
if !matchedByGroup && !matchedByPeer {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromRoutes: route %s matched (byGroup=%t byPeer=%t) -> folding groups=%v peerGroups=%v accessControlGroups=%v peer=%q",
|
||||
rt.ID, matchedByGroup, matchedByPeer, rt.Groups, rt.PeerGroups, rt.AccessControlGroups, rt.Peer)
|
||||
addAll(r.groupSet, rt.Groups, rt.PeerGroups, rt.AccessControlGroups)
|
||||
addAll(r.affectedGroups, rt.Groups, rt.PeerGroups, rt.AccessControlGroups)
|
||||
if rt.Peer != "" {
|
||||
r.peerSet[rt.Peer] = struct{}{}
|
||||
r.affectedPeers[rt.Peer] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resolver) collectFromNameServers() {
|
||||
if len(r.changedGroupSet) == 0 {
|
||||
if len(r.linkGroups) == 0 {
|
||||
return
|
||||
}
|
||||
for _, ns := range r.snap.nsGroups {
|
||||
if anyInSet(ns.Groups, r.changedGroupSet) {
|
||||
if anyInSet(ns.Groups, r.linkGroups) {
|
||||
log.WithContext(r.ctx).Tracef("collectFromNameServers: nameserver group %s references a changed group -> folding its groups %v", ns.ID, ns.Groups)
|
||||
addAll(r.groupSet, ns.Groups)
|
||||
addAll(r.affectedGroups, ns.Groups)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resolver) collectFromDNSSettings() {
|
||||
if len(r.changedGroupSet) == 0 || r.snap.dnsSettings == nil {
|
||||
if len(r.linkGroups) == 0 || r.snap.dnsSettings == nil {
|
||||
return
|
||||
}
|
||||
for _, gID := range r.snap.dnsSettings.DisabledManagementGroups {
|
||||
if _, ok := r.changedGroupSet[gID]; ok {
|
||||
if _, ok := r.linkGroups[gID]; ok {
|
||||
log.WithContext(r.ctx).Tracef("collectFromDNSSettings: changed group %s is in DisabledManagementGroups -> folding it", gID)
|
||||
r.groupSet[gID] = struct{}{}
|
||||
r.affectedGroups[gID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// collectFromNetworkRouters handles a changed group/peer that BACKS a router (the
|
||||
// routing peer set moved): the router's own peers refresh and so do the sources of
|
||||
// the policies reaching its network's resources. Sibling routers on the network are
|
||||
// independent and are not folded.
|
||||
func (r *resolver) collectFromNetworkRouters() {
|
||||
for _, router := range r.networkRouters() {
|
||||
matchedByGroup := anyInSet(router.PeerGroups, r.changedGroupSet)
|
||||
matchedByPeer := router.Peer != "" && len(r.changedPeerSet) > 0 && isInSet(router.Peer, r.changedPeerSet)
|
||||
matchedByGroup := anyInSet(router.PeerGroups, r.linkGroups)
|
||||
matchedByPeer := router.Peer != "" && len(r.changedPeers) > 0 && isInSet(router.Peer, r.changedPeers)
|
||||
if !matchedByGroup && !matchedByPeer {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromNetworkRouters: router %s on network %s matched (byGroup=%t byPeer=%t) -> folding peerGroups=%v peer=%q and marking network for source bridge",
|
||||
log.WithContext(r.ctx).Tracef("collectFromNetworkRouters: router %s on network %s matched (byGroup=%t byPeer=%t) -> folding its peerGroups=%v peer=%q + sources reaching network resources",
|
||||
router.ID, router.NetworkID, matchedByGroup, matchedByPeer, router.PeerGroups, router.Peer)
|
||||
addAll(r.groupSet, router.PeerGroups)
|
||||
addAll(r.affectedGroups, router.PeerGroups)
|
||||
if router.Peer != "" {
|
||||
r.peerSet[router.Peer] = struct{}{}
|
||||
r.affectedPeers[router.Peer] = struct{}{}
|
||||
}
|
||||
if router.NetworkID != "" {
|
||||
r.foldPolicySourcesForResources(r.networkResourceIDs(router.NetworkID))
|
||||
}
|
||||
r.networkIDs[router.NetworkID] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -526,42 +787,45 @@ func (r *resolver) collectFromProxyServices() {
|
||||
expanded := r.expandChangedPeersWithGroups()
|
||||
|
||||
for _, svc := range services {
|
||||
if svc == nil {
|
||||
continue
|
||||
if svc == nil || !svc.Enabled {
|
||||
continue // a disabled service proxies nothing; skip existing account data
|
||||
}
|
||||
proxyPeers := proxyByCluster[svc.ProxyCluster]
|
||||
if len(proxyPeers) == 0 {
|
||||
continue
|
||||
}
|
||||
matchedByPeer := serviceMatchesChangedPeers(svc, proxyPeers, expanded)
|
||||
matchedByAccessGroup := anyInSet(svc.AccessGroups, r.changedGroupSet)
|
||||
matchedByAccessGroup := anyInSet(svc.AccessGroups, r.linkGroups)
|
||||
if !matchedByPeer && !matchedByAccessGroup {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromProxyServices: service %s (cluster=%s) matched (byProxyOrTargetPeer=%t byAccessGroup=%t) -> folding %d proxy peers, peer targets and access groups %v",
|
||||
svc.ID, svc.ProxyCluster, matchedByPeer, matchedByAccessGroup, len(proxyPeers), svc.AccessGroups)
|
||||
for _, pid := range proxyPeers {
|
||||
r.peerSet[pid] = struct{}{}
|
||||
r.affectedPeers[pid] = struct{}{}
|
||||
}
|
||||
for _, target := range svc.Targets {
|
||||
if !target.Enabled {
|
||||
continue // a disabled target forwards nothing
|
||||
}
|
||||
if target.TargetType == rpservice.TargetTypePeer && target.TargetId != "" {
|
||||
r.peerSet[target.TargetId] = struct{}{}
|
||||
r.affectedPeers[target.TargetId] = struct{}{}
|
||||
}
|
||||
}
|
||||
addAll(r.groupSet, svc.AccessGroups)
|
||||
addAll(r.affectedGroups, svc.AccessGroups)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resolver) expandChangedPeersWithGroups() map[string]struct{} {
|
||||
if len(r.changedGroupSet) == 0 {
|
||||
return r.changedPeerSet
|
||||
if len(r.linkGroups) == 0 {
|
||||
return r.changedPeers
|
||||
}
|
||||
ids := r.peerIDsForGroups(r.changedGroupSet)
|
||||
ids := r.peerIDsForGroups(r.linkGroups)
|
||||
if len(ids) == 0 {
|
||||
return r.changedPeerSet
|
||||
return r.changedPeers
|
||||
}
|
||||
merged := make(map[string]struct{}, len(r.changedPeerSet)+len(ids))
|
||||
for id := range r.changedPeerSet {
|
||||
merged := make(map[string]struct{}, len(r.changedPeers)+len(ids))
|
||||
for id := range r.changedPeers {
|
||||
merged[id] = struct{}{}
|
||||
}
|
||||
for _, id := range ids {
|
||||
@@ -570,54 +834,36 @@ func (r *resolver) expandChangedPeersWithGroups() map[string]struct{} {
|
||||
return merged
|
||||
}
|
||||
|
||||
// collectResourceRouterBridge crosses between source peers and routing peers, which
|
||||
// are reachable only via resource -> network -> router, not through the policy's own
|
||||
// groups: source -> router (targeted resources' networks), then router -> source.
|
||||
func (r *resolver) collectResourceRouterBridge() {
|
||||
r.bridgeSourceToRouters()
|
||||
r.bridgeRoutersToSources()
|
||||
}
|
||||
|
||||
func (r *resolver) bridgeSourceToRouters() {
|
||||
resourceIDs := r.policyDestinationResourceIDs(r.matchedPolicies...)
|
||||
// foldRoutersForResources folds the routers serving the networks of the given
|
||||
// resources (a destination resource is reached through its network's routers). It is
|
||||
// the resource -> network -> router hop used by foldPolicySide for a destination.
|
||||
func (r *resolver) foldRoutersForResources(resourceIDs map[string]struct{}) {
|
||||
if len(resourceIDs) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
networkIDs := r.resourceNetworkIDs(resourceIDs)
|
||||
log.WithContext(r.ctx).Tracef("bridgeSourceToRouters: targeted resources %v -> networks %v (their routers become affected via the router->source pass)",
|
||||
setToSlice(resourceIDs), setToSlice(networkIDs))
|
||||
for id := range networkIDs {
|
||||
r.networkIDs[id] = struct{}{}
|
||||
}
|
||||
r.foldRoutersOnNetworks(r.resourceNetworkIDs(resourceIDs))
|
||||
}
|
||||
|
||||
func (r *resolver) bridgeRoutersToSources() {
|
||||
if len(r.networkIDs) == 0 {
|
||||
return
|
||||
// ruleDestinationResourceIDs returns the destination resource IDs of a single rule:
|
||||
// the direct DestinationResource plus the resources of its destination groups.
|
||||
func (r *resolver) ruleDestinationResourceIDs(rule *types.PolicyRule) map[string]struct{} {
|
||||
resourceIDs := make(map[string]struct{})
|
||||
if rule.DestinationResource.Type != types.ResourceTypePeer && rule.DestinationResource.ID != "" {
|
||||
resourceIDs[rule.DestinationResource.ID] = struct{}{}
|
||||
}
|
||||
r.addGroupResourceIDs(toSet(rule.Destinations), resourceIDs)
|
||||
return resourceIDs
|
||||
}
|
||||
|
||||
log.WithContext(r.ctx).Tracef("bridgeRoutersToSources: affected networks %v -> folding their routing peers and the source peers of policies targeting their resources",
|
||||
setToSlice(r.networkIDs))
|
||||
|
||||
r.foldRoutersOnNetworks(r.networkIDs)
|
||||
|
||||
// networkResourceIDs returns the IDs of all resources on the given network.
|
||||
func (r *resolver) networkResourceIDs(networkID string) map[string]struct{} {
|
||||
resourceIDs := make(map[string]struct{})
|
||||
for _, resource := range r.networkResources() {
|
||||
if _, ok := r.networkIDs[resource.NetworkID]; ok {
|
||||
if resource.NetworkID == networkID {
|
||||
resourceIDs[resource.ID] = struct{}{}
|
||||
}
|
||||
}
|
||||
if len(resourceIDs) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
for _, policy := range r.policies() {
|
||||
if r.policyTargetsResources(policy, resourceIDs) {
|
||||
log.WithContext(r.ctx).Tracef("bridgeRoutersToSources: policy %s (%s) targets an affected-network resource -> folding its source groups/peers", policy.ID, policy.Name)
|
||||
collectPolicySources(policy, r.groupSet, r.peerSet)
|
||||
}
|
||||
}
|
||||
return resourceIDs
|
||||
}
|
||||
|
||||
func (r *resolver) foldRoutersOnNetworks(networkIDs map[string]struct{}) {
|
||||
@@ -627,9 +873,9 @@ func (r *resolver) foldRoutersOnNetworks(networkIDs map[string]struct{}) {
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("bridgeRoutersToSources: router %s serves affected network %s -> folding peerGroups=%v peer=%q",
|
||||
router.ID, router.NetworkID, router.PeerGroups, router.Peer)
|
||||
addAll(r.groupSet, router.PeerGroups)
|
||||
addAll(r.affectedGroups, router.PeerGroups)
|
||||
if router.Peer != "" {
|
||||
r.peerSet[router.Peer] = struct{}{}
|
||||
r.affectedPeers[router.Peer] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -650,6 +896,9 @@ func (r *resolver) policyTargetsResources(policy *types.Policy, resourceIDs map[
|
||||
}
|
||||
destGroupSet := make(map[string]struct{})
|
||||
for _, rule := range policy.Rules {
|
||||
if !rule.Enabled {
|
||||
continue
|
||||
}
|
||||
if rule.DestinationResource.Type != types.ResourceTypePeer && isInSet(rule.DestinationResource.ID, resourceIDs) {
|
||||
return true
|
||||
}
|
||||
@@ -714,44 +963,20 @@ func (r *resolver) addGroupResourceIDs(groupIDs map[string]struct{}, resourceIDs
|
||||
}
|
||||
}
|
||||
|
||||
func collectPolicyDirectPeers(policy *types.Policy, peerSet map[string]struct{}) {
|
||||
// collectPolicySources folds the source groups/peers of a snapshot policy's enabled
|
||||
// rules (a disabled rule grants no access).
|
||||
func collectPolicySources(policy *types.Policy, groups, peers map[string]struct{}) {
|
||||
for _, rule := range policy.Rules {
|
||||
if !rule.Enabled {
|
||||
continue
|
||||
}
|
||||
addAll(groups, rule.Sources)
|
||||
if rule.SourceResource.Type == types.ResourceTypePeer && rule.SourceResource.ID != "" {
|
||||
peerSet[rule.SourceResource.ID] = struct{}{}
|
||||
}
|
||||
if rule.DestinationResource.Type == types.ResourceTypePeer && rule.DestinationResource.ID != "" {
|
||||
peerSet[rule.DestinationResource.ID] = struct{}{}
|
||||
peers[rule.SourceResource.ID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func collectPolicySources(policy *types.Policy, groupSet, peerSet map[string]struct{}) {
|
||||
for _, rule := range policy.Rules {
|
||||
addAll(groupSet, rule.Sources)
|
||||
if rule.SourceResource.Type == types.ResourceTypePeer && rule.SourceResource.ID != "" {
|
||||
peerSet[rule.SourceResource.ID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func policyReferencesGroups(policy *types.Policy, groupSet map[string]struct{}) bool {
|
||||
for _, rule := range policy.Rules {
|
||||
if anyInSet(rule.Sources, groupSet) || anyInSet(rule.Destinations, groupSet) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func policyReferencesDirectPeers(policy *types.Policy, changedSet map[string]struct{}) bool {
|
||||
for _, rule := range policy.Rules {
|
||||
if isDirectPeerInSet(rule.SourceResource, changedSet) || isDirectPeerInSet(rule.DestinationResource, changedSet) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func policyReferencesPostureChecks(policy *types.Policy, ids map[string]struct{}) bool {
|
||||
for _, id := range policy.SourcePostureChecks {
|
||||
if _, ok := ids[id]; ok {
|
||||
@@ -776,7 +1001,7 @@ func serviceMatchesChangedPeers(svc *rpservice.Service, proxyPeers []string, cha
|
||||
}
|
||||
}
|
||||
for _, target := range svc.Targets {
|
||||
if target.TargetType != rpservice.TargetTypePeer || target.TargetId == "" {
|
||||
if !target.Enabled || target.TargetType != rpservice.TargetTypePeer || target.TargetId == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := changedPeers[target.TargetId]; ok {
|
||||
|
||||
@@ -10,8 +10,8 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// policyGroupsAndPeers mirrors the explicit-policy extraction (RuleGroups +
|
||||
// direct peers) the resolver folds in, for asserting the pure logic.
|
||||
// policyGroupsAndPeers mirrors the both-sides extraction (RuleGroups + direct peers)
|
||||
// the resolver folds in for a changed policy, for asserting the pure logic.
|
||||
func policyGroupsAndPeers(policies ...*types.Policy) (groups []string, peers []string) {
|
||||
peerSet := map[string]struct{}{}
|
||||
for _, p := range policies {
|
||||
@@ -19,7 +19,14 @@ func policyGroupsAndPeers(policies ...*types.Policy) (groups []string, peers []s
|
||||
continue
|
||||
}
|
||||
groups = append(groups, p.RuleGroups()...)
|
||||
collectPolicyDirectPeers(p, peerSet)
|
||||
for _, rule := range p.Rules {
|
||||
if rule.SourceResource.Type == types.ResourceTypePeer && rule.SourceResource.ID != "" {
|
||||
peerSet[rule.SourceResource.ID] = struct{}{}
|
||||
}
|
||||
if rule.DestinationResource.Type == types.ResourceTypePeer && rule.DestinationResource.ID != "" {
|
||||
peerSet[rule.DestinationResource.ID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
for id := range peerSet {
|
||||
peers = append(peers, id)
|
||||
@@ -80,26 +87,6 @@ func TestChangeIsEmpty(t *testing.T) {
|
||||
assert.False(t, Change{PostureCheckIDs: []string{"pc"}}.isEmpty())
|
||||
}
|
||||
|
||||
func TestPolicyReferencesGroups(t *testing.T) {
|
||||
policy := &types.Policy{Rules: []*types.PolicyRule{{Sources: []string{"g1", "g2"}, Destinations: []string{"g3"}}}}
|
||||
|
||||
assert.True(t, policyReferencesGroups(policy, map[string]struct{}{"g1": {}}))
|
||||
assert.True(t, policyReferencesGroups(policy, map[string]struct{}{"g3": {}}))
|
||||
assert.False(t, policyReferencesGroups(policy, map[string]struct{}{"g4": {}}))
|
||||
assert.False(t, policyReferencesGroups(policy, map[string]struct{}{}))
|
||||
}
|
||||
|
||||
func TestPolicyReferencesDirectPeers(t *testing.T) {
|
||||
policy := &types.Policy{Rules: []*types.PolicyRule{{
|
||||
SourceResource: types.Resource{Type: types.ResourceTypePeer, ID: "p1"},
|
||||
DestinationResource: types.Resource{Type: types.ResourceTypeHost, ID: "r1"},
|
||||
}}}
|
||||
|
||||
assert.True(t, policyReferencesDirectPeers(policy, map[string]struct{}{"p1": {}}))
|
||||
assert.False(t, policyReferencesDirectPeers(policy, map[string]struct{}{"r1": {}}))
|
||||
assert.False(t, policyReferencesDirectPeers(policy, map[string]struct{}{"p2": {}}))
|
||||
}
|
||||
|
||||
func TestPolicyReferencesPostureChecks(t *testing.T) {
|
||||
policy := &types.Policy{SourcePostureChecks: []string{"pc1", "pc2"}}
|
||||
|
||||
@@ -107,24 +94,9 @@ func TestPolicyReferencesPostureChecks(t *testing.T) {
|
||||
assert.False(t, policyReferencesPostureChecks(policy, map[string]struct{}{"pc3": {}}))
|
||||
}
|
||||
|
||||
func TestCollectPolicyDirectPeers(t *testing.T) {
|
||||
policy := &types.Policy{Rules: []*types.PolicyRule{{
|
||||
SourceResource: types.Resource{Type: types.ResourceTypePeer, ID: "p1"},
|
||||
DestinationResource: types.Resource{Type: types.ResourceTypePeer, ID: "p2"},
|
||||
}, {
|
||||
DestinationResource: types.Resource{Type: types.ResourceTypeHost, ID: "r1"},
|
||||
}}}
|
||||
|
||||
peerSet := map[string]struct{}{}
|
||||
collectPolicyDirectPeers(policy, peerSet)
|
||||
|
||||
assert.Contains(t, peerSet, "p1")
|
||||
assert.Contains(t, peerSet, "p2")
|
||||
assert.NotContains(t, peerSet, "r1")
|
||||
}
|
||||
|
||||
func TestCollectPolicySources(t *testing.T) {
|
||||
policy := &types.Policy{Rules: []*types.PolicyRule{{
|
||||
Enabled: true,
|
||||
Sources: []string{"g1"},
|
||||
SourceResource: types.Resource{Type: types.ResourceTypePeer, ID: "p1"},
|
||||
Destinations: []string{"g2"},
|
||||
|
||||
@@ -520,7 +520,12 @@ func collectDeletableGroups(ctx context.Context, transaction store.Store, accoun
|
||||
// GroupAddPeer appends peer to the group
|
||||
func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
|
||||
var snap *affectedpeers.Snapshot
|
||||
change := affectedpeers.Change{ChangedGroupIDs: []string{groupID}}
|
||||
// A membership change affects only the peer itself and the opposite side of THIS
|
||||
// group's policies — not the group's other members, and not the peer's other
|
||||
// groups. LinkGroups walks only this group (matched, not expanded); OutputPeerIDs
|
||||
// refreshes the peer without seeding its other group memberships. For an
|
||||
// intra-group policy the opposite side is the group, so its members still refresh.
|
||||
change := affectedpeers.Change{OutputPeerIDs: []string{peerID}, LinkGroups: []string{groupID}}
|
||||
|
||||
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err := transaction.AddPeerToGroup(ctx, accountID, peerID, groupID); err != nil {
|
||||
@@ -586,10 +591,11 @@ func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID
|
||||
// GroupDeletePeer removes peer from the group
|
||||
func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, groupID, peerID string) error {
|
||||
var snap *affectedpeers.Snapshot
|
||||
change := affectedpeers.Change{
|
||||
ChangedGroupIDs: []string{groupID},
|
||||
RemovedPeersByGroup: map[string][]string{groupID: {peerID}},
|
||||
}
|
||||
// Same as GroupAddPeer: the removed peer and the opposite side of THIS group's
|
||||
// policies refresh, not the group's other members or the peer's other groups. The
|
||||
// peer is no longer in the group's index, but LinkGroups still drives the
|
||||
// opposite-side walk, and OutputPeerIDs refreshes the removed peer itself.
|
||||
change := affectedpeers.Change{OutputPeerIDs: []string{peerID}, LinkGroups: []string{groupID}}
|
||||
|
||||
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err := transaction.RemovePeerFromGroup(ctx, peerID, groupID); err != nil {
|
||||
@@ -600,8 +606,6 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
|
||||
return err
|
||||
}
|
||||
|
||||
// The removed peer is carried in change.RemovedPeersByGroup and folded in
|
||||
// only when the group is linked, so loading post-removal is correct.
|
||||
var err error
|
||||
if snap, err = affectedpeers.Load(ctx, transaction, accountID, change); err != nil {
|
||||
return err
|
||||
|
||||
@@ -220,7 +220,7 @@ func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
includeServiceUser, err := strconv.ParseBool(serviceUser)
|
||||
log.WithContext(r.Context()).Debugf("Should include service user: %v", includeServiceUser)
|
||||
log.WithContext(r.Context()).Tracef("Should include service user: %v", includeServiceUser)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid service_user query parameter"), w)
|
||||
return
|
||||
|
||||
@@ -39,7 +39,7 @@ type MockAccountManager struct {
|
||||
GetUserFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
|
||||
ListUsersFunc func(ctx context.Context, accountID string) ([]*types.User, error)
|
||||
GetPeersFunc func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
|
||||
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error
|
||||
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error
|
||||
MarkPeerDisconnectedFunc func(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64) error
|
||||
SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
|
||||
DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error
|
||||
@@ -114,7 +114,7 @@ type MockAccountManager struct {
|
||||
GetIdpManagerFunc func() idp.Manager
|
||||
UpdateIntegratedValidatorFunc func(ctx context.Context, accountID, userID, validator string, groups []string) error
|
||||
GroupValidationFunc func(ctx context.Context, accountId string, groups []string) (bool, error)
|
||||
SyncPeerMetaFunc func(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
|
||||
SyncPeerMetaFunc func(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) error
|
||||
FindExistingPostureCheckFunc func(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
||||
GetAccountIDForPeerKeyFunc func(ctx context.Context, peerKey string) (string, error)
|
||||
GetAccountByIDFunc func(ctx context.Context, accountID string, userID string) (*types.Account, error)
|
||||
@@ -345,9 +345,9 @@ func (am *MockAccountManager) GetAccountIDByUserID(ctx context.Context, userAuth
|
||||
}
|
||||
|
||||
// MarkPeerConnected mock implementation of MarkPeerConnected from server.AccountManager interface
|
||||
func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error {
|
||||
func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error {
|
||||
if am.MarkPeerConnectedFunc != nil {
|
||||
return am.MarkPeerConnectedFunc(ctx, peerKey, realIP, accountID, sessionStartedAt, nmap)
|
||||
return am.MarkPeerConnectedFunc(ctx, peerKey, accountID, sessionStartedAt, nmap)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
|
||||
}
|
||||
@@ -975,9 +975,9 @@ func (am *MockAccountManager) GroupValidation(ctx context.Context, accountId str
|
||||
}
|
||||
|
||||
// SyncPeerMeta mocks SyncPeerMeta of the AccountManager interface
|
||||
func (am *MockAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error {
|
||||
func (am *MockAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) error {
|
||||
if am.SyncPeerMetaFunc != nil {
|
||||
return am.SyncPeerMetaFunc(ctx, peerPubKey, meta)
|
||||
return am.SyncPeerMetaFunc(ctx, peerPubKey, meta, realIP)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method SyncPeerMeta is not implemented")
|
||||
}
|
||||
|
||||
@@ -74,7 +74,7 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
|
||||
//
|
||||
// Disconnects use MarkPeerDisconnected and require the session to match
|
||||
// exactly; see PeerStatus.SessionStartedAt for the protocol.
|
||||
func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, realIP net.IP, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error {
|
||||
func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error {
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
am.metrics.AccountManagerMetrics().RecordPeerStatusUpdateDuration(telemetry.PeerStatusConnect, time.Since(start))
|
||||
@@ -102,10 +102,6 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
|
||||
}
|
||||
am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusConnect, telemetry.PeerStatusApplied)
|
||||
|
||||
if am.geo != nil && realIP != nil {
|
||||
am.updatePeerLocationIfChanged(ctx, accountID, peer, realIP)
|
||||
}
|
||||
|
||||
if err = am.schedulePeerExpirations(ctx, accountID, peer); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -192,27 +188,40 @@ func (am *DefaultAccountManager) MarkPeerDisconnected(ctx context.Context, peerP
|
||||
}
|
||||
}
|
||||
|
||||
if peer.AddedWithSSOLogin() && peer.InactivityExpirationEnabled {
|
||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed getting account settings to schedule inactivity expiration for peer %s: %v", peer.ID, err)
|
||||
} else if settings.PeerInactivityExpirationEnabled {
|
||||
am.checkAndSchedulePeerInactivityExpiration(ctx, accountID)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// updatePeerLocationIfChanged refreshes the geolocation on a separate
|
||||
// row update, only when the connection IP actually changed. Geo lookups
|
||||
// are expensive so we skip same-IP reconnects.
|
||||
func (am *DefaultAccountManager) updatePeerLocationIfChanged(ctx context.Context, accountID string, peer *nbpeer.Peer, realIP net.IP) {
|
||||
// resolvePeerLocation looks up the geo location for realIP, returning nil when
|
||||
// there is nothing to apply: geo disabled, no real IP, the IP is unchanged from
|
||||
// what the peer already has, or the lookup failed. Geo lookups are skipped on
|
||||
// same-IP reconnects since they are comparatively expensive. The returned value
|
||||
// is applied by Peer.UpdateMetaIfNew so the change is persisted by its peer save.
|
||||
func (am *DefaultAccountManager) resolvePeerLocation(ctx context.Context, peer *nbpeer.Peer, realIP net.IP) *nbpeer.Location {
|
||||
if am.geo == nil || realIP == nil {
|
||||
return nil
|
||||
}
|
||||
if peer.Location.ConnectionIP != nil && peer.Location.ConnectionIP.Equal(realIP) {
|
||||
return
|
||||
return nil
|
||||
}
|
||||
location, err := am.geo.Lookup(realIP)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to get location for peer %s realip: [%s]: %v", peer.ID, realIP.String(), err)
|
||||
return
|
||||
return nil
|
||||
}
|
||||
peer.Location.ConnectionIP = realIP
|
||||
peer.Location.CountryCode = location.Country.ISOCode
|
||||
peer.Location.CityName = location.City.Names.En
|
||||
peer.Location.GeoNameID = location.City.GeonameID
|
||||
if err := am.Store.SavePeerLocation(ctx, accountID, peer); err != nil {
|
||||
log.WithContext(ctx).Warnf("could not store location for peer %s: %s", peer.ID, err)
|
||||
return &nbpeer.Location{
|
||||
ConnectionIP: realIP,
|
||||
CountryCode: location.Country.ISOCode,
|
||||
CityName: location.City.Names.En,
|
||||
GeoNameID: location.City.GeonameID,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -721,7 +730,7 @@ func (am *DefaultAccountManager) handleSetupKeyAddedPeer(ctx context.Context, en
|
||||
func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKey, userID string, peer *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.Network, []*posture.Checks, bool, error) {
|
||||
if setupKey == "" && userID == "" && !peer.ProxyMeta.Embedded {
|
||||
// no auth method provided => reject access
|
||||
return nil, nil, nil, false, status.Errorf(status.Unauthenticated, "no peer auth method provided, please use a setup key or interactive SSO login")
|
||||
return nil, nil, nil, false, status.ErrNoAuthMethodProvided
|
||||
}
|
||||
|
||||
upperKey := strings.ToUpper(setupKey)
|
||||
@@ -980,7 +989,8 @@ func getPeerIPDNSLabel(ip netip.Addr, peerHostName string) (string, error) {
|
||||
// SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible
|
||||
func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
|
||||
var peer *nbpeer.Peer
|
||||
var updated, versionChanged, ipv6CapabilityChanged bool
|
||||
var ipv6CapabilityChanged bool
|
||||
var metaDiff nbpeer.MetaDiff
|
||||
var err error
|
||||
|
||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
@@ -1010,9 +1020,10 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
|
||||
}
|
||||
|
||||
oldHasIPv6Cap := peer.HasCapability(nbpeer.PeerCapabilityIPv6Overlay)
|
||||
updated, versionChanged = peer.UpdateMetaIfNew(ctx, sync.Meta)
|
||||
newLocation := am.resolvePeerLocation(ctx, peer, sync.RealIP)
|
||||
metaDiff = peer.UpdateMetaIfNew(ctx, sync.Meta, newLocation)
|
||||
ipv6CapabilityChanged = oldHasIPv6Cap != peer.HasCapability(nbpeer.PeerCapabilityIPv6Overlay)
|
||||
if updated {
|
||||
if metaDiff.Updated() {
|
||||
am.metrics.AccountManagerMetrics().CountPeerMetUpdate()
|
||||
log.WithContext(ctx).Tracef("peer %s metadata updated", peer.ID)
|
||||
if err = transaction.SavePeer(ctx, accountID, peer); err != nil {
|
||||
@@ -1040,9 +1051,10 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
|
||||
return nil, nil, nil, 0, err
|
||||
}
|
||||
|
||||
if isStatusChanged || sync.UpdateAccountPeers || ipv6CapabilityChanged || (updated && (len(resPostureChecks) > 0 || versionChanged)) {
|
||||
metaDiffAffectsPosture := posture.AffectsPosture(ctx, &metaDiff, resPostureChecks)
|
||||
if isStatusChanged || sync.UpdateAccountPeers || ipv6CapabilityChanged || metaDiffAffectsPosture || metaDiff.VersionChanged() || metaDiff.HostnameChanged() {
|
||||
changedPeerIDs := []string{peer.ID}
|
||||
affectedPeerIDs := am.syncPeerAffectedPeers(ctx, accountID, peer.ID, nmap, peerNotValid, updated, len(resPostureChecks) > 0)
|
||||
affectedPeerIDs := am.syncPeerAffectedPeers(ctx, accountID, peer.ID, nmap, peerNotValid, metaDiffAffectsPosture)
|
||||
if err = am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs); err != nil {
|
||||
return nil, nil, nil, 0, fmt.Errorf("notify network map controller of peer update: %w", err)
|
||||
}
|
||||
@@ -1059,8 +1071,8 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
|
||||
// metadata change that flips a posture result removes this peer from others'
|
||||
// maps asymmetrically; that case (and an invalid peer, whose map is empty) falls
|
||||
// back to the resolver.
|
||||
func (am *DefaultAccountManager) syncPeerAffectedPeers(ctx context.Context, accountID, peerID string, nmap *types.NetworkMap, peerNotValid, metaUpdated, hasPostureChecks bool) []string {
|
||||
if peerNotValid || (metaUpdated && hasPostureChecks) {
|
||||
func (am *DefaultAccountManager) syncPeerAffectedPeers(ctx context.Context, accountID, peerID string, nmap *types.NetworkMap, peerNotValid, metaChangeAffectedPosture bool) []string {
|
||||
if peerNotValid || metaChangeAffectedPosture {
|
||||
return am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, []string{peerID})
|
||||
}
|
||||
return affectedPeerIDsFromNetworkMap(nmap, peerID)
|
||||
|
||||
@@ -107,6 +107,15 @@ type Location struct {
|
||||
GeoNameID uint // city level geoname id
|
||||
}
|
||||
|
||||
// equal reports whether two locations match. ConnectionIP is a net.IP slice, so it uses
|
||||
// IP.Equal, not ==.
|
||||
func (l Location) equal(other Location) bool {
|
||||
return l.CountryCode == other.CountryCode &&
|
||||
l.CityName == other.CityName &&
|
||||
l.GeoNameID == other.GeoNameID &&
|
||||
l.ConnectionIP.Equal(other.ConnectionIP)
|
||||
}
|
||||
|
||||
// NetworkAddress is the IP address with network and MAC address of a network interface
|
||||
type NetworkAddress struct {
|
||||
NetIP netip.Prefix `gorm:"serializer:json"`
|
||||
@@ -256,50 +265,88 @@ func (p *Peer) Copy() *Peer {
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateMetaIfNew updates peer's system metadata if new information is provided
|
||||
// returns true if meta was updated, false otherwise
|
||||
func (p *Peer) UpdateMetaIfNew(ctx context.Context, meta PeerSystemMeta) (updated, versionChanged bool) {
|
||||
// UpdateMetaIfNew updates peer's system metadata and connection geo location if
|
||||
// new information is provided. newLocation is the geo location resolved from the
|
||||
// peer's current connection IP, or nil when there is nothing to apply (geo
|
||||
// disabled, no real IP, or the IP is unchanged); the caller owns the expensive
|
||||
// lookup and the same-IP guard. It returns a MetaDiff describing what changed;
|
||||
// diff.Updated() reports whether the peer needs to be persisted.
|
||||
func (p *Peer) UpdateMetaIfNew(ctx context.Context, meta PeerSystemMeta, newLocation *Location) MetaDiff {
|
||||
if meta.isEmpty() {
|
||||
return updated, versionChanged
|
||||
return MetaDiff{}
|
||||
}
|
||||
|
||||
versionChanged = p.Meta.WtVersion != meta.WtVersion
|
||||
|
||||
// Avoid overwriting UIVersion if the update was triggered sole by the CLI client
|
||||
if meta.UIVersion == "" {
|
||||
meta.UIVersion = p.Meta.UIVersion
|
||||
}
|
||||
|
||||
oldVersion := p.Meta.WtVersion
|
||||
effectiveLocation := p.Location
|
||||
if newLocation != nil {
|
||||
effectiveLocation = *newLocation
|
||||
}
|
||||
|
||||
diff := metaDiff(p.Meta, meta)
|
||||
if len(diff) != 0 {
|
||||
diff := diffMeta(p.Meta, meta, p.Location, effectiveLocation)
|
||||
if diff.Updated() {
|
||||
p.Meta = meta
|
||||
updated = true
|
||||
}
|
||||
p.Location = effectiveLocation
|
||||
|
||||
if diff.Updated() {
|
||||
log.WithContext(ctx).Debug(diff.LogSummary())
|
||||
}
|
||||
|
||||
versionInfo := ""
|
||||
if versionChanged {
|
||||
versionInfo = fmt.Sprintf("version changed: %s -> %s, ", oldVersion, meta.WtVersion)
|
||||
}
|
||||
|
||||
if len(diff) > 0 || versionChanged {
|
||||
log.WithContext(ctx).
|
||||
Debugf("peer meta updated, %s%d field(s) changed: %s", versionInfo, len(diff), strings.Join(diff, ", "))
|
||||
}
|
||||
|
||||
return updated, versionChanged
|
||||
return diff
|
||||
}
|
||||
|
||||
// MetaDiff holds a peer's full before/after state across a sync: both metas and both
|
||||
// connection locations (the location lives on Peer, not PeerSystemMeta, but posture
|
||||
// checks read it). Changed lists what moved, for logging and the persistence decision;
|
||||
// the snapshots let a posture check be replayed against old and new. Everything is derived
|
||||
// from these fields, so there are no parallel per-field flags to keep in sync.
|
||||
type MetaDiff struct {
|
||||
OldMeta PeerSystemMeta
|
||||
NewMeta PeerSystemMeta
|
||||
OldLocation Location
|
||||
NewLocation Location
|
||||
|
||||
Changed []string
|
||||
}
|
||||
|
||||
// Updated reports whether anything changed and the peer must be persisted. diffMeta fills
|
||||
// Changed in the pass that builds the diff, so this is a length check, not a re-comparison.
|
||||
// Pointer receiver: MetaDiff embeds two metas, so copying it per call is wasteful.
|
||||
func (d *MetaDiff) Updated() bool {
|
||||
return len(d.Changed) != 0
|
||||
}
|
||||
|
||||
// VersionChanged reports whether the WireGuard client version changed (a client upgrade).
|
||||
func (d *MetaDiff) VersionChanged() bool {
|
||||
return d.OldMeta.WtVersion != d.NewMeta.WtVersion
|
||||
}
|
||||
|
||||
// HostnameChanged reports whether the peer's hostname changed.
|
||||
func (d *MetaDiff) HostnameChanged() bool {
|
||||
return d.OldMeta.Hostname != d.NewMeta.Hostname
|
||||
}
|
||||
|
||||
// LogSummary renders the changed fields as a single human-readable line.
|
||||
func (d *MetaDiff) LogSummary() string {
|
||||
return fmt.Sprintf("peer meta updated, %d field(s) changed: %s",
|
||||
len(d.Changed), strings.Join(d.Changed, ", "))
|
||||
}
|
||||
|
||||
// metaDiff returns a human-readable list of the fields that differ between the
|
||||
// old and new meta, each formatted as `field: <old> -> <new>`. It is the single
|
||||
// source of truth for meta comparison: isEqual reports equality as an empty
|
||||
// diff, so the log line can never disagree with the change decision. Slices are
|
||||
// cloned before sorting, so callers' meta is not mutated.
|
||||
func metaDiff(oldMeta, newMeta PeerSystemMeta) []string {
|
||||
var diff []string
|
||||
return diffMeta(oldMeta, newMeta, Location{}, Location{}).Changed
|
||||
}
|
||||
|
||||
// diffMeta snapshots a peer's old and new state and records a Changed entry per field that
|
||||
// moved. It is the single source of truth for the comparison: isEqual is an empty Changed
|
||||
// list, so the log line and the persistence decision can never disagree.
|
||||
func diffMeta(oldMeta, newMeta PeerSystemMeta, oldLocation, newLocation Location) MetaDiff {
|
||||
d := MetaDiff{OldMeta: oldMeta, NewMeta: newMeta, OldLocation: oldLocation, NewLocation: newLocation}
|
||||
add := func(field string, oldVal, newVal any) {
|
||||
diff = append(diff, fmt.Sprintf("%s: %v -> %v", field, oldVal, newVal))
|
||||
d.Changed = append(d.Changed, fmt.Sprintf("%s: %v -> %v", field, oldVal, newVal))
|
||||
}
|
||||
|
||||
if oldMeta.Hostname != newMeta.Hostname {
|
||||
@@ -353,16 +400,18 @@ func metaDiff(oldMeta, newMeta PeerSystemMeta) []string {
|
||||
if !capabilitiesEqual(oldMeta.Capabilities, newMeta.Capabilities) {
|
||||
add("capabilities", oldMeta.Capabilities, newMeta.Capabilities)
|
||||
}
|
||||
|
||||
if !sameMultiset(oldMeta.NetworkAddresses, newMeta.NetworkAddresses) {
|
||||
add("network_addresses", fmt.Sprintf("%v", oldMeta.NetworkAddresses), fmt.Sprintf("%v", newMeta.NetworkAddresses))
|
||||
}
|
||||
|
||||
if !sameMultiset(oldMeta.Files, newMeta.Files) {
|
||||
add("files", fmt.Sprintf("%v", oldMeta.Files), fmt.Sprintf("%v", newMeta.Files))
|
||||
}
|
||||
|
||||
return diff
|
||||
if !oldLocation.equal(newLocation) {
|
||||
add("connection_ip", oldLocation.ConnectionIP, newLocation.ConnectionIP)
|
||||
}
|
||||
|
||||
return d
|
||||
}
|
||||
|
||||
// sameMultiset reports whether two slices contain the same elements with the
|
||||
|
||||
202
management/server/posture/affects_posture_test.go
Normal file
202
management/server/posture/affects_posture_test.go
Normal file
@@ -0,0 +1,202 @@
|
||||
package posture
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
)
|
||||
|
||||
// diffFrom builds a MetaDiff from the old/new snapshots AffectsPosture replays against.
|
||||
func diffFrom(oldMeta, newMeta nbpeer.PeerSystemMeta, oldLoc, newLoc nbpeer.Location) *nbpeer.MetaDiff {
|
||||
return &nbpeer.MetaDiff{
|
||||
OldMeta: oldMeta,
|
||||
NewMeta: newMeta,
|
||||
OldLocation: oldLoc,
|
||||
NewLocation: newLoc,
|
||||
}
|
||||
}
|
||||
|
||||
func checks(def ChecksDefinition) []*Checks {
|
||||
return []*Checks{{Checks: def}}
|
||||
}
|
||||
|
||||
func TestAffectsPosture_NilDiff(t *testing.T) {
|
||||
assert.False(t, AffectsPosture(context.Background(), nil, checks(ChecksDefinition{
|
||||
NBVersionCheck: &NBVersionCheck{MinVersion: "1.0.0"},
|
||||
})))
|
||||
}
|
||||
|
||||
func TestAffectsPosture_NBVersion(t *testing.T) {
|
||||
c := checks(ChecksDefinition{NBVersionCheck: &NBVersionCheck{MinVersion: "1.2.0"}})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
oldVer, newVer string
|
||||
want bool
|
||||
}{
|
||||
{"both above min, no flip", "1.3.0", "1.4.0", false},
|
||||
{"both below min, no flip", "1.0.0", "1.1.0", false},
|
||||
{"crosses up below->above", "1.1.0", "1.3.0", true},
|
||||
{"crosses down above->below", "1.3.0", "1.1.0", true},
|
||||
{"unparsable old only -> flip", "garbage", "1.3.0", true},
|
||||
{"unparsable both -> no flip", "garbage", "junk", false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
diff := diffFrom(
|
||||
nbpeer.PeerSystemMeta{WtVersion: tt.oldVer},
|
||||
nbpeer.PeerSystemMeta{WtVersion: tt.newVer},
|
||||
nbpeer.Location{}, nbpeer.Location{},
|
||||
)
|
||||
assert.Equal(t, tt.want, AffectsPosture(context.Background(), diff, c))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAffectsPosture_OSVersion_KernelBumpWithinMin(t *testing.T) {
|
||||
c := checks(ChecksDefinition{OSVersionCheck: &OSVersionCheck{
|
||||
Linux: &MinKernelVersionCheck{MinKernelVersion: "5.0.0"},
|
||||
}})
|
||||
|
||||
// Kernel moves but stays above the minimum: verdict stays pass -> not affected.
|
||||
withinMin := diffFrom(
|
||||
nbpeer.PeerSystemMeta{GoOS: "linux", KernelVersion: "5.10.0-arch1"},
|
||||
nbpeer.PeerSystemMeta{GoOS: "linux", KernelVersion: "5.15.0-arch2"},
|
||||
nbpeer.Location{}, nbpeer.Location{},
|
||||
)
|
||||
assert.False(t, AffectsPosture(context.Background(), withinMin, c))
|
||||
|
||||
// Kernel drops below the minimum: verdict flips pass -> fail -> affected.
|
||||
crossesDown := diffFrom(
|
||||
nbpeer.PeerSystemMeta{GoOS: "linux", KernelVersion: "5.10.0-arch1"},
|
||||
nbpeer.PeerSystemMeta{GoOS: "linux", KernelVersion: "4.19.0-arch1"},
|
||||
nbpeer.Location{}, nbpeer.Location{},
|
||||
)
|
||||
assert.True(t, AffectsPosture(context.Background(), crossesDown, c))
|
||||
}
|
||||
|
||||
func TestAffectsPosture_OSVersion_GoOSSwitchFlipsVerdict(t *testing.T) {
|
||||
// Only Linux is constrained. An OS outside the switch (freebsd) passes; switching to a
|
||||
// failing linux kernel flips the verdict pass -> fail.
|
||||
c := checks(ChecksDefinition{OSVersionCheck: &OSVersionCheck{
|
||||
Linux: &MinKernelVersionCheck{MinKernelVersion: "6.0.0"},
|
||||
}})
|
||||
|
||||
diff := diffFrom(
|
||||
nbpeer.PeerSystemMeta{GoOS: "freebsd"},
|
||||
nbpeer.PeerSystemMeta{GoOS: "linux", KernelVersion: "4.19.0"},
|
||||
nbpeer.Location{}, nbpeer.Location{},
|
||||
)
|
||||
assert.True(t, AffectsPosture(context.Background(), diff, c))
|
||||
}
|
||||
|
||||
func TestAffectsPosture_Process_GoOSSwitchFlipsVerdict(t *testing.T) {
|
||||
// Process runs at a linux path. Switching GoOS to windows (no WindowsPath configured)
|
||||
// flips the verdict.
|
||||
c := checks(ChecksDefinition{ProcessCheck: &ProcessCheck{
|
||||
Processes: []Process{{LinuxPath: "/usr/bin/foo"}},
|
||||
}})
|
||||
|
||||
files := []nbpeer.File{{Path: "/usr/bin/foo", ProcessIsRunning: true}}
|
||||
diff := diffFrom(
|
||||
nbpeer.PeerSystemMeta{GoOS: "linux", Files: files},
|
||||
nbpeer.PeerSystemMeta{GoOS: "windows", Files: files},
|
||||
nbpeer.Location{}, nbpeer.Location{},
|
||||
)
|
||||
assert.True(t, AffectsPosture(context.Background(), diff, c))
|
||||
}
|
||||
|
||||
func TestAffectsPosture_Process_UnrelatedFileChange(t *testing.T) {
|
||||
// A tracked process stays running while an unrelated file is added: the verdict does
|
||||
// not move, so posture is not affected.
|
||||
c := checks(ChecksDefinition{ProcessCheck: &ProcessCheck{
|
||||
Processes: []Process{{LinuxPath: "/usr/bin/foo"}},
|
||||
}})
|
||||
|
||||
diff := diffFrom(
|
||||
nbpeer.PeerSystemMeta{GoOS: "linux", Files: []nbpeer.File{
|
||||
{Path: "/usr/bin/foo", ProcessIsRunning: true},
|
||||
}},
|
||||
nbpeer.PeerSystemMeta{GoOS: "linux", Files: []nbpeer.File{
|
||||
{Path: "/usr/bin/foo", ProcessIsRunning: true},
|
||||
{Path: "/usr/bin/bar", ProcessIsRunning: true},
|
||||
}},
|
||||
nbpeer.Location{}, nbpeer.Location{},
|
||||
)
|
||||
assert.False(t, AffectsPosture(context.Background(), diff, c))
|
||||
}
|
||||
|
||||
func TestAffectsPosture_GeoLocation(t *testing.T) {
|
||||
c := checks(ChecksDefinition{GeoLocationCheck: &GeoLocationCheck{
|
||||
Action: CheckActionAllow,
|
||||
Locations: []Location{{CountryCode: "DE"}},
|
||||
}})
|
||||
|
||||
// Moving within allowed countries keeps the verdict; moving out flips it.
|
||||
stayAllowed := diffFrom(
|
||||
nbpeer.PeerSystemMeta{}, nbpeer.PeerSystemMeta{},
|
||||
nbpeer.Location{CountryCode: "DE", CityName: "Berlin"},
|
||||
nbpeer.Location{CountryCode: "DE", CityName: "Munich"},
|
||||
)
|
||||
assert.False(t, AffectsPosture(context.Background(), stayAllowed, c))
|
||||
|
||||
moveOut := diffFrom(
|
||||
nbpeer.PeerSystemMeta{}, nbpeer.PeerSystemMeta{},
|
||||
nbpeer.Location{CountryCode: "DE"},
|
||||
nbpeer.Location{CountryCode: "FR"},
|
||||
)
|
||||
assert.True(t, AffectsPosture(context.Background(), moveOut, c))
|
||||
}
|
||||
|
||||
func TestAffectsPosture_PeerNetworkRange_ConnectionIP(t *testing.T) {
|
||||
// The check reads the connection IP. Moving out of the allowed range flips the verdict;
|
||||
// moving within it does not.
|
||||
_, allowed, _ := net.ParseCIDR("10.0.0.0/8")
|
||||
c := checks(ChecksDefinition{PeerNetworkRangeCheck: &PeerNetworkRangeCheck{
|
||||
Action: CheckActionAllow,
|
||||
Ranges: []netip.Prefix{netip.MustParsePrefix(allowed.String())},
|
||||
}})
|
||||
|
||||
movesOutOfRange := diffFrom(
|
||||
nbpeer.PeerSystemMeta{}, nbpeer.PeerSystemMeta{},
|
||||
nbpeer.Location{ConnectionIP: net.ParseIP("10.1.2.3")},
|
||||
nbpeer.Location{ConnectionIP: net.ParseIP("8.8.8.8")},
|
||||
)
|
||||
assert.True(t, AffectsPosture(context.Background(), movesOutOfRange, c))
|
||||
|
||||
staysInRange := diffFrom(
|
||||
nbpeer.PeerSystemMeta{}, nbpeer.PeerSystemMeta{},
|
||||
nbpeer.Location{ConnectionIP: net.ParseIP("10.1.2.3")},
|
||||
nbpeer.Location{ConnectionIP: net.ParseIP("10.9.9.9")},
|
||||
)
|
||||
assert.False(t, AffectsPosture(context.Background(), staysInRange, c))
|
||||
}
|
||||
|
||||
func TestAffectsPosture_IrrelevantFieldChange(t *testing.T) {
|
||||
// Hostname changes but no check reads it: not affected even with checks present.
|
||||
c := checks(ChecksDefinition{
|
||||
NBVersionCheck: &NBVersionCheck{MinVersion: "1.0.0"},
|
||||
GeoLocationCheck: &GeoLocationCheck{Action: CheckActionAllow, Locations: []Location{{CountryCode: "DE"}}},
|
||||
})
|
||||
|
||||
diff := diffFrom(
|
||||
nbpeer.PeerSystemMeta{Hostname: "old", WtVersion: "1.5.0"},
|
||||
nbpeer.PeerSystemMeta{Hostname: "new", WtVersion: "1.5.0"},
|
||||
nbpeer.Location{CountryCode: "DE"}, nbpeer.Location{CountryCode: "DE"},
|
||||
)
|
||||
assert.False(t, AffectsPosture(context.Background(), diff, c))
|
||||
}
|
||||
|
||||
func TestAffectsPosture_NoChecks(t *testing.T) {
|
||||
diff := diffFrom(
|
||||
nbpeer.PeerSystemMeta{WtVersion: "1.0.0"},
|
||||
nbpeer.PeerSystemMeta{WtVersion: "2.0.0"},
|
||||
nbpeer.Location{}, nbpeer.Location{},
|
||||
)
|
||||
assert.False(t, AffectsPosture(context.Background(), diff, nil))
|
||||
}
|
||||
@@ -7,6 +7,8 @@ import (
|
||||
"regexp"
|
||||
|
||||
"github.com/hashicorp/go-version"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
@@ -51,6 +53,46 @@ type Checks struct {
|
||||
Checks ChecksDefinition `gorm:"serializer:json"`
|
||||
}
|
||||
|
||||
// AffectsPosture reports whether the change in diff flips the verdict of any check. It
|
||||
// replays each check against the peer's old and new state and compares verdicts, so a
|
||||
// change that moves a field but stays the right side of a threshold (e.g. a kernel bump
|
||||
// still above the minimum) does not force a re-evaluation. See verdictChanged for how an
|
||||
// evaluation error counts.
|
||||
func AffectsPosture(ctx context.Context, diff *nbpeer.MetaDiff, checks []*Checks) bool {
|
||||
if diff == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
oldPeer := nbpeer.Peer{Meta: diff.OldMeta, Location: diff.OldLocation}
|
||||
newPeer := nbpeer.Peer{Meta: diff.NewMeta, Location: diff.NewLocation}
|
||||
|
||||
for _, c := range checks {
|
||||
for _, check := range c.GetChecks() {
|
||||
if verdictChanged(ctx, check, oldPeer, newPeer) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// verdictChanged replays check against old and new state and reports whether the verdict
|
||||
// differs. Like callers, it treats an evaluation error as deny: two errors are the same
|
||||
// verdict (no change), an error on one side only is a flip.
|
||||
func verdictChanged(ctx context.Context, check Check, oldPeer, newPeer nbpeer.Peer) bool {
|
||||
oldPass, oldErr := check.Check(ctx, oldPeer)
|
||||
newPass, newErr := check.Check(ctx, newPeer)
|
||||
|
||||
oldVerdict := oldPass && (oldErr == nil)
|
||||
newVerdict := newPass && (newErr == nil)
|
||||
changed := oldVerdict != newVerdict
|
||||
|
||||
log.WithContext(ctx).Tracef("posture check %s replay: verdict %t -> %t (changed=%t), errs: %v -> %v",
|
||||
check.Name(), oldVerdict, newVerdict, changed, oldErr, newErr)
|
||||
|
||||
return changed
|
||||
}
|
||||
|
||||
// ChecksDefinition contains definition of actual check
|
||||
type ChecksDefinition struct {
|
||||
NBVersionCheck *NBVersionCheck `json:",omitempty"`
|
||||
|
||||
@@ -489,6 +489,7 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) {
|
||||
|
||||
policy := &types.Policy{
|
||||
AccountID: account.Id,
|
||||
Enabled: true,
|
||||
Rules: []*types.PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
|
||||
@@ -581,28 +581,6 @@ func (s *SqlStore) MarkPeerDisconnectedIfSameSession(ctx context.Context, accoun
|
||||
return result.RowsAffected > 0, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) SavePeerLocation(ctx context.Context, accountID string, peerWithLocation *nbpeer.Peer) error {
|
||||
// To maintain data integrity, we create a copy of the peer's location to prevent unintended updates to other fields.
|
||||
var peerCopy nbpeer.Peer
|
||||
// Since the location field has been migrated to JSON serialization,
|
||||
// updating the struct ensures the correct data format is inserted into the database.
|
||||
peerCopy.Location = peerWithLocation.Location
|
||||
|
||||
result := s.db.Model(&nbpeer.Peer{}).
|
||||
Where(accountAndIDQueryCondition, accountID, peerWithLocation.ID).
|
||||
Updates(peerCopy)
|
||||
|
||||
if result.Error != nil {
|
||||
return status.Errorf(status.Internal, "failed to save peer locations to store: %v", result.Error)
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return status.Errorf(status.NotFound, peerNotFoundFMT, peerWithLocation.ID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ApproveAccountPeers marks all peers that currently require approval in the given account as approved.
|
||||
func (s *SqlStore) ApproveAccountPeers(ctx context.Context, accountID string) (int, error) {
|
||||
result := s.db.Model(&nbpeer.Peer{}).
|
||||
|
||||
@@ -618,56 +618,6 @@ func TestSqlStore_SavePeerStatus(t *testing.T) {
|
||||
assert.WithinDurationf(t, newStatus.LastSeen, actual.LastSeen.UTC(), time.Millisecond, "LastSeen should be equal")
|
||||
}
|
||||
|
||||
func TestSqlStore_SavePeerLocation(t *testing.T) {
|
||||
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanUp)
|
||||
assert.NoError(t, err)
|
||||
|
||||
account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b")
|
||||
require.NoError(t, err)
|
||||
|
||||
peer := &nbpeer.Peer{
|
||||
AccountID: account.Id,
|
||||
ID: "testpeer",
|
||||
Location: nbpeer.Location{
|
||||
ConnectionIP: net.ParseIP("0.0.0.0"),
|
||||
CountryCode: "YY",
|
||||
CityName: "City",
|
||||
GeoNameID: 1,
|
||||
},
|
||||
CreatedAt: time.Now().UTC(),
|
||||
Meta: nbpeer.PeerSystemMeta{},
|
||||
}
|
||||
// error is expected as peer is not in store yet
|
||||
err = store.SavePeerLocation(context.Background(), account.Id, peer)
|
||||
assert.Error(t, err)
|
||||
|
||||
account.Peers[peer.ID] = peer
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
|
||||
peer.Location.ConnectionIP = net.ParseIP("35.1.1.1")
|
||||
peer.Location.CountryCode = "DE"
|
||||
peer.Location.CityName = "Berlin"
|
||||
peer.Location.GeoNameID = 2950159
|
||||
|
||||
err = store.SavePeerLocation(context.Background(), account.Id, account.Peers[peer.ID])
|
||||
assert.NoError(t, err)
|
||||
|
||||
account, err = store.GetAccount(context.Background(), account.Id)
|
||||
require.NoError(t, err)
|
||||
|
||||
actual := account.Peers[peer.ID].Location
|
||||
assert.Equal(t, peer.Location, actual)
|
||||
|
||||
peer.ID = "non-existing-peer"
|
||||
err = store.SavePeerLocation(context.Background(), account.Id, peer)
|
||||
assert.Error(t, err)
|
||||
parsedErr, ok := status.FromError(err)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error")
|
||||
}
|
||||
|
||||
func Test_TestGetAccountByPrivateDomain(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||
|
||||
@@ -185,7 +185,6 @@ type Store interface {
|
||||
// recorded by the database. Returns true when the update happened,
|
||||
// false when a newer session has taken over.
|
||||
MarkPeerDisconnectedIfSameSession(ctx context.Context, accountID, peerID string, sessionStartedAt int64) (bool, error)
|
||||
SavePeerLocation(ctx context.Context, accountID string, peer *nbpeer.Peer) error
|
||||
ApproveAccountPeers(ctx context.Context, accountID string) (int, error)
|
||||
DeletePeer(ctx context.Context, accountID string, peerID string) error
|
||||
|
||||
|
||||
@@ -2968,20 +2968,6 @@ func (mr *MockStoreMockRecorder) SavePeer(ctx, accountID, peer interface{}) *gom
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SavePeer", reflect.TypeOf((*MockStore)(nil).SavePeer), ctx, accountID, peer)
|
||||
}
|
||||
|
||||
// SavePeerLocation mocks base method.
|
||||
func (m *MockStore) SavePeerLocation(ctx context.Context, accountID string, peer *peer.Peer) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SavePeerLocation", ctx, accountID, peer)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// SavePeerLocation indicates an expected call of SavePeerLocation.
|
||||
func (mr *MockStoreMockRecorder) SavePeerLocation(ctx, accountID, peer interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SavePeerLocation", reflect.TypeOf((*MockStore)(nil).SavePeerLocation), ctx, accountID, peer)
|
||||
}
|
||||
|
||||
// SavePeerStatus mocks base method.
|
||||
func (m *MockStore) SavePeerStatus(ctx context.Context, accountID, peerID string, status peer.PeerStatus) error {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
@@ -12,6 +12,9 @@ type PeerSync struct {
|
||||
WireGuardPubKey string
|
||||
// Meta is the system information passed by peer, must be always present
|
||||
Meta nbpeer.PeerSystemMeta
|
||||
// RealIP is the peer's connection IP, used to refresh its geo location.
|
||||
// May be nil when the request has no associated connection IP.
|
||||
RealIP net.IP
|
||||
// UpdateAccountPeers indicate updating account peers,
|
||||
// which occurs when the peer's metadata is updated
|
||||
UpdateAccountPeers bool
|
||||
|
||||
@@ -1059,8 +1059,8 @@ func (am *DefaultAccountManager) BuildUserInfosForAccount(ctx context.Context, a
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
log.WithContext(ctx).Debugf("Got %d users from ExternalCache for account %s", len(usersFromIntegration), accountID)
|
||||
log.WithContext(ctx).Debugf("Got %d users from InternalCache for account %s", len(queriedUsers), accountID)
|
||||
log.WithContext(ctx).Tracef("Got %d users from ExternalCache for account %s", len(usersFromIntegration), accountID)
|
||||
log.WithContext(ctx).Tracef("Got %d users from InternalCache for account %s", len(queriedUsers), accountID)
|
||||
queriedUsers = append(queriedUsers, usersFromIntegration...)
|
||||
}
|
||||
|
||||
|
||||
@@ -48,6 +48,10 @@ type Type int32
|
||||
var (
|
||||
ErrExtraSettingsNotFound = errors.New("extra settings not found")
|
||||
ErrPeerAlreadyLoggedIn = errors.New("peer with the same public key is already logged in")
|
||||
|
||||
// ErrNoAuthMethodProvided is returned when a peer login attempt carries neither a
|
||||
// setup key nor an SSO token. Match it with errors.Is.
|
||||
ErrNoAuthMethodProvided = Errorf(Unauthenticated, "no peer auth method provided, please use a setup key or interactive SSO login")
|
||||
)
|
||||
|
||||
// Error is an internal error
|
||||
@@ -66,6 +70,16 @@ func (e *Error) Error() string {
|
||||
return e.Message
|
||||
}
|
||||
|
||||
// Is reports whether target is an *Error with the same type and message,
|
||||
// enabling matching with errors.Is against sentinel errors.
|
||||
func (e *Error) Is(target error) bool {
|
||||
var t *Error
|
||||
if !errors.As(target, &t) {
|
||||
return false
|
||||
}
|
||||
return e.ErrorType == t.ErrorType && e.Message == t.Message
|
||||
}
|
||||
|
||||
// Errorf returns Error(ErrorType, fmt.Sprintf(format, a...)).
|
||||
func Errorf(errorType Type, format string, a ...interface{}) error {
|
||||
return &Error{
|
||||
|
||||
Reference in New Issue
Block a user