Compare commits

..

11 Commits

Author SHA1 Message Date
riccardom
4988b6726e Aligns new tests to signature 2026-06-28 17:25:17 +02:00
riccardom
2552830184 Prevents skipping of intermediate map updates potentially not applied
by moving the persistence from applySync to the map state manager
2026-06-28 17:23:34 +02:00
riccardom
3b8fc688f4 Do the wholesale (firewall/routes/dns) once only 2026-06-28 17:23:34 +02:00
riccardom
d82d62e818 Adds explicit merge call for future map updates 2026-06-28 17:20:00 +02:00
riccardom
0bf964dad7 Do not process intermediate one if new ones are fresher just use the freshest 2026-06-28 17:20:00 +02:00
riccardom
297dcb3e24 Always run onConverged for every map that is processed 2026-06-28 17:20:00 +02:00
riccardom
bc22926fe0 Drop in case of error, will reconcile with next update 2026-06-28 17:20:00 +02:00
riccardom
d3f2ef9adb Comment why not serial 2026-06-28 17:20:00 +02:00
riccardom
5bec1e8f03 Adds map state manager 2026-06-28 17:20:00 +02:00
riccardom
74bb5c613e Allows to specify max batch for tests 2026-06-28 17:20:00 +02:00
riccardom
29dde908ae Modifies handleSync to support progressive peers conns convergence 2026-06-28 17:19:27 +02:00
25 changed files with 777 additions and 1644 deletions

View File

@@ -33,7 +33,7 @@
<br/>
<br/>
<strong>
🚀 <a href="https://netbird.io/careers">We are hiring! Join us at https://netbird.io/careers</a>
🚀 <a href="https://careers.netbird.io">We are hiring! Join us at careers.netbird.io</a>
</strong>
</p>

View File

@@ -11,7 +11,6 @@ import (
"time"
"github.com/hashicorp/go-multierror"
"github.com/mitchellh/hashstructure/v2"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
@@ -31,13 +30,11 @@ type Manager interface {
// DefaultManager uses firewall manager to handle
type DefaultManager struct {
firewall firewall.Manager
ipsetCounter int
peerRulesPairs map[id.RuleID][]firewall.Rule
routeRules map[id.RuleID]struct{}
previousConfigHash uint64
hasAppliedConfig bool
mutex sync.Mutex
firewall firewall.Manager
ipsetCounter int
peerRulesPairs map[id.RuleID][]firewall.Rule
routeRules map[id.RuleID]struct{}
mutex sync.Mutex
}
func NewDefaultManager(fm firewall.Manager) *DefaultManager {
@@ -60,23 +57,6 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout
return
}
// Skip the full rebuild + flush when the inputs that drive the firewall
// state are byte-for-byte identical to the last successfully applied
// update. Management re-sends the same network map far more often than it
// actually changes (account-wide updates, peer meta churn), and rebuilding
// every peer/route ACL and flushing the firewall on every such sync is the
// dominant client-side cost when nothing changed. Mirrors the same guard the
// DNS server already uses (previousConfigHash). Only the fields ApplyFiltering
// consumes participate in the hash, so an unrelated map change cannot mask a
// real ACL change.
hash, err := d.firewallConfigHash(networkMap, dnsRouteFeatureFlag)
if err != nil {
log.Errorf("unable to hash firewall configuration, applying unconditionally: %v", err)
} else if d.hasAppliedConfig && d.previousConfigHash == hash {
log.Debugf("not applying the firewall configuration update as there is nothing new (hash: %d)", hash)
return
}
start := time.Now()
defer func() {
total := 0
@@ -90,49 +70,13 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout
d.applyPeerACLs(networkMap)
routeErr := d.applyRouteACLs(networkMap.RoutesFirewallRules, dnsRouteFeatureFlag)
if routeErr != nil {
log.Errorf("Failed to apply route ACLs: %v", routeErr)
if err := d.applyRouteACLs(networkMap.RoutesFirewallRules, dnsRouteFeatureFlag); err != nil {
log.Errorf("Failed to apply route ACLs: %v", err)
}
flushErr := d.firewall.Flush()
if flushErr != nil {
log.Error("failed to flush firewall rules: ", flushErr)
if err := d.firewall.Flush(); err != nil {
log.Error("failed to flush firewall rules: ", err)
}
// Only remember the hash once the firewall actually reflects this config.
// If applying or flushing failed, leave the previous hash untouched so the
// next (possibly identical) update is not skipped and gets a chance to
// reconcile the firewall state.
if err == nil && routeErr == nil && flushErr == nil {
d.previousConfigHash = hash
d.hasAppliedConfig = true
} else {
d.hasAppliedConfig = false
}
}
// firewallConfigHash hashes exactly the inputs ApplyFiltering uses to build the
// firewall state, so an identical hash means an identical resulting ruleset.
func (d *DefaultManager) firewallConfigHash(networkMap *mgmProto.NetworkMap, dnsRouteFeatureFlag bool) (uint64, error) {
return hashstructure.Hash(struct {
PeerRules []*mgmProto.FirewallRule
PeerRulesIsEmpty bool
RouteRules []*mgmProto.RouteFirewallRule
RouteRulesIsEmpty bool
DNSRouteFeatureFlag bool
}{
PeerRules: networkMap.GetFirewallRules(),
PeerRulesIsEmpty: networkMap.GetFirewallRulesIsEmpty(),
RouteRules: networkMap.GetRoutesFirewallRules(),
RouteRulesIsEmpty: networkMap.GetRoutesFirewallRulesIsEmpty(),
DNSRouteFeatureFlag: dnsRouteFeatureFlag,
}, hashstructure.FormatV2, &hashstructure.HashOptions{
ZeroNil: true,
IgnoreZeroValue: true,
SlicesAsSets: true,
UseStringer: true,
})
}
func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {

View File

@@ -1,7 +1,6 @@
package acl
import (
"fmt"
"net/netip"
"testing"
@@ -486,149 +485,3 @@ func TestPortInfoEmpty(t *testing.T) {
})
}
}
// TestApplyFilteringSkipsUnchangedConfig verifies that an identical network map
// re-applied is recognized as a no-op (hash unchanged), while a real change to
// any firewall-relevant input forces a re-apply (hash changes). This is the
// guard that prevents a full ruleset rebuild + flush on every redundant sync.
func TestApplyFilteringSkipsUnchangedConfig(t *testing.T) {
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
t.Setenv(firewall.EnvForceUserspaceFirewall, "true")
ctrl := gomock.NewController(t)
defer ctrl.Finish()
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
ifaceMock.EXPECT().SetFilter(gomock.Any())
network := netip.MustParsePrefix("172.0.0.1/32")
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
IP: network.Addr(),
Network: network,
}).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, fw.Close(nil))
}()
acl := NewDefaultManager(fw)
networkMap := &mgmProto.NetworkMap{
FirewallRules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "22",
},
},
FirewallRulesIsEmpty: false,
}
acl.ApplyFiltering(networkMap, false)
require.True(t, acl.hasAppliedConfig, "config should be marked applied after first apply")
firstHash := acl.previousConfigHash
require.NotZero(t, firstHash)
// Re-applying the identical map must not change the recorded hash: the
// expensive rebuild path was skipped.
acl.ApplyFiltering(networkMap, false)
assert.Equal(t, firstHash, acl.previousConfigHash,
"identical re-apply must be a no-op (hash unchanged)")
// A real change must produce a different hash and re-apply.
networkMap.FirewallRules[0].Action = mgmProto.RuleAction_DROP
acl.ApplyFiltering(networkMap, false)
assert.NotEqual(t, firstHash, acl.previousConfigHash,
"changing a rule's action must force a re-apply (hash changed)")
// The dnsRouteFeatureFlag also participates in the hash.
changedHash := acl.previousConfigHash
acl.ApplyFiltering(networkMap, true)
assert.NotEqual(t, changedHash, acl.previousConfigHash,
"flipping dnsRouteFeatureFlag must force a re-apply (hash changed)")
}
func buildNetworkMap(peerRules, routeRules int) *mgmProto.NetworkMap {
nm := &mgmProto.NetworkMap{
FirewallRulesIsEmpty: peerRules == 0,
RoutesFirewallRulesIsEmpty: routeRules == 0,
}
for i := range peerRules {
nm.FirewallRules = append(nm.FirewallRules, &mgmProto.FirewallRule{
PeerIP: fmt.Sprintf("10.%d.%d.%d", i>>16&0xff, i>>8&0xff, i&0xff),
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: fmt.Sprintf("%d", 1024+i%64511),
})
}
for i := range routeRules {
nm.RoutesFirewallRules = append(nm.RoutesFirewallRules, &mgmProto.RouteFirewallRule{
Destination: fmt.Sprintf("192.168.%d.0/24", i%256),
SourceRanges: []string{fmt.Sprintf("10.0.%d.0/24", i%256)},
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
})
}
return nm
}
func BenchmarkFirewallConfigHash_Small(b *testing.B) {
d := &DefaultManager{}
nm := buildNetworkMap(10, 5)
b.ResetTimer()
for b.Loop() {
_, _ = d.firewallConfigHash(nm, false)
}
}
func BenchmarkFirewallConfigHash_Medium(b *testing.B) {
d := &DefaultManager{}
nm := buildNetworkMap(100, 50)
b.ResetTimer()
for b.Loop() {
_, _ = d.firewallConfigHash(nm, false)
}
}
func BenchmarkFirewallConfigHash_Large(b *testing.B) {
d := &DefaultManager{}
nm := buildNetworkMap(1000, 200)
b.ResetTimer()
for b.Loop() {
_, _ = d.firewallConfigHash(nm, false)
}
}
// TestFirewallConfigHashDeterministic verifies the hash is stable for equal
// inputs and order-independent for the rule slices (management does not
// guarantee rule order).
func TestFirewallConfigHashDeterministic(t *testing.T) {
d := &DefaultManager{}
nm1 := &mgmProto.NetworkMap{
FirewallRules: []*mgmProto.FirewallRule{
{PeerIP: "10.0.0.1", Direction: mgmProto.RuleDirection_IN, Action: mgmProto.RuleAction_ACCEPT, Protocol: mgmProto.RuleProtocol_TCP, Port: "22"},
{PeerIP: "10.0.0.2", Direction: mgmProto.RuleDirection_IN, Action: mgmProto.RuleAction_DROP, Protocol: mgmProto.RuleProtocol_TCP, Port: "80"},
},
}
// Same rules, reversed order.
nm2 := &mgmProto.NetworkMap{
FirewallRules: []*mgmProto.FirewallRule{
nm1.FirewallRules[1],
nm1.FirewallRules[0],
},
}
h1, err := d.firewallConfigHash(nm1, false)
require.NoError(t, err)
h2, err := d.firewallConfigHash(nm2, false)
require.NoError(t, err)
assert.Equal(t, h1, h2, "hash must be order-independent for rule slices")
}

View File

