mirror of
https://github.com/netbirdio/netbird.git
synced 2026-07-01 12:19:56 +00:00
Compare commits
2 Commits
fix/profil
...
refactor/r
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
50a29c07ce | ||
|
|
7d8e20030b |
@@ -33,7 +33,7 @@
|
||||
<br/>
|
||||
<br/>
|
||||
<strong>
|
||||
🚀 <a href="https://netbird.io/careers">We are hiring! Join us at https://netbird.io/careers</a>
|
||||
🚀 <a href="https://careers.netbird.io">We are hiring! Join us at careers.netbird.io</a>
|
||||
</strong>
|
||||
</p>
|
||||
|
||||
|
||||
@@ -136,11 +136,6 @@ func (p *ProxyBind) CloseConn() error {
|
||||
return p.close()
|
||||
}
|
||||
|
||||
// InjectPacket is a no-op for the userspace proxy: first-packet reinjection is kernel-only.
|
||||
func (p *ProxyBind) InjectPacket(_ []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *ProxyBind) close() error {
|
||||
if p.remoteConn == nil {
|
||||
return nil
|
||||
|
||||
@@ -219,17 +219,6 @@ func (p *ProxyWrapper) RedirectAs(endpoint *net.UDPAddr) {
|
||||
p.pausedCond.L.Unlock()
|
||||
}
|
||||
|
||||
// InjectPacket writes b to the remote peer over the underlying transport.
|
||||
func (p *ProxyWrapper) InjectPacket(b []byte) error {
|
||||
if p.remoteConn == nil {
|
||||
return errors.New("proxy not started")
|
||||
}
|
||||
if _, err := p.remoteConn.Write(b); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CloseConn close the remoteConn and automatically remove the conn instance from the map
|
||||
func (p *ProxyWrapper) CloseConn() error {
|
||||
if p.cancel == nil {
|
||||
|
||||
@@ -18,9 +18,4 @@ type Proxy interface {
|
||||
RedirectAs(endpoint *net.UDPAddr)
|
||||
CloseConn() error
|
||||
SetDisconnectListener(disconnected func())
|
||||
|
||||
// InjectPacket writes a raw packet directly to the remote peer over the underlying transport,
|
||||
// bypassing WireGuard. Used to replay the captured lazyconn handshake initiation. Only the
|
||||
// kernel-mode proxies act on it; the userspace proxy is a no-op since reinjection is kernel-only.
|
||||
InjectPacket(b []byte) error
|
||||
}
|
||||
|
||||
@@ -147,17 +147,6 @@ func (p *WGUDPProxy) RedirectAs(endpoint *net.UDPAddr) {
|
||||
p.sendPkg = p.srcFakerConn.SendPkg
|
||||
}
|
||||
|
||||
// InjectPacket writes b to the remote peer over the underlying transport.
|
||||
func (p *WGUDPProxy) InjectPacket(b []byte) error {
|
||||
if p.remoteConn == nil {
|
||||
return errors.New("proxy not started")
|
||||
}
|
||||
if _, err := p.remoteConn.Write(b); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CloseConn close the localConn
|
||||
func (p *WGUDPProxy) CloseConn() error {
|
||||
if p.cancel == nil {
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/mitchellh/hashstructure/v2"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
@@ -31,13 +30,11 @@ type Manager interface {
|
||||
|
||||
// DefaultManager uses firewall manager to handle
|
||||
type DefaultManager struct {
|
||||
firewall firewall.Manager
|
||||
ipsetCounter int
|
||||
peerRulesPairs map[id.RuleID][]firewall.Rule
|
||||
routeRules map[id.RuleID]struct{}
|
||||
previousConfigHash uint64
|
||||
hasAppliedConfig bool
|
||||
mutex sync.Mutex
|
||||
firewall firewall.Manager
|
||||
ipsetCounter int
|
||||
peerRulesPairs map[id.RuleID][]firewall.Rule
|
||||
routeRules map[id.RuleID]struct{}
|
||||
mutex sync.Mutex
|
||||
}
|
||||
|
||||
func NewDefaultManager(fm firewall.Manager) *DefaultManager {
|
||||
@@ -60,23 +57,6 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout
|
||||
return
|
||||
}
|
||||
|
||||
// Skip the full rebuild + flush when the inputs that drive the firewall
|
||||
// state are byte-for-byte identical to the last successfully applied
|
||||
// update. Management re-sends the same network map far more often than it
|
||||
// actually changes (account-wide updates, peer meta churn), and rebuilding
|
||||
// every peer/route ACL and flushing the firewall on every such sync is the
|
||||
// dominant client-side cost when nothing changed. Mirrors the same guard the
|
||||
// DNS server already uses (previousConfigHash). Only the fields ApplyFiltering
|
||||
// consumes participate in the hash, so an unrelated map change cannot mask a
|
||||
// real ACL change.
|
||||
hash, err := d.firewallConfigHash(networkMap, dnsRouteFeatureFlag)
|
||||
if err != nil {
|
||||
log.Errorf("unable to hash firewall configuration, applying unconditionally: %v", err)
|
||||
} else if d.hasAppliedConfig && d.previousConfigHash == hash {
|
||||
log.Debugf("not applying the firewall configuration update as there is nothing new (hash: %d)", hash)
|
||||
return
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
total := 0
|
||||
@@ -90,49 +70,13 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout
|
||||
|
||||
d.applyPeerACLs(networkMap)
|
||||
|
||||
routeErr := d.applyRouteACLs(networkMap.RoutesFirewallRules, dnsRouteFeatureFlag)
|
||||
if routeErr != nil {
|
||||
log.Errorf("Failed to apply route ACLs: %v", routeErr)
|
||||
if err := d.applyRouteACLs(networkMap.RoutesFirewallRules, dnsRouteFeatureFlag); err != nil {
|
||||
log.Errorf("Failed to apply route ACLs: %v", err)
|
||||
}
|
||||
|
||||
flushErr := d.firewall.Flush()
|
||||
if flushErr != nil {
|
||||
log.Error("failed to flush firewall rules: ", flushErr)
|
||||
if err := d.firewall.Flush(); err != nil {
|
||||
log.Error("failed to flush firewall rules: ", err)
|
||||
}
|
||||
|
||||
// Only remember the hash once the firewall actually reflects this config.
|
||||
// If applying or flushing failed, leave the previous hash untouched so the
|
||||
// next (possibly identical) update is not skipped and gets a chance to
|
||||
// reconcile the firewall state.
|
||||
if err == nil && routeErr == nil && flushErr == nil {
|
||||
d.previousConfigHash = hash
|
||||
d.hasAppliedConfig = true
|
||||
} else {
|
||||
d.hasAppliedConfig = false
|
||||
}
|
||||
}
|
||||
|
||||
// firewallConfigHash hashes exactly the inputs ApplyFiltering uses to build the
|
||||
// firewall state, so an identical hash means an identical resulting ruleset.
|
||||
func (d *DefaultManager) firewallConfigHash(networkMap *mgmProto.NetworkMap, dnsRouteFeatureFlag bool) (uint64, error) {
|
||||
return hashstructure.Hash(struct {
|
||||
PeerRules []*mgmProto.FirewallRule
|
||||
PeerRulesIsEmpty bool
|
||||
RouteRules []*mgmProto.RouteFirewallRule
|
||||
RouteRulesIsEmpty bool
|
||||
DNSRouteFeatureFlag bool
|
||||
}{
|
||||
PeerRules: networkMap.GetFirewallRules(),
|
||||
PeerRulesIsEmpty: networkMap.GetFirewallRulesIsEmpty(),
|
||||
RouteRules: networkMap.GetRoutesFirewallRules(),
|
||||
RouteRulesIsEmpty: networkMap.GetRoutesFirewallRulesIsEmpty(),
|
||||
DNSRouteFeatureFlag: dnsRouteFeatureFlag,
|
||||
}, hashstructure.FormatV2, &hashstructure.HashOptions{
|
||||
ZeroNil: true,
|
||||
IgnoreZeroValue: true,
|
||||
SlicesAsSets: true,
|
||||
UseStringer: true,
|
||||
})
|
||||
}
|
||||
|
||||
func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package acl
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
@@ -486,149 +485,3 @@ func TestPortInfoEmpty(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestApplyFilteringSkipsUnchangedConfig verifies that an identical network map
|
||||
// re-applied is recognized as a no-op (hash unchanged), while a real change to
|
||||
// any firewall-relevant input forces a re-apply (hash changes). This is the
|
||||
// guard that prevents a full ruleset rebuild + flush on every redundant sync.
|
||||
func TestApplyFilteringSkipsUnchangedConfig(t *testing.T) {
|
||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||
t.Setenv(firewall.EnvForceUserspaceFirewall, "true")
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
||||
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
||||
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
||||
network := netip.MustParsePrefix("172.0.0.1/32")
|
||||
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
||||
IP: network.Addr(),
|
||||
Network: network,
|
||||
}).AnyTimes()
|
||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||
|
||||
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, fw.Close(nil))
|
||||
}()
|
||||
|
||||
acl := NewDefaultManager(fw)
|
||||
|
||||
networkMap := &mgmProto.NetworkMap{
|
||||
FirewallRules: []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "22",
|
||||
},
|
||||
},
|
||||
FirewallRulesIsEmpty: false,
|
||||
}
|
||||
|
||||
acl.ApplyFiltering(networkMap, false)
|
||||
require.True(t, acl.hasAppliedConfig, "config should be marked applied after first apply")
|
||||
firstHash := acl.previousConfigHash
|
||||
require.NotZero(t, firstHash)
|
||||
|
||||
// Re-applying the identical map must not change the recorded hash: the
|
||||
// expensive rebuild path was skipped.
|
||||
acl.ApplyFiltering(networkMap, false)
|
||||
assert.Equal(t, firstHash, acl.previousConfigHash,
|
||||
"identical re-apply must be a no-op (hash unchanged)")
|
||||
|
||||
// A real change must produce a different hash and re-apply.
|
||||
networkMap.FirewallRules[0].Action = mgmProto.RuleAction_DROP
|
||||
acl.ApplyFiltering(networkMap, false)
|
||||
assert.NotEqual(t, firstHash, acl.previousConfigHash,
|
||||
"changing a rule's action must force a re-apply (hash changed)")
|
||||
|
||||
// The dnsRouteFeatureFlag also participates in the hash.
|
||||
changedHash := acl.previousConfigHash
|
||||
acl.ApplyFiltering(networkMap, true)
|
||||
assert.NotEqual(t, changedHash, acl.previousConfigHash,
|
||||
"flipping dnsRouteFeatureFlag must force a re-apply (hash changed)")
|
||||
}
|
||||
|
||||
func buildNetworkMap(peerRules, routeRules int) *mgmProto.NetworkMap {
|
||||
nm := &mgmProto.NetworkMap{
|
||||
FirewallRulesIsEmpty: peerRules == 0,
|
||||
RoutesFirewallRulesIsEmpty: routeRules == 0,
|
||||
}
|
||||
for i := range peerRules {
|
||||
nm.FirewallRules = append(nm.FirewallRules, &mgmProto.FirewallRule{
|
||||
PeerIP: fmt.Sprintf("10.%d.%d.%d", i>>16&0xff, i>>8&0xff, i&0xff),
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: fmt.Sprintf("%d", 1024+i%64511),
|
||||
})
|
||||
}
|
||||
for i := range routeRules {
|
||||
nm.RoutesFirewallRules = append(nm.RoutesFirewallRules, &mgmProto.RouteFirewallRule{
|
||||
Destination: fmt.Sprintf("192.168.%d.0/24", i%256),
|
||||
SourceRanges: []string{fmt.Sprintf("10.0.%d.0/24", i%256)},
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_ALL,
|
||||
})
|
||||
}
|
||||
return nm
|
||||
}
|
||||
|
||||
func BenchmarkFirewallConfigHash_Small(b *testing.B) {
|
||||
d := &DefaultManager{}
|
||||
nm := buildNetworkMap(10, 5)
|
||||
b.ResetTimer()
|
||||
for b.Loop() {
|
||||
_, _ = d.firewallConfigHash(nm, false)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkFirewallConfigHash_Medium(b *testing.B) {
|
||||
d := &DefaultManager{}
|
||||
nm := buildNetworkMap(100, 50)
|
||||
b.ResetTimer()
|
||||
for b.Loop() {
|
||||
_, _ = d.firewallConfigHash(nm, false)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkFirewallConfigHash_Large(b *testing.B) {
|
||||
d := &DefaultManager{}
|
||||
nm := buildNetworkMap(1000, 200)
|
||||
b.ResetTimer()
|
||||
for b.Loop() {
|
||||
_, _ = d.firewallConfigHash(nm, false)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFirewallConfigHashDeterministic verifies the hash is stable for equal
|
||||
// inputs and order-independent for the rule slices (management does not
|
||||
// guarantee rule order).
|
||||
func TestFirewallConfigHashDeterministic(t *testing.T) {
|
||||
d := &DefaultManager{}
|
||||
|
||||
nm1 := &mgmProto.NetworkMap{
|
||||
FirewallRules: []*mgmProto.FirewallRule{
|
||||
{PeerIP: "10.0.0.1", Direction: mgmProto.RuleDirection_IN, Action: mgmProto.RuleAction_ACCEPT, Protocol: mgmProto.RuleProtocol_TCP, Port: "22"},
|
||||
{PeerIP: "10.0.0.2", Direction: mgmProto.RuleDirection_IN, Action: mgmProto.RuleAction_DROP, Protocol: mgmProto.RuleProtocol_TCP, Port: "80"},
|
||||
},
|
||||
}
|
||||
// Same rules, reversed order.
|
||||
nm2 := &mgmProto.NetworkMap{
|
||||
FirewallRules: []*mgmProto.FirewallRule{
|
||||
nm1.FirewallRules[1],
|
||||
nm1.FirewallRules[0],
|
||||
},
|
||||
}
|
||||
|
||||
h1, err := d.firewallConfigHash(nm1, false)
|
||||
require.NoError(t, err)
|
||||
h2, err := d.firewallConfigHash(nm2, false)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, h1, h2, "hash must be order-independent for rule slices")
|
||||
}
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"errors"
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
@@ -168,10 +167,7 @@ func getRcodeForNotFound(ctx context.Context, r resolver, domain string, origina
|
||||
case dns.TypeA:
|
||||
alternativeNetwork = "ip6"
|
||||
default:
|
||||
// Non-address types reach LookupIP only unexpectedly; without an
|
||||
// address pair to probe we cannot prove the name is absent, so answer
|
||||
// NODATA rather than a poisoning NXDOMAIN.
|
||||
return dns.RcodeSuccess
|
||||
return dns.RcodeNameError
|
||||
}
|
||||
|
||||
if _, err := r.LookupNetIP(ctx, alternativeNetwork, domain); err != nil {
|
||||
@@ -188,230 +184,6 @@ func getRcodeForNotFound(ctx context.Context, r resolver, domain string, origina
|
||||
return dns.RcodeSuccess
|
||||
}
|
||||
|
||||
// RecordResolver is the host resolver surface used to forward non-address
|
||||
// record queries. net.DefaultResolver satisfies it.
|
||||
type RecordResolver interface {
|
||||
LookupMX(ctx context.Context, name string) ([]*net.MX, error)
|
||||
LookupTXT(ctx context.Context, name string) ([]string, error)
|
||||
LookupNS(ctx context.Context, name string) ([]*net.NS, error)
|
||||
LookupSRV(ctx context.Context, service, proto, name string) (string, []*net.SRV, error)
|
||||
LookupCNAME(ctx context.Context, host string) (string, error)
|
||||
LookupAddr(ctx context.Context, addr string) ([]string, error)
|
||||
}
|
||||
|
||||
// LookupRecords resolves a non-address DNS record type through the host
|
||||
// resolver and returns the resource records and the DNS rcode. Types the host
|
||||
// resolver cannot answer (anything not covered by the net.Resolver Lookup*
|
||||
// methods) yield NODATA so that a routed name is never poisoned with NXDOMAIN
|
||||
// for an unsupported type.
|
||||
func LookupRecords(ctx context.Context, r RecordResolver, name string, qtype uint16, ttl uint32) ([]dns.RR, int) {
|
||||
fqdn := dns.Fqdn(name)
|
||||
|
||||
switch qtype {
|
||||
case dns.TypeMX:
|
||||
return lookupMX(ctx, r, name, fqdn, ttl)
|
||||
case dns.TypeTXT:
|
||||
return lookupTXT(ctx, r, name, fqdn, ttl)
|
||||
case dns.TypeNS:
|
||||
return lookupNS(ctx, r, name, fqdn, ttl)
|
||||
case dns.TypeSRV:
|
||||
return lookupSRV(ctx, r, name, fqdn, ttl)
|
||||
case dns.TypeCNAME:
|
||||
return lookupCNAME(ctx, r, name, fqdn, ttl)
|
||||
case dns.TypePTR:
|
||||
return lookupPTR(ctx, r, name, fqdn, ttl)
|
||||
default:
|
||||
return nil, dns.RcodeSuccess
|
||||
}
|
||||
}
|
||||
|
||||
func recordHeader(fqdn string, rrtype uint16, ttl uint32) dns.RR_Header {
|
||||
return dns.RR_Header{Name: fqdn, Rrtype: rrtype, Class: dns.ClassINET, Ttl: ttl}
|
||||
}
|
||||
|
||||
func lookupMX(ctx context.Context, r RecordResolver, name, fqdn string, ttl uint32) ([]dns.RR, int) {
|
||||
recs, err := r.LookupMX(ctx, name)
|
||||
if err != nil {
|
||||
return nil, rcodeForRecordError(err)
|
||||
}
|
||||
rrs := make([]dns.RR, 0, len(recs))
|
||||
for _, mx := range recs {
|
||||
rrs = append(rrs, &dns.MX{
|
||||
Hdr: recordHeader(fqdn, dns.TypeMX, ttl),
|
||||
Preference: mx.Pref,
|
||||
Mx: dns.Fqdn(mx.Host),
|
||||
})
|
||||
}
|
||||
return rrs, dns.RcodeSuccess
|
||||
}
|
||||
|
||||
func lookupTXT(ctx context.Context, r RecordResolver, name, fqdn string, ttl uint32) ([]dns.RR, int) {
|
||||
recs, err := r.LookupTXT(ctx, name)
|
||||
if err != nil {
|
||||
return nil, rcodeForRecordError(err)
|
||||
}
|
||||
rrs := make([]dns.RR, 0, len(recs))
|
||||
for _, txt := range recs {
|
||||
rrs = append(rrs, &dns.TXT{
|
||||
Hdr: recordHeader(fqdn, dns.TypeTXT, ttl),
|
||||
Txt: chunkTXT(txt),
|
||||
})
|
||||
}
|
||||
return rrs, dns.RcodeSuccess
|
||||
}
|
||||
|
||||
func lookupNS(ctx context.Context, r RecordResolver, name, fqdn string, ttl uint32) ([]dns.RR, int) {
|
||||
recs, err := r.LookupNS(ctx, name)
|
||||
if err != nil {
|
||||
return nil, rcodeForRecordError(err)
|
||||
}
|
||||
rrs := make([]dns.RR, 0, len(recs))
|
||||
for _, ns := range recs {
|
||||
rrs = append(rrs, &dns.NS{
|
||||
Hdr: recordHeader(fqdn, dns.TypeNS, ttl),
|
||||
Ns: dns.Fqdn(ns.Host),
|
||||
})
|
||||
}
|
||||
return rrs, dns.RcodeSuccess
|
||||
}
|
||||
|
||||
func lookupSRV(ctx context.Context, r RecordResolver, name, fqdn string, ttl uint32) ([]dns.RR, int) {
|
||||
_, recs, err := r.LookupSRV(ctx, "", "", name)
|
||||
if err != nil {
|
||||
return nil, rcodeForRecordError(err)
|
||||
}
|
||||
rrs := make([]dns.RR, 0, len(recs))
|
||||
for _, srv := range recs {
|
||||
rrs = append(rrs, &dns.SRV{
|
||||
Hdr: recordHeader(fqdn, dns.TypeSRV, ttl),
|
||||
Priority: srv.Priority,
|
||||
Weight: srv.Weight,
|
||||
Port: srv.Port,
|
||||
Target: dns.Fqdn(srv.Target),
|
||||
})
|
||||
}
|
||||
return rrs, dns.RcodeSuccess
|
||||
}
|
||||
|
||||
func lookupCNAME(ctx context.Context, r RecordResolver, name, fqdn string, ttl uint32) ([]dns.RR, int) {
|
||||
cname, err := r.LookupCNAME(ctx, name)
|
||||
if err != nil {
|
||||
return nil, rcodeForRecordError(err)
|
||||
}
|
||||
// LookupCNAME returns the queried name itself when the name resolves but
|
||||
// has no CNAME record; that is a NODATA result, not a CNAME.
|
||||
if strings.EqualFold(dns.Fqdn(cname), fqdn) {
|
||||
return nil, dns.RcodeSuccess
|
||||
}
|
||||
return []dns.RR{&dns.CNAME{
|
||||
Hdr: recordHeader(fqdn, dns.TypeCNAME, ttl),
|
||||
Target: dns.Fqdn(cname),
|
||||
}}, dns.RcodeSuccess
|
||||
}
|
||||
|
||||
func lookupPTR(ctx context.Context, r RecordResolver, name, fqdn string, ttl uint32) ([]dns.RR, int) {
|
||||
addr, ok := ptrQueryAddr(name)
|
||||
if !ok {
|
||||
return nil, dns.RcodeSuccess
|
||||
}
|
||||
names, err := r.LookupAddr(ctx, addr)
|
||||
if err != nil {
|
||||
return nil, rcodeForRecordError(err)
|
||||
}
|
||||
rrs := make([]dns.RR, 0, len(names))
|
||||
for _, n := range names {
|
||||
rrs = append(rrs, &dns.PTR{
|
||||
Hdr: recordHeader(fqdn, dns.TypePTR, ttl),
|
||||
Ptr: dns.Fqdn(n),
|
||||
})
|
||||
}
|
||||
return rrs, dns.RcodeSuccess
|
||||
}
|
||||
|
||||
// ptrQueryAddr converts a reverse-DNS query name (in-addr.arpa or ip6.arpa)
|
||||
// into the address string expected by net.Resolver.LookupAddr. It reports false
|
||||
// when the name is not a well-formed reverse name.
|
||||
func ptrQueryAddr(qname string) (string, bool) {
|
||||
name := strings.TrimSuffix(strings.ToLower(dns.Fqdn(qname)), ".")
|
||||
|
||||
switch {
|
||||
case strings.HasSuffix(name, ".in-addr.arpa"):
|
||||
return parseInAddrArpa(strings.TrimSuffix(name, ".in-addr.arpa"))
|
||||
case strings.HasSuffix(name, ".ip6.arpa"):
|
||||
return parseIP6Arpa(strings.TrimSuffix(name, ".ip6.arpa"))
|
||||
default:
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
|
||||
// parseInAddrArpa turns the label portion of an in-addr.arpa name into an IPv4
|
||||
// address string, reporting false when it is not a well-formed reverse name.
|
||||
func parseInAddrArpa(labelPart string) (string, bool) {
|
||||
labels := strings.Split(labelPart, ".")
|
||||
if len(labels) != 4 {
|
||||
return "", false
|
||||
}
|
||||
slices.Reverse(labels)
|
||||
addr, err := netip.ParseAddr(strings.Join(labels, "."))
|
||||
if err != nil || !addr.Is4() {
|
||||
return "", false
|
||||
}
|
||||
return addr.String(), true
|
||||
}
|
||||
|
||||
// parseIP6Arpa turns the nibble portion of an ip6.arpa name into an IPv6
|
||||
// address string, reporting false when it is not a well-formed reverse name.
|
||||
func parseIP6Arpa(nibblePart string) (string, bool) {
|
||||
nibbles := strings.Split(nibblePart, ".")
|
||||
if len(nibbles) != 32 {
|
||||
return "", false
|
||||
}
|
||||
slices.Reverse(nibbles)
|
||||
var sb strings.Builder
|
||||
for i, n := range nibbles {
|
||||
if i > 0 && i%4 == 0 {
|
||||
sb.WriteByte(':')
|
||||
}
|
||||
sb.WriteString(n)
|
||||
}
|
||||
addr, err := netip.ParseAddr(sb.String())
|
||||
if err != nil || !addr.Is6() {
|
||||
return "", false
|
||||
}
|
||||
return addr.String(), true
|
||||
}
|
||||
|
||||
// rcodeForRecordError maps a non-address lookup error to a DNS rcode. A
|
||||
// not-found result becomes NODATA rather than NXDOMAIN: net.DNSError.IsNotFound
|
||||
// does not distinguish a missing name from a name that exists only with records
|
||||
// of other types, so the name cannot be proven absent and must not be poisoned.
|
||||
func rcodeForRecordError(err error) int {
|
||||
var dnsErr *net.DNSError
|
||||
if errors.As(err, &dnsErr) && dnsErr.IsNotFound {
|
||||
return dns.RcodeSuccess
|
||||
}
|
||||
return dns.RcodeServerFailure
|
||||
}
|
||||
|
||||
// chunkTXT splits a TXT string into character-strings no longer than 255 bytes
|
||||
// so the record can be packed. The chunks form one TXT resource record.
|
||||
func chunkTXT(s string) []string {
|
||||
const maxLen = 255
|
||||
if len(s) <= maxLen {
|
||||
return []string{s}
|
||||
}
|
||||
|
||||
var chunks []string
|
||||
for len(s) > maxLen {
|
||||
chunks = append(chunks, s[:maxLen])
|
||||
s = s[maxLen:]
|
||||
}
|
||||
if len(s) > 0 {
|
||||
chunks = append(chunks, s)
|
||||
}
|
||||
return chunks
|
||||
}
|
||||
|
||||
// FormatAnswers formats DNS resource records for logging.
|
||||
func FormatAnswers(answers []dns.RR) string {
|
||||
if len(answers) == 0 {
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"errors"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
@@ -122,164 +121,6 @@ func TestLookupIP_DNSErrorNotIsNotFound(t *testing.T) {
|
||||
assert.Equal(t, dns.RcodeServerFailure, result.Rcode, "upstream failure should map to SERVFAIL")
|
||||
}
|
||||
|
||||
func TestPtrQueryAddr(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
qname string
|
||||
want string
|
||||
wantOK bool
|
||||
}{
|
||||
{name: "ipv4", qname: "4.3.2.1.in-addr.arpa.", want: "1.2.3.4", wantOK: true},
|
||||
{name: "ipv4 no trailing dot", qname: "1.0.0.127.in-addr.arpa", want: "127.0.0.1", wantOK: true},
|
||||
{
|
||||
name: "ipv6",
|
||||
qname: "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.",
|
||||
want: "2001:db8::1",
|
||||
wantOK: true,
|
||||
},
|
||||
{name: "ipv4 wrong label count", qname: "2.1.in-addr.arpa.", wantOK: false},
|
||||
{name: "ipv6 wrong nibble count", qname: "1.0.ip6.arpa.", wantOK: false},
|
||||
{name: "not a reverse name", qname: "example.com.", wantOK: false},
|
||||
{name: "ipv4 bad octet", qname: "4.3.2.999.in-addr.arpa.", wantOK: false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, ok := ptrQueryAddr(tt.qname)
|
||||
assert.Equal(t, tt.wantOK, ok, "parse success mismatch")
|
||||
if tt.wantOK {
|
||||
assert.Equal(t, tt.want, got, "parsed address mismatch")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type mockRecordResolver struct {
|
||||
mx []*net.MX
|
||||
txt []string
|
||||
ns []*net.NS
|
||||
srv []*net.SRV
|
||||
cname string
|
||||
ptr []string
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *mockRecordResolver) LookupMX(context.Context, string) ([]*net.MX, error) {
|
||||
return m.mx, m.err
|
||||
}
|
||||
func (m *mockRecordResolver) LookupTXT(context.Context, string) ([]string, error) {
|
||||
return m.txt, m.err
|
||||
}
|
||||
func (m *mockRecordResolver) LookupNS(context.Context, string) ([]*net.NS, error) {
|
||||
return m.ns, m.err
|
||||
}
|
||||
func (m *mockRecordResolver) LookupSRV(context.Context, string, string, string) (string, []*net.SRV, error) {
|
||||
return "", m.srv, m.err
|
||||
}
|
||||
func (m *mockRecordResolver) LookupCNAME(context.Context, string) (string, error) {
|
||||
return m.cname, m.err
|
||||
}
|
||||
func (m *mockRecordResolver) LookupAddr(context.Context, string) ([]string, error) {
|
||||
return m.ptr, m.err
|
||||
}
|
||||
|
||||
func TestLookupRecords(t *testing.T) {
|
||||
notFound := &net.DNSError{IsNotFound: true, Name: "example.com."}
|
||||
|
||||
t.Run("MX success", func(t *testing.T) {
|
||||
r := &mockRecordResolver{mx: []*net.MX{{Host: "mail.example.com.", Pref: 10}}}
|
||||
rrs, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypeMX, 300)
|
||||
assert.Equal(t, dns.RcodeSuccess, rcode)
|
||||
require.Len(t, rrs, 1)
|
||||
assert.Equal(t, "mail.example.com.", rrs[0].(*dns.MX).Mx)
|
||||
})
|
||||
|
||||
t.Run("TXT short string is one character-string", func(t *testing.T) {
|
||||
r := &mockRecordResolver{txt: []string{"v=spf1 -all"}}
|
||||
rrs, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypeTXT, 300)
|
||||
assert.Equal(t, dns.RcodeSuccess, rcode)
|
||||
require.Len(t, rrs, 1)
|
||||
assert.Equal(t, []string{"v=spf1 -all"}, rrs[0].(*dns.TXT).Txt)
|
||||
})
|
||||
|
||||
t.Run("TXT chunks long strings", func(t *testing.T) {
|
||||
long := strings.Repeat("a", 300)
|
||||
r := &mockRecordResolver{txt: []string{long}}
|
||||
rrs, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypeTXT, 300)
|
||||
assert.Equal(t, dns.RcodeSuccess, rcode)
|
||||
require.Len(t, rrs, 1)
|
||||
txt := rrs[0].(*dns.TXT).Txt
|
||||
require.Len(t, txt, 2, "300-byte string should split into two character-strings")
|
||||
assert.Equal(t, 255, len(txt[0]))
|
||||
assert.Equal(t, 45, len(txt[1]))
|
||||
})
|
||||
|
||||
t.Run("NS success", func(t *testing.T) {
|
||||
r := &mockRecordResolver{ns: []*net.NS{{Host: "ns1.example.com."}}}
|
||||
rrs, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypeNS, 300)
|
||||
assert.Equal(t, dns.RcodeSuccess, rcode)
|
||||
require.Len(t, rrs, 1)
|
||||
assert.Equal(t, "ns1.example.com.", rrs[0].(*dns.NS).Ns)
|
||||
})
|
||||
|
||||
t.Run("SRV success", func(t *testing.T) {
|
||||
r := &mockRecordResolver{srv: []*net.SRV{{Target: "sip.example.com.", Port: 5060}}}
|
||||
rrs, rcode := LookupRecords(context.Background(), r, "_sip._tcp.example.com.", dns.TypeSRV, 300)
|
||||
assert.Equal(t, dns.RcodeSuccess, rcode)
|
||||
require.Len(t, rrs, 1)
|
||||
assert.Equal(t, uint16(5060), rrs[0].(*dns.SRV).Port)
|
||||
})
|
||||
|
||||
t.Run("CNAME success", func(t *testing.T) {
|
||||
r := &mockRecordResolver{cname: "target.example.com."}
|
||||
rrs, rcode := LookupRecords(context.Background(), r, "www.example.com.", dns.TypeCNAME, 300)
|
||||
assert.Equal(t, dns.RcodeSuccess, rcode)
|
||||
require.Len(t, rrs, 1)
|
||||
assert.Equal(t, "target.example.com.", rrs[0].(*dns.CNAME).Target)
|
||||
})
|
||||
|
||||
t.Run("CNAME equal to name is NODATA", func(t *testing.T) {
|
||||
r := &mockRecordResolver{cname: "example.com."}
|
||||
rrs, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypeCNAME, 300)
|
||||
assert.Equal(t, dns.RcodeSuccess, rcode)
|
||||
assert.Empty(t, rrs, "self-referential CNAME is NODATA")
|
||||
})
|
||||
|
||||
t.Run("PTR success", func(t *testing.T) {
|
||||
r := &mockRecordResolver{ptr: []string{"host.example.com."}}
|
||||
rrs, rcode := LookupRecords(context.Background(), r, "4.3.2.1.in-addr.arpa.", dns.TypePTR, 300)
|
||||
assert.Equal(t, dns.RcodeSuccess, rcode)
|
||||
require.Len(t, rrs, 1)
|
||||
assert.Equal(t, "host.example.com.", rrs[0].(*dns.PTR).Ptr)
|
||||
})
|
||||
|
||||
t.Run("PTR malformed name is NODATA", func(t *testing.T) {
|
||||
r := &mockRecordResolver{}
|
||||
rrs, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypePTR, 300)
|
||||
assert.Equal(t, dns.RcodeSuccess, rcode)
|
||||
assert.Empty(t, rrs)
|
||||
})
|
||||
|
||||
t.Run("not found is NODATA never NXDOMAIN", func(t *testing.T) {
|
||||
r := &mockRecordResolver{err: notFound}
|
||||
_, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypeMX, 300)
|
||||
assert.Equal(t, dns.RcodeSuccess, rcode, "missing record must not poison the name")
|
||||
})
|
||||
|
||||
t.Run("server failure maps to SERVFAIL", func(t *testing.T) {
|
||||
r := &mockRecordResolver{err: &net.DNSError{Err: "server misbehaving", IsTemporary: true}}
|
||||
_, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypeMX, 300)
|
||||
assert.Equal(t, dns.RcodeServerFailure, rcode)
|
||||
})
|
||||
|
||||
t.Run("unsupported type is NODATA", func(t *testing.T) {
|
||||
r := &mockRecordResolver{}
|
||||
rrs, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypeCAA, 300)
|
||||
assert.Equal(t, dns.RcodeSuccess, rcode)
|
||||
assert.Empty(t, rrs)
|
||||
})
|
||||
}
|
||||
|
||||
func TestStripOPT(t *testing.T) {
|
||||
rm := &dns.Msg{
|
||||
Extra: []dns.RR{
|
||||
|
||||
@@ -37,12 +37,6 @@ const (
|
||||
|
||||
type resolver interface {
|
||||
LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error)
|
||||
LookupMX(ctx context.Context, name string) ([]*net.MX, error)
|
||||
LookupTXT(ctx context.Context, name string) ([]string, error)
|
||||
LookupNS(ctx context.Context, name string) ([]*net.NS, error)
|
||||
LookupSRV(ctx context.Context, service, proto, name string) (string, []*net.SRV, error)
|
||||
LookupCNAME(ctx context.Context, host string) (string, error)
|
||||
LookupAddr(ctx context.Context, addr string) ([]string, error)
|
||||
}
|
||||
|
||||
type firewaller interface {
|
||||
@@ -216,6 +210,12 @@ func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, q
|
||||
qname, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass])
|
||||
|
||||
resp := query.SetReply(query)
|
||||
network := resutil.NetworkForQtype(question.Qtype)
|
||||
if network == "" {
|
||||
resp.Rcode = dns.RcodeNotImplemented
|
||||
f.writeResponse(logger, w, resp, qname, startTime)
|
||||
return
|
||||
}
|
||||
|
||||
mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(qname, "."))
|
||||
if mostSpecificResId == "" {
|
||||
@@ -227,46 +227,9 @@ func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, q
|
||||
ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout)
|
||||
defer cancel()
|
||||
|
||||
reqHasEdns := query.IsEdns0() != nil
|
||||
|
||||
switch question.Qtype {
|
||||
case dns.TypeA, dns.TypeAAAA:
|
||||
f.handleAddressQuery(ctx, logger, w, resp, mostSpecificResId, matchingEntries, reqHasEdns, startTime)
|
||||
case dns.TypeMX, dns.TypeTXT, dns.TypeNS, dns.TypeSRV, dns.TypeCNAME, dns.TypePTR:
|
||||
f.handleRecordQuery(ctx, logger, w, resp, startTime)
|
||||
default:
|
||||
// The domain is routed here, so any other type is answered NODATA
|
||||
// (NOERROR, empty answer) rather than falling back to a resolver that
|
||||
// would poison the name with NXDOMAIN. The Extended DNS Error lets a
|
||||
// client tell this capability-driven NODATA apart from an
|
||||
// authoritative one. The OPT pseudo-record must not appear unless the
|
||||
// query advertised EDNS0.
|
||||
if reqHasEdns {
|
||||
attachEDE(resp, dns.ExtendedErrorCodeNotSupported, "netbird forwarder: unsupported query type")
|
||||
}
|
||||
f.writeResponse(logger, w, resp, qname, startTime)
|
||||
}
|
||||
}
|
||||
|
||||
// handleAddressQuery resolves A/AAAA queries, programs the firewall sets and
|
||||
// resolved-IP state, and caches the answer for resilience on upstream failure.
|
||||
func (f *DNSForwarder) handleAddressQuery(
|
||||
ctx context.Context,
|
||||
logger *log.Entry,
|
||||
w dns.ResponseWriter,
|
||||
resp *dns.Msg,
|
||||
mostSpecificResId route.ResID,
|
||||
matchingEntries []*ForwarderEntry,
|
||||
reqHasEdns bool,
|
||||
startTime time.Time,
|
||||
) {
|
||||
question := resp.Question[0]
|
||||
qname := strings.ToLower(question.Name)
|
||||
|
||||
network := resutil.NetworkForQtype(question.Qtype)
|
||||
result := resutil.LookupIP(ctx, f.resolver, network, qname, question.Qtype)
|
||||
if result.Err != nil {
|
||||
f.handleDNSError(ctx, logger, w, question, resp, qname, result, reqHasEdns, startTime)
|
||||
f.handleDNSError(ctx, logger, w, question, resp, qname, result, query.IsEdns0() != nil, startTime)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -277,25 +240,6 @@ func (f *DNSForwarder) handleAddressQuery(
|
||||
f.writeResponse(logger, w, resp, qname, startTime)
|
||||
}
|
||||
|
||||
// handleRecordQuery resolves non-address record types (MX, TXT, NS, SRV,
|
||||
// CNAME, PTR) through the host resolver. Missing records are answered NODATA so
|
||||
// the routed name is never poisoned with NXDOMAIN.
|
||||
func (f *DNSForwarder) handleRecordQuery(
|
||||
ctx context.Context,
|
||||
logger *log.Entry,
|
||||
w dns.ResponseWriter,
|
||||
resp *dns.Msg,
|
||||
startTime time.Time,
|
||||
) {
|
||||
question := resp.Question[0]
|
||||
qname := strings.ToLower(question.Name)
|
||||
|
||||
records, rcode := resutil.LookupRecords(ctx, f.resolver, qname, question.Qtype, f.ttl)
|
||||
resp.Rcode = rcode
|
||||
resp.Answer = append(resp.Answer, records...)
|
||||
f.writeResponse(logger, w, resp, qname, startTime)
|
||||
}
|
||||
|
||||
func (f *DNSForwarder) writeResponse(logger *log.Entry, w dns.ResponseWriter, resp *dns.Msg, qname string, startTime time.Time) {
|
||||
if err := w.WriteMsg(resp); err != nil {
|
||||
logger.Errorf("failed to write DNS response: %v", err)
|
||||
|
||||
@@ -133,41 +133,6 @@ func (m *MockResolver) LookupNetIP(ctx context.Context, network, host string) ([
|
||||
return args.Get(0).([]netip.Addr), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockResolver) LookupMX(ctx context.Context, name string) ([]*net.MX, error) {
|
||||
args := m.Called(ctx, name)
|
||||
recs, _ := args.Get(0).([]*net.MX)
|
||||
return recs, args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockResolver) LookupTXT(ctx context.Context, name string) ([]string, error) {
|
||||
args := m.Called(ctx, name)
|
||||
recs, _ := args.Get(0).([]string)
|
||||
return recs, args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockResolver) LookupNS(ctx context.Context, name string) ([]*net.NS, error) {
|
||||
args := m.Called(ctx, name)
|
||||
recs, _ := args.Get(0).([]*net.NS)
|
||||
return recs, args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockResolver) LookupSRV(ctx context.Context, service, proto, name string) (string, []*net.SRV, error) {
|
||||
args := m.Called(ctx, service, proto, name)
|
||||
recs, _ := args.Get(1).([]*net.SRV)
|
||||
return args.String(0), recs, args.Error(2)
|
||||
}
|
||||
|
||||
func (m *MockResolver) LookupCNAME(ctx context.Context, host string) (string, error) {
|
||||
args := m.Called(ctx, host)
|
||||
return args.String(0), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockResolver) LookupAddr(ctx context.Context, addr string) ([]string, error) {
|
||||
args := m.Called(ctx, addr)
|
||||
recs, _ := args.Get(0).([]string)
|
||||
return recs, args.Error(1)
|
||||
}
|
||||
|
||||
func TestDNSForwarder_SubdomainAccessLogic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -580,15 +545,12 @@ func TestDNSForwarder_MultipleIPsInSingleUpdate(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDNSForwarder_ResponseCodes(t *testing.T) {
|
||||
// A type with no net.Resolver Lookup method (CAA) must answer NODATA
|
||||
// (NOERROR, empty) rather than NXDOMAIN/NOTIMP to avoid poisoning the name.
|
||||
tests := []struct {
|
||||
name string
|
||||
queryType uint16
|
||||
queryDomain string
|
||||
configured string
|
||||
expectedCode int
|
||||
expectEDE bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
@@ -600,13 +562,28 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {
|
||||
description: "RFC compliant REFUSED for unauthorized queries",
|
||||
},
|
||||
{
|
||||
name: "unsupported query type returns NODATA",
|
||||
queryType: dns.TypeCAA,
|
||||
name: "unsupported query type returns NOTIMP",
|
||||
queryType: dns.TypeMX,
|
||||
queryDomain: "example.com",
|
||||
configured: "example.com",
|
||||
expectedCode: dns.RcodeSuccess,
|
||||
expectEDE: true,
|
||||
description: "Unsupported types answer NODATA, not NXDOMAIN/NOTIMP",
|
||||
expectedCode: dns.RcodeNotImplemented,
|
||||
description: "RFC compliant NOTIMP for unsupported types",
|
||||
},
|
||||
{
|
||||
name: "CNAME query returns NOTIMP",
|
||||
queryType: dns.TypeCNAME,
|
||||
queryDomain: "example.com",
|
||||
configured: "example.com",
|
||||
expectedCode: dns.RcodeNotImplemented,
|
||||
description: "CNAME queries not supported",
|
||||
},
|
||||
{
|
||||
name: "TXT query returns NOTIMP",
|
||||
queryType: dns.TypeTXT,
|
||||
queryDomain: "example.com",
|
||||
configured: "example.com",
|
||||
expectedCode: dns.RcodeNotImplemented,
|
||||
description: "TXT queries not supported",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -622,7 +599,6 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {
|
||||
|
||||
query := &dns.Msg{}
|
||||
query.SetQuestion(dns.Fqdn(tt.queryDomain), tt.queryType)
|
||||
query.SetEdns0(dns.DefaultMsgSize, false)
|
||||
|
||||
// Capture the written response
|
||||
var writtenResp *dns.Msg
|
||||
@@ -638,213 +614,10 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {
|
||||
// Check the response written to the writer
|
||||
require.NotNil(t, writtenResp, "Expected response to be written")
|
||||
assert.Equal(t, tt.expectedCode, writtenResp.Rcode, tt.description)
|
||||
assert.Empty(t, writtenResp.Answer, "Non-address response should carry no answers")
|
||||
|
||||
if tt.expectEDE {
|
||||
require.NotNil(t, writtenResp.IsEdns0(), "EDNS0 client should get an OPT in the reply")
|
||||
assert.True(t, hasEDE(writtenResp, dns.ExtendedErrorCodeNotSupported),
|
||||
"unsupported type NODATA should carry EDE Not Supported")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func hasEDE(m *dns.Msg, code uint16) bool {
|
||||
opt := m.IsEdns0()
|
||||
if opt == nil {
|
||||
return false
|
||||
}
|
||||
for _, o := range opt.Option {
|
||||
if ede, ok := o.(*dns.EDNS0_EDE); ok && ede.InfoCode == code {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func TestDNSForwarder_RecordQueries(t *testing.T) {
|
||||
notFound := &net.DNSError{IsNotFound: true, Name: "example.com"}
|
||||
|
||||
t.Run("MX records are forwarded", func(t *testing.T) {
|
||||
mockResolver := &MockResolver{}
|
||||
forwarder := newRecordTestForwarder(t, mockResolver, "example.com")
|
||||
|
||||
mockResolver.On("LookupMX", mock.Anything, "example.com.").
|
||||
Return([]*net.MX{{Host: "mail.example.com.", Pref: 10}}, nil).Once()
|
||||
|
||||
resp := runRecordQuery(t, forwarder, "example.com", dns.TypeMX)
|
||||
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||
require.Len(t, resp.Answer, 1)
|
||||
mx, ok := resp.Answer[0].(*dns.MX)
|
||||
require.True(t, ok, "answer should be an MX record")
|
||||
assert.Equal(t, uint16(10), mx.Preference)
|
||||
assert.Equal(t, "mail.example.com.", mx.Mx)
|
||||
mockResolver.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("missing MX is NODATA not NXDOMAIN", func(t *testing.T) {
|
||||
mockResolver := &MockResolver{}
|
||||
forwarder := newRecordTestForwarder(t, mockResolver, "example.com")
|
||||
|
||||
// A not-found cannot prove the name is absent (it may exist with only
|
||||
// other record types), so it must answer NODATA, never NXDOMAIN.
|
||||
mockResolver.On("LookupMX", mock.Anything, "example.com.").
|
||||
Return(nil, notFound).Once()
|
||||
|
||||
resp := runRecordQuery(t, forwarder, "example.com", dns.TypeMX)
|
||||
assert.Equal(t, dns.RcodeSuccess, resp.Rcode, "missing record must be NODATA")
|
||||
assert.Empty(t, resp.Answer)
|
||||
mockResolver.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("NS records are forwarded", func(t *testing.T) {
|
||||
mockResolver := &MockResolver{}
|
||||
forwarder := newRecordTestForwarder(t, mockResolver, "example.com")
|
||||
|
||||
mockResolver.On("LookupNS", mock.Anything, "example.com.").
|
||||
Return([]*net.NS{{Host: "ns1.example.com."}}, nil).Once()
|
||||
|
||||
resp := runRecordQuery(t, forwarder, "example.com", dns.TypeNS)
|
||||
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||
require.Len(t, resp.Answer, 1)
|
||||
ns, ok := resp.Answer[0].(*dns.NS)
|
||||
require.True(t, ok, "answer should be an NS record")
|
||||
assert.Equal(t, "ns1.example.com.", ns.Ns)
|
||||
mockResolver.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("missing NS is NODATA", func(t *testing.T) {
|
||||
mockResolver := &MockResolver{}
|
||||
forwarder := newRecordTestForwarder(t, mockResolver, "example.com")
|
||||
|
||||
mockResolver.On("LookupNS", mock.Anything, "example.com.").
|
||||
Return(nil, notFound).Once()
|
||||
|
||||
resp := runRecordQuery(t, forwarder, "example.com", dns.TypeNS)
|
||||
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||
assert.Empty(t, resp.Answer)
|
||||
mockResolver.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("SRV records are forwarded", func(t *testing.T) {
|
||||
mockResolver := &MockResolver{}
|
||||
forwarder := newRecordTestForwarder(t, mockResolver, "_sip._tcp.example.com")
|
||||
|
||||
mockResolver.On("LookupSRV", mock.Anything, "", "", "_sip._tcp.example.com.").
|
||||
Return("", []*net.SRV{{Target: "sip.example.com.", Port: 5060, Priority: 10, Weight: 5}}, nil).Once()
|
||||
|
||||
resp := runRecordQuery(t, forwarder, "_sip._tcp.example.com", dns.TypeSRV)
|
||||
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||
require.Len(t, resp.Answer, 1)
|
||||
srv, ok := resp.Answer[0].(*dns.SRV)
|
||||
require.True(t, ok, "answer should be an SRV record")
|
||||
assert.Equal(t, "sip.example.com.", srv.Target)
|
||||
assert.Equal(t, uint16(5060), srv.Port)
|
||||
assert.Equal(t, uint16(10), srv.Priority)
|
||||
mockResolver.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("missing SRV is NODATA", func(t *testing.T) {
|
||||
mockResolver := &MockResolver{}
|
||||
forwarder := newRecordTestForwarder(t, mockResolver, "_sip._tcp.example.com")
|
||||
|
||||
mockResolver.On("LookupSRV", mock.Anything, "", "", "_sip._tcp.example.com.").
|
||||
Return("", nil, notFound).Once()
|
||||
|
||||
resp := runRecordQuery(t, forwarder, "_sip._tcp.example.com", dns.TypeSRV)
|
||||
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||
assert.Empty(t, resp.Answer)
|
||||
mockResolver.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("TXT records are forwarded", func(t *testing.T) {
|
||||
mockResolver := &MockResolver{}
|
||||
forwarder := newRecordTestForwarder(t, mockResolver, "example.com")
|
||||
|
||||
mockResolver.On("LookupTXT", mock.Anything, "example.com.").
|
||||
Return([]string{"v=spf1 -all"}, nil).Once()
|
||||
|
||||
resp := runRecordQuery(t, forwarder, "example.com", dns.TypeTXT)
|
||||
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||
require.Len(t, resp.Answer, 1)
|
||||
txt, ok := resp.Answer[0].(*dns.TXT)
|
||||
require.True(t, ok, "answer should be a TXT record")
|
||||
assert.Equal(t, []string{"v=spf1 -all"}, txt.Txt)
|
||||
mockResolver.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("CNAME record is forwarded", func(t *testing.T) {
|
||||
mockResolver := &MockResolver{}
|
||||
forwarder := newRecordTestForwarder(t, mockResolver, "www.example.com")
|
||||
|
||||
mockResolver.On("LookupCNAME", mock.Anything, "www.example.com.").
|
||||
Return("target.example.com.", nil).Once()
|
||||
|
||||
resp := runRecordQuery(t, forwarder, "www.example.com", dns.TypeCNAME)
|
||||
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||
require.Len(t, resp.Answer, 1)
|
||||
cname, ok := resp.Answer[0].(*dns.CNAME)
|
||||
require.True(t, ok, "answer should be a CNAME record")
|
||||
assert.Equal(t, "target.example.com.", cname.Target)
|
||||
mockResolver.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("CNAME equal to the name is NODATA", func(t *testing.T) {
|
||||
mockResolver := &MockResolver{}
|
||||
forwarder := newRecordTestForwarder(t, mockResolver, "example.com")
|
||||
|
||||
// No CNAME exists: LookupCNAME echoes the queried name back.
|
||||
mockResolver.On("LookupCNAME", mock.Anything, "example.com.").
|
||||
Return("example.com.", nil).Once()
|
||||
|
||||
resp := runRecordQuery(t, forwarder, "example.com", dns.TypeCNAME)
|
||||
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||
assert.Empty(t, resp.Answer, "self-referential CNAME means no CNAME record")
|
||||
mockResolver.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("PTR record is forwarded", func(t *testing.T) {
|
||||
mockResolver := &MockResolver{}
|
||||
forwarder := newRecordTestForwarder(t, mockResolver, "*.in-addr.arpa")
|
||||
|
||||
// The reverse name is parsed back to the address LookupAddr expects.
|
||||
mockResolver.On("LookupAddr", mock.Anything, "1.2.3.4").
|
||||
Return([]string{"host.example.com."}, nil).Once()
|
||||
|
||||
resp := runRecordQuery(t, forwarder, "4.3.2.1.in-addr.arpa", dns.TypePTR)
|
||||
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||
require.Len(t, resp.Answer, 1)
|
||||
ptr, ok := resp.Answer[0].(*dns.PTR)
|
||||
require.True(t, ok, "answer should be a PTR record")
|
||||
assert.Equal(t, "host.example.com.", ptr.Ptr)
|
||||
mockResolver.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
|
||||
func newRecordTestForwarder(t *testing.T, r resolver, configured string) *DNSForwarder {
|
||||
t.Helper()
|
||||
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil)
|
||||
forwarder.resolver = r
|
||||
|
||||
d, err := domain.FromString(configured)
|
||||
require.NoError(t, err)
|
||||
forwarder.UpdateDomains([]*ForwarderEntry{{Domain: d, ResID: "test-res"}})
|
||||
return forwarder
|
||||
}
|
||||
|
||||
func runRecordQuery(t *testing.T, forwarder *DNSForwarder, qname string, qtype uint16) *dns.Msg {
|
||||
t.Helper()
|
||||
query := &dns.Msg{}
|
||||
query.SetQuestion(dns.Fqdn(qname), qtype)
|
||||
|
||||
mockWriter := &test.MockResponseWriter{}
|
||||
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
|
||||
|
||||
resp := mockWriter.GetLastResponse()
|
||||
require.NotNil(t, resp, "expected response to be written")
|
||||
return resp
|
||||
}
|
||||
|
||||
func TestDNSForwarder_UpstreamFailureEDE(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
@@ -82,12 +82,6 @@ const (
|
||||
PeerConnectionTimeoutMax = 45000 // ms
|
||||
PeerConnectionTimeoutMin = 30000 // ms
|
||||
disableAutoUpdate = "disabled"
|
||||
|
||||
// systemInfoTimeout bounds how long the sync loop waits for system info / posture
|
||||
// check gathering. The gathering runs uncancellable system calls (process scan,
|
||||
// exec, os.Stat); without this bound a single stuck call freezes handleSync, and
|
||||
// thus syncMsgMux, for as long as the call hangs (observed multi-minute freezes).
|
||||
systemInfoTimeout = 15 * time.Second
|
||||
)
|
||||
|
||||
var ErrResetConnection = fmt.Errorf("reset connection")
|
||||
@@ -901,16 +895,6 @@ func (e *Engine) handleAutoUpdateVersion(autoUpdateSettings *mgmProto.AutoUpdate
|
||||
e.updateManager.SetVersion(autoUpdateSettings.Version, autoUpdateSettings.AlwaysUpdate)
|
||||
}
|
||||
|
||||
// phase times a sync sub-phase: it returns a function that records the elapsed
|
||||
// duration when called. Starting the timer at the call site keeps inter-phase
|
||||
// glue code out of the measurement.
|
||||
func (e *Engine) phase(name string) func() {
|
||||
start := time.Now()
|
||||
return func() {
|
||||
e.clientMetrics.RecordSyncPhase(e.ctx, name, time.Since(start))
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||
started := time.Now()
|
||||
defer func() {
|
||||
@@ -930,10 +914,7 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||
e.handleAutoUpdateVersion(update.NetworkMap.PeerConfig.AutoUpdate)
|
||||
}
|
||||
|
||||
done := e.phase("netbird_config")
|
||||
err := e.updateNetbirdConfig(update.GetNetbirdConfig())
|
||||
done()
|
||||
if err != nil {
|
||||
if err := e.updateNetbirdConfig(update.GetNetbirdConfig()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -947,16 +928,11 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
done = e.phase("checks")
|
||||
err = e.updateChecksIfNew(update.Checks)
|
||||
done()
|
||||
if err != nil {
|
||||
if err := e.updateChecksIfNew(update.Checks); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
done = e.phase("persist")
|
||||
e.persistSyncResponse(update)
|
||||
done()
|
||||
|
||||
// only apply new changes and ignore old ones
|
||||
if err := e.updateNetworkMap(nm); err != nil {
|
||||
@@ -1090,22 +1066,11 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
|
||||
}
|
||||
e.checks = checks
|
||||
|
||||
info, ok := system.GetInfoWithChecksTimeout(e.ctx, systemInfoTimeout, checks, e.overlayAddresses()...)
|
||||
if !ok {
|
||||
// Gathering timed out; skip the meta sync this cycle rather than blocking the
|
||||
// sync loop (and syncMsgMux) on a stuck system call. A later sync will retry.
|
||||
return nil
|
||||
info, err := system.GetInfoWithChecks(e.ctx, checks, e.overlayAddresses()...)
|
||||
if err != nil {
|
||||
log.Warnf("failed to get system info with checks: %v", err)
|
||||
info = system.GetInfo(e.ctx)
|
||||
}
|
||||
e.applyInfoFlags(info)
|
||||
|
||||
if err := e.mgmClient.SyncMeta(info); err != nil {
|
||||
return fmt.Errorf("could not sync meta: error %s", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// applyInfoFlags sets the engine's config-derived feature flags on the gathered system info.
|
||||
func (e *Engine) applyInfoFlags(info *system.Info) {
|
||||
info.SetFlags(
|
||||
e.config.RosenpassEnabled,
|
||||
e.config.RosenpassPermissive,
|
||||
@@ -1124,6 +1089,12 @@ func (e *Engine) applyInfoFlags(info *system.Info) {
|
||||
e.config.EnableSSHRemotePortForwarding,
|
||||
e.config.DisableSSHAuth,
|
||||
)
|
||||
|
||||
if err := e.mgmClient.SyncMeta(info); err != nil {
|
||||
log.Errorf("could not sync meta: error %s", err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// overlayAddresses returns our own WireGuard overlay address (v4 and v6) so it
|
||||
@@ -1283,15 +1254,31 @@ func (e *Engine) receiveManagementEvents() {
|
||||
e.shutdownWg.Add(1)
|
||||
go func() {
|
||||
defer e.shutdownWg.Done()
|
||||
info, ok := system.GetInfoWithChecksTimeout(e.ctx, systemInfoTimeout, e.checks, e.overlayAddresses()...)
|
||||
if !ok {
|
||||
// Gathering timed out; connect the stream with base info so management
|
||||
// connectivity still comes up rather than blocking here.
|
||||
info, err := system.GetInfoWithChecks(e.ctx, e.checks, e.overlayAddresses()...)
|
||||
if err != nil {
|
||||
log.Warnf("failed to get system info with checks: %v", err)
|
||||
info = system.GetInfo(e.ctx)
|
||||
}
|
||||
e.applyInfoFlags(info)
|
||||
info.SetFlags(
|
||||
e.config.RosenpassEnabled,
|
||||
e.config.RosenpassPermissive,
|
||||
&e.config.ServerSSHAllowed,
|
||||
e.config.DisableClientRoutes,
|
||||
e.config.DisableServerRoutes,
|
||||
e.config.DisableDNS,
|
||||
e.config.DisableFirewall,
|
||||
e.config.BlockLANAccess,
|
||||
e.config.BlockInbound,
|
||||
e.config.DisableIPv6,
|
||||
e.config.LazyConnectionEnabled,
|
||||
e.config.EnableSSHRoot,
|
||||
e.config.EnableSSHSFTP,
|
||||
e.config.EnableSSHLocalPortForwarding,
|
||||
e.config.EnableSSHRemotePortForwarding,
|
||||
e.config.DisableSSHAuth,
|
||||
)
|
||||
|
||||
err := e.mgmClient.Sync(e.ctx, info, e.handleSync)
|
||||
err = e.mgmClient.Sync(e.ctx, info, e.handleSync)
|
||||
if err != nil {
|
||||
// happens if management is unavailable for a long time.
|
||||
// We want to cancel the operation of the whole client
|
||||
@@ -1384,16 +1371,13 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
|
||||
dnsConfig := toDNSConfig(protoDNSConfig, e.wgInterface.Address())
|
||||
|
||||
done := e.phase("dns_server")
|
||||
if err := e.dnsServer.UpdateDNSServer(serial, dnsConfig); err != nil {
|
||||
log.Errorf("failed to update dns server, err: %v", err)
|
||||
}
|
||||
done()
|
||||
|
||||
e.routeManager.SetDNSForwarderPort(dnsConfig.ForwarderPort)
|
||||
|
||||
// apply routes first, route related actions might depend on routing being enabled
|
||||
done = e.phase("routes_classify")
|
||||
routes := toRoutes(networkMap.GetRoutes())
|
||||
serverRoutes, clientRoutes := e.routeManager.ClassifyRoutes(routes)
|
||||
|
||||
@@ -1402,60 +1386,29 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
e.connMgr.UpdateRouteHAMap(clientRoutes)
|
||||
log.Debugf("updated lazy connection manager with %d HA groups", len(clientRoutes))
|
||||
}
|
||||
done()
|
||||
|
||||
done = e.phase("routes_apply")
|
||||
dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap)
|
||||
if err := e.routeManager.UpdateRoutes(serial, serverRoutes, clientRoutes, dnsRouteFeatureFlag); err != nil {
|
||||
log.Errorf("failed to update routes: %v", err)
|
||||
}
|
||||
done()
|
||||
|
||||
done = e.phase("filtering")
|
||||
if e.acl != nil {
|
||||
e.acl.ApplyFiltering(networkMap, dnsRouteFeatureFlag)
|
||||
}
|
||||
done()
|
||||
|
||||
done = e.phase("dns_forwarder")
|
||||
fwdEntries := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), routes)
|
||||
e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries)
|
||||
done()
|
||||
|
||||
// Ingress forward rules
|
||||
done = e.phase("forward_rules")
|
||||
forwardingRules, err := e.updateForwardRules(networkMap.GetForwardingRules())
|
||||
if err != nil {
|
||||
log.Errorf("failed to update forward rules, err: %v", err)
|
||||
}
|
||||
done()
|
||||
|
||||
log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers()))
|
||||
|
||||
done = e.phase("offline_peers")
|
||||
e.updateOfflinePeers(networkMap.GetOfflinePeers())
|
||||
done()
|
||||
|
||||
remotePeers, err := e.reconcilePeers(networkMap)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store
|
||||
done = e.phase("lazy_exclude")
|
||||
excludedLazyPeers := e.toExcludedLazyPeers(forwardingRules, remotePeers)
|
||||
e.connMgr.SetExcludeList(e.ctx, excludedLazyPeers)
|
||||
done()
|
||||
|
||||
e.networkSerial = serial
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// reconcilePeers applies the remote peer list from the network map (removing,
|
||||
// modifying and adding peers, then updating SSH config) and returns the remote
|
||||
// peers with our own peer filtered out, for use by later sync steps.
|
||||
func (e *Engine) reconcilePeers(networkMap *mgmProto.NetworkMap) ([]*mgmProto.RemotePeerConfig, error) {
|
||||
// Filter out own peer from the remote peers list
|
||||
localPubKey := e.config.WgPrivateKey.PublicKey().String()
|
||||
remotePeers := make([]*mgmProto.RemotePeerConfig, 0, len(networkMap.GetRemotePeers()))
|
||||
@@ -1470,43 +1423,42 @@ func (e *Engine) reconcilePeers(networkMap *mgmProto.NetworkMap) ([]*mgmProto.Re
|
||||
err := e.removeAllPeers()
|
||||
e.statusRecorder.FinishPeerListModifications()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
return remotePeers, nil
|
||||
} else {
|
||||
err := e.removePeers(remotePeers)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = e.modifyPeers(remotePeers)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = e.addNewPeers(remotePeers)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
e.statusRecorder.FinishPeerListModifications()
|
||||
|
||||
e.updatePeerSSHHostKeys(remotePeers)
|
||||
|
||||
if err := e.updateSSHClientConfig(remotePeers); err != nil {
|
||||
log.Warnf("failed to update SSH client config: %v", err)
|
||||
}
|
||||
|
||||
e.updateSSHServerAuth(networkMap.GetSshAuth())
|
||||
}
|
||||
|
||||
done := e.phase("removed_peers")
|
||||
err := e.removePeers(remotePeers)
|
||||
done()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store
|
||||
excludedLazyPeers := e.toExcludedLazyPeers(forwardingRules, remotePeers)
|
||||
e.connMgr.SetExcludeList(e.ctx, excludedLazyPeers)
|
||||
|
||||
done = e.phase("modified_peers")
|
||||
err = e.modifyPeers(remotePeers)
|
||||
done()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
e.networkSerial = serial
|
||||
|
||||
done = e.phase("added_peers")
|
||||
err = e.addNewPeers(remotePeers)
|
||||
done()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
e.statusRecorder.FinishPeerListModifications()
|
||||
|
||||
e.updatePeerSSHHostKeys(remotePeers)
|
||||
|
||||
if err := e.updateSSHClientConfig(remotePeers); err != nil {
|
||||
log.Warnf("failed to update SSH client config: %v", err)
|
||||
}
|
||||
|
||||
e.updateSSHServerAuth(networkMap.GetSshAuth())
|
||||
|
||||
return remotePeers, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func toDNSFeatureFlag(networkMap *mgmProto.NetworkMap) bool {
|
||||
|
||||
@@ -178,10 +178,6 @@ func (m *MockWGIface) LastActivities() map[string]monotime.Time {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockWGIface) MTU() uint16 {
|
||||
return 1280
|
||||
}
|
||||
|
||||
func (m *MockWGIface) SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -44,5 +44,4 @@ type wgIfaceBase interface {
|
||||
FullStats() (*configurer.Stats, error)
|
||||
LastActivities() map[string]monotime.Time
|
||||
SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error
|
||||
MTU() uint16
|
||||
}
|
||||
|
||||
@@ -124,11 +124,6 @@ func (d *BindListener) ReadPackets() {
|
||||
d.done.Done()
|
||||
}
|
||||
|
||||
// CapturedPacket is unused in userspace bind mode: first-packet reinjection is kernel-only.
|
||||
func (d *BindListener) CapturedPacket() []byte {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close stops the listener and cleans up resources.
|
||||
func (d *BindListener) Close() {
|
||||
d.peerCfg.Log.Infof("closing activity listener (LazyConn)")
|
||||
|
||||
@@ -45,6 +45,10 @@ type MockWGIfaceBind struct {
|
||||
endpointMgr *mockEndpointManager
|
||||
}
|
||||
|
||||
func (m *MockWGIfaceBind) RemovePeer(string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockWGIfaceBind) UpdatePeer(string, []netip.Prefix, time.Duration, *net.UDPAddr, *wgtypes.Key) error {
|
||||
return nil
|
||||
}
|
||||
@@ -64,10 +68,6 @@ func (m *MockWGIfaceBind) GetBind() device.EndpointManager {
|
||||
return m.endpointMgr
|
||||
}
|
||||
|
||||
func (m *MockWGIfaceBind) MTU() uint16 {
|
||||
return 1280
|
||||
}
|
||||
|
||||
func TestBindListener_Creation(t *testing.T) {
|
||||
mockEndpointMgr := newMockEndpointManager()
|
||||
mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr}
|
||||
@@ -207,9 +207,8 @@ func TestManager_BindMode(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case ev := <-mgr.OnActivityChan:
|
||||
assert.Equal(t, cfg.PeerConnID, ev.PeerConnID, "Received peer connection ID should match")
|
||||
assert.Nil(t, ev.FirstPacket, "Bind mode does not capture packets: reinjection is kernel-only")
|
||||
case peerConnID := <-mgr.OnActivityChan:
|
||||
assert.Equal(t, cfg.PeerConnID, peerConnID, "Received peer connection ID should match")
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for activity notification")
|
||||
}
|
||||
@@ -267,8 +266,8 @@ func TestManager_BindMode_MultiplePeers(t *testing.T) {
|
||||
receivedPeers := make(map[peerid.ConnID]bool)
|
||||
for i := 0; i < 2; i++ {
|
||||
select {
|
||||
case ev := <-mgr.OnActivityChan:
|
||||
receivedPeers[ev.PeerConnID] = true
|
||||
case peerConnID := <-mgr.OnActivityChan:
|
||||
receivedPeers[peerConnID] = true
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for activity notifications")
|
||||
}
|
||||
|
||||
@@ -3,13 +3,11 @@ package activity
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"slices"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bufsize"
|
||||
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||
)
|
||||
|
||||
@@ -22,8 +20,6 @@ type UDPListener struct {
|
||||
done sync.Mutex
|
||||
|
||||
isClosed atomic.Bool
|
||||
|
||||
capturedPacket []byte
|
||||
}
|
||||
|
||||
// NewUDPListener creates a listener that detects activity via UDP socket reads.
|
||||
@@ -50,13 +46,9 @@ func NewUDPListener(wgIface WgInterface, cfg lazyconn.PeerConfig) (*UDPListener,
|
||||
}
|
||||
|
||||
// ReadPackets blocks reading from the UDP socket until activity is detected or the listener is closed.
|
||||
// The first packet that triggers activity is captured so it can be reinjected through the real
|
||||
// transport once it is established. Without this, kernel WireGuard's handshake initiation would be
|
||||
// dropped and WG would only retry after REKEY_TIMEOUT.
|
||||
func (d *UDPListener) ReadPackets() {
|
||||
for {
|
||||
buf := make([]byte, int(d.wgIface.MTU())+bufsize.WGBufferOverhead)
|
||||
n, remoteAddr, err := d.conn.ReadFromUDP(buf)
|
||||
n, remoteAddr, err := d.conn.ReadFromUDP(make([]byte, 1))
|
||||
if err != nil {
|
||||
if d.isClosed.Load() {
|
||||
d.peerCfg.Log.Infof("exit from activity listener")
|
||||
@@ -70,24 +62,20 @@ func (d *UDPListener) ReadPackets() {
|
||||
d.peerCfg.Log.Warnf("received %d bytes from %s, too short", n, remoteAddr)
|
||||
continue
|
||||
}
|
||||
d.capturedPacket = slices.Clone(buf[:n])
|
||||
d.peerCfg.Log.Infof("activity detected, captured %d bytes for reinjection", n)
|
||||
d.peerCfg.Log.Infof("activity detected")
|
||||
break
|
||||
}
|
||||
|
||||
// Leave the peer in place. ConfigureWGEndpoint will UpdatePeer with the real endpoint;
|
||||
// removing the peer here wipes kernel WG's staged queue and drops the user packet that
|
||||
// triggered activation.
|
||||
d.peerCfg.Log.Debugf("removing lazy endpoint: %s", d.endpoint.String())
|
||||
if err := d.wgIface.RemovePeer(d.peerCfg.PublicKey); err != nil {
|
||||
d.peerCfg.Log.Errorf("failed to remove endpoint: %s", err)
|
||||
}
|
||||
|
||||
// Ignore close error as it may return "use of closed network connection" if already closed.
|
||||
_ = d.conn.Close()
|
||||
d.done.Unlock()
|
||||
}
|
||||
|
||||
// CapturedPacket returns the first packet that triggered activity, or nil if none was captured.
|
||||
// Safe to call after ReadPackets returns.
|
||||
func (d *UDPListener) CapturedPacket() []byte {
|
||||
return d.capturedPacket
|
||||
}
|
||||
|
||||
// Close stops the listener and cleans up resources.
|
||||
func (d *UDPListener) Close() {
|
||||
d.peerCfg.Log.Infof("closing activity listener: %s", d.conn.LocalAddr().String())
|
||||
|
||||
@@ -19,25 +19,17 @@ import (
|
||||
type listener interface {
|
||||
ReadPackets()
|
||||
Close()
|
||||
CapturedPacket() []byte
|
||||
}
|
||||
|
||||
// Event reports activity on a managed peer. FirstPacket is the bytes that triggered activation,
|
||||
// captured for reinjection through the real transport.
|
||||
type Event struct {
|
||||
PeerConnID peerid.ConnID
|
||||
FirstPacket []byte
|
||||
}
|
||||
|
||||
type WgInterface interface {
|
||||
RemovePeer(peerKey string) error
|
||||
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
||||
IsUserspaceBind() bool
|
||||
Address() wgaddr.Address
|
||||
MTU() uint16
|
||||
}
|
||||
|
||||
type Manager struct {
|
||||
OnActivityChan chan Event
|
||||
OnActivityChan chan peerid.ConnID
|
||||
|
||||
wgIface WgInterface
|
||||
|
||||
@@ -49,7 +41,7 @@ type Manager struct {
|
||||
|
||||
func NewManager(wgIface WgInterface) *Manager {
|
||||
m := &Manager{
|
||||
OnActivityChan: make(chan Event, 1),
|
||||
OnActivityChan: make(chan peerid.ConnID, 1),
|
||||
wgIface: wgIface,
|
||||
peers: make(map[peerid.ConnID]listener),
|
||||
done: make(chan struct{}),
|
||||
@@ -124,12 +116,12 @@ func (m *Manager) waitForTraffic(l listener, peerConnID peerid.ConnID) {
|
||||
delete(m.peers, peerConnID)
|
||||
m.mu.Unlock()
|
||||
|
||||
m.notify(Event{PeerConnID: peerConnID, FirstPacket: l.CapturedPacket()})
|
||||
m.notify(peerConnID)
|
||||
}
|
||||
|
||||
func (m *Manager) notify(ev Event) {
|
||||
func (m *Manager) notify(peerConnID peerid.ConnID) {
|
||||
select {
|
||||
case <-m.done:
|
||||
case m.OnActivityChan <- ev:
|
||||
case m.OnActivityChan <- peerConnID:
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package activity
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
@@ -26,6 +25,10 @@ func (m *MocPeer) ConnID() peerid.ConnID {
|
||||
type MocWGIface struct {
|
||||
}
|
||||
|
||||
func (m MocWGIface) RemovePeer(string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m MocWGIface) UpdatePeer(string, []netip.Prefix, time.Duration, *net.UDPAddr, *wgtypes.Key) error {
|
||||
return nil
|
||||
}
|
||||
@@ -41,10 +44,6 @@ func (m MocWGIface) Address() wgaddr.Address {
|
||||
}
|
||||
}
|
||||
|
||||
func (m MocWGIface) MTU() uint16 {
|
||||
return 1280
|
||||
}
|
||||
|
||||
// GetPeerListener is a test helper to access listeners
|
||||
func (m *Manager) GetPeerListener(peerConnID peerid.ConnID) (listener, bool) {
|
||||
m.mu.Lock()
|
||||
@@ -87,15 +86,11 @@ func TestManager_MonitorPeerActivity(t *testing.T) {
|
||||
}
|
||||
|
||||
select {
|
||||
case ev := <-mgr.OnActivityChan:
|
||||
if ev.PeerConnID != peerCfg1.PeerConnID {
|
||||
t.Fatalf("unexpected peerConnID: %v", ev.PeerConnID)
|
||||
}
|
||||
if !bytes.Equal(ev.FirstPacket, []byte{0x01, 0x02, 0x03, 0x04, 0x05}) {
|
||||
t.Fatalf("unexpected first packet: %v", ev.FirstPacket)
|
||||
case peerConnID := <-mgr.OnActivityChan:
|
||||
if peerConnID != peerCfg1.PeerConnID {
|
||||
t.Fatalf("unexpected peerConnID: %v", peerConnID)
|
||||
}
|
||||
case <-time.After(1 * time.Second):
|
||||
t.Fatal("timed out waiting for activity")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -130,8 +130,8 @@ func (m *Manager) Start(ctx context.Context) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case ev := <-m.activityManager.OnActivityChan:
|
||||
m.onPeerActivity(ev)
|
||||
case peerConnID := <-m.activityManager.OnActivityChan:
|
||||
m.onPeerActivity(peerConnID)
|
||||
case peerIDs := <-m.inactivityManager.InactivePeersChan():
|
||||
m.onPeerInactivityTimedOut(peerIDs)
|
||||
}
|
||||
@@ -513,13 +513,13 @@ func (m *Manager) checkHaGroupActivity(haGroup route.HAUniqueID, peerID string,
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Manager) onPeerActivity(ev activity.Event) {
|
||||
func (m *Manager) onPeerActivity(peerConnID peerid.ConnID) {
|
||||
m.managedPeersMu.Lock()
|
||||
defer m.managedPeersMu.Unlock()
|
||||
|
||||
mp, ok := m.managedPeersByConnID[ev.PeerConnID]
|
||||
mp, ok := m.managedPeersByConnID[peerConnID]
|
||||
if !ok {
|
||||
log.Errorf("peer not found by conn id: %v", ev.PeerConnID)
|
||||
log.Errorf("peer not found by conn id: %v", peerConnID)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -536,7 +536,7 @@ func (m *Manager) onPeerActivity(ev activity.Event) {
|
||||
|
||||
m.activateHAGroupPeers(mp.peerCfg)
|
||||
|
||||
m.peerStore.PeerConnOpenWithFirstPacket(m.engineCtx, mp.peerCfg.PublicKey, ev.FirstPacket)
|
||||
m.peerStore.PeerConnOpen(m.engineCtx, mp.peerCfg.PublicKey)
|
||||
}
|
||||
|
||||
func (m *Manager) onPeerInactivityTimedOut(peerIDs map[string]struct{}) {
|
||||
|
||||
@@ -17,5 +17,4 @@ type WGIface interface {
|
||||
IsUserspaceBind() bool
|
||||
Address() wgaddr.Address
|
||||
LastActivities() map[string]monotime.Time
|
||||
MTU() uint16
|
||||
}
|
||||
|
||||
@@ -120,30 +120,6 @@ func (m *influxDBMetrics) RecordSyncDuration(_ context.Context, agentInfo AgentI
|
||||
m.trimLocked()
|
||||
}
|
||||
|
||||
func (m *influxDBMetrics) RecordSyncPhase(_ context.Context, agentInfo AgentInfo, phase string, duration time.Duration) {
|
||||
tags := fmt.Sprintf("deployment_type=%s,version=%s,os=%s,arch=%s,peer_id=%s,phase=%s",
|
||||
agentInfo.DeploymentType.String(),
|
||||
agentInfo.Version,
|
||||
agentInfo.OS,
|
||||
agentInfo.Arch,
|
||||
agentInfo.peerID,
|
||||
phase,
|
||||
)
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
m.samples = append(m.samples, influxSample{
|
||||
measurement: "netbird_sync_phase",
|
||||
tags: tags,
|
||||
fields: map[string]float64{
|
||||
"duration_seconds": duration.Seconds(),
|
||||
},
|
||||
timestamp: time.Now(),
|
||||
})
|
||||
m.trimLocked()
|
||||
}
|
||||
|
||||
func (m *influxDBMetrics) RecordLoginDuration(_ context.Context, agentInfo AgentInfo, duration time.Duration, success bool) {
|
||||
result := "success"
|
||||
if !success {
|
||||
|
||||
@@ -78,25 +78,6 @@ Tags:
|
||||
- `os`: Operating system (linux, darwin, windows, android, ios, etc.)
|
||||
- `arch`: CPU architecture (amd64, arm64, etc.)
|
||||
|
||||
### Sync Phase Timing
|
||||
|
||||
Measurement: `netbird_sync_phase`
|
||||
|
||||
Breaks down where time goes inside a single sync, so the total `netbird_sync` duration can be attributed to the sub-step that dominates.
|
||||
|
||||
| Field | Description |
|
||||
|-------|-------------|
|
||||
| `duration_seconds` | Time spent in one sub-phase of sync processing |
|
||||
|
||||
Tags:
|
||||
- `phase`: the sub-phase — `netbird_config`, `checks`, `persist`, `dns_server`, `routes_classify`, `routes_apply`, `filtering`, `dns_forwarder`, `forward_rules`, `offline_peers`, `removed_peers`, `modified_peers`, `added_peers`, `lazy_exclude`
|
||||
- `deployment_type`: "cloud" | "selfhosted" | "unknown"
|
||||
- `version`: NetBird version string
|
||||
- `os`: Operating system (linux, darwin, windows, android, ios, etc.)
|
||||
- `arch`: CPU architecture (amd64, arm64, etc.)
|
||||
|
||||
**Note:** this is wall-time per phase — it includes both CPU work and time spent waiting on locks. A slow phase points to *where* the time goes, not *why*; pair it with lock-wait metrics to tell contention apart from real work.
|
||||
|
||||
### Login Duration
|
||||
|
||||
Measurement: `netbird_login`
|
||||
@@ -210,52 +191,4 @@ docker compose exec influxdb influx query \
|
||||
|
||||
# Check ingest server health
|
||||
curl http://localhost:8087/health
|
||||
```
|
||||
|
||||
## Analyzing a Debug Bundle
|
||||
|
||||
Metrics collection is always on, so every debug bundle ships a `metrics.txt` in InfluxDB line protocol — a timestamped time series of all recorded events (sync durations, sync phases, connection stages, login). You can replay it into the local stack and graph it, without a running client.
|
||||
|
||||
The bundle's `metrics.txt` is a rolling window (capped at 5 days / ~20k samples, see [Buffer Limits](#buffer-limits)). For a connection incident the relevant window is short (connection setup is seconds), so a bundle captured during the issue is enough.
|
||||
|
||||
### 1. Start the stack
|
||||
|
||||
```bash
|
||||
# From this directory (client/internal/metrics/infra)
|
||||
INFLUXDB_ADMIN_TOKEN=admin123 INFLUXDB_ADMIN_PASSWORD=admin123 GRAFANA_ADMIN_PASSWORD=admin123 \
|
||||
docker compose up -d
|
||||
```
|
||||
|
||||
(`admin123` are throwaway local credentials — fine for offline analysis.)
|
||||
|
||||
### 2. Clear any previous data
|
||||
|
||||
So you only see this bundle:
|
||||
|
||||
```bash
|
||||
docker exec influxdb influx delete --org netbird --bucket metrics --token admin123 \
|
||||
--start 1970-01-01T00:00:00Z --stop 2100-01-01T00:00:00Z
|
||||
```
|
||||
|
||||
### 3. Import the bundle's metrics.txt
|
||||
|
||||
InfluxDB is not exposed on the host, so import inside the container:
|
||||
|
||||
```bash
|
||||
docker cp /path/to/bundle/metrics.txt influxdb:/tmp/m.txt
|
||||
docker exec influxdb influx write --org netbird --bucket metrics --precision ns \
|
||||
--token admin123 --file /tmp/m.txt
|
||||
```
|
||||
|
||||
Re-importing the same file is idempotent (same measurement+tags+timestamp overwrites).
|
||||
|
||||
### 4. View the dashboards
|
||||
|
||||
Grafana on http://localhost:3001 (login `admin` / `admin123`), datasource pre-provisioned:
|
||||
|
||||
- **Where sync time goes:** http://localhost:3001/d/netbird-sync-phases/netbird-sync-phases-where-time-goes
|
||||
- **General client metrics:** http://localhost:3001/d/netbird-influxdb-metrics
|
||||
|
||||
**Set the time range** to cover the bundle's timestamps (e.g. "Last 7 days" or an absolute range matching when the bundle was taken) — with the default short range the panels look empty.
|
||||
|
||||
Bundles are distinguishable by the `version` tag; add a tag at import time (e.g. `sed 's/^netbird_\([a-z_]*\),/netbird_\1,bundle=mycase,/' metrics.txt`) if you want to compare several side by side.
|
||||
```
|
||||
@@ -1,259 +0,0 @@
|
||||
{
|
||||
"annotations": {
|
||||
"list": []
|
||||
},
|
||||
"editable": true,
|
||||
"fiscalYearStartMonth": 0,
|
||||
"graphTooltip": 1,
|
||||
"links": [],
|
||||
"refresh": "",
|
||||
"schemaVersion": 39,
|
||||
"tags": [
|
||||
"netbird",
|
||||
"sync"
|
||||
],
|
||||
"templating": {
|
||||
"list": [
|
||||
{
|
||||
"current": {
|
||||
"text": "All",
|
||||
"value": "$__all"
|
||||
},
|
||||
"datasource": {
|
||||
"type": "influxdb",
|
||||
"uid": "influxdb"
|
||||
},
|
||||
"definition": "import \"influxdata/influxdb/schema\"\nschema.tagValues(bucket: \"metrics\", tag: \"version\")",
|
||||
"includeAll": true,
|
||||
"label": "version",
|
||||
"multi": true,
|
||||
"name": "version",
|
||||
"query": "import \"influxdata/influxdb/schema\"\nschema.tagValues(bucket: \"metrics\", tag: \"version\")",
|
||||
"refresh": 2,
|
||||
"type": "query",
|
||||
"allValue": ".*"
|
||||
}
|
||||
]
|
||||
},
|
||||
"time": {
|
||||
"from": "now-2d",
|
||||
"to": "now"
|
||||
},
|
||||
"timepicker": {},
|
||||
"timezone": "",
|
||||
"title": "NetBird Sync Phases (where time goes)",
|
||||
"uid": "netbird-sync-phases",
|
||||
"version": 1,
|
||||
"panels": [
|
||||
{
|
||||
"id": 1,
|
||||
"title": "Time per phase over time (stacked, ms)",
|
||||
"type": "timeseries",
|
||||
"datasource": {
|
||||
"type": "influxdb",
|
||||
"uid": "influxdb"
|
||||
},
|
||||
"gridPos": {
|
||||
"h": 10,
|
||||
"w": 24,
|
||||
"x": 0,
|
||||
"y": 0
|
||||
},
|
||||
"fieldConfig": {
|
||||
"defaults": {
|
||||
"unit": "ms",
|
||||
"custom": {
|
||||
"drawStyle": "bars",
|
||||
"stacking": {
|
||||
"mode": "normal",
|
||||
"group": "A"
|
||||
},
|
||||
"fillOpacity": 80,
|
||||
"lineWidth": 0
|
||||
}
|
||||
},
|
||||
"overrides": []
|
||||
},
|
||||
"options": {
|
||||
"legend": {
|
||||
"displayMode": "table",
|
||||
"placement": "right",
|
||||
"calcs": [
|
||||
"max",
|
||||
"mean"
|
||||
]
|
||||
},
|
||||
"tooltip": {
|
||||
"mode": "multi",
|
||||
"sort": "desc"
|
||||
}
|
||||
},
|
||||
"targets": [
|
||||
{
|
||||
"refId": "A",
|
||||
"datasource": {
|
||||
"type": "influxdb",
|
||||
"uid": "influxdb"
|
||||
},
|
||||
"query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_sync_phase\" and r._field == \"duration_seconds\")\n |> filter(fn: (r) => r.version =~ /${version:regex}/)\n |> map(fn: (r) => ({ r with _value: r._value * 1000.0 }))\n |> keep(columns: [\"_time\", \"_value\", \"phase\"])\n |> group(columns: [\"phase\"])"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"title": "p95 per phase (ms)",
|
||||
"type": "bargauge",
|
||||
"datasource": {
|
||||
"type": "influxdb",
|
||||
"uid": "influxdb"
|
||||
},
|
||||
"gridPos": {
|
||||
"h": 11,
|
||||
"w": 12,
|
||||
"x": 0,
|
||||
"y": 10
|
||||
},
|
||||
"fieldConfig": {
|
||||
"defaults": {
|
||||
"unit": "ms",
|
||||
"color": {
|
||||
"mode": "continuous-GrYlRd"
|
||||
}
|
||||
},
|
||||
"overrides": []
|
||||
},
|
||||
"options": {
|
||||
"displayMode": "gradient",
|
||||
"orientation": "horizontal",
|
||||
"reduceOptions": {
|
||||
"calcs": [
|
||||
"lastNotNull"
|
||||
],
|
||||
"fields": "",
|
||||
"values": false
|
||||
},
|
||||
"showUnfilled": true
|
||||
},
|
||||
"targets": [
|
||||
{
|
||||
"refId": "A",
|
||||
"datasource": {
|
||||
"type": "influxdb",
|
||||
"uid": "influxdb"
|
||||
},
|
||||
"query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_sync_phase\" and r._field == \"duration_seconds\")\n |> filter(fn: (r) => r.version =~ /${version:regex}/)\n |> map(fn: (r) => ({ r with _value: r._value * 1000.0 }))\n |> group(columns: [\"phase\"])\n |> quantile(q: 0.95)\n |> group()\n |> sort(columns: [\"_value\"], desc: true)"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": 3,
|
||||
"title": "Per-phase stats (ms): mean / p95 / max",
|
||||
"type": "table",
|
||||
"datasource": {
|
||||
"type": "influxdb",
|
||||
"uid": "influxdb"
|
||||
},
|
||||
"gridPos": {
|
||||
"h": 11,
|
||||
"w": 12,
|
||||
"x": 12,
|
||||
"y": 10
|
||||
},
|
||||
"fieldConfig": {
|
||||
"defaults": {
|
||||
"unit": "ms"
|
||||
},
|
||||
"overrides": []
|
||||
},
|
||||
"options": {
|
||||
"showHeader": true,
|
||||
"sortBy": [
|
||||
{
|
||||
"displayName": "max",
|
||||
"desc": true
|
||||
}
|
||||
]
|
||||
},
|
||||
"transformations": [
|
||||
{
|
||||
"id": "merge",
|
||||
"options": {}
|
||||
}
|
||||
],
|
||||
"targets": [
|
||||
{
|
||||
"refId": "mean",
|
||||
"datasource": {
|
||||
"type": "influxdb",
|
||||
"uid": "influxdb"
|
||||
},
|
||||
"query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_sync_phase\" and r._field == \"duration_seconds\")\n |> filter(fn: (r) => r.version =~ /${version:regex}/)\n |> map(fn: (r) => ({ r with _value: r._value * 1000.0 }))\n |> group(columns: [\"phase\"])\n |> mean()\n |> group()\n |> keep(columns: [\"phase\", \"_value\"])\n |> rename(columns: {_value: \"mean\"})"
|
||||
},
|
||||
{
|
||||
"refId": "p95",
|
||||
"datasource": {
|
||||
"type": "influxdb",
|
||||
"uid": "influxdb"
|
||||
},
|
||||
"query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_sync_phase\" and r._field == \"duration_seconds\")\n |> filter(fn: (r) => r.version =~ /${version:regex}/)\n |> map(fn: (r) => ({ r with _value: r._value * 1000.0 }))\n |> group(columns: [\"phase\"])\n |> quantile(q: 0.95)\n |> group()\n |> keep(columns: [\"phase\", \"_value\"])\n |> rename(columns: {_value: \"p95\"})"
|
||||
},
|
||||
{
|
||||
"refId": "max",
|
||||
"datasource": {
|
||||
"type": "influxdb",
|
||||
"uid": "influxdb"
|
||||
},
|
||||
"query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_sync_phase\" and r._field == \"duration_seconds\")\n |> filter(fn: (r) => r.version =~ /${version:regex}/)\n |> map(fn: (r) => ({ r with _value: r._value * 1000.0 }))\n |> group(columns: [\"phase\"])\n |> max()\n |> group()\n |> keep(columns: [\"phase\", \"_value\"])\n |> rename(columns: {_value: \"max\"})"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": 4,
|
||||
"title": "Total sync duration (netbird_sync, ms) \u2014 reference",
|
||||
"type": "timeseries",
|
||||
"datasource": {
|
||||
"type": "influxdb",
|
||||
"uid": "influxdb"
|
||||
},
|
||||
"gridPos": {
|
||||
"h": 8,
|
||||
"w": 24,
|
||||
"x": 0,
|
||||
"y": 21
|
||||
},
|
||||
"fieldConfig": {
|
||||
"defaults": {
|
||||
"unit": "ms",
|
||||
"custom": {
|
||||
"drawStyle": "points",
|
||||
"pointSize": 5
|
||||
}
|
||||
},
|
||||
"overrides": []
|
||||
},
|
||||
"options": {
|
||||
"legend": {
|
||||
"displayMode": "table",
|
||||
"placement": "right",
|
||||
"calcs": [
|
||||
"max",
|
||||
"mean"
|
||||
]
|
||||
},
|
||||
"tooltip": {
|
||||
"mode": "single"
|
||||
}
|
||||
},
|
||||
"targets": [
|
||||
{
|
||||
"refId": "A",
|
||||
"datasource": {
|
||||
"type": "influxdb",
|
||||
"uid": "influxdb"
|
||||
},
|
||||
"query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_sync\" and r._field == \"duration_seconds\")\n |> filter(fn: (r) => r.version =~ /${version:regex}/)\n |> map(fn: (r) => ({ r with _value: r._value * 1000.0 }))\n |> keep(columns: [\"_time\", \"_value\", \"version\"])\n |> group(columns: [\"version\"])"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -19,7 +19,7 @@ const (
|
||||
defaultListenAddr = ":8087"
|
||||
defaultInfluxDBURL = "http://influxdb:8086/api/v2/write?org=netbird&bucket=metrics&precision=ns"
|
||||
maxBodySize = 50 * 1024 * 1024 // 50 MB max request body
|
||||
maxDurationSeconds = 86400.0 // reject any duration field > 24 hours
|
||||
maxDurationSeconds = 300.0 // reject any duration field > 5 minutes
|
||||
peerIDLength = 16 // truncated SHA-256: 8 bytes = 16 hex chars
|
||||
maxTagValueLength = 64 // reject tag values longer than this
|
||||
)
|
||||
@@ -59,19 +59,6 @@ var allowedMeasurements = map[string]measurementSpec{
|
||||
"peer_id": true,
|
||||
},
|
||||
},
|
||||
"netbird_sync_phase": {
|
||||
allowedFields: map[string]bool{
|
||||
"duration_seconds": true,
|
||||
},
|
||||
allowedTags: map[string]bool{
|
||||
"deployment_type": true,
|
||||
"version": true,
|
||||
"os": true,
|
||||
"arch": true,
|
||||
"peer_id": true,
|
||||
"phase": true,
|
||||
},
|
||||
},
|
||||
"netbird_login": {
|
||||
allowedFields: map[string]bool{
|
||||
"duration_seconds": true,
|
||||
|
||||
@@ -53,14 +53,14 @@ func TestValidateLine_NegativeValue(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestValidateLine_DurationTooLarge(t *testing.T) {
|
||||
line := `netbird_sync,deployment_type=cloud,version=1.0.0,os=linux,arch=amd64,peer_id=abc duration_seconds=100000 1234567890`
|
||||
line := `netbird_sync,deployment_type=cloud,version=1.0.0,os=linux,arch=amd64,peer_id=abc duration_seconds=999 1234567890`
|
||||
err := validateLine(line)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "too large")
|
||||
}
|
||||
|
||||
func TestValidateLine_TotalSecondsTooLarge(t *testing.T) {
|
||||
line := `netbird_peer_connection,deployment_type=cloud,connection_type=ice,attempt_type=initial,version=1.0.0,os=linux,arch=amd64,peer_id=abc,connection_pair_id=pair total_seconds=100000 1234567890`
|
||||
line := `netbird_peer_connection,deployment_type=cloud,connection_type=ice,attempt_type=initial,version=1.0.0,os=linux,arch=amd64,peer_id=abc,connection_pair_id=pair total_seconds=500 1234567890`
|
||||
err := validateLine(line)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "too large")
|
||||
|
||||
@@ -56,9 +56,6 @@ type metricsImplementation interface {
|
||||
// RecordSyncDuration records how long it took to process a sync message
|
||||
RecordSyncDuration(ctx context.Context, agentInfo AgentInfo, duration time.Duration)
|
||||
|
||||
// RecordSyncPhase records how long a single sub-phase of sync processing took
|
||||
RecordSyncPhase(ctx context.Context, agentInfo AgentInfo, phase string, duration time.Duration)
|
||||
|
||||
// RecordLoginDuration records how long the login to management took
|
||||
RecordLoginDuration(ctx context.Context, agentInfo AgentInfo, duration time.Duration, success bool)
|
||||
|
||||
@@ -130,18 +127,6 @@ func (c *ClientMetrics) RecordSyncDuration(ctx context.Context, duration time.Du
|
||||
c.impl.RecordSyncDuration(ctx, agentInfo, duration)
|
||||
}
|
||||
|
||||
// RecordSyncPhase records the duration of a single sub-phase of sync processing
|
||||
func (c *ClientMetrics) RecordSyncPhase(ctx context.Context, phase string, duration time.Duration) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
c.mu.RLock()
|
||||
agentInfo := c.agentInfo
|
||||
c.mu.RUnlock()
|
||||
|
||||
c.impl.RecordSyncPhase(ctx, agentInfo, phase, duration)
|
||||
}
|
||||
|
||||
// RecordLoginDuration records how long the login to management server took
|
||||
func (c *ClientMetrics) RecordLoginDuration(ctx context.Context, duration time.Duration, success bool) {
|
||||
if c == nil {
|
||||
|
||||
@@ -70,9 +70,6 @@ func (m *mockMetrics) RecordConnectionStages(_ context.Context, _ AgentInfo, _ s
|
||||
func (m *mockMetrics) RecordSyncDuration(_ context.Context, _ AgentInfo, _ time.Duration) {
|
||||
}
|
||||
|
||||
func (m *mockMetrics) RecordSyncPhase(_ context.Context, _ AgentInfo, _ string, _ time.Duration) {
|
||||
}
|
||||
|
||||
func (m *mockMetrics) RecordLoginDuration(_ context.Context, _ AgentInfo, _ time.Duration, _ bool) {
|
||||
}
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"slices"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -137,39 +136,6 @@ type Conn struct {
|
||||
// Connection stage timestamps for metrics
|
||||
metricsRecorder MetricsRecorder
|
||||
metricsStages *MetricsStages
|
||||
|
||||
// pendingFirstPacket is the lazyconn-captured handshake init, replayed once the real
|
||||
// transport is up.
|
||||
pendingFirstPacket []byte
|
||||
}
|
||||
|
||||
// injectPendingFirstPacket replays the captured handshake through the proxy if present, else
|
||||
// directly through the ICE conn. The packet is cleared only after a successful write, so a failed
|
||||
// or transport-less attempt leaves it available for a later reinjection. Caller must hold conn.mu.
|
||||
func (conn *Conn) injectPendingFirstPacket(proxy wgproxy.Proxy, directConn net.Conn) {
|
||||
pkt := conn.pendingFirstPacket
|
||||
if len(pkt) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
switch {
|
||||
case proxy != nil:
|
||||
if err := proxy.InjectPacket(pkt); err != nil {
|
||||
conn.Log.Debugf("failed to reinject captured first packet via proxy: %v", err)
|
||||
return
|
||||
}
|
||||
case directConn != nil:
|
||||
if _, err := directConn.Write(pkt); err != nil {
|
||||
conn.Log.Debugf("failed to reinject captured first packet via direct conn: %v", err)
|
||||
return
|
||||
}
|
||||
default:
|
||||
conn.Log.Debugf("no transport available to reinject captured first packet")
|
||||
return
|
||||
}
|
||||
|
||||
conn.pendingFirstPacket = nil
|
||||
conn.Log.Debugf("reinjected captured first packet (%d bytes)", len(pkt))
|
||||
}
|
||||
|
||||
// NewConn creates a new not opened Conn to the remote peer.
|
||||
@@ -206,16 +172,6 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
|
||||
// It will try to establish a connection using ICE and in parallel with relay. The higher priority connection type will
|
||||
// be used.
|
||||
func (conn *Conn) Open(engineCtx context.Context) error {
|
||||
return conn.open(engineCtx, nil)
|
||||
}
|
||||
|
||||
// OpenWithFirstPacket opens the connection like Open and stashes firstPacket to be replayed once
|
||||
// the real transport is established. The packet is retained only on a successful open.
|
||||
func (conn *Conn) OpenWithFirstPacket(engineCtx context.Context, firstPacket []byte) error {
|
||||
return conn.open(engineCtx, firstPacket)
|
||||
}
|
||||
|
||||
func (conn *Conn) open(engineCtx context.Context, firstPacket []byte) error {
|
||||
conn.mu.Lock()
|
||||
defer conn.mu.Unlock()
|
||||
|
||||
@@ -271,9 +227,6 @@ func (conn *Conn) open(engineCtx context.Context, firstPacket []byte) error {
|
||||
defer conn.wg.Done()
|
||||
conn.guard.Start(conn.ctx, conn.onGuardEvent)
|
||||
}()
|
||||
if len(firstPacket) > 0 {
|
||||
conn.pendingFirstPacket = slices.Clone(firstPacket)
|
||||
}
|
||||
conn.opened = true
|
||||
return nil
|
||||
}
|
||||
@@ -470,8 +423,6 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
|
||||
conn.wgProxyRelay.RedirectAs(ep)
|
||||
}
|
||||
|
||||
conn.injectPendingFirstPacket(wgProxy, iceConnInfo.RemoteConn)
|
||||
|
||||
conn.currentConnPriority = priority
|
||||
conn.statusICE.SetConnected()
|
||||
conn.updateIceState(iceConnInfo, updateTime)
|
||||
@@ -595,8 +546,6 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
|
||||
|
||||
wgConfigWorkaround()
|
||||
|
||||
conn.injectPendingFirstPacket(wgProxy, nil)
|
||||
|
||||
conn.rosenpassRemoteKey = rci.rosenpassPubKey
|
||||
conn.currentConnPriority = conntype.Relay
|
||||
conn.statusRelay.SetConnected()
|
||||
|
||||
@@ -54,19 +54,15 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
||||
w.relaySupportedOnRemotePeer.Store(true)
|
||||
|
||||
// the relayManager will return with error in case if the connection has lost with relay server
|
||||
currentRelayAddress, _, err := w.relayManager.RelayInstanceAddress()
|
||||
_, _, err := w.relayManager.RelayInstanceAddress()
|
||||
if err != nil {
|
||||
w.log.Errorf("failed to handle new offer: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
srv := w.preferredRelayServer(currentRelayAddress, remoteOfferAnswer.RelaySrvAddress)
|
||||
var serverIP netip.Addr
|
||||
if srv == remoteOfferAnswer.RelaySrvAddress {
|
||||
serverIP = remoteOfferAnswer.RelaySrvIP
|
||||
}
|
||||
|
||||
relayedConn, err := w.relayManager.OpenConn(w.peerCtx, srv, w.config.Key, serverIP)
|
||||
preferForeign := !w.isController
|
||||
remoteRelayServer := relayClient.RelayServer{Addr: remoteOfferAnswer.RelaySrvAddress, IP: remoteOfferAnswer.RelaySrvIP}
|
||||
relayedConn, err := w.relayManager.OpenConn(w.peerCtx, remoteRelayServer, w.config.Key, preferForeign)
|
||||
if err != nil {
|
||||
if errors.Is(err, relayClient.ErrConnAlreadyExists) {
|
||||
w.log.Debugf("handled offer by reusing existing relay connection")
|
||||
@@ -80,14 +76,13 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
||||
w.relayedConn = relayedConn
|
||||
w.relayLock.Unlock()
|
||||
|
||||
err = w.relayManager.AddCloseListener(srv, w.onRelayClientDisconnected)
|
||||
if err != nil {
|
||||
log.Errorf("failed to add close listener: %s", err)
|
||||
if err := w.relayManager.AddCloseListener(relayedConn.RemoteAddr().String(), w.onRelayClientDisconnected); err != nil {
|
||||
w.log.Errorf("failed to add close listener: %s", err)
|
||||
_ = relayedConn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
w.log.Debugf("peer conn opened via Relay: %s", srv)
|
||||
w.log.Debugf("peer conn opened via Relay: %s", relayedConn.RemoteAddr())
|
||||
go w.conn.onRelayConnectionIsReady(RelayConnInfo{
|
||||
relayedConn: relayedConn,
|
||||
rosenpassPubKey: remoteOfferAnswer.RosenpassPubKey,
|
||||
@@ -126,13 +121,6 @@ func (w *WorkerRelay) isRelaySupported(answer *OfferAnswer) bool {
|
||||
return answer.RelaySrvAddress != ""
|
||||
}
|
||||
|
||||
func (w *WorkerRelay) preferredRelayServer(myRelayAddress, remoteRelayAddress string) string {
|
||||
if w.isController {
|
||||
return myRelayAddress
|
||||
}
|
||||
return remoteRelayAddress
|
||||
}
|
||||
|
||||
func (w *WorkerRelay) onRelayClientDisconnected() {
|
||||
go w.conn.onRelayDisconnected()
|
||||
}
|
||||
|
||||
@@ -88,24 +88,11 @@ func (s *Store) PeerConnOpen(ctx context.Context, pubKey string) {
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
// this can be blocked because of the connect open limiter semaphore
|
||||
if err := p.Open(ctx); err != nil {
|
||||
p.Log.Errorf("failed to open peer connection: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// PeerConnOpenWithFirstPacket opens the peer connection and stashes a first packet to be
|
||||
// reinjected once the real transport is established.
|
||||
func (s *Store) PeerConnOpenWithFirstPacket(ctx context.Context, pubKey string, firstPacket []byte) {
|
||||
s.peerConnsMu.RLock()
|
||||
defer s.peerConnsMu.RUnlock()
|
||||
|
||||
p, ok := s.peerConns[pubKey]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if err := p.OpenWithFirstPacket(ctx, firstPacket); err != nil {
|
||||
p.Log.Errorf("failed to open peer connection: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Store) PeerConnIdle(pubKey string) {
|
||||
|
||||
@@ -386,7 +386,7 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.NetworkMonitor != nil && (config.NetworkMonitor == nil || *input.NetworkMonitor != *config.NetworkMonitor) {
|
||||
if input.NetworkMonitor != nil && input.NetworkMonitor != config.NetworkMonitor {
|
||||
log.Infof("switching Network Monitor to %t", *input.NetworkMonitor)
|
||||
config.NetworkMonitor = input.NetworkMonitor
|
||||
updated = true
|
||||
@@ -454,7 +454,7 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.EnableSSHRoot != nil && (config.EnableSSHRoot == nil || *input.EnableSSHRoot != *config.EnableSSHRoot) {
|
||||
if input.EnableSSHRoot != nil && input.EnableSSHRoot != config.EnableSSHRoot {
|
||||
if *input.EnableSSHRoot {
|
||||
log.Infof("enabling SSH root login")
|
||||
} else {
|
||||
@@ -464,7 +464,7 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.EnableSSHSFTP != nil && (config.EnableSSHSFTP == nil || *input.EnableSSHSFTP != *config.EnableSSHSFTP) {
|
||||
if input.EnableSSHSFTP != nil && input.EnableSSHSFTP != config.EnableSSHSFTP {
|
||||
if *input.EnableSSHSFTP {
|
||||
log.Infof("enabling SSH SFTP subsystem")
|
||||
} else {
|
||||
@@ -474,7 +474,7 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.EnableSSHLocalPortForwarding != nil && (config.EnableSSHLocalPortForwarding == nil || *input.EnableSSHLocalPortForwarding != *config.EnableSSHLocalPortForwarding) {
|
||||
if input.EnableSSHLocalPortForwarding != nil && input.EnableSSHLocalPortForwarding != config.EnableSSHLocalPortForwarding {
|
||||
if *input.EnableSSHLocalPortForwarding {
|
||||
log.Infof("enabling SSH local port forwarding")
|
||||
} else {
|
||||
@@ -484,7 +484,7 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.EnableSSHRemotePortForwarding != nil && (config.EnableSSHRemotePortForwarding == nil || *input.EnableSSHRemotePortForwarding != *config.EnableSSHRemotePortForwarding) {
|
||||
if input.EnableSSHRemotePortForwarding != nil && input.EnableSSHRemotePortForwarding != config.EnableSSHRemotePortForwarding {
|
||||
if *input.EnableSSHRemotePortForwarding {
|
||||
log.Infof("enabling SSH remote port forwarding")
|
||||
} else {
|
||||
@@ -494,7 +494,7 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.DisableSSHAuth != nil && (config.DisableSSHAuth == nil || *input.DisableSSHAuth != *config.DisableSSHAuth) {
|
||||
if input.DisableSSHAuth != nil && input.DisableSSHAuth != config.DisableSSHAuth {
|
||||
if *input.DisableSSHAuth {
|
||||
log.Infof("disabling SSH authentication")
|
||||
} else {
|
||||
@@ -504,7 +504,7 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.SSHJWTCacheTTL != nil && (config.SSHJWTCacheTTL == nil || *input.SSHJWTCacheTTL != *config.SSHJWTCacheTTL) {
|
||||
if input.SSHJWTCacheTTL != nil && input.SSHJWTCacheTTL != config.SSHJWTCacheTTL {
|
||||
log.Infof("updating SSH JWT cache TTL to %d seconds", *input.SSHJWTCacheTTL)
|
||||
config.SSHJWTCacheTTL = input.SSHJWTCacheTTL
|
||||
updated = true
|
||||
@@ -587,7 +587,7 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.DisableNotifications != nil && (config.DisableNotifications == nil || *input.DisableNotifications != *config.DisableNotifications) {
|
||||
if input.DisableNotifications != nil && input.DisableNotifications != config.DisableNotifications {
|
||||
if *input.DisableNotifications {
|
||||
log.Infof("disabling notifications")
|
||||
} else {
|
||||
|
||||
@@ -226,11 +226,12 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
return
|
||||
}
|
||||
|
||||
// All query types for an intercepted domain are forwarded to the peer's
|
||||
// DNS forwarder, which owns the name. Falling through to the system
|
||||
// resolver would let it answer NXDOMAIN for a name it isn't authoritative
|
||||
// for, poisoning the whole name (including the A/AAAA records the route
|
||||
// does serve). The forwarder answers NODATA for types it cannot resolve.
|
||||
// pass if non A/AAAA query
|
||||
if r.Question[0].Qtype != dns.TypeA && r.Question[0].Qtype != dns.TypeAAAA {
|
||||
d.continueToNextHandler(w, r, logger, "non A/AAAA query")
|
||||
return
|
||||
}
|
||||
|
||||
d.mu.RLock()
|
||||
peerKey := d.currentPeerKey
|
||||
d.mu.RUnlock()
|
||||
@@ -292,6 +293,19 @@ func (d *DnsInterceptor) writeDNSError(w dns.ResponseWriter, r *dns.Msg, logger
|
||||
}
|
||||
}
|
||||
|
||||
// continueToNextHandler signals the handler chain to try the next handler
|
||||
func (d *DnsInterceptor) continueToNextHandler(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry, reason string) {
|
||||
logger.Tracef("continuing to next handler for domain=%s reason=%s", r.Question[0].Name, reason)
|
||||
|
||||
resp := new(dns.Msg)
|
||||
resp.SetRcode(r, dns.RcodeNameError)
|
||||
// Set Zero bit to signal handler chain to continue
|
||||
resp.MsgHdr.Zero = true
|
||||
if err := w.WriteMsg(resp); err != nil {
|
||||
logger.Errorf("failed writing DNS continue response: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DnsInterceptor) getUpstreamIP(peerKey string) (netip.Addr, error) {
|
||||
peerAllowedIP, exists := d.peerStore.AllowedIP(peerKey)
|
||||
if !exists {
|
||||
|
||||
@@ -2,11 +2,9 @@ package system
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc/metadata"
|
||||
@@ -176,7 +174,7 @@ func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks, excludeIPs .
|
||||
processCheckPaths = append(processCheckPaths, check.GetFiles()...)
|
||||
}
|
||||
|
||||
files, err := checkFileAndProcess(ctx, processCheckPaths)
|
||||
files, err := checkFileAndProcess(processCheckPaths)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -189,43 +187,3 @@ func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks, excludeIPs .
|
||||
log.Debugf("all system information gathered successfully")
|
||||
return info, nil
|
||||
}
|
||||
|
||||
// GetInfoWithChecksTimeout is GetInfoWithChecks bounded by timeout. Posture-check gathering
|
||||
// runs uncancellable system calls (process enumeration, os.Stat), so calling it inline can
|
||||
// block the caller for as long as such a call hangs. It runs in a goroutine instead: if it
|
||||
// does not return within timeout the caller gets (nil, false) and should proceed with
|
||||
// degraded behavior rather than block. On a gathering error it falls back to base GetInfo.
|
||||
//
|
||||
// The buffered channel lets the abandoned goroutine finish and exit once its blocking call
|
||||
// returns, so it does not leak beyond the duration of that call.
|
||||
func GetInfoWithChecksTimeout(ctx context.Context, timeout time.Duration, checks []*proto.Checks, excludeIPs ...netip.Addr) (*Info, bool) {
|
||||
ctx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
infoCh := make(chan *Info, 1)
|
||||
go func() {
|
||||
info, err := GetInfoWithChecks(ctx, checks, excludeIPs...)
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
log.Warnf("failed to get system info with checks: %v", err)
|
||||
info = GetInfo(ctx)
|
||||
info.removeAddresses(excludeIPs...)
|
||||
}
|
||||
infoCh <- info
|
||||
}()
|
||||
|
||||
select {
|
||||
case info := <-infoCh:
|
||||
return info, true
|
||||
case <-ctx.Done():
|
||||
if errors.Is(ctx.Err(), context.DeadlineExceeded) {
|
||||
log.Warnf("gathering system info with checks timed out after %s", timeout)
|
||||
} else {
|
||||
// Parent context canceled (e.g. shutdown), not a timeout.
|
||||
log.Warnf("gathering system info with checks canceled: %v", ctx.Err())
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
|
||||
@@ -50,7 +50,7 @@ func GetInfo(ctx context.Context) *Info {
|
||||
}
|
||||
|
||||
// checkFileAndProcess checks if the file path exists and if a process is running at that path.
|
||||
func checkFileAndProcess(_ context.Context, _ []string) ([]File, error) {
|
||||
func checkFileAndProcess(paths []string) ([]File, error) {
|
||||
return []File{}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -32,7 +32,7 @@ func GetInfo(ctx context.Context) *Info {
|
||||
sysName := string(bytes.Split(utsname.Sysname[:], []byte{0})[0])
|
||||
machine := string(bytes.Split(utsname.Machine[:], []byte{0})[0])
|
||||
release := string(bytes.Split(utsname.Release[:], []byte{0})[0])
|
||||
swVersion, err := exec.CommandContext(ctx, "sw_vers", "-productVersion").Output()
|
||||
swVersion, err := exec.Command("sw_vers", "-productVersion").Output()
|
||||
if err != nil {
|
||||
log.Warnf("got an error while retrieving macOS version with sw_vers, error: %s. Using darwin version instead.\n", err)
|
||||
swVersion = []byte(release)
|
||||
|
||||
@@ -105,7 +105,7 @@ func isDuplicated(addresses []NetworkAddress, addr NetworkAddress) bool {
|
||||
}
|
||||
|
||||
// checkFileAndProcess checks if the file path exists and if a process is running at that path.
|
||||
func checkFileAndProcess(_ context.Context, _ []string) ([]File, error) {
|
||||
func checkFileAndProcess(paths []string) ([]File, error) {
|
||||
return []File{}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -103,7 +103,7 @@ func collectLocationInfo(info *Info) {
|
||||
}
|
||||
}
|
||||
|
||||
func checkFileAndProcess(_ context.Context, _ []string) ([]File, error) {
|
||||
func checkFileAndProcess(_ []string) ([]File, error) {
|
||||
return []File{}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/grpc/metadata"
|
||||
@@ -36,20 +35,6 @@ func Test_CustomHostname(t *testing.T) {
|
||||
assert.Equal(t, want, got.Hostname)
|
||||
}
|
||||
|
||||
func TestGetInfoWithChecksTimeout_Success(t *testing.T) {
|
||||
info, ok := GetInfoWithChecksTimeout(context.Background(), 30*time.Second, nil)
|
||||
assert.True(t, ok, "expected gathering to complete within the timeout")
|
||||
assert.NotNil(t, info)
|
||||
}
|
||||
|
||||
func TestGetInfoWithChecksTimeout_Timeout(t *testing.T) {
|
||||
// A 1ns budget expires before the (real) system-info gathering can finish, so the
|
||||
// caller must get (nil, false) instead of blocking on the in-flight goroutine.
|
||||
info, ok := GetInfoWithChecksTimeout(context.Background(), time.Nanosecond, nil)
|
||||
assert.False(t, ok, "expected timeout to be reported")
|
||||
assert.Nil(t, info)
|
||||
}
|
||||
|
||||
func Test_NetAddresses(t *testing.T) {
|
||||
addr, err := networkAddresses()
|
||||
if err != nil {
|
||||
|
||||
@@ -3,30 +3,24 @@
|
||||
package system
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"slices"
|
||||
|
||||
"github.com/shirou/gopsutil/v3/process"
|
||||
)
|
||||
|
||||
// getRunningProcesses returns a list of running process paths. The context bounds the work:
|
||||
// the per-PID loop bails as soon as ctx is done, and the gopsutil calls honor it where they
|
||||
// can, so a stuck enumeration cannot run unbounded.
|
||||
func getRunningProcesses(ctx context.Context) ([]string, error) {
|
||||
processIDs, err := process.PidsWithContext(ctx)
|
||||
// getRunningProcesses returns a list of running process paths.
|
||||
func getRunningProcesses() ([]string, error) {
|
||||
processIDs, err := process.Pids()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
processMap := make(map[string]bool)
|
||||
for _, pID := range processIDs {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
p := &process.Process{Pid: pID}
|
||||
|
||||
path, _ := p.ExeWithContext(ctx)
|
||||
path, _ := p.Exe()
|
||||
if path != "" {
|
||||
processMap[path] = false
|
||||
}
|
||||
@@ -41,21 +35,18 @@ func getRunningProcesses(ctx context.Context) ([]string, error) {
|
||||
}
|
||||
|
||||
// checkFileAndProcess checks if the file path exists and if a process is running at that path.
|
||||
func checkFileAndProcess(ctx context.Context, paths []string) ([]File, error) {
|
||||
func checkFileAndProcess(paths []string) ([]File, error) {
|
||||
files := make([]File, len(paths))
|
||||
if len(paths) == 0 {
|
||||
return files, nil
|
||||
}
|
||||
|
||||
runningProcesses, err := getRunningProcesses(ctx)
|
||||
runningProcesses, err := getRunningProcesses()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for i, path := range paths {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
file := File{Path: path}
|
||||
|
||||
_, err := os.Stat(path)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package system
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/shirou/gopsutil/v3/process"
|
||||
@@ -10,7 +9,7 @@ import (
|
||||
func Benchmark_getRunningProcesses(b *testing.B) {
|
||||
b.Run("getRunningProcesses new", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
ps, err := getRunningProcesses(context.Background())
|
||||
ps, err := getRunningProcesses()
|
||||
if err != nil {
|
||||
b.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
@@ -30,38 +29,12 @@ func Benchmark_getRunningProcesses(b *testing.B) {
|
||||
}
|
||||
}
|
||||
})
|
||||
s, _ := getRunningProcesses(context.Background())
|
||||
s, _ := getRunningProcesses()
|
||||
b.Logf("getRunningProcesses returned %d processes", len(s))
|
||||
s, _ = getRunningProcessesOld()
|
||||
b.Logf("getRunningProcessesOld returned %d processes", len(s))
|
||||
}
|
||||
|
||||
func TestCheckFileAndProcess_ContextCanceled(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
// With a canceled context and non-empty paths the gathering must bail with an error
|
||||
// instead of running the (potentially blocking) process scan / stat loop.
|
||||
if _, err := checkFileAndProcess(ctx, []string{"/does/not/exist"}); err == nil {
|
||||
t.Fatal("expected error on canceled context, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckFileAndProcess_EmptyPaths(t *testing.T) {
|
||||
// No check paths means no work to do: it must return immediately with no error,
|
||||
// even on a canceled context (nothing to scan or stat).
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
files, err := checkFileAndProcess(ctx, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error for empty paths: %v", err)
|
||||
}
|
||||
if len(files) != 0 {
|
||||
t.Fatalf("expected no files, got %d", len(files))
|
||||
}
|
||||
}
|
||||
|
||||
func getRunningProcessesOld() ([]string, error) {
|
||||
processes, err := process.Processes()
|
||||
if err != nil {
|
||||
|
||||
@@ -9,8 +9,6 @@ set -o pipefail
|
||||
|
||||
SED_STRIP_PADDING='s/=//g'
|
||||
|
||||
NETBIRD_EULA_URL="https://netbird.io/self-hosted-EULA"
|
||||
|
||||
check_docker_compose() {
|
||||
if command -v docker-compose &> /dev/null; then
|
||||
echo "docker-compose"
|
||||
@@ -141,43 +139,6 @@ read_yes_no() {
|
||||
esac
|
||||
}
|
||||
|
||||
# Gate the install on explicit acceptance of the NetBird On-Premise EULA.
|
||||
require_eula_acceptance() {
|
||||
cat > /dev/stderr <<EOF
|
||||
|
||||
──────────────────────────────────────────────────────────────────────
|
||||
NetBird On-Premise End User License Agreement
|
||||
──────────────────────────────────────────────────────────────────────
|
||||
NetBird's on-premise software is commercial software, licensed and not
|
||||
sold. Your installation, deployment and use are governed by the NetBird
|
||||
On-Premise End User License Agreement (the "EULA"). Please read the EULA
|
||||
in full before continuing:
|
||||
|
||||
${NETBIRD_EULA_URL}
|
||||
|
||||
By typing "accept" and continuing the installation, you confirm that you
|
||||
have read and agree to the EULA, that you are authorized to accept it on
|
||||
behalf of your organization (the "Customer"), and that the Software is
|
||||
used for business purposes only.
|
||||
──────────────────────────────────────────────────────────────────────
|
||||
EOF
|
||||
|
||||
if [[ "${NB_ACCEPT_EULA:-}" == "yes" ]]; then
|
||||
echo "EULA accepted via NB_ACCEPT_EULA=yes." > /dev/stderr
|
||||
return 0
|
||||
fi
|
||||
|
||||
local ans=""
|
||||
echo -n 'Type "accept" to agree, or anything else to abort: ' > /dev/stderr
|
||||
read -r ans < /dev/tty
|
||||
if [[ "$ans" != "accept" ]]; then
|
||||
echo "" > /dev/stderr
|
||||
echo "EULA not accepted. Aborting installation." > /dev/stderr
|
||||
exit 1
|
||||
fi
|
||||
echo "" > /dev/stderr
|
||||
}
|
||||
|
||||
wait_postgres() {
|
||||
set +e
|
||||
echo -n "Waiting for postgres to become ready"
|
||||
@@ -213,9 +174,6 @@ init_environment() {
|
||||
exit 1
|
||||
fi
|
||||
|
||||
require_eula_acceptance
|
||||
NETBIRD_EULA_ACCEPTED_AT=$(date -u +%Y-%m-%dT%H:%M:%SZ)
|
||||
|
||||
echo "NetBird Enterprise bootstrap"
|
||||
echo ""
|
||||
echo "Traffic flow:"
|
||||
@@ -302,11 +260,6 @@ render_env() {
|
||||
# Generated by getting-started-enterprise.sh
|
||||
# Holds all configuration and secrets for the stack. Mode 600.
|
||||
|
||||
# NetBird On-Premise EULA acceptance
|
||||
NETBIRD_EULA_ACCEPTED=yes
|
||||
NETBIRD_EULA_ACCEPTED_AT=${NETBIRD_EULA_ACCEPTED_AT}
|
||||
NETBIRD_EULA_URL=${NETBIRD_EULA_URL}
|
||||
|
||||
# Features (set by the script; don't edit without re-running)
|
||||
NETBIRD_TRAFFIC_FLOW_ENABLED=${NETBIRD_TRAFFIC_FLOW}
|
||||
|
||||
|
||||
@@ -25,8 +25,6 @@ set -o pipefail
|
||||
OVERRIDE_FILE="docker-compose.override.yml"
|
||||
ENTERPRISE_CONFIG_FILE="config.yaml.enterprise"
|
||||
|
||||
NETBIRD_EULA_URL="https://netbird.io/self-hosted-EULA"
|
||||
|
||||
check_docker_compose() {
|
||||
if command -v docker-compose &> /dev/null; then
|
||||
echo "docker-compose"
|
||||
@@ -117,43 +115,6 @@ read_yes_no() {
|
||||
esac
|
||||
}
|
||||
|
||||
# Gate the migration on explicit acceptance of the NetBird On-Premise EULA.
|
||||
require_eula_acceptance() {
|
||||
cat > /dev/stderr <<EOF
|
||||
|
||||
──────────────────────────────────────────────────────────────────────
|
||||
NetBird On-Premise End User License Agreement
|
||||
──────────────────────────────────────────────────────────────────────
|
||||
NetBird's on-premise software is commercial software, licensed and not
|
||||
sold. Your installation, deployment and use are governed by the NetBird
|
||||
On-Premise End User License Agreement (the "EULA"). Please read the EULA
|
||||
in full before continuing:
|
||||
|
||||
${NETBIRD_EULA_URL}
|
||||
|
||||
By typing "accept" and continuing the installation, you confirm that you
|
||||
have read and agree to the EULA, that you are authorized to accept it on
|
||||
behalf of your organization (the "Customer"), and that the Software is
|
||||
used for business purposes only.
|
||||
──────────────────────────────────────────────────────────────────────
|
||||
EOF
|
||||
|
||||
if [[ "${NB_ACCEPT_EULA:-}" == "yes" ]]; then
|
||||
echo "EULA accepted via NB_ACCEPT_EULA=yes." > /dev/stderr
|
||||
return 0
|
||||
fi
|
||||
|
||||
local ans=""
|
||||
echo -n 'Type "accept" to agree, or anything else to abort: ' > /dev/stderr
|
||||
read -r ans < /dev/tty
|
||||
if [[ "$ans" != "accept" ]]; then
|
||||
echo "" > /dev/stderr
|
||||
echo "EULA not accepted. Aborting migration." > /dev/stderr
|
||||
exit 1
|
||||
fi
|
||||
echo "" > /dev/stderr
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Detection — read the operator's existing compose to find service names and
|
||||
# paths we need to override. Bail loudly if shape isn't recognised.
|
||||
@@ -475,9 +436,6 @@ init_migration() {
|
||||
echo " Network: $COMPOSE_NETWORK"
|
||||
echo ""
|
||||
|
||||
require_eula_acceptance
|
||||
NETBIRD_EULA_ACCEPTED_AT=$(date -u +%Y-%m-%dT%H:%M:%SZ)
|
||||
|
||||
local proceed
|
||||
proceed=$(read_yes_no "Proceed with migration?" "y")
|
||||
if [[ "$proceed" != "yes" ]]; then
|
||||
@@ -571,10 +529,6 @@ apply_changes() {
|
||||
{
|
||||
echo ""
|
||||
echo "# Added by migrate-to-enterprise.sh on $(date -u +%Y-%m-%dT%H:%M:%SZ)"
|
||||
echo "# NetBird On-Premise EULA accepted at install time"
|
||||
echo "NETBIRD_EULA_ACCEPTED=yes"
|
||||
echo "NETBIRD_EULA_ACCEPTED_AT=${NETBIRD_EULA_ACCEPTED_AT}"
|
||||
echo "NETBIRD_EULA_URL=${NETBIRD_EULA_URL}"
|
||||
echo "NB_LICENSE_KEY=${NB_LICENSE_KEY}"
|
||||
if [[ -n "${NETBIRD_LICENSE_SERVER_BASE_URL:-}" ]]; then
|
||||
echo "NETBIRD_LICENSE_SERVER_BASE_URL=${NETBIRD_LICENSE_SERVER_BASE_URL}"
|
||||
|
||||
@@ -55,14 +55,6 @@ type GrpcClient struct {
|
||||
connStateCallback ConnStateNotifier
|
||||
connStateCallbackLock sync.RWMutex
|
||||
serverURL string
|
||||
|
||||
// syncStreamErr holds the last Sync stream error, or nil while the stream
|
||||
// is established and healthy. GetServerKey succeeds even when the peer
|
||||
// cannot sync (e.g. the server returns "settings not found"), so the
|
||||
// health probe must consult this to avoid reporting a healthy management
|
||||
// connection while the Sync stream keeps failing.
|
||||
syncStreamMu sync.RWMutex
|
||||
syncStreamErr error
|
||||
}
|
||||
|
||||
type ExposeRequest struct {
|
||||
@@ -372,8 +364,6 @@ func (c *GrpcClient) handleSyncStream(ctx context.Context, serverPubKey wgtypes.
|
||||
stream, err := c.connectToSyncStream(ctx, serverPubKey, sysInfo)
|
||||
if err != nil {
|
||||
log.Debugf("failed to open Management Service stream: %s", err)
|
||||
c.notifyDisconnected(err)
|
||||
c.setSyncStreamDisconnected(err)
|
||||
if s, ok := gstatus.FromError(err); ok && s.Code() == codes.PermissionDenied {
|
||||
return backoff.Permanent(err) // unrecoverable error, propagate to the upper layer
|
||||
}
|
||||
@@ -382,13 +372,11 @@ func (c *GrpcClient) handleSyncStream(ctx context.Context, serverPubKey wgtypes.
|
||||
|
||||
log.Infof("connected to the Management Service stream")
|
||||
c.notifyConnected()
|
||||
c.setSyncStreamConnected()
|
||||
|
||||
// blocking until error
|
||||
err = c.receiveUpdatesEvents(stream, serverPubKey, msgHandler)
|
||||
if err != nil {
|
||||
c.notifyDisconnected(err)
|
||||
c.setSyncStreamDisconnected(err)
|
||||
if ctx.Err() != nil {
|
||||
log.Debugf("management connection context has been canceled, this usually indicates shutdown")
|
||||
return nil
|
||||
@@ -542,13 +530,6 @@ func (c *GrpcClient) IsHealthy() bool {
|
||||
log.Warnf("health check returned: %s", err)
|
||||
return false
|
||||
}
|
||||
|
||||
if syncErr := c.syncStreamError(); syncErr != nil {
|
||||
c.notifyDisconnected(syncErr)
|
||||
log.Warnf("management transport is up but the Sync stream is unhealthy: %s", syncErr)
|
||||
return false
|
||||
}
|
||||
|
||||
c.notifyConnected()
|
||||
return true
|
||||
}
|
||||
@@ -790,24 +771,6 @@ func (c *GrpcClient) SyncMeta(sysInfo *system.Info) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *GrpcClient) setSyncStreamConnected() {
|
||||
c.syncStreamMu.Lock()
|
||||
defer c.syncStreamMu.Unlock()
|
||||
c.syncStreamErr = nil
|
||||
}
|
||||
|
||||
func (c *GrpcClient) setSyncStreamDisconnected(err error) {
|
||||
c.syncStreamMu.Lock()
|
||||
defer c.syncStreamMu.Unlock()
|
||||
c.syncStreamErr = err
|
||||
}
|
||||
|
||||
func (c *GrpcClient) syncStreamError() error {
|
||||
c.syncStreamMu.RLock()
|
||||
defer c.syncStreamMu.RUnlock()
|
||||
return c.syncStreamErr
|
||||
}
|
||||
|
||||
func (c *GrpcClient) notifyDisconnected(err error) {
|
||||
c.connStateCallbackLock.RLock()
|
||||
defer c.connStateCallbackLock.RUnlock()
|
||||
|
||||
158
shared/relay/client/fallback.go
Normal file
158
shared/relay/client/fallback.go
Normal file
@@ -0,0 +1,158 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
raceTotalTimeout = 60 * time.Second
|
||||
raceFallbackDelay = 10 * time.Second
|
||||
)
|
||||
|
||||
type raceAttempt struct {
|
||||
conn net.Conn
|
||||
err error
|
||||
}
|
||||
|
||||
type connRace struct {
|
||||
racer *ConnRacer
|
||||
peerKey string
|
||||
remoteRelayServer RelayServer
|
||||
preferForeign bool
|
||||
|
||||
raceCtx context.Context
|
||||
otherCtx context.Context
|
||||
cancelPreferred context.CancelFunc
|
||||
cancelOther context.CancelFunc
|
||||
results chan raceAttempt
|
||||
fallbackTimer *time.Timer
|
||||
|
||||
otherStarted bool
|
||||
settled int
|
||||
lastErr error
|
||||
}
|
||||
|
||||
type ConnRacer struct {
|
||||
home *Client
|
||||
foreignStore *ForeignRelaysStore
|
||||
}
|
||||
|
||||
func NewConnRacer(home *Client, foreignStore *ForeignRelaysStore) *ConnRacer {
|
||||
return &ConnRacer{
|
||||
home: home,
|
||||
foreignStore: foreignStore,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *ConnRacer) Run(ctx context.Context, peerKey string, remoteRelayServer RelayServer, preferForeign bool) (net.Conn, error) {
|
||||
raceCtx, cancel := context.WithTimeout(ctx, raceTotalTimeout)
|
||||
defer cancel()
|
||||
|
||||
preferredCtx, cancelPreferred := context.WithCancel(raceCtx)
|
||||
otherCtx, cancelOther := context.WithCancel(raceCtx)
|
||||
|
||||
race := &connRace{
|
||||
racer: r,
|
||||
peerKey: peerKey,
|
||||
remoteRelayServer: remoteRelayServer,
|
||||
preferForeign: preferForeign,
|
||||
raceCtx: raceCtx,
|
||||
otherCtx: otherCtx,
|
||||
cancelPreferred: cancelPreferred,
|
||||
cancelOther: cancelOther,
|
||||
results: make(chan raceAttempt, 2),
|
||||
fallbackTimer: time.NewTimer(raceFallbackDelay),
|
||||
}
|
||||
defer race.fallbackTimer.Stop()
|
||||
|
||||
go func() {
|
||||
race.results <- r.open(preferredCtx, peerKey, remoteRelayServer, preferForeign)
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-race.fallbackTimer.C:
|
||||
race.startOther()
|
||||
case res := <-race.results:
|
||||
if conn, err, done := race.handleResult(res); done {
|
||||
return conn, err
|
||||
}
|
||||
case <-raceCtx.Done():
|
||||
return race.onTimeout()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *connRace) startOther() {
|
||||
if c.otherStarted {
|
||||
return
|
||||
}
|
||||
c.otherStarted = true
|
||||
c.fallbackTimer.Stop()
|
||||
go func() {
|
||||
c.results <- c.racer.open(c.otherCtx, c.peerKey, c.remoteRelayServer, !c.preferForeign)
|
||||
}()
|
||||
}
|
||||
|
||||
func (c *connRace) handleResult(res raceAttempt) (net.Conn, error, bool) {
|
||||
if (res.err == nil && res.conn != nil) || errors.Is(res.err, ErrConnAlreadyExists) {
|
||||
c.stop()
|
||||
return res.conn, res.err, true
|
||||
}
|
||||
|
||||
c.lastErr = res.err
|
||||
c.settled++
|
||||
if !c.otherStarted {
|
||||
c.startOther()
|
||||
return nil, nil, false
|
||||
}
|
||||
if c.settled == 2 {
|
||||
c.cancelPreferred()
|
||||
c.cancelOther()
|
||||
return nil, c.lastErr, true
|
||||
}
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
func (c *connRace) onTimeout() (net.Conn, error) {
|
||||
c.stop()
|
||||
if c.lastErr != nil {
|
||||
return nil, c.lastErr
|
||||
}
|
||||
return nil, c.raceCtx.Err()
|
||||
}
|
||||
|
||||
func (c *connRace) stop() {
|
||||
c.cancelPreferred()
|
||||
c.cancelOther()
|
||||
go c.racer.drainLoser(c.results, c.settled, c.otherStarted)
|
||||
}
|
||||
|
||||
func (r *ConnRacer) open(ctx context.Context, peerKey string, remoteRelayServer RelayServer, foreign bool) raceAttempt {
|
||||
if foreign {
|
||||
conn, err := r.foreignStore.OpenConn(ctx, peerKey, remoteRelayServer)
|
||||
return raceAttempt{conn: conn, err: err}
|
||||
}
|
||||
conn, err := r.home.OpenConn(ctx, peerKey)
|
||||
return raceAttempt{conn: conn, err: err}
|
||||
}
|
||||
|
||||
func (r *ConnRacer) drainLoser(results chan raceAttempt, settled int, otherStarted bool) {
|
||||
started := 1
|
||||
if otherStarted {
|
||||
started = 2
|
||||
}
|
||||
for i := settled; i < started; i++ {
|
||||
res := <-results
|
||||
if res.conn != nil {
|
||||
if err := res.conn.Close(); err != nil {
|
||||
log.Debugf("failed to close losing relay connection: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
155
shared/relay/client/foreign_relays.go
Normal file
155
shared/relay/client/foreign_relays.go
Normal file
@@ -0,0 +1,155 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sync/singleflight"
|
||||
|
||||
relayAuth "github.com/netbirdio/netbird/shared/relay/auth/hmac"
|
||||
)
|
||||
|
||||
type foreignRelay struct {
|
||||
client *Client
|
||||
created time.Time
|
||||
inUse int
|
||||
}
|
||||
|
||||
type ForeignRelaysStore struct {
|
||||
mu sync.RWMutex
|
||||
clients map[string]*foreignRelay
|
||||
|
||||
group singleflight.Group
|
||||
|
||||
ctx context.Context
|
||||
tokenStore *relayAuth.TokenStore
|
||||
peerID string
|
||||
mtu uint16
|
||||
transportFallback *transportFallback
|
||||
onDisconnect func(string)
|
||||
keepUnusedServerTime time.Duration
|
||||
}
|
||||
|
||||
func NewForeignRelaysStore(ctx context.Context, tokenStore *relayAuth.TokenStore, peerID string, mtu uint16, transportFallback *transportFallback, onDisconnect func(string), keepUnusedServerTime time.Duration) *ForeignRelaysStore {
|
||||
return &ForeignRelaysStore{
|
||||
clients: make(map[string]*foreignRelay),
|
||||
ctx: ctx,
|
||||
tokenStore: tokenStore,
|
||||
peerID: peerID,
|
||||
mtu: mtu,
|
||||
transportFallback: transportFallback,
|
||||
onDisconnect: onDisconnect,
|
||||
keepUnusedServerTime: keepUnusedServerTime,
|
||||
}
|
||||
}
|
||||
|
||||
func (f *ForeignRelaysStore) OpenConn(ctx context.Context, peerKey string, remoteRelayServer RelayServer) (net.Conn, error) {
|
||||
fr, err := f.acquire(remoteRelayServer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer f.release(fr)
|
||||
|
||||
return fr.client.OpenConn(ctx, peerKey)
|
||||
}
|
||||
|
||||
func (f *ForeignRelaysStore) acquire(remoteRelayServer RelayServer) (*foreignRelay, error) {
|
||||
f.mu.Lock()
|
||||
if fr, ok := f.clients[remoteRelayServer.Addr]; ok {
|
||||
fr.inUse++
|
||||
f.mu.Unlock()
|
||||
return fr, nil
|
||||
}
|
||||
f.mu.Unlock()
|
||||
|
||||
v, err, _ := f.group.Do(remoteRelayServer.Addr, func() (any, error) {
|
||||
f.mu.RLock()
|
||||
fr, ok := f.clients[remoteRelayServer.Addr]
|
||||
f.mu.RUnlock()
|
||||
if ok {
|
||||
return fr, nil
|
||||
}
|
||||
|
||||
relayClient := NewClientWithServerIP(remoteRelayServer.Addr, remoteRelayServer.IP, f.tokenStore, f.peerID, f.mtu)
|
||||
relayClient.SetTransportFallback(f.transportFallback)
|
||||
if err := relayClient.Connect(f.ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
relayClient.SetOnDisconnectListener(f.onDisconnect)
|
||||
|
||||
f.mu.Lock()
|
||||
fr = &foreignRelay{client: relayClient, created: time.Now()}
|
||||
f.clients[remoteRelayServer.Addr] = fr
|
||||
f.mu.Unlock()
|
||||
return fr, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fr := v.(*foreignRelay)
|
||||
f.mu.Lock()
|
||||
if cur, ok := f.clients[remoteRelayServer.Addr]; !ok || cur != fr {
|
||||
f.mu.Unlock()
|
||||
return f.acquire(remoteRelayServer)
|
||||
}
|
||||
fr.inUse++
|
||||
f.mu.Unlock()
|
||||
return fr, nil
|
||||
}
|
||||
|
||||
func (f *ForeignRelaysStore) release(fr *foreignRelay) {
|
||||
f.mu.Lock()
|
||||
fr.inUse--
|
||||
f.mu.Unlock()
|
||||
}
|
||||
|
||||
func (f *ForeignRelaysStore) evict(serverAddress string) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
if _, ok := f.clients[serverAddress]; ok {
|
||||
delete(f.clients, serverAddress)
|
||||
log.Debugf("evicted disconnected foreign relay client: %s", serverAddress)
|
||||
}
|
||||
}
|
||||
|
||||
func (f *ForeignRelaysStore) cleanupUnused() {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
|
||||
for addr, fr := range f.clients {
|
||||
if time.Since(fr.created) <= f.keepUnusedServerTime {
|
||||
continue
|
||||
}
|
||||
if fr.inUse > 0 {
|
||||
continue
|
||||
}
|
||||
if fr.client.HasConns() {
|
||||
continue
|
||||
}
|
||||
fr.client.SetOnDisconnectListener(nil)
|
||||
go func() {
|
||||
_ = fr.client.Close()
|
||||
}()
|
||||
log.Debugf("clean up unused relay server connection: %s", addr)
|
||||
delete(f.clients, addr)
|
||||
}
|
||||
}
|
||||
|
||||
func (f *ForeignRelaysStore) states() []RelayConnState {
|
||||
f.mu.RLock()
|
||||
clients := make([]*Client, 0, len(f.clients))
|
||||
for _, fr := range f.clients {
|
||||
clients = append(clients, fr.client)
|
||||
}
|
||||
f.mu.RUnlock()
|
||||
|
||||
states := make([]RelayConnState, 0, len(clients))
|
||||
for _, c := range clients {
|
||||
states = append(states, relayConnState(c))
|
||||
}
|
||||
return states
|
||||
}
|
||||
@@ -22,22 +22,6 @@ var (
|
||||
ErrRelayClientNotConnected = fmt.Errorf("relay client not connected")
|
||||
)
|
||||
|
||||
// RelayTrack hold the relay clients for the foreign relay servers.
|
||||
// With the mutex can ensure we can open new connection in case the relay connection has been established with
|
||||
// the relay server.
|
||||
type RelayTrack struct {
|
||||
sync.RWMutex
|
||||
relayClient *Client
|
||||
err error
|
||||
created time.Time
|
||||
}
|
||||
|
||||
func NewRelayTrack() *RelayTrack {
|
||||
return &RelayTrack{
|
||||
created: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
type OnServerCloseListener func()
|
||||
|
||||
// ManagerOption configures a Manager at construction time.
|
||||
@@ -54,6 +38,11 @@ type RelayConnState struct {
|
||||
Err error
|
||||
}
|
||||
|
||||
type RelayServer struct {
|
||||
Addr string
|
||||
IP netip.Addr
|
||||
}
|
||||
|
||||
// WithMaxBackoffInterval caps the exponential backoff between reconnect
|
||||
// attempts to the home relay. A non-positive value keeps the default.
|
||||
func WithMaxBackoffInterval(d time.Duration) ManagerOption {
|
||||
@@ -78,8 +67,7 @@ type Manager struct {
|
||||
relayClientMu sync.RWMutex
|
||||
reconnectGuard *Guard
|
||||
|
||||
relayClients map[string]*RelayTrack
|
||||
relayClientsMutex sync.RWMutex
|
||||
foreign *ForeignRelaysStore
|
||||
|
||||
onDisconnectedListeners map[string]*list.List
|
||||
onReconnectedListenerFn func()
|
||||
@@ -115,7 +103,6 @@ func NewManager(ctx context.Context, serverURLs []string, peerID string, mtu uin
|
||||
ConnectionTimeout: defaultConnectionTimeout,
|
||||
TransportFallback: tf,
|
||||
},
|
||||
relayClients: make(map[string]*RelayTrack),
|
||||
onDisconnectedListeners: make(map[string]*list.List),
|
||||
cleanupInterval: relayCleanupInterval,
|
||||
keepUnusedServerTime: keepUnusedServerTime,
|
||||
@@ -123,6 +110,7 @@ func NewManager(ctx context.Context, serverURLs []string, peerID string, mtu uin
|
||||
for _, opt := range opts {
|
||||
opt(m)
|
||||
}
|
||||
m.foreign = NewForeignRelaysStore(ctx, tokenStore, peerID, mtu, tf, m.onServerDisconnected, m.keepUnusedServerTime)
|
||||
m.serverPicker.ServerURLs.Store(serverURLs)
|
||||
m.reconnectGuard = NewGuard(m.serverPicker, m.maxBackoffInterval)
|
||||
return m
|
||||
@@ -154,13 +142,7 @@ func (m *Manager) Serve() error {
|
||||
return err
|
||||
}
|
||||
|
||||
// OpenConn opens a connection to the given peer key. If the peer is on the same relay server, the connection will be
|
||||
// established via the relay server. If the peer is on a different relay server, the manager will establish a new
|
||||
// connection to the relay server. It returns back with a net.Conn what represent the remote peer connection.
|
||||
//
|
||||
// serverIP, when valid and serverAddress is foreign, is used as a dial target if the FQDN-based dial fails.
|
||||
// Ignored for the local home-server path. TLS verification still uses the FQDN via SNI.
|
||||
func (m *Manager) OpenConn(ctx context.Context, serverAddress, peerKey string, serverIP netip.Addr) (net.Conn, error) {
|
||||
func (m *Manager) OpenConn(ctx context.Context, remoteRelayServer RelayServer, peerKey string, preferForeign bool) (net.Conn, error) {
|
||||
m.relayClientMu.RLock()
|
||||
defer m.relayClientMu.RUnlock()
|
||||
|
||||
@@ -168,26 +150,17 @@ func (m *Manager) OpenConn(ctx context.Context, serverAddress, peerKey string, s
|
||||
return nil, ErrRelayClientNotConnected
|
||||
}
|
||||
|
||||
foreign, err := m.isForeignServer(serverAddress)
|
||||
foreign, err := m.isForeignServer(remoteRelayServer.Addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var (
|
||||
netConn net.Conn
|
||||
)
|
||||
if !foreign {
|
||||
log.Debugf("open peer connection via permanent server: %s", peerKey)
|
||||
netConn, err = m.relayClient.OpenConn(ctx, peerKey)
|
||||
} else {
|
||||
log.Debugf("open peer connection via foreign server: %s", serverAddress)
|
||||
netConn, err = m.openConnVia(ctx, serverAddress, peerKey, serverIP)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return m.relayClient.OpenConn(ctx, peerKey)
|
||||
}
|
||||
|
||||
return netConn, err
|
||||
racer := NewConnRacer(m.relayClient, m.foreign)
|
||||
return racer.Run(ctx, peerKey, remoteRelayServer, preferForeign)
|
||||
}
|
||||
|
||||
// Ready returns true if the home Relay client is connected to the relay server.
|
||||
@@ -282,26 +255,7 @@ func (m *Manager) RelayStates() []RelayConnState {
|
||||
states = append(states, st)
|
||||
}
|
||||
|
||||
// Snapshot the tracks, then query each outside the map lock: a track can be
|
||||
// held by an in-progress Connect, and blocking on it must not stall other
|
||||
// relay operations.
|
||||
m.relayClientsMutex.RLock()
|
||||
tracks := make([]*RelayTrack, 0, len(m.relayClients))
|
||||
for _, rt := range m.relayClients {
|
||||
tracks = append(tracks, rt)
|
||||
}
|
||||
m.relayClientsMutex.RUnlock()
|
||||
|
||||
// Only connected foreign relays carry state; a failed connect is evicted
|
||||
// immediately (openConnVia), so there is no error state to surface.
|
||||
for _, rt := range tracks {
|
||||
rt.RLock()
|
||||
rc := rt.relayClient
|
||||
rt.RUnlock()
|
||||
if rc != nil {
|
||||
states = append(states, relayConnState(rc))
|
||||
}
|
||||
}
|
||||
states = append(states, m.foreign.states()...)
|
||||
|
||||
return states
|
||||
}
|
||||
@@ -322,64 +276,6 @@ func (m *Manager) UpdateToken(token *relayAuth.Token) error {
|
||||
return m.tokenStore.UpdateToken(token)
|
||||
}
|
||||
|
||||
func (m *Manager) openConnVia(ctx context.Context, serverAddress, peerKey string, serverIP netip.Addr) (net.Conn, error) {
|
||||
// check if already has a connection to the desired relay server
|
||||
m.relayClientsMutex.RLock()
|
||||
rt, ok := m.relayClients[serverAddress]
|
||||
if ok {
|
||||
rt.RLock()
|
||||
m.relayClientsMutex.RUnlock()
|
||||
defer rt.RUnlock()
|
||||
if rt.err != nil {
|
||||
return nil, rt.err
|
||||
}
|
||||
return rt.relayClient.OpenConn(ctx, peerKey)
|
||||
}
|
||||
m.relayClientsMutex.RUnlock()
|
||||
|
||||
// if not, establish a new connection but check it again (because changed the lock type) before starting the
|
||||
// connection
|
||||
m.relayClientsMutex.Lock()
|
||||
rt, ok = m.relayClients[serverAddress]
|
||||
if ok {
|
||||
rt.RLock()
|
||||
m.relayClientsMutex.Unlock()
|
||||
defer rt.RUnlock()
|
||||
if rt.err != nil {
|
||||
return nil, rt.err
|
||||
}
|
||||
return rt.relayClient.OpenConn(ctx, peerKey)
|
||||
}
|
||||
|
||||
// create a new relay client and store it in the relayClients map
|
||||
rt = NewRelayTrack()
|
||||
rt.Lock()
|
||||
m.relayClients[serverAddress] = rt
|
||||
m.relayClientsMutex.Unlock()
|
||||
|
||||
relayClient := NewClientWithServerIP(serverAddress, serverIP, m.tokenStore, m.peerID, m.mtu)
|
||||
relayClient.SetTransportFallback(m.transportFallback)
|
||||
err := relayClient.Connect(m.ctx)
|
||||
if err != nil {
|
||||
rt.err = err
|
||||
rt.Unlock()
|
||||
m.relayClientsMutex.Lock()
|
||||
delete(m.relayClients, serverAddress)
|
||||
m.relayClientsMutex.Unlock()
|
||||
return nil, err
|
||||
}
|
||||
// if connection closed then delete the relay client from the list
|
||||
relayClient.SetOnDisconnectListener(m.onServerDisconnected)
|
||||
rt.relayClient = relayClient
|
||||
rt.Unlock()
|
||||
|
||||
conn, err := relayClient.OpenConn(ctx, peerKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (m *Manager) onServerConnected() {
|
||||
m.listenerLock.Lock()
|
||||
defer m.listenerLock.Unlock()
|
||||
@@ -405,21 +301,12 @@ func (m *Manager) onServerDisconnected(serverAddress string) {
|
||||
m.relayClientMu.Unlock()
|
||||
|
||||
if !isHome {
|
||||
m.evictForeignRelay(serverAddress)
|
||||
m.foreign.evict(serverAddress)
|
||||
}
|
||||
|
||||
m.notifyOnDisconnectListeners(serverAddress)
|
||||
}
|
||||
|
||||
func (m *Manager) evictForeignRelay(serverAddress string) {
|
||||
m.relayClientsMutex.Lock()
|
||||
defer m.relayClientsMutex.Unlock()
|
||||
if _, ok := m.relayClients[serverAddress]; ok {
|
||||
delete(m.relayClients, serverAddress)
|
||||
log.Debugf("evicted disconnected foreign relay client: %s", serverAddress)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) listenGuardEvent(ctx context.Context) {
|
||||
for {
|
||||
select {
|
||||
@@ -458,43 +345,11 @@ func (m *Manager) startCleanupLoop() {
|
||||
case <-m.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
m.cleanUpUnusedRelays()
|
||||
m.foreign.cleanupUnused()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) cleanUpUnusedRelays() {
|
||||
m.relayClientsMutex.Lock()
|
||||
defer m.relayClientsMutex.Unlock()
|
||||
|
||||
for addr, rt := range m.relayClients {
|
||||
rt.Lock()
|
||||
// if the connection failed to the server the relay client will be nil
|
||||
// but the instance will be kept in the relayClients until the next locking
|
||||
if rt.err != nil {
|
||||
rt.Unlock()
|
||||
continue
|
||||
}
|
||||
|
||||
if time.Since(rt.created) <= m.keepUnusedServerTime {
|
||||
rt.Unlock()
|
||||
continue
|
||||
}
|
||||
|
||||
if rt.relayClient.HasConns() {
|
||||
rt.Unlock()
|
||||
continue
|
||||
}
|
||||
rt.relayClient.SetOnDisconnectListener(nil)
|
||||
go func() {
|
||||
_ = rt.relayClient.Close()
|
||||
}()
|
||||
log.Debugf("clean up unused relay server connection: %s", addr)
|
||||
delete(m.relayClients, addr)
|
||||
rt.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) addListener(serverAddress string, onClosedListener OnServerCloseListener) {
|
||||
m.listenerLock.Lock()
|
||||
defer m.listenerLock.Unlock()
|
||||
|
||||
@@ -85,7 +85,6 @@ type GrpcClient struct {
|
||||
// receive backpressure as a dead stream: reconnecting cannot help, since the
|
||||
// new stream feeds the same worker, and only triggers a reconnect storm.
|
||||
receiveHandoffBlocked atomic.Bool
|
||||
watchdogWg sync.WaitGroup
|
||||
}
|
||||
|
||||
// NewClient creates a new Signal client
|
||||
@@ -201,18 +200,10 @@ func (c *GrpcClient) Receive(ctx context.Context, msgHandler func(msg *proto.Mes
|
||||
// Guard the receive direction: the transport can stay healthy while the
|
||||
// server stops delivering messages. The watchdog reconnects via cancelStream.
|
||||
c.markReceived()
|
||||
c.watchdogWg.Add(1)
|
||||
go func() {
|
||||
defer c.watchdogWg.Done()
|
||||
c.watchReceiveStream(streamCtx, cancelStream)
|
||||
}()
|
||||
go c.watchReceiveStream(streamCtx, cancelStream)
|
||||
|
||||
// start receiving messages from the Signal stream (from other peers through signal)
|
||||
err = c.receive(stream)
|
||||
|
||||
cancelStream()
|
||||
c.watchdogWg.Wait()
|
||||
|
||||
if err != nil {
|
||||
// Check the parent context, not streamCtx: a watchdog-triggered
|
||||
// cancelStream must reconnect, only a parent cancel is shutdown.
|
||||
@@ -409,12 +400,7 @@ func (c *GrpcClient) encryptMessage(msg *proto.Message) (*proto.EncryptedMessage
|
||||
|
||||
// Send sends a message to the remote Peer through the Signal Exchange.
|
||||
func (c *GrpcClient) Send(msg *proto.Message) error {
|
||||
return c.send(c.ctx, msg)
|
||||
}
|
||||
|
||||
// send delivers a message deriving per-attempt timeouts from parentCtx, so a
|
||||
// caller can abort an in-flight send by cancelling that context.
|
||||
func (c *GrpcClient) send(parentCtx context.Context, msg *proto.Message) error {
|
||||
if !c.Ready() {
|
||||
return fmt.Errorf("no connection to signal")
|
||||
}
|
||||
@@ -430,7 +416,7 @@ func (c *GrpcClient) send(parentCtx context.Context, msg *proto.Message) error {
|
||||
if attempt > 1 {
|
||||
attemptTimeout = time.Duration(attempt) * 5 * time.Second
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(parentCtx, attemptTimeout)
|
||||
ctx, cancel := context.WithTimeout(c.ctx, attemptTimeout)
|
||||
|
||||
_, err = c.realClient.Send(ctx, encryptedMessage)
|
||||
|
||||
@@ -500,7 +486,7 @@ func (c *GrpcClient) watchReceiveStream(ctx context.Context, cancelStream contex
|
||||
}
|
||||
|
||||
if probeSentAt.IsZero() {
|
||||
if err := c.sendReceiveProbe(ctx); err != nil {
|
||||
if err := c.sendReceiveProbe(); err != nil {
|
||||
log.Debugf("failed to send signal receive probe: %v", err)
|
||||
}
|
||||
probeSentAt = time.Now()
|
||||
@@ -509,13 +495,11 @@ func (c *GrpcClient) watchReceiveStream(ctx context.Context, cancelStream contex
|
||||
}
|
||||
}
|
||||
|
||||
// sendReceiveProbe sends a self-addressed heartbeat bound to ctx, so cancelStream
|
||||
// aborts an in-flight probe instead of leaving the watchdog blocked on send timeouts.
|
||||
// The Signal server routes it back to this client, exercising the exact receive
|
||||
// path the watchdog guards.
|
||||
func (c *GrpcClient) sendReceiveProbe(ctx context.Context) error {
|
||||
// sendReceiveProbe sends a self-addressed heartbeat. The Signal server routes it
|
||||
// back to this client, exercising the exact receive path the watchdog guards.
|
||||
func (c *GrpcClient) sendReceiveProbe() error {
|
||||
self := c.key.PublicKey().String()
|
||||
return c.send(ctx, &proto.Message{
|
||||
return c.Send(&proto.Message{
|
||||
Key: self,
|
||||
RemoteKey: self,
|
||||
Body: &proto.Body{Type: proto.Body_HEARTBEAT},
|
||||
@@ -557,9 +541,6 @@ func (c *GrpcClient) receive(stream proto.SignalExchange_ConnectStreamClient) er
|
||||
if err := c.decryptionWorker.AddMsg(c.ctx, msg); err != nil {
|
||||
log.Errorf("failed to add message to decryption worker: %v", err)
|
||||
}
|
||||
// Refresh liveness before clearing the flag so the window between here and
|
||||
// the next Recv does not read a stale timestamp as a dead stream.
|
||||
c.markReceived()
|
||||
c.receiveHandoffBlocked.Store(false)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -75,7 +74,7 @@ func TestReceiveProbeRoundTrips(t *testing.T) {
|
||||
t.Fatal("signal stream did not connect within timeout")
|
||||
}
|
||||
|
||||
require.NoError(t, client.sendReceiveProbe(ctx))
|
||||
require.NoError(t, client.sendReceiveProbe())
|
||||
|
||||
select {
|
||||
case <-received:
|
||||
@@ -107,72 +106,3 @@ func TestReceiveAliveTreatsHandoffBlockAsLiveness(t *testing.T) {
|
||||
c.markReceived()
|
||||
require.True(t, c.receiveAlive(), "a freshly received frame must keep the stream alive")
|
||||
}
|
||||
|
||||
// fakeRecvStream feeds the receive loop frames from a channel and reports EOF
|
||||
// once the channel is closed. Only Recv is exercised by the loop.
|
||||
type fakeRecvStream struct {
|
||||
sigProto.SignalExchange_ConnectStreamClient
|
||||
frames chan *sigProto.EncryptedMessage
|
||||
}
|
||||
|
||||
func (s *fakeRecvStream) Recv() (*sigProto.EncryptedMessage, error) {
|
||||
msg, ok := <-s.frames
|
||||
if !ok {
|
||||
return nil, io.EOF
|
||||
}
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// TestReceiveLoopRefreshesLivenessAfterBlockedHandoff drives the real receive
|
||||
// loop into a handoff that blocks past the inactivity threshold, then checks the
|
||||
// window after the handoff drains but before the next Recv. The loop must have
|
||||
// refreshed the timestamp on unblocking, otherwise that window reads the stale
|
||||
// pre-handoff timestamp as a dead stream and the watchdog tears down a healthy
|
||||
// connection.
|
||||
func TestReceiveLoopRefreshesLivenessAfterBlockedHandoff(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
t.Cleanup(cancel)
|
||||
c := &GrpcClient{ctx: ctx}
|
||||
|
||||
handling := make(chan struct{}, 8)
|
||||
gate := make(chan struct{})
|
||||
decrypt := func(*sigProto.EncryptedMessage) (*sigProto.Message, error) { return &sigProto.Message{}, nil }
|
||||
handler := func(*sigProto.Message) error {
|
||||
handling <- struct{}{}
|
||||
<-gate
|
||||
return nil
|
||||
}
|
||||
c.decryptionWorker = NewWorker(decrypt, handler)
|
||||
workerCtx, workerCancel := context.WithCancel(context.Background())
|
||||
go c.decryptionWorker.Work(workerCtx)
|
||||
t.Cleanup(workerCancel)
|
||||
|
||||
frames := make(chan *sigProto.EncryptedMessage)
|
||||
t.Cleanup(func() { close(frames) })
|
||||
go func() { _ = c.receive(&fakeRecvStream{frames: frames}) }()
|
||||
|
||||
// First frame: the worker drains it and parks in the blocking handler.
|
||||
frames <- &sigProto.EncryptedMessage{}
|
||||
<-handling
|
||||
// Second frame fills the worker's single-slot pool.
|
||||
frames <- &sigProto.EncryptedMessage{}
|
||||
// Third frame: the pool is full, so the loop parks on the handoff.
|
||||
frames <- &sigProto.EncryptedMessage{}
|
||||
|
||||
require.Eventually(t, c.receiveHandoffBlocked.Load, time.Second, time.Millisecond,
|
||||
"receive loop should park on the worker handoff")
|
||||
|
||||
// Simulate the handoff having blocked past the inactivity threshold.
|
||||
c.lastReceived.Store(time.Now().Add(-2 * receiveInactivityThreshold).UnixNano())
|
||||
require.True(t, c.receiveAlive(), "a loop parked on the handoff must stay alive")
|
||||
|
||||
// Drain the worker so the handoff returns and the loop resumes reading.
|
||||
close(gate)
|
||||
|
||||
// Once the handoff clears, the loop is parked on the next Recv with no frame
|
||||
// pending. The stream must still read as alive in that window.
|
||||
require.Eventually(t, func() bool { return !c.receiveHandoffBlocked.Load() }, time.Second, time.Millisecond,
|
||||
"handoff should drain once the worker is released")
|
||||
require.True(t, c.receiveAlive(),
|
||||
"the loop must refresh liveness when the handoff drains, before the next Recv")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user