@@ -8,7 +8,6 @@ import (
"errors"
"net"
"net/netip"
"slices"
"strings"
"github.com/miekg/dns"
@@ -168,10 +167,7 @@ func getRcodeForNotFound(ctx context.Context, r resolver, domain string, origina
case dns.TypeA:
alternativeNetwork = "ip6"
default:
// Non-address types reach LookupIP only unexpectedly; without an
// address pair to probe we cannot prove the name is absent, so answer
// NODATA rather than a poisoning NXDOMAIN.
return dns.RcodeSuccess
return dns.RcodeNameError
}
if _, err := r.LookupNetIP(ctx, alternativeNetwork, domain); err != nil {
@@ -188,230 +184,6 @@ func getRcodeForNotFound(ctx context.Context, r resolver, domain string, origina
return dns.RcodeSuccess
}
// RecordResolver is the host resolver surface used to forward non-address
// record queries. net.DefaultResolver satisfies it.
type RecordResolver interface {
LookupMX(ctx context.Context, name string) ([]*net.MX, error)
LookupTXT(ctx context.Context, name string) ([]string, error)
LookupNS(ctx context.Context, name string) ([]*net.NS, error)
LookupSRV(ctx context.Context, service, proto, name string) (string, []*net.SRV, error)
LookupCNAME(ctx context.Context, host string) (string, error)
LookupAddr(ctx context.Context, addr string) ([]string, error)
}
// LookupRecords resolves a non-address DNS record type through the host
// resolver and returns the resource records and the DNS rcode. Types the host
// resolver cannot answer (anything not covered by the net.Resolver Lookup*
// methods) yield NODATA so that a routed name is never poisoned with NXDOMAIN
// for an unsupported type.
func LookupRecords(ctx context.Context, r RecordResolver, name string, qtype uint16, ttl uint32) ([]dns.RR, int) {
fqdn := dns.Fqdn(name)
switch qtype {
case dns.TypeMX:
return lookupMX(ctx, r, name, fqdn, ttl)
case dns.TypeTXT:
return lookupTXT(ctx, r, name, fqdn, ttl)
case dns.TypeNS:
return lookupNS(ctx, r, name, fqdn, ttl)
case dns.TypeSRV:
return lookupSRV(ctx, r, name, fqdn, ttl)
case dns.TypeCNAME:
return lookupCNAME(ctx, r, name, fqdn, ttl)
case dns.TypePTR:
return lookupPTR(ctx, r, name, fqdn, ttl)
default:
return nil, dns.RcodeSuccess
}
}
func recordHeader(fqdn string, rrtype uint16, ttl uint32) dns.RR_Header {
return dns.RR_Header{Name: fqdn, Rrtype: rrtype, Class: dns.ClassINET, Ttl: ttl}
}
func lookupMX(ctx context.Context, r RecordResolver, name, fqdn string, ttl uint32) ([]dns.RR, int) {
recs, err := r.LookupMX(ctx, name)
if err != nil {
return nil, rcodeForRecordError(err)
}
rrs := make([]dns.RR, 0, len(recs))
for _, mx := range recs {
rrs = append(rrs, &dns.MX{
Hdr: recordHeader(fqdn, dns.TypeMX, ttl),
Preference: mx.Pref,
Mx: dns.Fqdn(mx.Host),
})
}
return rrs, dns.RcodeSuccess
}
func lookupTXT(ctx context.Context, r RecordResolver, name, fqdn string, ttl uint32) ([]dns.RR, int) {
recs, err := r.LookupTXT(ctx, name)
if err != nil {
return nil, rcodeForRecordError(err)
}
rrs := make([]dns.RR, 0, len(recs))
for _, txt := range recs {
rrs = append(rrs, &dns.TXT{
Hdr: recordHeader(fqdn, dns.TypeTXT, ttl),
Txt: chunkTXT(txt),
})
}
return rrs, dns.RcodeSuccess
}
func lookupNS(ctx context.Context, r RecordResolver, name, fqdn string, ttl uint32) ([]dns.RR, int) {
recs, err := r.LookupNS(ctx, name)
if err != nil {
return nil, rcodeForRecordError(err)
}
rrs := make([]dns.RR, 0, len(recs))
for _, ns := range recs {
rrs = append(rrs, &dns.NS{
Hdr: recordHeader(fqdn, dns.TypeNS, ttl),
Ns: dns.Fqdn(ns.Host),
})
}
return rrs, dns.RcodeSuccess
}
func lookupSRV(ctx context.Context, r RecordResolver, name, fqdn string, ttl uint32) ([]dns.RR, int) {
_, recs, err := r.LookupSRV(ctx, "", "", name)
if err != nil {
return nil, rcodeForRecordError(err)
}
rrs := make([]dns.RR, 0, len(recs))
for _, srv := range recs {
rrs = append(rrs, &dns.SRV{
Hdr: recordHeader(fqdn, dns.TypeSRV, ttl),
Priority: srv.Priority,
Weight: srv.Weight,
Port: srv.Port,
Target: dns.Fqdn(srv.Target),
})
}
return rrs, dns.RcodeSuccess
}
func lookupCNAME(ctx context.Context, r RecordResolver, name, fqdn string, ttl uint32) ([]dns.RR, int) {
cname, err := r.LookupCNAME(ctx, name)
if err != nil {
return nil, rcodeForRecordError(err)
}
// LookupCNAME returns the queried name itself when the name resolves but
// has no CNAME record; that is a NODATA result, not a CNAME.
if strings.EqualFold(dns.Fqdn(cname), fqdn) {
return nil, dns.RcodeSuccess
}
return []dns.RR{&dns.CNAME{
Hdr: recordHeader(fqdn, dns.TypeCNAME, ttl),
Target: dns.Fqdn(cname),
}}, dns.RcodeSuccess
}
func lookupPTR(ctx context.Context, r RecordResolver, name, fqdn string, ttl uint32) ([]dns.RR, int) {
addr, ok := ptrQueryAddr(name)
if !ok {
return nil, dns.RcodeSuccess
}
names, err := r.LookupAddr(ctx, addr)
if err != nil {
return nil, rcodeForRecordError(err)
}
rrs := make([]dns.RR, 0, len(names))
for _, n := range names {
rrs = append(rrs, &dns.PTR{
Hdr: recordHeader(fqdn, dns.TypePTR, ttl),
Ptr: dns.Fqdn(n),
})
}
return rrs, dns.RcodeSuccess
}
// ptrQueryAddr converts a reverse-DNS query name (in-addr.arpa or ip6.arpa)
// into the address string expected by net.Resolver.LookupAddr. It reports false
// when the name is not a well-formed reverse name.
func ptrQueryAddr(qname string) (string, bool) {
name := strings.TrimSuffix(strings.ToLower(dns.Fqdn(qname)), ".")
switch {
case strings.HasSuffix(name, ".in-addr.arpa"):
return parseInAddrArpa(strings.TrimSuffix(name, ".in-addr.arpa"))
case strings.HasSuffix(name, ".ip6.arpa"):
return parseIP6Arpa(strings.TrimSuffix(name, ".ip6.arpa"))
default:
return "", false
}
}
// parseInAddrArpa turns the label portion of an in-addr.arpa name into an IPv4
// address string, reporting false when it is not a well-formed reverse name.
func parseInAddrArpa(labelPart string) (string, bool) {
labels := strings.Split(labelPart, ".")
if len(labels) != 4 {
return "", false
}
slices.Reverse(labels)
addr, err := netip.ParseAddr(strings.Join(labels, "."))
if err != nil || !addr.Is4() {
return "", false
}
return addr.String(), true
}
// parseIP6Arpa turns the nibble portion of an ip6.arpa name into an IPv6
// address string, reporting false when it is not a well-formed reverse name.
func parseIP6Arpa(nibblePart string) (string, bool) {
nibbles := strings.Split(nibblePart, ".")
if len(nibbles) != 32 {
return "", false
}
slices.Reverse(nibbles)
var sb strings.Builder
for i, n := range nibbles {
if i > 0 && i%4 == 0 {
sb.WriteByte(':')
}
sb.WriteString(n)
}
addr, err := netip.ParseAddr(sb.String())
if err != nil || !addr.Is6() {
return "", false
}
return addr.String(), true
}
// rcodeForRecordError maps a non-address lookup error to a DNS rcode. A
// not-found result becomes NODATA rather than NXDOMAIN: net.DNSError.IsNotFound
// does not distinguish a missing name from a name that exists only with records
// of other types, so the name cannot be proven absent and must not be poisoned.
func rcodeForRecordError(err error) int {
var dnsErr *net.DNSError
if errors.As(err, &dnsErr) && dnsErr.IsNotFound {
return dns.RcodeSuccess
}
return dns.RcodeServerFailure
}
// chunkTXT splits a TXT string into character-strings no longer than 255 bytes
// so the record can be packed. The chunks form one TXT resource record.
func chunkTXT(s string) []string {
const maxLen = 255
if len(s) <= maxLen {
return []string{s}
}
var chunks []string
for len(s) > maxLen {
chunks = append(chunks, s[:maxLen])
s = s[maxLen:]
}
if len(s) > 0 {
chunks = append(chunks, s)
}
return chunks
}
// FormatAnswers formats DNS resource records for logging.
func FormatAnswers(answers []dns.RR) string {
if len(answers) == 0 {

View File

@@ -5,7 +5,6 @@ import (
"errors"
"net"
"net/netip"
"strings"
"testing"
"github.com/miekg/dns"
@@ -122,164 +121,6 @@ func TestLookupIP_DNSErrorNotIsNotFound(t *testing.T) {
assert.Equal(t, dns.RcodeServerFailure, result.Rcode, "upstream failure should map to SERVFAIL")
}
func TestPtrQueryAddr(t *testing.T) {
tests := []struct {
name string
qname string
want string
wantOK bool
}{
{name: "ipv4", qname: "4.3.2.1.in-addr.arpa.", want: "1.2.3.4", wantOK: true},
{name: "ipv4 no trailing dot", qname: "1.0.0.127.in-addr.arpa", want: "127.0.0.1", wantOK: true},
{
name: "ipv6",
qname: "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.",
want: "2001:db8::1",
wantOK: true,
},
{name: "ipv4 wrong label count", qname: "2.1.in-addr.arpa.", wantOK: false},
{name: "ipv6 wrong nibble count", qname: "1.0.ip6.arpa.", wantOK: false},
{name: "not a reverse name", qname: "example.com.", wantOK: false},
{name: "ipv4 bad octet", qname: "4.3.2.999.in-addr.arpa.", wantOK: false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, ok := ptrQueryAddr(tt.qname)
assert.Equal(t, tt.wantOK, ok, "parse success mismatch")
if tt.wantOK {
assert.Equal(t, tt.want, got, "parsed address mismatch")
}
})
}
}
type mockRecordResolver struct {
mx []*net.MX
txt []string
ns []*net.NS
srv []*net.SRV
cname string
ptr []string
err error
}
func (m *mockRecordResolver) LookupMX(context.Context, string) ([]*net.MX, error) {
return m.mx, m.err
}
func (m *mockRecordResolver) LookupTXT(context.Context, string) ([]string, error) {
return m.txt, m.err
}
func (m *mockRecordResolver) LookupNS(context.Context, string) ([]*net.NS, error) {
return m.ns, m.err
}
func (m *mockRecordResolver) LookupSRV(context.Context, string, string, string) (string, []*net.SRV, error) {
return "", m.srv, m.err
}
func (m *mockRecordResolver) LookupCNAME(context.Context, string) (string, error) {
return m.cname, m.err
}
func (m *mockRecordResolver) LookupAddr(context.Context, string) ([]string, error) {
return m.ptr, m.err
}
func TestLookupRecords(t *testing.T) {
notFound := &net.DNSError{IsNotFound: true, Name: "example.com."}
t.Run("MX success", func(t *testing.T) {
r := &mockRecordResolver{mx: []*net.MX{{Host: "mail.example.com.", Pref: 10}}}
rrs, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypeMX, 300)
assert.Equal(t, dns.RcodeSuccess, rcode)
require.Len(t, rrs, 1)
assert.Equal(t, "mail.example.com.", rrs[0].(*dns.MX).Mx)
})
t.Run("TXT short string is one character-string", func(t *testing.T) {
r := &mockRecordResolver{txt: []string{"v=spf1 -all"}}
rrs, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypeTXT, 300)
assert.Equal(t, dns.RcodeSuccess, rcode)
require.Len(t, rrs, 1)
assert.Equal(t, []string{"v=spf1 -all"}, rrs[0].(*dns.TXT).Txt)
})
t.Run("TXT chunks long strings", func(t *testing.T) {
long := strings.Repeat("a", 300)
r := &mockRecordResolver{txt: []string{long}}
rrs, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypeTXT, 300)
assert.Equal(t, dns.RcodeSuccess, rcode)
require.Len(t, rrs, 1)
txt := rrs[0].(*dns.TXT).Txt
require.Len(t, txt, 2, "300-byte string should split into two character-strings")
assert.Equal(t, 255, len(txt[0]))
assert.Equal(t, 45, len(txt[1]))
})
t.Run("NS success", func(t *testing.T) {
r := &mockRecordResolver{ns: []*net.NS{{Host: "ns1.example.com."}}}
rrs, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypeNS, 300)
assert.Equal(t, dns.RcodeSuccess, rcode)
require.Len(t, rrs, 1)
assert.Equal(t, "ns1.example.com.", rrs[0].(*dns.NS).Ns)
})
t.Run("SRV success", func(t *testing.T) {
r := &mockRecordResolver{srv: []*net.SRV{{Target: "sip.example.com.", Port: 5060}}}
rrs, rcode := LookupRecords(context.Background(), r, "_sip._tcp.example.com.", dns.TypeSRV, 300)
assert.Equal(t, dns.RcodeSuccess, rcode)
require.Len(t, rrs, 1)
assert.Equal(t, uint16(5060), rrs[0].(*dns.SRV).Port)
})
t.Run("CNAME success", func(t *testing.T) {
r := &mockRecordResolver{cname: "target.example.com."}
rrs, rcode := LookupRecords(context.Background(), r, "www.example.com.", dns.TypeCNAME, 300)
assert.Equal(t, dns.RcodeSuccess, rcode)
require.Len(t, rrs, 1)
assert.Equal(t, "target.example.com.", rrs[0].(*dns.CNAME).Target)
})
t.Run("CNAME equal to name is NODATA", func(t *testing.T) {
r := &mockRecordResolver{cname: "example.com."}
rrs, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypeCNAME, 300)
assert.Equal(t, dns.RcodeSuccess, rcode)
assert.Empty(t, rrs, "self-referential CNAME is NODATA")
})
t.Run("PTR success", func(t *testing.T) {
r := &mockRecordResolver{ptr: []string{"host.example.com."}}
rrs, rcode := LookupRecords(context.Background(), r, "4.3.2.1.in-addr.arpa.", dns.TypePTR, 300)
assert.Equal(t, dns.RcodeSuccess, rcode)
require.Len(t, rrs, 1)
assert.Equal(t, "host.example.com.", rrs[0].(*dns.PTR).Ptr)
})
t.Run("PTR malformed name is NODATA", func(t *testing.T) {
r := &mockRecordResolver{}
rrs, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypePTR, 300)
assert.Equal(t, dns.RcodeSuccess, rcode)
assert.Empty(t, rrs)
})
t.Run("not found is NODATA never NXDOMAIN", func(t *testing.T) {
r := &mockRecordResolver{err: notFound}
_, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypeMX, 300)
assert.Equal(t, dns.RcodeSuccess, rcode, "missing record must not poison the name")
})
t.Run("server failure maps to SERVFAIL", func(t *testing.T) {
r := &mockRecordResolver{err: &net.DNSError{Err: "server misbehaving", IsTemporary: true}}
_, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypeMX, 300)
assert.Equal(t, dns.RcodeServerFailure, rcode)
})
t.Run("unsupported type is NODATA", func(t *testing.T) {
r := &mockRecordResolver{}
rrs, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypeCAA, 300)
assert.Equal(t, dns.RcodeSuccess, rcode)
assert.Empty(t, rrs)
})
}
func TestStripOPT(t *testing.T) {
rm := &dns.Msg{
Extra: []dns.RR{

View File

@@ -37,12 +37,6 @@ const (
type resolver interface {
LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error)
LookupMX(ctx context.Context, name string) ([]*net.MX, error)
LookupTXT(ctx context.Context, name string) ([]string, error)
LookupNS(ctx context.Context, name string) ([]*net.NS, error)
LookupSRV(ctx context.Context, service, proto, name string) (string, []*net.SRV, error)
LookupCNAME(ctx context.Context, host string) (string, error)
LookupAddr(ctx context.Context, addr string) ([]string, error)
}
type firewaller interface {
@@ -216,6 +210,12 @@ func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, q
qname, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass])
resp := query.SetReply(query)
network := resutil.NetworkForQtype(question.Qtype)
if network == "" {
resp.Rcode = dns.RcodeNotImplemented
f.writeResponse(logger, w, resp, qname, startTime)
return
}
mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(qname, "."))
if mostSpecificResId == "" {
@@ -227,46 +227,9 @@ func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, q
ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout)
defer cancel()
reqHasEdns := query.IsEdns0() != nil
switch question.Qtype {
case dns.TypeA, dns.TypeAAAA:
f.handleAddressQuery(ctx, logger, w, resp, mostSpecificResId, matchingEntries, reqHasEdns, startTime)
case dns.TypeMX, dns.TypeTXT, dns.TypeNS, dns.TypeSRV, dns.TypeCNAME, dns.TypePTR:
f.handleRecordQuery(ctx, logger, w, resp, startTime)
default:
// The domain is routed here, so any other type is answered NODATA
// (NOERROR, empty answer) rather than falling back to a resolver that
// would poison the name with NXDOMAIN. The Extended DNS Error lets a
// client tell this capability-driven NODATA apart from an
// authoritative one. The OPT pseudo-record must not appear unless the
// query advertised EDNS0.
if reqHasEdns {
attachEDE(resp, dns.ExtendedErrorCodeNotSupported, "netbird forwarder: unsupported query type")
}
f.writeResponse(logger, w, resp, qname, startTime)
}
}
// handleAddressQuery resolves A/AAAA queries, programs the firewall sets and
// resolved-IP state, and caches the answer for resilience on upstream failure.
func (f *DNSForwarder) handleAddressQuery(
ctx context.Context,
logger *log.Entry,
w dns.ResponseWriter,
resp *dns.Msg,
mostSpecificResId route.ResID,
matchingEntries []*ForwarderEntry,
reqHasEdns bool,
startTime time.Time,
) {
question := resp.Question[0]
qname := strings.ToLower(question.Name)
network := resutil.NetworkForQtype(question.Qtype)
result := resutil.LookupIP(ctx, f.resolver, network, qname, question.Qtype)
if result.Err != nil {
f.handleDNSError(ctx, logger, w, question, resp, qname, result, reqHasEdns, startTime)
f.handleDNSError(ctx, logger, w, question, resp, qname, result, query.IsEdns0() != nil, startTime)
return
}
@@ -277,25 +240,6 @@ func (f *DNSForwarder) handleAddressQuery(
f.writeResponse(logger, w, resp, qname, startTime)
}
// handleRecordQuery resolves non-address record types (MX, TXT, NS, SRV,
// CNAME, PTR) through the host resolver. Missing records are answered NODATA so
// the routed name is never poisoned with NXDOMAIN.
func (f *DNSForwarder) handleRecordQuery(
ctx context.Context,
logger *log.Entry,
w dns.ResponseWriter,
resp *dns.Msg,
startTime time.Time,
) {
question := resp.Question[0]
qname := strings.ToLower(question.Name)
records, rcode := resutil.LookupRecords(ctx, f.resolver, qname, question.Qtype, f.ttl)
resp.Rcode = rcode
resp.Answer = append(resp.Answer, records...)
f.writeResponse(logger, w, resp, qname, startTime)
}
func (f *DNSForwarder) writeResponse(logger *log.Entry, w dns.ResponseWriter, resp *dns.Msg, qname string, startTime time.Time) {
if err := w.WriteMsg(resp); err != nil {
logger.Errorf("failed to write DNS response: %v", err)

View File

@@ -133,41 +133,6 @@ func (m *MockResolver) LookupNetIP(ctx context.Context, network, host string) ([
return args.Get(0).([]netip.Addr), args.Error(1)
}
func (m *MockResolver) LookupMX(ctx context.Context, name string) ([]*net.MX, error) {
args := m.Called(ctx, name)
recs, _ := args.Get(0).([]*net.MX)
return recs, args.Error(1)
}
func (m *MockResolver) LookupTXT(ctx context.Context, name string) ([]string, error) {
args := m.Called(ctx, name)
recs, _ := args.Get(0).([]string)
return recs, args.Error(1)
}
func (m *MockResolver) LookupNS(ctx context.Context, name string) ([]*net.NS, error) {
args := m.Called(ctx, name)
recs, _ := args.Get(0).([]*net.NS)
return recs, args.Error(1)
}
func (m *MockResolver) LookupSRV(ctx context.Context, service, proto, name string) (string, []*net.SRV, error) {
args := m.Called(ctx, service, proto, name)
recs, _ := args.Get(1).([]*net.SRV)
return args.String(0), recs, args.Error(2)
}
func (m *MockResolver) LookupCNAME(ctx context.Context, host string) (string, error) {
args := m.Called(ctx, host)
return args.String(0), args.Error(1)
}
func (m *MockResolver) LookupAddr(ctx context.Context, addr string) ([]string, error) {
args := m.Called(ctx, addr)
recs, _ := args.Get(0).([]string)
return recs, args.Error(1)
}
func TestDNSForwarder_SubdomainAccessLogic(t *testing.T) {
tests := []struct {
name string
@@ -580,15 +545,12 @@ func TestDNSForwarder_MultipleIPsInSingleUpdate(t *testing.T) {
}
func TestDNSForwarder_ResponseCodes(t *testing.T) {
// A type with no net.Resolver Lookup method (CAA) must answer NODATA
// (NOERROR, empty) rather than NXDOMAIN/NOTIMP to avoid poisoning the name.
tests := []struct {
name string
queryType uint16
queryDomain string
configured string
expectedCode int
expectEDE bool
description string
}{
{
@@ -600,13 +562,28 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {
description: "RFC compliant REFUSED for unauthorized queries",
},
{
name: "unsupported query type returns NODATA",
queryType: dns.TypeCAA,
name: "unsupported query type returns NOTIMP",
queryType: dns.TypeMX,
queryDomain: "example.com",
configured: "example.com",
expectedCode: dns.RcodeSuccess,
expectEDE: true,
description: "Unsupported types answer NODATA, not NXDOMAIN/NOTIMP",
expectedCode: dns.RcodeNotImplemented,
description: "RFC compliant NOTIMP for unsupported types",
},
{
name: "CNAME query returns NOTIMP",
queryType: dns.TypeCNAME,
queryDomain: "example.com",
configured: "example.com",
expectedCode: dns.RcodeNotImplemented,
description: "CNAME queries not supported",
},
{
name: "TXT query returns NOTIMP",
queryType: dns.TypeTXT,
queryDomain: "example.com",
configured: "example.com",
expectedCode: dns.RcodeNotImplemented,
description: "TXT queries not supported",
},
}
@@ -622,7 +599,6 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {
query := &dns.Msg{}
query.SetQuestion(dns.Fqdn(tt.queryDomain), tt.queryType)
query.SetEdns0(dns.DefaultMsgSize, false)
// Capture the written response
var writtenResp *dns.Msg
@@ -638,213 +614,10 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {
// Check the response written to the writer
require.NotNil(t, writtenResp, "Expected response to be written")
assert.Equal(t, tt.expectedCode, writtenResp.Rcode, tt.description)
assert.Empty(t, writtenResp.Answer, "Non-address response should carry no answers")
if tt.expectEDE {
require.NotNil(t, writtenResp.IsEdns0(), "EDNS0 client should get an OPT in the reply")
assert.True(t, hasEDE(writtenResp, dns.ExtendedErrorCodeNotSupported),
"unsupported type NODATA should carry EDE Not Supported")
}
})
}
}
func hasEDE(m *dns.Msg, code uint16) bool {
opt := m.IsEdns0()
if opt == nil {
return false
}
for _, o := range opt.Option {
if ede, ok := o.(*dns.EDNS0_EDE); ok && ede.InfoCode == code {
return true
}
}
return false
}
func TestDNSForwarder_RecordQueries(t *testing.T) {
notFound := &net.DNSError{IsNotFound: true, Name: "example.com"}
t.Run("MX records are forwarded", func(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := newRecordTestForwarder(t, mockResolver, "example.com")
mockResolver.On("LookupMX", mock.Anything, "example.com.").
Return([]*net.MX{{Host: "mail.example.com.", Pref: 10}}, nil).Once()
resp := runRecordQuery(t, forwarder, "example.com", dns.TypeMX)
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 1)
mx, ok := resp.Answer[0].(*dns.MX)
require.True(t, ok, "answer should be an MX record")
assert.Equal(t, uint16(10), mx.Preference)
assert.Equal(t, "mail.example.com.", mx.Mx)
mockResolver.AssertExpectations(t)
})
t.Run("missing MX is NODATA not NXDOMAIN", func(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := newRecordTestForwarder(t, mockResolver, "example.com")
// A not-found cannot prove the name is absent (it may exist with only
// other record types), so it must answer NODATA, never NXDOMAIN.
mockResolver.On("LookupMX", mock.Anything, "example.com.").
Return(nil, notFound).Once()
resp := runRecordQuery(t, forwarder, "example.com", dns.TypeMX)
assert.Equal(t, dns.RcodeSuccess, resp.Rcode, "missing record must be NODATA")
assert.Empty(t, resp.Answer)
mockResolver.AssertExpectations(t)
})
t.Run("NS records are forwarded", func(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := newRecordTestForwarder(t, mockResolver, "example.com")
mockResolver.On("LookupNS", mock.Anything, "example.com.").
Return([]*net.NS{{Host: "ns1.example.com."}}, nil).Once()
resp := runRecordQuery(t, forwarder, "example.com", dns.TypeNS)
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 1)
ns, ok := resp.Answer[0].(*dns.NS)
require.True(t, ok, "answer should be an NS record")
assert.Equal(t, "ns1.example.com.", ns.Ns)
mockResolver.AssertExpectations(t)
})
t.Run("missing NS is NODATA", func(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := newRecordTestForwarder(t, mockResolver, "example.com")
mockResolver.On("LookupNS", mock.Anything, "example.com.").
Return(nil, notFound).Once()
resp := runRecordQuery(t, forwarder, "example.com", dns.TypeNS)
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
assert.Empty(t, resp.Answer)
mockResolver.AssertExpectations(t)
})
t.Run("SRV records are forwarded", func(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := newRecordTestForwarder(t, mockResolver, "_sip._tcp.example.com")
mockResolver.On("LookupSRV", mock.Anything, "", "", "_sip._tcp.example.com.").
Return("", []*net.SRV{{Target: "sip.example.com.", Port: 5060, Priority: 10, Weight: 5}}, nil).Once()
resp := runRecordQuery(t, forwarder, "_sip._tcp.example.com", dns.TypeSRV)
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 1)
srv, ok := resp.Answer[0].(*dns.SRV)
require.True(t, ok, "answer should be an SRV record")
assert.Equal(t, "sip.example.com.", srv.Target)
assert.Equal(t, uint16(5060), srv.Port)
assert.Equal(t, uint16(10), srv.Priority)
mockResolver.AssertExpectations(t)
})
t.Run("missing SRV is NODATA", func(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := newRecordTestForwarder(t, mockResolver, "_sip._tcp.example.com")
mockResolver.On("LookupSRV", mock.Anything, "", "", "_sip._tcp.example.com.").
Return("", nil, notFound).Once()
resp := runRecordQuery(t, forwarder, "_sip._tcp.example.com", dns.TypeSRV)
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
assert.Empty(t, resp.Answer)
mockResolver.AssertExpectations(t)
})
t.Run("TXT records are forwarded", func(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := newRecordTestForwarder(t, mockResolver, "example.com")
mockResolver.On("LookupTXT", mock.Anything, "example.com.").
Return([]string{"v=spf1 -all"}, nil).Once()
resp := runRecordQuery(t, forwarder, "example.com", dns.TypeTXT)
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 1)
txt, ok := resp.Answer[0].(*dns.TXT)
require.True(t, ok, "answer should be a TXT record")
assert.Equal(t, []string{"v=spf1 -all"}, txt.Txt)
mockResolver.AssertExpectations(t)
})
t.Run("CNAME record is forwarded", func(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := newRecordTestForwarder(t, mockResolver, "www.example.com")
mockResolver.On("LookupCNAME", mock.Anything, "www.example.com.").
Return("target.example.com.", nil).Once()
resp := runRecordQuery(t, forwarder, "www.example.com", dns.TypeCNAME)
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 1)
cname, ok := resp.Answer[0].(*dns.CNAME)
require.True(t, ok, "answer should be a CNAME record")
assert.Equal(t, "target.example.com.", cname.Target)
mockResolver.AssertExpectations(t)
})
t.Run("CNAME equal to the name is NODATA", func(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := newRecordTestForwarder(t, mockResolver, "example.com")
// No CNAME exists: LookupCNAME echoes the queried name back.
mockResolver.On("LookupCNAME", mock.Anything, "example.com.").
Return("example.com.", nil).Once()
resp := runRecordQuery(t, forwarder, "example.com", dns.TypeCNAME)
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
assert.Empty(t, resp.Answer, "self-referential CNAME means no CNAME record")
mockResolver.AssertExpectations(t)
})
t.Run("PTR record is forwarded", func(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := newRecordTestForwarder(t, mockResolver, "*.in-addr.arpa")
// The reverse name is parsed back to the address LookupAddr expects.
mockResolver.On("LookupAddr", mock.Anything, "1.2.3.4").
Return([]string{"host.example.com."}, nil).Once()
resp := runRecordQuery(t, forwarder, "4.3.2.1.in-addr.arpa", dns.TypePTR)
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 1)
ptr, ok := resp.Answer[0].(*dns.PTR)
require.True(t, ok, "answer should be a PTR record")
assert.Equal(t, "host.example.com.", ptr.Ptr)
mockResolver.AssertExpectations(t)
})
}
func newRecordTestForwarder(t *testing.T, r resolver, configured string) *DNSForwarder {
t.Helper()
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil)
forwarder.resolver = r
d, err := domain.FromString(configured)
require.NoError(t, err)
forwarder.UpdateDomains([]*ForwarderEntry{{Domain: d, ResID: "test-res"}})
return forwarder
}
func runRecordQuery(t *testing.T, forwarder *DNSForwarder, qname string, qtype uint16) *dns.Msg {
t.Helper()
query := &dns.Msg{}
query.SetQuestion(dns.Fqdn(qname), qtype)
mockWriter := &test.MockResponseWriter{}
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
resp := mockWriter.GetLastResponse()
require.NotNil(t, resp, "expected response to be written")
return resp
}
func TestDNSForwarder_UpstreamFailureEDE(t *testing.T) {
tests := []struct {
name string

View File

@@ -210,6 +210,12 @@ type Engine struct {
// networkSerial is the latest CurrentSerial (state ID) of the network sent by the Management service
networkSerial uint64
// forwardingRules holds the ingress forward rules applied for the current target.
// Wholesale sections (incl. forward rules) run only on the first pass of a target;
// it is stashed here so the final, peer-converged pass can build the lazy-connection
// exclude list without recomputing them on every bounded peer pass.
forwardingRules []firewallManager.ForwardRule
networkMonitor *networkmonitor.NetworkMonitor
sshServer sshServer
@@ -762,7 +768,15 @@ func (e *Engine) blockLanAccess() {
// modifyPeers updates peers that have been modified (e.g. IP address has been changed).
// It closes the existing connection, removes it from the peerConns map, and creates a new one.
func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
// maxPeersPerSyncPass is the default per-pass cap on how many peers each of
// removePeers/modifyPeers/addNewPeers applies, so syncMsgMux is held only for a
// batch at a time and other subsystems can interleave between passes. It is
// passed in (not read globally) so tests can exercise the multi-pass path.
const maxPeersPerSyncPass = 300
// modifyPeers re-applies up to maxBatch changed peers per call. It returns true
// when more changed peers remained than the cap, so the caller re-runs.
func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig, maxBatch int) (bool, error) {
// first, check if peers have been modified
var modified []*mgmProto.RemotePeerConfig
@@ -792,26 +806,32 @@ func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
}
}
more := false
if len(modified) > maxBatch {
modified = modified[:maxBatch]
more = true
}
// second, close all modified connections and remove them from the state map
for _, p := range modified {
err := e.removePeer(p.GetWgPubKey())
if err != nil {
return err
if err := e.removePeer(p.GetWgPubKey()); err != nil {
return false, err
}
}
// third, add the peer connections again
for _, p := range modified {
err := e.addNewPeer(p)
if err != nil {
return err
if err := e.addNewPeer(p); err != nil {
return false, err
}
}
return nil
return more, nil
}
// removePeers finds and removes peers that do not exist anymore in the network map received from the Management Service.
// It also removes peers that have been modified (e.g. change of IP address). They will be added again in addPeers method.
func (e *Engine) removePeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
// removePeers removes up to maxBatch peers per call. It returns true when more
// peers remained to remove than the cap, so the caller re-runs.
func (e *Engine) removePeers(peersUpdate []*mgmProto.RemotePeerConfig, maxBatch int) (bool, error) {
newPeers := make([]string, 0, len(peersUpdate))
for _, p := range peersUpdate {
newPeers = append(newPeers, p.GetWgPubKey())
@@ -819,14 +839,19 @@ func (e *Engine) removePeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
toRemove := util.SliceDiff(e.peerStore.PeersPubKey(), newPeers)
more := false
if len(toRemove) > maxBatch {
toRemove = toRemove[:maxBatch]
more = true
}
for _, p := range toRemove {
err := e.removePeer(p)
if err != nil {
return err
if err := e.removePeer(p); err != nil {
return false, err
}
log.Infof("removed peer %s", p)
}
return nil
return more, nil
}
func (e *Engine) removeAllPeers() error {
@@ -895,40 +920,25 @@ func (e *Engine) handleAutoUpdateVersion(autoUpdateSettings *mgmProto.AutoUpdate
e.updateManager.SetVersion(autoUpdateSettings.Version, autoUpdateSettings.AlwaysUpdate)
}
// phase times a sync sub-phase: it returns a function that records the elapsed
// duration when called. Starting the timer at the call site keeps inter-phase
// glue code out of the measurement.
func (e *Engine) phase(name string) func() {
start := time.Now()
return func() {
e.clientMetrics.RecordSyncPhase(e.ctx, name, time.Since(start))
}
}
func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
started := time.Now()
defer func() {
duration := time.Since(started)
log.Infof("sync finished in %s", duration)
e.clientMetrics.RecordSyncDuration(e.ctx, duration)
}()
// applySyncPass applies one bounded pass of the sync update under syncMsgMux and
// returns true if more peers remained than the per-pass cap. It is driven by the
// mapStateManager, which re-invokes it (releasing the lock between passes) until
// the update is fully applied.
func (e *Engine) applySyncPass(update *mgmProto.SyncResponse, firstPass bool) (bool, error) {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
// Check context INSIDE lock to ensure atomicity with shutdown
if e.ctx.Err() != nil {
return e.ctx.Err()
return false, e.ctx.Err()
}
if update.NetworkMap != nil && update.NetworkMap.PeerConfig != nil {
e.handleAutoUpdateVersion(update.NetworkMap.PeerConfig.AutoUpdate)
}
done := e.phase("netbird_config")
err := e.updateNetbirdConfig(update.GetNetbirdConfig())
done()
if err != nil {
return err
if err := e.updateNetbirdConfig(update.GetNetbirdConfig()); err != nil {
return false, err
}
// Posture checks are bound to the network map presence:
@@ -938,28 +948,22 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
// leave the previously applied checks untouched
nm := update.GetNetworkMap()
if nm == nil {
return nil
return false, nil
}
done = e.phase("checks")
err = e.updateChecksIfNew(update.Checks)
done()
if err != nil {
return err
if err := e.updateChecksIfNew(update.Checks); err != nil {
return false, err
}
done = e.phase("persist")
e.persistSyncResponse(update)
done()
// only apply new changes and ignore old ones
if err := e.updateNetworkMap(nm); err != nil {
return err
more, err := e.updateNetworkMap(nm, maxPeersPerSyncPass, firstPass)
if err != nil {
return false, err
}
e.statusRecorder.PublishEvent(cProto.SystemEvent_INFO, cProto.SystemEvent_SYSTEM, "Network map updated", "", nil)
return nil
return more, nil
}
// updateNetbirdConfig applies the management-provided NetBird configuration:
@@ -1005,6 +1009,13 @@ func (e *Engine) updateNetbirdConfig(wCfg *mgmProto.NetbirdConfig) error {
// (not syncMsgMux) is held for the whole Set so the store cannot be cleared (disabled /
// engine close) mid-call and have this write resurrect a file that was just removed.
func (e *Engine) persistSyncResponse(update *mgmProto.SyncResponse) {
// Only persist updates that carry a network map. Config-only updates (e.g. relay
// token rotation, STUN/TURN) have a nil NetworkMap; persisting them would overwrite
// the last full map on disk and break restore-on-restart.
if update.GetNetworkMap() == nil {
return
}
e.syncRespMux.RLock()
defer e.syncRespMux.RUnlock()
@@ -1296,7 +1307,19 @@ func (e *Engine) receiveManagementEvents() {
e.config.DisableSSHAuth,
)
err = e.mgmClient.Sync(e.ctx, info, e.handleSync)
// The map-state manager converges the latest update in the background in
// bounded passes; the stream callback only hands it the newest target.
manager := newMapStateManager(e.applySyncPass, e.persistSyncResponse, func(d time.Duration) {
log.Infof("sync finished in %s", d)
e.clientMetrics.RecordSyncDuration(e.ctx, d)
})
e.shutdownWg.Add(1)
go func() {
defer e.shutdownWg.Done()
manager.run(e.ctx)
}()
err = e.mgmClient.Sync(e.ctx, info, manager.SetTarget)
if err != nil {
// happens if management is unavailable for a long time.
// We want to cancel the operation of the whole client
@@ -1347,21 +1370,104 @@ func (e *Engine) updateTURNs(turns []*mgmProto.ProtectedHostConfig) error {
return nil
}
func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
// updateNetworkMap applies the wholesale parts (config, routes, ACL, DNS) in full
// and up to maxBatch peers per phase. It returns true when more peers remained
// than the cap, so the caller re-runs until convergence.
func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap, maxBatch int, firstPass bool) (bool, error) {
// intentionally leave it before checking serial because for now it can happen that peer IP changed but serial didn't
if networkMap.GetPeerConfig() != nil {
err := e.updateConfig(networkMap.GetPeerConfig())
if err != nil {
return err
return false, err
}
}
serial := networkMap.GetSerial()
if e.networkSerial > serial {
log.Debugf("received outdated NetworkMap with serial %d, ignoring", serial)
return nil
return false, nil
}
// Wholesale sections (firewall/ACL, DNS, routes, forward rules) are applied
// up-front and only once per target: they are cheap, local, idempotent and must
// be in place before peers come up (fail-closed). On the bounded re-runs that only
// drain the remaining peer batches they are skipped — the applied forward rules are
// reused from e.forwardingRules for the lazy-exclude finalize.
if firstPass {
e.applyWholesale(networkMap, serial)
}
log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers()))
e.updateOfflinePeers(networkMap.GetOfflinePeers())
// Filter out own peer from the remote peers list
localPubKey := e.config.WgPrivateKey.PublicKey().String()
remotePeers := make([]*mgmProto.RemotePeerConfig, 0, len(networkMap.GetRemotePeers()))
for _, p := range networkMap.GetRemotePeers() {
if p.GetWgPubKey() != localPubKey {
remotePeers = append(remotePeers, p)
}
}
// needMore signals the caller to re-run when a peer phase hit its per-pass cap.
needMore := false
// cleanup request, most likely our peer has been deleted
if networkMap.GetRemotePeersIsEmpty() {
err := e.removeAllPeers()
e.statusRecorder.FinishPeerListModifications()
if err != nil {
return false, err
}
} else {
removeMore, err := e.removePeers(remotePeers, maxBatch)
if err != nil {
return false, err
}
modifyMore, err := e.modifyPeers(remotePeers, maxBatch)
if err != nil {
return false, err
}
addMore, err := e.addNewPeers(remotePeers, maxBatch)
if err != nil {
return false, err
}
needMore = removeMore || modifyMore || addMore
e.statusRecorder.FinishPeerListModifications()
e.updatePeerSSHHostKeys(remotePeers)
if err := e.updateSSHClientConfig(remotePeers); err != nil {
log.Warnf("failed to update SSH client config: %v", err)
}
e.updateSSHServerAuth(networkMap.GetSshAuth())
}
// Set the exclude list only once peers have fully converged (this pass added
// the last batch). It needs all target peers present in the store, and
// ExcludePeer has replace-semantics — a partial set mid-convergence would be wrong.
if !needMore {
excludedLazyPeers := e.toExcludedLazyPeers(e.forwardingRules, remotePeers)
e.connMgr.SetExcludeList(e.ctx, excludedLazyPeers)
}
e.networkSerial = serial
return needMore, nil
}
// applyWholesale applies the cheap, local, idempotent map sections — lazy feature
// flag, firewall/legacy management, DNS, routes, ACL filtering, DNS forwarder and
// ingress forward rules — that must be in place before peers come up. It runs once
// per target (first pass only); the resulting forward rules are stashed in
// e.forwardingRules for the lazy-exclude finalize on the peer-converged pass.
func (e *Engine) applyWholesale(networkMap *mgmProto.NetworkMap, serial uint64) {
if err := e.connMgr.UpdatedRemoteFeatureFlag(e.ctx, networkMap.GetPeerConfig().GetLazyConnectionEnabled()); err != nil {
log.Errorf("failed to update lazy connection feature flag: %v", err)
}
@@ -1389,16 +1495,13 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
dnsConfig := toDNSConfig(protoDNSConfig, e.wgInterface.Address())
done := e.phase("dns_server")
if err := e.dnsServer.UpdateDNSServer(serial, dnsConfig); err != nil {
log.Errorf("failed to update dns server, err: %v", err)
}
done()
e.routeManager.SetDNSForwarderPort(dnsConfig.ForwarderPort)
// apply routes first, route related actions might depend on routing being enabled
done = e.phase("routes_classify")
routes := toRoutes(networkMap.GetRoutes())
serverRoutes, clientRoutes := e.routeManager.ClassifyRoutes(routes)
@@ -1407,111 +1510,25 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
e.connMgr.UpdateRouteHAMap(clientRoutes)
log.Debugf("updated lazy connection manager with %d HA groups", len(clientRoutes))
}
done()
done = e.phase("routes_apply")
dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap)
if err := e.routeManager.UpdateRoutes(serial, serverRoutes, clientRoutes, dnsRouteFeatureFlag); err != nil {
log.Errorf("failed to update routes: %v", err)
}
done()
done = e.phase("filtering")
if e.acl != nil {
e.acl.ApplyFiltering(networkMap, dnsRouteFeatureFlag)
}
done()
done = e.phase("dns_forwarder")
fwdEntries := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), routes)
e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries)
done()
// Ingress forward rules
done = e.phase("forward_rules")
forwardingRules, err := e.updateForwardRules(networkMap.GetForwardingRules())
if err != nil {
log.Errorf("failed to update forward rules, err: %v", err)
}
done()
log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers()))
done = e.phase("offline_peers")
e.updateOfflinePeers(networkMap.GetOfflinePeers())
done()
remotePeers, err := e.reconcilePeers(networkMap)
if err != nil {
return err
}
// must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store
done = e.phase("lazy_exclude")
excludedLazyPeers := e.toExcludedLazyPeers(forwardingRules, remotePeers)
e.connMgr.SetExcludeList(e.ctx, excludedLazyPeers)
done()
e.networkSerial = serial
return nil
}
// reconcilePeers applies the remote peer list from the network map (removing,
// modifying and adding peers, then updating SSH config) and returns the remote
// peers with our own peer filtered out, for use by later sync steps.
func (e *Engine) reconcilePeers(networkMap *mgmProto.NetworkMap) ([]*mgmProto.RemotePeerConfig, error) {
// Filter out own peer from the remote peers list
localPubKey := e.config.WgPrivateKey.PublicKey().String()
remotePeers := make([]*mgmProto.RemotePeerConfig, 0, len(networkMap.GetRemotePeers()))
for _, p := range networkMap.GetRemotePeers() {
if p.GetWgPubKey() != localPubKey {
remotePeers = append(remotePeers, p)
}
}
// cleanup request, most likely our peer has been deleted
if networkMap.GetRemotePeersIsEmpty() {
err := e.removeAllPeers()
e.statusRecorder.FinishPeerListModifications()
if err != nil {
return nil, err
}
return remotePeers, nil
}
done := e.phase("removed_peers")
err := e.removePeers(remotePeers)
done()
if err != nil {
return nil, err
}
done = e.phase("modified_peers")
err = e.modifyPeers(remotePeers)
done()
if err != nil {
return nil, err
}
done = e.phase("added_peers")
err = e.addNewPeers(remotePeers)
done()
if err != nil {
return nil, err
}
e.statusRecorder.FinishPeerListModifications()
e.updatePeerSSHHostKeys(remotePeers)
if err := e.updateSSHClientConfig(remotePeers); err != nil {
log.Warnf("failed to update SSH client config: %v", err)
}
e.updateSSHServerAuth(networkMap.GetSshAuth())
return remotePeers, nil
e.forwardingRules = forwardingRules
}
func toDNSFeatureFlag(networkMap *mgmProto.NetworkMap) bool {
@@ -1691,14 +1708,23 @@ func addrToString(addr netip.Addr) string {
}
// addNewPeers adds peers that were not know before but arrived from the Management service with the update
func (e *Engine) addNewPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
// addNewPeers adds up to maxBatch not-yet-present peers per call. It returns true
// when more new peers remained than the cap, so the caller re-runs.
func (e *Engine) addNewPeers(peersUpdate []*mgmProto.RemotePeerConfig, maxBatch int) (bool, error) {
added := 0
for _, p := range peersUpdate {
err := e.addNewPeer(p)
if err != nil {
return err
if _, ok := e.peerStore.PeerConn(p.GetWgPubKey()); ok {
continue // already present (cheap skip), does not count toward the cap
}
if added >= maxBatch {
return true, nil // at least one more new peer remains
}
if err := e.addNewPeer(p); err != nil {
return false, err
}
added++
}
return nil
return false, nil
}
// addNewPeer add peer if connection doesn't exist

View File

@@ -124,7 +124,7 @@ func TestEngine_SSH(t *testing.T) {
RemotePeersIsEmpty: false,
}
err = engine.updateNetworkMap(networkMap)
_, err = engine.updateNetworkMap(networkMap, maxPeersPerSyncPass, true)
require.NoError(t, err)
assert.Nil(t, engine.sshServer)
@@ -146,7 +146,7 @@ func TestEngine_SSH(t *testing.T) {
RemotePeersIsEmpty: false,
}
err = engine.updateNetworkMap(networkMap)
_, err = engine.updateNetworkMap(networkMap, maxPeersPerSyncPass, true)
require.NoError(t, err)
time.Sleep(250 * time.Millisecond)
@@ -159,7 +159,7 @@ func TestEngine_SSH(t *testing.T) {
RemotePeersIsEmpty: false,
}
err = engine.updateNetworkMap(networkMap)
_, err = engine.updateNetworkMap(networkMap, maxPeersPerSyncPass, true)
require.NoError(t, err)
// time.Sleep(250 * time.Millisecond)
@@ -174,7 +174,7 @@ func TestEngine_SSH(t *testing.T) {
RemotePeersIsEmpty: false,
}
err = engine.updateNetworkMap(networkMap)
_, err = engine.updateNetworkMap(networkMap, maxPeersPerSyncPass, true)
require.NoError(t, err)
assert.Nil(t, engine.sshServer)

View File

@@ -433,7 +433,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
for _, c := range []testCase{case1, case2, case3, case4, case5, case6} {
t.Run(c.name, func(t *testing.T) {
err = engine.updateNetworkMap(c.networkMap)
_, err = engine.updateNetworkMap(c.networkMap, maxPeersPerSyncPass, true)
if err != nil {
t.Fatal(err)
return
@@ -460,6 +460,47 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
}
})
}
// chunked apply: with a per-pass cap smaller than the number of peers, a
// single updateNetworkMap applies one batch and reports more==true; the
// caller re-runs until convergence. (engine currently holds 0 peers.)
t.Run("chunked add converges over multiple passes", func(t *testing.T) {
nm := &mgmtProto.NetworkMap{
Serial: 6,
RemotePeers: []*mgmtProto.RemotePeerConfig{peer1, peer2, peer3},
}
more, err := engine.updateNetworkMap(nm, 1, true)
require.NoError(t, err)
require.True(t, more, "pass 1 should signal more")
require.Len(t, engine.peerStore.PeersPubKey(), 1)
more, err = engine.updateNetworkMap(nm, 1, false)
require.NoError(t, err)
require.True(t, more, "pass 2 should signal more")
require.Len(t, engine.peerStore.PeersPubKey(), 2)
more, err = engine.updateNetworkMap(nm, 1, false)
require.NoError(t, err)
require.False(t, more, "pass 3 should converge")
require.Len(t, engine.peerStore.PeersPubKey(), 3)
})
t.Run("chunked remove converges over multiple passes", func(t *testing.T) {
nm := &mgmtProto.NetworkMap{
Serial: 7,
RemotePeers: []*mgmtProto.RemotePeerConfig{peer1}, // remove peer2, peer3
}
more, err := engine.updateNetworkMap(nm, 1, true)
require.NoError(t, err)
require.True(t, more, "pass 1 should signal more (2 to remove, cap 1)")
more, err = engine.updateNetworkMap(nm, 1, false)
require.NoError(t, err)
require.False(t, more, "pass 2 should converge")
require.Len(t, engine.peerStore.PeersPubKey(), 1)
})
}
func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
@@ -630,7 +671,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
}
}()
err = engine.updateNetworkMap(testCase.networkMap)
_, err = engine.updateNetworkMap(testCase.networkMap, maxPeersPerSyncPass, true)
assert.NoError(t, err, "shouldn't return error")
assert.Equal(t, testCase.expectedSerial, input.inputSerial, "serial should match")
assert.Len(t, input.clientRoutes, testCase.expectedLen, "clientRoutes len should match")
@@ -834,7 +875,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
}
}()
err = engine.updateNetworkMap(testCase.networkMap)
_, err = engine.updateNetworkMap(testCase.networkMap, maxPeersPerSyncPass, true)
assert.NoError(t, err, "shouldn't return error")
assert.Equal(t, testCase.expectedSerial, input.inputSerial, "serial should match")
assert.Len(t, input.inputNSGroups, testCase.expectedZonesLen, "zones len should match")

190
client/internal/mapsync.go Normal file
View File

@@ -0,0 +1,190 @@
package internal
import (
"context"
"sync"
"time"
log "github.com/sirupsen/logrus"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
)
// mapStateManager is the single read/write point between the management stream
// (writes) and the convergence loop (reads/applies).
//
// The stream calls SetTarget with the latest full SyncResponse — the complete
// desired state. A single background goroutine (run) applies it to the engine in
// bounded passes via apply() until converged, releasing syncMsgMux between passes
// so other subsystems interleave. If a newer update arrives mid-flight, the loop
// coalesces: it keeps converging toward the latest target and the intermediate one
// is SKIPPED — never applied on its own (logged, no onConverged).
//
// Convergence is a single comparison: appliedGen == targetGen. targetGen
// increments on every SetTarget (an internal generation counter, so it also covers
// config-only updates that carry no network-map serial).
//
// onConverged fires once for each — and only each — map that is actually processed
// (i.e. converged as the target). Skipped/superseded maps and dropped-on-error maps
// do NOT fire it. So "sync finished in X" / RecordSyncDuration always corresponds
// to a real, completed alignment.
type mapStateManager struct {
// apply performs one bounded apply pass and reports whether more passes are needed.
// firstPass is true on the first pass of a given target, so the caller can run
// wholesale (firewall/routes/DNS/forward-rules) once per target and skip it on the
// re-runs that only drain the bounded peer batches. The manager owns this signal
// because it owns the convergence boundary; the engine need not track serials for it.
apply func(update *mgmProto.SyncResponse, firstPass bool) (bool, error)
// onConverged is called once per processed map, with the elapsed time since that
// map was received (for the sync-duration metric / "sync finished" log).
onConverged func(time.Duration)
// persist snapshots an update to disk for restore-on-restart. Called once per
// update received from management (in SetTarget), including ones later coalesced
// or skipped from apply, so the on-disk state mirrors what management last sent.
// The impl skips config-only updates (nil NetworkMap). May be nil.
persist func(*mgmProto.SyncResponse)
mu sync.Mutex
target *mgmProto.SyncResponse
targetGen uint64
appliedGen uint64
targetSetAt time.Time
wake chan struct{}
}
func newMapStateManager(apply func(update *mgmProto.SyncResponse, firstPass bool) (bool, error), persist func(*mgmProto.SyncResponse), onConverged func(time.Duration)) *mapStateManager {
return &mapStateManager{
apply: apply,
persist: persist,
onConverged: onConverged,
wake: make(chan struct{}, 1),
}
}
// SetTarget records the latest update as the desired state and wakes the loop.
// It returns immediately; convergence happens in the background. Serial-based
// staleness of the network map is still enforced inside apply (updateNetworkMap).
func (m *mapStateManager) SetTarget(update *mgmProto.SyncResponse) error {
m.mu.Lock()
// A target that has not settled yet (targetGen > appliedGen) is being superseded
// before it converged: we coalesce to the latest map and never apply this one on
// its own. It is SKIPPED — logged here, and it will not fire onConverged.
if m.target != nil && m.targetGen > m.appliedGen {
log.Debugf("sync map (gen %d) superseded before convergence, skipping", m.targetGen)
}
m.target = m.mergeTarget(m.target, update)
// Bump an internal generation counter, NOT the map serial: config-only updates
// (relay token rotation, STUN/TURN) arrive with NetworkMap == nil and carry no
// serial, yet must still be applied. Every SetTarget is therefore a distinct
// target regardless of payload. Map-serial staleness is enforced separately
// inside apply (updateNetworkMap).
m.targetGen++
m.targetSetAt = time.Now()
m.mu.Unlock()
select {
case m.wake <- struct{}{}:
default:
}
// Persist every update received from management — once per update (not per apply
// pass), and including ones that get coalesced/skipped from apply, so the on-disk
// state always reflects the latest map management sent. Done after waking the loop
// so convergence can start in parallel with the disk write. The persist impl skips
// config-only updates (nil NetworkMap).
if m.persist != nil {
m.persist(update)
}
return nil
}
// mergeTarget combines the currently pending target with a freshly received update
// and returns the new desired state. It is called under m.mu from SetTarget and is
// the single seam where the replace-vs-squash decision lives.
//
// Today management always sends a FULL map (the complete desired state), so the
// update simply replaces whatever was pending — prev is ignored. When management
// starts sending incremental/delta updates, squash `update` onto `prev` here; the
// rest of the manager (generation tracking, convergence, signaling) is unaffected
// because it already treats target as "the complete desired state, whatever it is".
func (m *mapStateManager) mergeTarget(prev, update *mgmProto.SyncResponse) *mgmProto.SyncResponse {
return update
}
// run drives convergence until ctx is done. It is meant to run in its own goroutine.
func (m *mapStateManager) run(ctx context.Context) {
// passGen is the generation of the most recent apply() call (0 = none). A pass is
// the first for its target when its generation differs from the previous one —
// true on a fresh target and on a coalesced switch to a newer target mid-flight.
var passGen uint64
for {
m.mu.Lock()
target, tg, ag := m.target, m.targetGen, m.appliedGen
m.mu.Unlock()
// Fully converged (or nothing yet): block until a new target arrives.
if target == nil || ag == tg {
select {
case <-ctx.Done():
return
case <-m.wake:
continue
}
}
firstPass := tg != passGen
passGen = tg
more, err := m.apply(target, firstPass)
if err != nil {
if ctx.Err() != nil {
return
}
// Log and DROP this target — do not retry it. A deterministic failure
// (e.g. a malformed peer in the map) would otherwise spin every pass
// making no progress. Management is the source of truth and re-delivers
// the full map on the next sync, so dropping is safe; peers already
// applied this convergence stay (idempotent diffs) and the remainder is
// reconciled by the next target. Mirrors the legacy handleSync path,
// where the apply error was logged by the gRPC client and the update
// dropped. No onConverged: this target did not converge.
log.Errorf("apply sync pass, dropping update: %v", err)
m.settle(tg, false)
continue
}
if more {
// keep converging the current target; syncMsgMux was released by apply
// between passes so other subsystems interleave.
continue
}
// This pass converged. Mark applied and signal this one map.
m.settle(tg, true)
// if a newer target arrived mid-pass, settle is a no-op (targetGen != tg) and
// ag<tg next iteration -> apply it; this generation was skipped (logged in
// SetTarget) and is not signaled.
}
}
// settle marks generation tg as processed so the loop goes idle instead of
// re-applying the same target. It is a no-op when a newer target arrived during the
// pass (targetGen != tg), leaving appliedGen behind so that target re-applies — the
// just-finished generation was already counted as skipped.
//
// When signal is true (the pass converged) it fires onConverged once for this map;
// when false (the target was dropped on error) it does not — the map did not converge.
func (m *mapStateManager) settle(tg uint64, signal bool) {
m.mu.Lock()
if m.targetGen != tg {
m.mu.Unlock()
return
}
m.appliedGen = tg
setAt := m.targetSetAt
m.mu.Unlock()
if signal && m.onConverged != nil {
m.onConverged(time.Since(setAt))
}
}

View File

@@ -0,0 +1,242 @@
package internal
import (
"context"
"errors"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/require"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
)
// converges over the bounded passes (apply returns more until the 3rd pass),
// fires onConverged exactly once, then blocks (no further apply) until a new target.
func TestMapStateManager_ConvergesThenStops(t *testing.T) {
var passes int32
var firstPasses int32
converged := make(chan struct{}, 1)
apply := func(_ *mgmProto.SyncResponse, firstPass bool) (bool, error) {
n := atomic.AddInt32(&passes, 1)
if firstPass {
atomic.AddInt32(&firstPasses, 1)
}
return n < 3, nil // more on pass 1 and 2, converge on pass 3
}
m := newMapStateManager(apply, nil, func(time.Duration) { converged <- struct{}{} })
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go m.run(ctx)
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{}))
select {
case <-converged:
case <-time.After(2 * time.Second):
t.Fatal("manager did not converge")
}
require.EqualValues(t, 3, atomic.LoadInt32(&passes))
require.EqualValues(t, 1, atomic.LoadInt32(&firstPasses), "firstPass true only on pass 1, false on re-runs of the same target")
// once converged the loop blocks: no further apply calls
time.Sleep(100 * time.Millisecond)
require.EqualValues(t, 3, atomic.LoadInt32(&passes), "apply must not run after convergence")
}
// persist runs once per received update (not per apply pass), regardless of how many
// bounded passes that target takes to converge.
func TestMapStateManager_PersistsOncePerUpdate(t *testing.T) {
var passes, persists int32
converged := make(chan struct{}, 1)
apply := func(_ *mgmProto.SyncResponse, _ bool) (bool, error) {
n := atomic.AddInt32(&passes, 1)
return n < 3, nil // 3 passes for one target
}
persist := func(*mgmProto.SyncResponse) { atomic.AddInt32(&persists, 1) }
m := newMapStateManager(apply, persist, func(time.Duration) { converged <- struct{}{} })
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go m.run(ctx)
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{}))
select {
case <-converged:
case <-time.After(2 * time.Second):
t.Fatal("did not converge")
}
require.EqualValues(t, 3, atomic.LoadInt32(&passes))
require.EqualValues(t, 1, atomic.LoadInt32(&persists), "persist once per update, not per pass")
}
// every update received from management is persisted — even one that is coalesced /
// skipped from apply before it ever converges.
func TestMapStateManager_PersistsEveryUpdateIncludingSkipped(t *testing.T) {
release := make(chan struct{})
var persists int32
apply := func(_ *mgmProto.SyncResponse, _ bool) (bool, error) {
<-release // hold the first apply so the second update coalesces/skips
return false, nil
}
persist := func(*mgmProto.SyncResponse) { atomic.AddInt32(&persists, 1) }
m := newMapStateManager(apply, persist, nil)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go m.run(ctx)
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{})) // map1 -> apply blocks
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{})) // map2 supersedes map1 (skipped from apply)
close(release)
// both updates persisted even though map1 is skipped from apply
require.Eventually(t, func() bool { return atomic.LoadInt32(&persists) == 2 }, 2*time.Second, 10*time.Millisecond)
}
// each map that is actually processed (converged before the next arrives) fires
// onConverged exactly once — mirroring the legacy per-message handleSync timing.
func TestMapStateManager_SignalsEachProcessedMap(t *testing.T) {
converged := make(chan struct{}, 8)
apply := func(_ *mgmProto.SyncResponse, _ bool) (bool, error) {
return false, nil // converge in one pass
}
m := newMapStateManager(apply, nil, func(time.Duration) { converged <- struct{}{} })
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go m.run(ctx)
const maps = 3
for i := 0; i < maps; i++ {
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{}))
select { // wait for this map to converge before sending the next (no coalescing)
case <-converged:
case <-time.After(2 * time.Second):
t.Fatalf("map %d not signaled", i)
}
}
// no extra signals once the stream goes quiet
select {
case <-converged:
t.Fatal("unexpected extra onConverged")
case <-time.After(100 * time.Millisecond):
}
}
// a map superseded before it converges is skipped: only the latest (processed) map
// fires onConverged, not the skipped one.
func TestMapStateManager_SkippedMapNotSignaled(t *testing.T) {
release := make(chan struct{})
var applies, converged atomic.Int32
apply := func(_ *mgmProto.SyncResponse, _ bool) (bool, error) {
applies.Add(1)
<-release // hold the first apply in-flight so we can queue a newer target
return false, nil
}
m := newMapStateManager(apply, nil, func(time.Duration) { converged.Add(1) })
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go m.run(ctx)
// map1 is picked up; its apply blocks on release
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{}))
require.Eventually(t, func() bool { return applies.Load() >= 1 }, 2*time.Second, 5*time.Millisecond)
// map2 supersedes map1 before it settled -> map1 is skipped
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{}))
close(release) // let both applies proceed
// only the processed (latest) map signals; the skipped one does not
require.Eventually(t, func() bool { return converged.Load() == 1 }, 2*time.Second, 10*time.Millisecond)
time.Sleep(150 * time.Millisecond)
require.EqualValues(t, 1, converged.Load(), "skipped map must not fire onConverged")
require.EqualValues(t, 2, applies.Load(), "both targets entered apply (map1 once, map2 once)")
}
// an apply error drops the target: no retry of the same target, no onConverged,
// the loop goes idle — and a fresh target is still applied afterwards.
func TestMapStateManager_DropsTargetOnError(t *testing.T) {
applied := make(chan struct{}, 8)
var failNext atomic.Bool
failNext.Store(true)
apply := func(_ *mgmProto.SyncResponse, _ bool) (bool, error) {
applied <- struct{}{}
if failNext.Load() {
return false, errors.New("boom")
}
return false, nil // converge in one pass
}
var converged atomic.Int32
m := newMapStateManager(apply, nil, func(time.Duration) { converged.Add(1) })
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go m.run(ctx)
// first target errors -> applied once, then dropped (no retry, no onConverged)
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{}))
select {
case <-applied:
case <-time.After(2 * time.Second):
t.Fatal("errored target not applied")
}
select {
case <-applied:
t.Fatal("errored target must not be retried")
case <-time.After(150 * time.Millisecond):
}
require.EqualValues(t, 0, converged.Load(), "onConverged must not fire on error")
// a new target is still processed normally and converges
failNext.Store(false)
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{}))
select {
case <-applied:
case <-time.After(2 * time.Second):
t.Fatal("new target after error not applied")
}
require.Eventually(t, func() bool { return converged.Load() == 1 }, 2*time.Second, 10*time.Millisecond)
}
// a new target after convergence triggers a fresh apply; an idle (converged)
// manager does not apply on its own.
func TestMapStateManager_ReappliesOnNewTarget(t *testing.T) {
applied := make(chan struct{}, 8)
apply := func(_ *mgmProto.SyncResponse, _ bool) (bool, error) {
applied <- struct{}{}
return false, nil // converge in one pass
}
m := newMapStateManager(apply, nil, nil)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go m.run(ctx)
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{}))
select {
case <-applied:
case <-time.After(2 * time.Second):
t.Fatal("first target not applied")
}
// converged → must stay idle (no spurious apply)
select {
case <-applied:
t.Fatal("unexpected apply while idle/converged")
case <-time.After(150 * time.Millisecond):
}
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{}))
select {
case <-applied:
case <-time.After(2 * time.Second):
t.Fatal("new target not applied")
}
}

View File

@@ -120,30 +120,6 @@ func (m *influxDBMetrics) RecordSyncDuration(_ context.Context, agentInfo AgentI
m.trimLocked()
}
func (m *influxDBMetrics) RecordSyncPhase(_ context.Context, agentInfo AgentInfo, phase string, duration time.Duration) {
tags := fmt.Sprintf("deployment_type=%s,version=%s,os=%s,arch=%s,peer_id=%s,phase=%s",
agentInfo.DeploymentType.String(),
agentInfo.Version,
agentInfo.OS,
agentInfo.Arch,
agentInfo.peerID,
phase,
)
m.mu.Lock()
defer m.mu.Unlock()
m.samples = append(m.samples, influxSample{
measurement: "netbird_sync_phase",
tags: tags,
fields: map[string]float64{
"duration_seconds": duration.Seconds(),
},
timestamp: time.Now(),
})
m.trimLocked()
}
func (m *influxDBMetrics) RecordLoginDuration(_ context.Context, agentInfo AgentInfo, duration time.Duration, success bool) {
result := "success"
if !success {

View File

@@ -78,25 +78,6 @@ Tags:
- `os`: Operating system (linux, darwin, windows, android, ios, etc.)
- `arch`: CPU architecture (amd64, arm64, etc.)
### Sync Phase Timing
Measurement: `netbird_sync_phase`
Breaks down where time goes inside a single sync, so the total `netbird_sync` duration can be attributed to the sub-step that dominates.
| Field | Description |
|-------|-------------|
| `duration_seconds` | Time spent in one sub-phase of sync processing |
Tags:
- `phase`: the sub-phase — `netbird_config`, `checks`, `persist`, `dns_server`, `routes_classify`, `routes_apply`, `filtering`, `dns_forwarder`, `forward_rules`, `offline_peers`, `removed_peers`, `modified_peers`, `added_peers`, `lazy_exclude`
- `deployment_type`: "cloud" | "selfhosted" | "unknown"
- `version`: NetBird version string
- `os`: Operating system (linux, darwin, windows, android, ios, etc.)
- `arch`: CPU architecture (amd64, arm64, etc.)
**Note:** this is wall-time per phase — it includes both CPU work and time spent waiting on locks. A slow phase points to *where* the time goes, not *why*; pair it with lock-wait metrics to tell contention apart from real work.
### Login Duration
Measurement: `netbird_login`
@@ -210,52 +191,4 @@ docker compose exec influxdb influx query \
# Check ingest server health
curl http://localhost:8087/health
```
## Analyzing a Debug Bundle
Metrics collection is always on, so every debug bundle ships a `metrics.txt` in InfluxDB line protocol — a timestamped time series of all recorded events (sync durations, sync phases, connection stages, login). You can replay it into the local stack and graph it, without a running client.
The bundle's `metrics.txt` is a rolling window (capped at 5 days / ~20k samples, see [Buffer Limits](#buffer-limits)). For a connection incident the relevant window is short (connection setup is seconds), so a bundle captured during the issue is enough.
### 1. Start the stack
```bash
# From this directory (client/internal/metrics/infra)
INFLUXDB_ADMIN_TOKEN=admin123 INFLUXDB_ADMIN_PASSWORD=admin123 GRAFANA_ADMIN_PASSWORD=admin123 \
docker compose up -d
```
(`admin123` are throwaway local credentials — fine for offline analysis.)
### 2. Clear any previous data
So you only see this bundle:
```bash
docker exec influxdb influx delete --org netbird --bucket metrics --token admin123 \
--start 1970-01-01T00:00:00Z --stop 2100-01-01T00:00:00Z
```
### 3. Import the bundle's metrics.txt
InfluxDB is not exposed on the host, so import inside the container:
```bash
docker cp /path/to/bundle/metrics.txt influxdb:/tmp/m.txt
docker exec influxdb influx write --org netbird --bucket metrics --precision ns \
--token admin123 --file /tmp/m.txt
```
Re-importing the same file is idempotent (same measurement+tags+timestamp overwrites).
### 4. View the dashboards
Grafana on http://localhost:3001 (login `admin` / `admin123`), datasource pre-provisioned:
- **Where sync time goes:** http://localhost:3001/d/netbird-sync-phases/netbird-sync-phases-where-time-goes
- **General client metrics:** http://localhost:3001/d/netbird-influxdb-metrics
**Set the time range** to cover the bundle's timestamps (e.g. "Last 7 days" or an absolute range matching when the bundle was taken) — with the default short range the panels look empty.
Bundles are distinguishable by the `version` tag; add a tag at import time (e.g. `sed 's/^netbird_\([a-z_]*\),/netbird_\1,bundle=mycase,/' metrics.txt`) if you want to compare several side by side.
```

View File

@@ -1,259 +0,0 @@
{
"annotations": {
"list": []
},
"editable": true,
"fiscalYearStartMonth": 0,
"graphTooltip": 1,
"links": [],
"refresh": "",
"schemaVersion": 39,
"tags": [
"netbird",
"sync"
],
"templating": {
"list": [
{
"current": {
"text": "All",
"value": "$__all"
},
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"definition": "import \"influxdata/influxdb/schema\"\nschema.tagValues(bucket: \"metrics\", tag: \"version\")",
"includeAll": true,
"label": "version",
"multi": true,
"name": "version",
"query": "import \"influxdata/influxdb/schema\"\nschema.tagValues(bucket: \"metrics\", tag: \"version\")",
"refresh": 2,
"type": "query",
"allValue": ".*"
}
]
},
"time": {
"from": "now-2d",
"to": "now"
},
"timepicker": {},
"timezone": "",
"title": "NetBird Sync Phases (where time goes)",
"uid": "netbird-sync-phases",
"version": 1,
"panels": [
{
"id": 1,
"title": "Time per phase over time (stacked, ms)",
"type": "timeseries",
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"gridPos": {
"h": 10,
"w": 24,
"x": 0,
"y": 0
},
"fieldConfig": {
"defaults": {
"unit": "ms",
"custom": {
"drawStyle": "bars",
"stacking": {
"mode": "normal",
"group": "A"
},
"fillOpacity": 80,
"lineWidth": 0
}
},
"overrides": []
},
"options": {
"legend": {
"displayMode": "table",
"placement": "right",
"calcs": [
"max",
"mean"
]
},
"tooltip": {
"mode": "multi",
"sort": "desc"
}
},
"targets": [
{
"refId": "A",
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_sync_phase\" and r._field == \"duration_seconds\")\n |> filter(fn: (r) => r.version =~ /${version:regex}/)\n |> map(fn: (r) => ({ r with _value: r._value * 1000.0 }))\n |> keep(columns: [\"_time\", \"_value\", \"phase\"])\n |> group(columns: [\"phase\"])"
}
]
},
{
"id": 2,
"title": "p95 per phase (ms)",
"type": "bargauge",
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"gridPos": {
"h": 11,
"w": 12,
"x": 0,
"y": 10
},
"fieldConfig": {
"defaults": {
"unit": "ms",
"color": {
"mode": "continuous-GrYlRd"
}
},
"overrides": []
},
"options": {
"displayMode": "gradient",
"orientation": "horizontal",
"reduceOptions": {
"calcs": [
"lastNotNull"
],
"fields": "",
"values": false
},
"showUnfilled": true
},
"targets": [
{
"refId": "A",
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_sync_phase\" and r._field == \"duration_seconds\")\n |> filter(fn: (r) => r.version =~ /${version:regex}/)\n |> map(fn: (r) => ({ r with _value: r._value * 1000.0 }))\n |> group(columns: [\"phase\"])\n |> quantile(q: 0.95)\n |> group()\n |> sort(columns: [\"_value\"], desc: true)"
}
]
},
{
"id": 3,
"title": "Per-phase stats (ms): mean / p95 / max",
"type": "table",
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"gridPos": {
"h": 11,
"w": 12,
"x": 12,
"y": 10
},
"fieldConfig": {
"defaults": {
"unit": "ms"
},
"overrides": []
},
"options": {
"showHeader": true,
"sortBy": [
{
"displayName": "max",
"desc": true
}
]
},
"transformations": [
{
"id": "merge",
"options": {}
}
],
"targets": [
{
"refId": "mean",
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_sync_phase\" and r._field == \"duration_seconds\")\n |> filter(fn: (r) => r.version =~ /${version:regex}/)\n |> map(fn: (r) => ({ r with _value: r._value * 1000.0 }))\n |> group(columns: [\"phase\"])\n |> mean()\n |> group()\n |> keep(columns: [\"phase\", \"_value\"])\n |> rename(columns: {_value: \"mean\"})"
},
{
"refId": "p95",
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_sync_phase\" and r._field == \"duration_seconds\")\n |> filter(fn: (r) => r.version =~ /${version:regex}/)\n |> map(fn: (r) => ({ r with _value: r._value * 1000.0 }))\n |> group(columns: [\"phase\"])\n |> quantile(q: 0.95)\n |> group()\n |> keep(columns: [\"phase\", \"_value\"])\n |> rename(columns: {_value: \"p95\"})"
},
{
"refId": "max",
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_sync_phase\" and r._field == \"duration_seconds\")\n |> filter(fn: (r) => r.version =~ /${version:regex}/)\n |> map(fn: (r) => ({ r with _value: r._value * 1000.0 }))\n |> group(columns: [\"phase\"])\n |> max()\n |> group()\n |> keep(columns: [\"phase\", \"_value\"])\n |> rename(columns: {_value: \"max\"})"
}
]
},
{
"id": 4,
"title": "Total sync duration (netbird_sync, ms) \u2014 reference",
"type": "timeseries",
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"gridPos": {
"h": 8,
"w": 24,
"x": 0,
"y": 21
},
"fieldConfig": {
"defaults": {
"unit": "ms",
"custom": {
"drawStyle": "points",
"pointSize": 5
}
},
"overrides": []
},
"options": {
"legend": {
"displayMode": "table",
"placement": "right",
"calcs": [
"max",
"mean"
]
},
"tooltip": {
"mode": "single"
}
},
"targets": [
{
"refId": "A",
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_sync\" and r._field == \"duration_seconds\")\n |> filter(fn: (r) => r.version =~ /${version:regex}/)\n |> map(fn: (r) => ({ r with _value: r._value * 1000.0 }))\n |> keep(columns: [\"_time\", \"_value\", \"version\"])\n |> group(columns: [\"version\"])"
}
]
}
]
}

View File

@@ -19,7 +19,7 @@ const (
defaultListenAddr = ":8087"
defaultInfluxDBURL = "http://influxdb:8086/api/v2/write?org=netbird&bucket=metrics&precision=ns"
maxBodySize = 50 * 1024 * 1024 // 50 MB max request body
maxDurationSeconds = 86400.0 // reject any duration field > 24 hours
maxDurationSeconds = 300.0 // reject any duration field > 5 minutes
peerIDLength = 16 // truncated SHA-256: 8 bytes = 16 hex chars
maxTagValueLength = 64 // reject tag values longer than this
)
@@ -59,19 +59,6 @@ var allowedMeasurements = map[string]measurementSpec{
"peer_id": true,
},
},
"netbird_sync_phase": {
allowedFields: map[string]bool{
"duration_seconds": true,
},
allowedTags: map[string]bool{
"deployment_type": true,
"version": true,
"os": true,
"arch": true,
"peer_id": true,
"phase": true,
},
},
"netbird_login": {
allowedFields: map[string]bool{
"duration_seconds": true,

View File

@@ -53,14 +53,14 @@ func TestValidateLine_NegativeValue(t *testing.T) {
}
func TestValidateLine_DurationTooLarge(t *testing.T) {
line := `netbird_sync,deployment_type=cloud,version=1.0.0,os=linux,arch=amd64,peer_id=abc duration_seconds=100000 1234567890`
line := `netbird_sync,deployment_type=cloud,version=1.0.0,os=linux,arch=amd64,peer_id=abc duration_seconds=999 1234567890`
err := validateLine(line)
require.Error(t, err)
assert.Contains(t, err.Error(), "too large")
}
func TestValidateLine_TotalSecondsTooLarge(t *testing.T) {
line := `netbird_peer_connection,deployment_type=cloud,connection_type=ice,attempt_type=initial,version=1.0.0,os=linux,arch=amd64,peer_id=abc,connection_pair_id=pair total_seconds=100000 1234567890`
line := `netbird_peer_connection,deployment_type=cloud,connection_type=ice,attempt_type=initial,version=1.0.0,os=linux,arch=amd64,peer_id=abc,connection_pair_id=pair total_seconds=500 1234567890`
err := validateLine(line)
require.Error(t, err)
assert.Contains(t, err.Error(), "too large")

View File

@@ -56,9 +56,6 @@ type metricsImplementation interface {
// RecordSyncDuration records how long it took to process a sync message
RecordSyncDuration(ctx context.Context, agentInfo AgentInfo, duration time.Duration)
// RecordSyncPhase records how long a single sub-phase of sync processing took
RecordSyncPhase(ctx context.Context, agentInfo AgentInfo, phase string, duration time.Duration)
// RecordLoginDuration records how long the login to management took
RecordLoginDuration(ctx context.Context, agentInfo AgentInfo, duration time.Duration, success bool)
@@ -130,18 +127,6 @@ func (c *ClientMetrics) RecordSyncDuration(ctx context.Context, duration time.Du
c.impl.RecordSyncDuration(ctx, agentInfo, duration)
}
// RecordSyncPhase records the duration of a single sub-phase of sync processing
func (c *ClientMetrics) RecordSyncPhase(ctx context.Context, phase string, duration time.Duration) {
if c == nil {
return
}
c.mu.RLock()
agentInfo := c.agentInfo
c.mu.RUnlock()
c.impl.RecordSyncPhase(ctx, agentInfo, phase, duration)
}
// RecordLoginDuration records how long the login to management server took
func (c *ClientMetrics) RecordLoginDuration(ctx context.Context, duration time.Duration, success bool) {
if c == nil {

View File

@@ -70,9 +70,6 @@ func (m *mockMetrics) RecordConnectionStages(_ context.Context, _ AgentInfo, _ s
func (m *mockMetrics) RecordSyncDuration(_ context.Context, _ AgentInfo, _ time.Duration) {
}
func (m *mockMetrics) RecordSyncPhase(_ context.Context, _ AgentInfo, _ string, _ time.Duration) {
}
func (m *mockMetrics) RecordLoginDuration(_ context.Context, _ AgentInfo, _ time.Duration, _ bool) {
}

View File

@@ -226,11 +226,12 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
return
}
// All query types for an intercepted domain are forwarded to the peer's
// DNS forwarder, which owns the name. Falling through to the system
// resolver would let it answer NXDOMAIN for a name it isn't authoritative
// for, poisoning the whole name (including the A/AAAA records the route
// does serve). The forwarder answers NODATA for types it cannot resolve.
// pass if non A/AAAA query
if r.Question[0].Qtype != dns.TypeA && r.Question[0].Qtype != dns.TypeAAAA {
d.continueToNextHandler(w, r, logger, "non A/AAAA query")
return
}
d.mu.RLock()
peerKey := d.currentPeerKey
d.mu.RUnlock()
@@ -292,6 +293,19 @@ func (d *DnsInterceptor) writeDNSError(w dns.ResponseWriter, r *dns.Msg, logger
}
}
// continueToNextHandler signals the handler chain to try the next handler
func (d *DnsInterceptor) continueToNextHandler(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry, reason string) {
logger.Tracef("continuing to next handler for domain=%s reason=%s", r.Question[0].Name, reason)
resp := new(dns.Msg)
resp.SetRcode(r, dns.RcodeNameError)
// Set Zero bit to signal handler chain to continue
resp.MsgHdr.Zero = true
if err := w.WriteMsg(resp); err != nil {
logger.Errorf("failed writing DNS continue response: %v", err)
}
}
func (d *DnsInterceptor) getUpstreamIP(peerKey string) (netip.Addr, error) {
peerAllowedIP, exists := d.peerStore.AllowedIP(peerKey)
if !exists {

36
go.mod
View File

@@ -34,10 +34,10 @@ require (
git.sr.ht/~jackmordaunt/go-toast/v2 v2.0.3
github.com/DeRuina/timberjack v1.4.2
github.com/awnumar/memguard v0.23.0
github.com/aws/aws-sdk-go-v2 v1.42.0
github.com/aws/aws-sdk-go-v2/config v1.32.26
github.com/aws/aws-sdk-go-v2/credentials v1.19.25
github.com/aws/aws-sdk-go-v2/service/s3 v1.104.1
github.com/aws/aws-sdk-go-v2 v1.38.3
github.com/aws/aws-sdk-go-v2/config v1.31.6
github.com/aws/aws-sdk-go-v2/credentials v1.18.10
github.com/aws/aws-sdk-go-v2/service/s3 v1.87.3
github.com/c-robinson/iplib v1.0.3
github.com/caddyserver/certmagic v0.21.3
github.com/cilium/ebpf v0.19.0
@@ -158,21 +158,21 @@ require (
github.com/apapsch/go-jsonmerge/v2 v2.0.0 // indirect
github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 // indirect
github.com/awnumar/memcall v0.4.0 // indirect
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.13 // indirect
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.29 // indirect
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.29 // indirect
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.29 // indirect
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.30 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.12 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.22 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.29 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.30 // indirect
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.1 // indirect
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.6 // indirect
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.6 // indirect
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.6 // indirect
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 // indirect
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.6 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.1 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.8.6 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.6 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.6 // indirect
github.com/aws/aws-sdk-go-v2/service/route53 v1.42.3 // indirect
github.com/aws/aws-sdk-go-v2/service/signin v1.2.1 // indirect
github.com/aws/aws-sdk-go-v2/service/sso v1.31.4 // indirect
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.36.7 // indirect
github.com/aws/aws-sdk-go-v2/service/sts v1.43.4 // indirect
github.com/aws/smithy-go v1.27.1 // indirect
github.com/aws/aws-sdk-go-v2/service/sso v1.29.1 // indirect
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.34.2 // indirect
github.com/aws/aws-sdk-go-v2/service/sts v1.38.2 // indirect
github.com/aws/smithy-go v1.23.0 // indirect
github.com/beevik/etree v1.6.0 // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc // indirect

72
go.sum
View File

@@ -52,44 +52,44 @@ github.com/awnumar/memcall v0.4.0 h1:B7hgZYdfH6Ot1Goaz8jGne/7i8xD4taZie/PNSFZ29g
github.com/awnumar/memcall v0.4.0/go.mod h1:8xOx1YbfyuCg3Fy6TO8DK0kZUua3V42/goA5Ru47E8w=
github.com/awnumar/memguard v0.23.0 h1:sJ3a1/SWlcuKIQ7MV+R9p0Pvo9CWsMbGZvcZQtmc68A=
github.com/awnumar/memguard v0.23.0/go.mod h1:olVofBrsPdITtJ2HgxQKrEYEMyIBAIciVG4wNnZhW9M=
github.com/aws/aws-sdk-go-v2 v1.42.0 h1:XvXMJTkFQtpBKIWZnmr9ZEOc2InWM2yldjXEJ/bymhA=
github.com/aws/aws-sdk-go-v2 v1.42.0/go.mod h1:27+ACypSLljLAEKsCYOmrjKh83vuTRkuAe9Uv/3A4bg=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.13 h1:p1BBrg/Hhp6uK7zpejeI8QFXHJeC/mynzi04Sl03k9g=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.13/go.mod h1:8cIfkE9MDhkRZGpQ22aV6/lkYeYSozpz16Smrs5x4Ls=
github.com/aws/aws-sdk-go-v2/config v1.32.26 h1:JI+W5B3jUA8UBz2ggbICGd9UCR6/+SB21G8EFl0SFTQ=
github.com/aws/aws-sdk-go-v2/config v1.32.26/go.mod h1:RLE2Ls/wRstvdSz1GPrIWNnXcKZ/znDdWyMuiQxdBoY=
github.com/aws/aws-sdk-go-v2/credentials v1.19.25 h1:TzPVjfUZ1hsKafvYE+DIzKXIik2KufQxsPHanlkttbo=
github.com/aws/aws-sdk-go-v2/credentials v1.19.25/go.mod h1:K4hw0buguVvtC74HnVfTRr0LzQQHAWPqJbBU9QGk2Pg=
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.29 h1:r6qZHbT+wxgWO/e9vYNUEtg7lv5+UN3pRqKhLXvnArg=
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.29/go.mod h1:QRnaRcTVGKPGRy8w78HMQtKUGRYcnMZAANATkeVA6Mo=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.29 h1:f3vKqSo13fhTYb+JEcXwXefZQE26I1FB5eTSniU67ko=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.29/go.mod h1:MzoLFUArKGpGD+ukmPiTPG1X5x4o6M2kq4v2dr1FiEc=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.29 h1:RdwIf/CuUsvJX3RgJagbOyotl/cxoLY4xviKuE7p2GY=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.29/go.mod h1:71wt8W2EgswdZy9Mf9KNnzxZ3TiZlv4caKghPktDOkA=
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.30 h1:VTGy885W5DKBxWRUJbym9hytNaYzsyaPkCHGRRMAOhU=
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.30/go.mod h1:AS0HycUvJRFvTt613AYDOgO2jzw+00cVSMny8XB3yMY=
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.12 h1:ZD2+BSw9vFsNlKYIasSNt3uDbjqqXIBcM13UJv/Lx2k=
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.12/go.mod h1:Ms4zlcVBbXbiP7EVLhl+lgjvA/a7YphqQ3Ih3174EmI=
github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.22 h1:V51LGlOq/1VsDsHUdoklAQi7rMmx4qQubvFYAlP2254=
github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.22/go.mod h1:4Pzhyz8hJOm2bepgl+NjvRx8vlUFAIIvJnZ/MkcNPpU=
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.29 h1:DRebniUGZ2MqiiIVmQJ04vIXr918hubdHMnarSLEWyU=
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.29/go.mod h1:LfRkPCD8YHDM2E5eTkos2UpwYeZnBcVarTa8L59bJHA=
github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.30 h1:4HbXxyipSYxexU0juMIpdS05dilL6dbB2VQHxxN2vGU=
github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.30/go.mod h1:G7RP+uhagpKtKhd1BM9N6JQqjCcGEU47K5lBVZQyRQw=
github.com/aws/aws-sdk-go-v2 v1.38.3 h1:B6cV4oxnMs45fql4yRH+/Po/YU+597zgWqvDpYMturk=
github.com/aws/aws-sdk-go-v2 v1.38.3/go.mod h1:sDioUELIUO9Znk23YVmIk86/9DOpkbyyVb1i/gUNFXY=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.1 h1:i8p8P4diljCr60PpJp6qZXNlgX4m2yQFpYk+9ZT+J4E=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.1/go.mod h1:ddqbooRZYNoJ2dsTwOty16rM+/Aqmk/GOXrK8cg7V00=
github.com/aws/aws-sdk-go-v2/config v1.31.6 h1:a1t8fXY4GT4xjyJExz4knbuoxSCacB5hT/WgtfPyLjo=
github.com/aws/aws-sdk-go-v2/config v1.31.6/go.mod h1:5ByscNi7R+ztvOGzeUaIu49vkMk2soq5NaH5PYe33MQ=
github.com/aws/aws-sdk-go-v2/credentials v1.18.10 h1:xdJnXCouCx8Y0NncgoptztUocIYLKeQxrCgN6x9sdhg=
github.com/aws/aws-sdk-go-v2/credentials v1.18.10/go.mod h1:7tQk08ntj914F/5i9jC4+2HQTAuJirq7m1vZVIhEkWs=
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.6 h1:wbjnrrMnKew78/juW7I2BtKQwa1qlf6EjQgS69uYY14=
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.6/go.mod h1:AtiqqNrDioJXuUgz3+3T0mBWN7Hro2n9wll2zRUc0ww=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.6 h1:uF68eJA6+S9iVr9WgX1NaRGyQ/6MdIyc4JNUo6TN1FA=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.6/go.mod h1:qlPeVZCGPiobx8wb1ft0GHT5l+dc6ldnwInDFaMvC7Y=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.6 h1:pa1DEC6JoI0zduhZePp3zmhWvk/xxm4NB8Hy/Tlsgos=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.6/go.mod h1:gxEjPebnhWGJoaDdtDkA0JX46VRg1wcTHYe63OfX5pE=
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 h1:bIqFDwgGXXN1Kpp99pDOdKMTTb5d2KyU5X/BZxjOkRo=
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3/go.mod h1:H5O/EsxDWyU+LP/V8i5sm8cxoZgc2fdNR9bxlOFrQTo=
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.6 h1:R0tNFJqfjHL3900cqhXuwQ+1K4G0xc9Yf8EDbFXCKEw=
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.6/go.mod h1:y/7sDdu+aJvPtGXr4xYosdpq9a6T9Z0jkXfugmti0rI=
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.1 h1:oegbebPEMA/1Jny7kvwejowCaHz1FWZAQ94WXFNCyTM=
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.1/go.mod h1:kemo5Myr9ac0U9JfSjMo9yHLtw+pECEHsFtJ9tqCEI8=
github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.8.6 h1:hncKj/4gR+TPauZgTAsxOxNcvBayhUlYZ6LO/BYiQ30=
github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.8.6/go.mod h1:OiIh45tp6HdJDDJGnja0mw8ihQGz3VGrUflLqSL0SmM=
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.6 h1:LHS1YAIJXJ4K9zS+1d/xa9JAA9sL2QyXIQCQFQW/X08=
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.6/go.mod h1:c9PCiTEuh0wQID5/KqA32J+HAgZxN9tOGXKCiYJjTZI=
github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.6 h1:nEXUSAwyUfLTgnc9cxlDWy637qsq4UWwp3sNAfl0Z3Y=
github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.6/go.mod h1:HGzIULx4Ge3Do2V0FaiYKcyKzOqwrhUZgCI77NisswQ=
github.com/aws/aws-sdk-go-v2/service/route53 v1.42.3 h1:MmLCRqP4U4Cw9gJ4bNrCG0mWqEtBlmAVleyelcHARMU=
github.com/aws/aws-sdk-go-v2/service/route53 v1.42.3/go.mod h1:AMPjK2YnRh0YgOID3PqhJA1BRNfXDfGOnSsKHtAe8yA=
github.com/aws/aws-sdk-go-v2/service/s3 v1.104.1 h1:yb03KevaOAG5e8suo79Af74vjIQvoeKmjl79WQchLrs=
github.com/aws/aws-sdk-go-v2/service/s3 v1.104.1/go.mod h1:mreYODw0Y4yv7xeczvqC6vciwFao8lPE9k1l1ulfY6E=
github.com/aws/aws-sdk-go-v2/service/signin v1.2.1 h1:BeJmkm5YOZs6lGRGcNoIuLSoTTtGLLCEqlSiRKYodfM=
github.com/aws/aws-sdk-go-v2/service/signin v1.2.1/go.mod h1:LxYujSTLPRlp2vTtcUO/+1ilrew8ytt6SvQyOgejzFQ=
github.com/aws/aws-sdk-go-v2/service/sso v1.31.4 h1:i465b/3c7xJd++pobNIDOggouekCuiWOnB0goQJy+94=
github.com/aws/aws-sdk-go-v2/service/sso v1.31.4/go.mod h1:Lk7PlmoTYryQmyBG0EXqj5BcUbj3whXdU2s3yGI3EAc=
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.36.7 h1:xbmJAnBbyYPkTzoCNCF/bpJ6ymQHRdXX1vquYfDIGYk=
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.36.7/go.mod h1:Q5N6icH+KJZDLh+ESNwzdv6cZ6vLFF/egy3IOxWhmz4=
github.com/aws/aws-sdk-go-v2/service/sts v1.43.4 h1:Np0vmL7op0Zs5xGacYMMX3v5O5pvZ46xhb5LwDgPj8M=
github.com/aws/aws-sdk-go-v2/service/sts v1.43.4/go.mod h1:r8wkDOuLaaMFqFiYAb8dGY2A3gJCOujMc6CFOVC4Zhc=
github.com/aws/smithy-go v1.27.1 h1:4T340VFndXtADGF52gYa1POyL7s9E4Z1OeZ1hCscIw8=
github.com/aws/smithy-go v1.27.1/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc=
github.com/aws/aws-sdk-go-v2/service/s3 v1.87.3 h1:ETkfWcXP2KNPLecaDa++5bsQhCRa5M5sLUJa5DWYIIg=
github.com/aws/aws-sdk-go-v2/service/s3 v1.87.3/go.mod h1:+/3ZTqoYb3Ur7DObD00tarKMLMuKg8iqz5CHEanqTnw=
github.com/aws/aws-sdk-go-v2/service/sso v1.29.1 h1:8OLZnVJPvjnrxEwHFg9hVUof/P4sibH+Ea4KKuqAGSg=
github.com/aws/aws-sdk-go-v2/service/sso v1.29.1/go.mod h1:27M3BpVi0C02UiQh1w9nsBEit6pLhlaH3NHna6WUbDE=
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.34.2 h1:gKWSTnqudpo8dAxqBqZnDoDWCiEh/40FziUjr/mo6uA=
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.34.2/go.mod h1:x7+rkNmRoEN1U13A6JE2fXne9EWyJy54o3n6d4mGaXQ=
github.com/aws/aws-sdk-go-v2/service/sts v1.38.2 h1:YZPjhyaGzhDQEvsffDEcpycq49nl7fiGcfJTIo8BszI=
github.com/aws/aws-sdk-go-v2/service/sts v1.38.2/go.mod h1:2dIN8qhQfv37BdUYGgEC8Q3tteM3zFxTI1MLO2O3J3c=
github.com/aws/smithy-go v1.23.0 h1:8n6I3gXzWJB2DxBDnfxgBaSX6oe0d/t10qGz7OKqMCE=
github.com/aws/smithy-go v1.23.0/go.mod h1:t1ufH5HMublsJYulve2RKmHDC15xu1f26kHCp/HgceI=
github.com/beevik/etree v1.6.0 h1:u8Kwy8pp9D9XeITj2Z0XtA5qqZEmtJtuXZRQi+j03eE=
github.com/beevik/etree v1.6.0/go.mod h1:bh4zJxiIr62SOf9pRzN7UUYaEDa9HEKafK25+sLc0Gc=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=

View File

@@ -55,14 +55,6 @@ type GrpcClient struct {
connStateCallback ConnStateNotifier
connStateCallbackLock sync.RWMutex
serverURL string
// syncStreamErr holds the last Sync stream error, or nil while the stream
// is established and healthy. GetServerKey succeeds even when the peer
// cannot sync (e.g. the server returns "settings not found"), so the
// health probe must consult this to avoid reporting a healthy management
// connection while the Sync stream keeps failing.
syncStreamMu sync.RWMutex
syncStreamErr error
}
type ExposeRequest struct {
@@ -372,8 +364,6 @@ func (c *GrpcClient) handleSyncStream(ctx context.Context, serverPubKey wgtypes.
stream, err := c.connectToSyncStream(ctx, serverPubKey, sysInfo)
if err != nil {
log.Debugf("failed to open Management Service stream: %s", err)
c.notifyDisconnected(err)
c.setSyncStreamDisconnected(err)
if s, ok := gstatus.FromError(err); ok && s.Code() == codes.PermissionDenied {
return backoff.Permanent(err) // unrecoverable error, propagate to the upper layer
}
@@ -382,13 +372,11 @@ func (c *GrpcClient) handleSyncStream(ctx context.Context, serverPubKey wgtypes.
log.Infof("connected to the Management Service stream")
c.notifyConnected()
c.setSyncStreamConnected()
// blocking until error
err = c.receiveUpdatesEvents(stream, serverPubKey, msgHandler)
if err != nil {
c.notifyDisconnected(err)
c.setSyncStreamDisconnected(err)
if ctx.Err() != nil {
log.Debugf("management connection context has been canceled, this usually indicates shutdown")
return nil
@@ -542,13 +530,6 @@ func (c *GrpcClient) IsHealthy() bool {
log.Warnf("health check returned: %s", err)
return false
}
if syncErr := c.syncStreamError(); syncErr != nil {
c.notifyDisconnected(syncErr)
log.Warnf("management transport is up but the Sync stream is unhealthy: %s", syncErr)
return false
}
c.notifyConnected()
return true
}
@@ -790,24 +771,6 @@ func (c *GrpcClient) SyncMeta(sysInfo *system.Info) error {
return err
}
func (c *GrpcClient) setSyncStreamConnected() {
c.syncStreamMu.Lock()
defer c.syncStreamMu.Unlock()
c.syncStreamErr = nil
}
func (c *GrpcClient) setSyncStreamDisconnected(err error) {
c.syncStreamMu.Lock()
defer c.syncStreamMu.Unlock()
c.syncStreamErr = err
}
func (c *GrpcClient) syncStreamError() error {
c.syncStreamMu.RLock()
defer c.syncStreamMu.RUnlock()
return c.syncStreamErr
}
func (c *GrpcClient) notifyDisconnected(err error) {
c.connStateCallbackLock.RLock()
defer c.connStateCallbackLock.RUnlock()

View File

@@ -85,7 +85,6 @@ type GrpcClient struct {
// receive backpressure as a dead stream: reconnecting cannot help, since the
// new stream feeds the same worker, and only triggers a reconnect storm.
receiveHandoffBlocked atomic.Bool
watchdogWg sync.WaitGroup
}
// NewClient creates a new Signal client
@@ -201,18 +200,10 @@ func (c *GrpcClient) Receive(ctx context.Context, msgHandler func(msg *proto.Mes
// Guard the receive direction: the transport can stay healthy while the
// server stops delivering messages. The watchdog reconnects via cancelStream.
c.markReceived()
c.watchdogWg.Add(1)
go func() {
defer c.watchdogWg.Done()
c.watchReceiveStream(streamCtx, cancelStream)
}()
go c.watchReceiveStream(streamCtx, cancelStream)
// start receiving messages from the Signal stream (from other peers through signal)
err = c.receive(stream)
cancelStream()
c.watchdogWg.Wait()
if err != nil {
// Check the parent context, not streamCtx: a watchdog-triggered
// cancelStream must reconnect, only a parent cancel is shutdown.
@@ -409,12 +400,7 @@ func (c *GrpcClient) encryptMessage(msg *proto.Message) (*proto.EncryptedMessage
// Send sends a message to the remote Peer through the Signal Exchange.
func (c *GrpcClient) Send(msg *proto.Message) error {
return c.send(c.ctx, msg)
}
// send delivers a message deriving per-attempt timeouts from parentCtx, so a
// caller can abort an in-flight send by cancelling that context.
func (c *GrpcClient) send(parentCtx context.Context, msg *proto.Message) error {
if !c.Ready() {
return fmt.Errorf("no connection to signal")
}
@@ -430,7 +416,7 @@ func (c *GrpcClient) send(parentCtx context.Context, msg *proto.Message) error {
if attempt > 1 {
attemptTimeout = time.Duration(attempt) * 5 * time.Second
}
ctx, cancel := context.WithTimeout(parentCtx, attemptTimeout)
ctx, cancel := context.WithTimeout(c.ctx, attemptTimeout)
_, err = c.realClient.Send(ctx, encryptedMessage)
@@ -500,7 +486,7 @@ func (c *GrpcClient) watchReceiveStream(ctx context.Context, cancelStream contex
}
if probeSentAt.IsZero() {
if err := c.sendReceiveProbe(ctx); err != nil {
if err := c.sendReceiveProbe(); err != nil {
log.Debugf("failed to send signal receive probe: %v", err)
}
probeSentAt = time.Now()
@@ -509,13 +495,11 @@ func (c *GrpcClient) watchReceiveStream(ctx context.Context, cancelStream contex
}
}
// sendReceiveProbe sends a self-addressed heartbeat bound to ctx, so cancelStream
// aborts an in-flight probe instead of leaving the watchdog blocked on send timeouts.
// The Signal server routes it back to this client, exercising the exact receive
// path the watchdog guards.
func (c *GrpcClient) sendReceiveProbe(ctx context.Context) error {
// sendReceiveProbe sends a self-addressed heartbeat. The Signal server routes it
// back to this client, exercising the exact receive path the watchdog guards.
func (c *GrpcClient) sendReceiveProbe() error {
self := c.key.PublicKey().String()
return c.send(ctx, &proto.Message{
return c.Send(&proto.Message{
Key: self,
RemoteKey: self,
Body: &proto.Body{Type: proto.Body_HEARTBEAT},
@@ -557,9 +541,6 @@ func (c *GrpcClient) receive(stream proto.SignalExchange_ConnectStreamClient) er
if err := c.decryptionWorker.AddMsg(c.ctx, msg); err != nil {
log.Errorf("failed to add message to decryption worker: %v", err)
}
// Refresh liveness before clearing the flag so the window between here and
// the next Recv does not read a stale timestamp as a dead stream.
c.markReceived()
c.receiveHandoffBlocked.Store(false)
}
}

View File

@@ -2,7 +2,6 @@ package client
import (
"context"
"io"
"net"
"testing"
"time"
@@ -75,7 +74,7 @@ func TestReceiveProbeRoundTrips(t *testing.T) {
t.Fatal("signal stream did not connect within timeout")
}
require.NoError(t, client.sendReceiveProbe(ctx))
require.NoError(t, client.sendReceiveProbe())
select {
case <-received:
@@ -107,72 +106,3 @@ func TestReceiveAliveTreatsHandoffBlockAsLiveness(t *testing.T) {
c.markReceived()
require.True(t, c.receiveAlive(), "a freshly received frame must keep the stream alive")
}
// fakeRecvStream feeds the receive loop frames from a channel and reports EOF
// once the channel is closed. Only Recv is exercised by the loop.
type fakeRecvStream struct {
sigProto.SignalExchange_ConnectStreamClient
frames chan *sigProto.EncryptedMessage
}
func (s *fakeRecvStream) Recv() (*sigProto.EncryptedMessage, error) {
msg, ok := <-s.frames
if !ok {
return nil, io.EOF
}
return msg, nil
}
// TestReceiveLoopRefreshesLivenessAfterBlockedHandoff drives the real receive
// loop into a handoff that blocks past the inactivity threshold, then checks the
// window after the handoff drains but before the next Recv. The loop must have
// refreshed the timestamp on unblocking, otherwise that window reads the stale
// pre-handoff timestamp as a dead stream and the watchdog tears down a healthy
// connection.
func TestReceiveLoopRefreshesLivenessAfterBlockedHandoff(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
c := &GrpcClient{ctx: ctx}
handling := make(chan struct{}, 8)
gate := make(chan struct{})
decrypt := func(*sigProto.EncryptedMessage) (*sigProto.Message, error) { return &sigProto.Message{}, nil }
handler := func(*sigProto.Message) error {
handling <- struct{}{}
<-gate
return nil
}
c.decryptionWorker = NewWorker(decrypt, handler)
workerCtx, workerCancel := context.WithCancel(context.Background())
go c.decryptionWorker.Work(workerCtx)
t.Cleanup(workerCancel)
frames := make(chan *sigProto.EncryptedMessage)
t.Cleanup(func() { close(frames) })
go func() { _ = c.receive(&fakeRecvStream{frames: frames}) }()
// First frame: the worker drains it and parks in the blocking handler.
frames <- &sigProto.EncryptedMessage{}
<-handling
// Second frame fills the worker's single-slot pool.
frames <- &sigProto.EncryptedMessage{}
// Third frame: the pool is full, so the loop parks on the handoff.
frames <- &sigProto.EncryptedMessage{}
require.Eventually(t, c.receiveHandoffBlocked.Load, time.Second, time.Millisecond,
"receive loop should park on the worker handoff")
// Simulate the handoff having blocked past the inactivity threshold.
c.lastReceived.Store(time.Now().Add(-2 * receiveInactivityThreshold).UnixNano())
require.True(t, c.receiveAlive(), "a loop parked on the handoff must stay alive")
// Drain the worker so the handoff returns and the loop resumes reading.
close(gate)
// Once the handoff clears, the loop is parked on the next Recv with no frame
// pending. The stream must still read as alive in that window.
require.Eventually(t, func() bool { return !c.receiveHandoffBlocked.Load() }, time.Second, time.Millisecond,
"handoff should drain once the worker is released")
require.True(t, c.receiveAlive(),
"the loop must refresh liveness when the handoff drains, before the next Recv")
}