Compare commits

...

6 Commits

Author SHA1 Message Date
Zoltán Papp
faf3559ea7 Revert "[client] Prefer systemd-resolved stub over file mode regardless of resolv.conf header (#5935)"
This reverts commit 75e408f51c.
2026-04-23 21:29:52 +02:00
Zoltán Papp
dc8c2edf50 Revert "[client] Add TTL-based refresh to mgmt DNS cache via handler chain (#5945)"
This reverts commit 801de8c68d.
2026-04-23 21:29:46 +02:00
Zoltan Papp
f732b01a05 [management] unify peer-update test timeout via constant (#5952)
peerShouldReceiveUpdate waited 500ms for the expected update message,
and every outer wrapper across the management/server test suite paired
it with a 1s goroutine-drain timeout. Both were too tight for slower
CI runners (MySQL, FreeBSD, loaded sqlite), producing intermittent
"Timed out waiting for update message" failures in tests like
TestDNSAccountPeersUpdate, TestPeerAccountPeersUpdate, and
TestNameServerAccountPeersUpdate.

Introduce peerUpdateTimeout (5s) next to the helper and use it both in
the helper and in every outer wrapper so the two timeouts stay in sync.
Only runs down on failure; passing tests return as soon as the channel
delivers, so there is no slowdown on green runs.
2026-04-23 21:19:21 +02:00
alsruf36
c07c726ea7 [proxy] Set session cookie path to root (#5915) 2026-04-23 18:20:54 +02:00
Pascal Fischer
fa0d58d093 [management] exclude peers for expiration job that have already been marked expired (#5970) 2026-04-23 16:01:54 +02:00
Vlad
b6038e8acd [management] refactor: changeable pat rate limiting (#5946) 2026-04-23 15:13:22 +02:00
32 changed files with 552 additions and 1406 deletions

View File

@@ -13,7 +13,6 @@ import (
const (
defaultResolvConfPath = "/etc/resolv.conf"
nsswitchConfPath = "/etc/nsswitch.conf"
)
type resolvConf struct {

View File

@@ -1,10 +1,7 @@
package dns
import (
"context"
"fmt"
"math"
"net"
"slices"
"strconv"
"strings"
@@ -195,12 +192,6 @@ func (c *HandlerChain) logHandlers() {
}
func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
c.dispatch(w, r, math.MaxInt)
}
// dispatch routes a DNS request through the chain, skipping handlers with
// priority > maxPriority. Shared by ServeDNS and ResolveInternal.
func (c *HandlerChain) dispatch(w dns.ResponseWriter, r *dns.Msg, maxPriority int) {
if len(r.Question) == 0 {
return
}
@@ -225,9 +216,6 @@ func (c *HandlerChain) dispatch(w dns.ResponseWriter, r *dns.Msg, maxPriority in
// Try handlers in priority order
for _, entry := range handlers {
if entry.Priority > maxPriority {
continue
}
if !c.isHandlerMatch(qname, entry) {
continue
}
@@ -285,55 +273,6 @@ func (c *HandlerChain) logResponse(logger *log.Entry, cw *ResponseWriterChain, q
cw.response.Len(), meta, time.Since(startTime))
}
// ResolveInternal runs an in-process DNS query against the chain, skipping any
// handler with priority > maxPriority. Used by internal callers (e.g. the mgmt
// cache refresher) that must bypass themselves to avoid loops. Honors ctx
// cancellation; on ctx.Done the dispatch goroutine is left to drain on its own
// (bounded by the invoked handler's internal timeout).
func (c *HandlerChain) ResolveInternal(ctx context.Context, r *dns.Msg, maxPriority int) (*dns.Msg, error) {
if len(r.Question) == 0 {
return nil, fmt.Errorf("empty question")
}
base := &internalResponseWriter{}
done := make(chan struct{})
go func() {
c.dispatch(base, r, maxPriority)
close(done)
}()
select {
case <-done:
case <-ctx.Done():
// Prefer a completed response if dispatch finished concurrently with cancellation.
select {
case <-done:
default:
return nil, fmt.Errorf("resolve %s: %w", strings.ToLower(r.Question[0].Name), ctx.Err())
}
}
if base.response == nil || base.response.Rcode == dns.RcodeRefused {
return nil, fmt.Errorf("no handler resolved %s at priority ≤ %d",
strings.ToLower(r.Question[0].Name), maxPriority)
}
return base.response, nil
}
// HasRootHandlerAtOrBelow reports whether any "." handler is registered at
// priority ≤ maxPriority.
func (c *HandlerChain) HasRootHandlerAtOrBelow(maxPriority int) bool {
c.mu.RLock()
defer c.mu.RUnlock()
for _, h := range c.handlers {
if h.Pattern == "." && h.Priority <= maxPriority {
return true
}
}
return false
}
func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
switch {
case entry.Pattern == ".":
@@ -352,36 +291,3 @@ func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
}
}
}
// internalResponseWriter captures a dns.Msg for in-process chain queries.
type internalResponseWriter struct {
response *dns.Msg
}
func (w *internalResponseWriter) WriteMsg(m *dns.Msg) error { w.response = m; return nil }
func (w *internalResponseWriter) LocalAddr() net.Addr { return nil }
func (w *internalResponseWriter) RemoteAddr() net.Addr { return nil }
// Write unpacks raw DNS bytes so handlers that call Write instead of WriteMsg
// still surface their answer to ResolveInternal.
func (w *internalResponseWriter) Write(p []byte) (int, error) {
msg := new(dns.Msg)
if err := msg.Unpack(p); err != nil {
return 0, err
}
w.response = msg
return len(p), nil
}
func (w *internalResponseWriter) Close() error { return nil }
func (w *internalResponseWriter) TsigStatus() error { return nil }
// TsigTimersOnly is part of dns.ResponseWriter.
func (w *internalResponseWriter) TsigTimersOnly(bool) {
// no-op: in-process queries carry no TSIG state.
}
// Hijack is part of dns.ResponseWriter.
func (w *internalResponseWriter) Hijack() {
// no-op: in-process queries have no underlying connection to hand off.
}

View File

@@ -1,15 +1,11 @@
package dns_test
import (
"context"
"net"
"testing"
"time"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
nbdns "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/dns/test"
@@ -1046,163 +1042,3 @@ func TestHandlerChain_AddRemoveRoundtrip(t *testing.T) {
})
}
}
// answeringHandler writes a fixed A record to ack the query. Used to verify
// which handler ResolveInternal dispatches to.
type answeringHandler struct {
name string
ip string
}
func (h *answeringHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
resp := &dns.Msg{}
resp.SetReply(r)
resp.Answer = []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP(h.ip).To4(),
}}
_ = w.WriteMsg(resp)
}
func (h *answeringHandler) String() string { return h.name }
func TestHandlerChain_ResolveInternal_SkipsAboveMaxPriority(t *testing.T) {
chain := nbdns.NewHandlerChain()
high := &answeringHandler{name: "high", ip: "10.0.0.1"}
low := &answeringHandler{name: "low", ip: "10.0.0.2"}
chain.AddHandler("example.com.", high, nbdns.PriorityMgmtCache)
chain.AddHandler("example.com.", low, nbdns.PriorityUpstream)
r := new(dns.Msg)
r.SetQuestion("example.com.", dns.TypeA)
resp, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream)
assert.NoError(t, err)
assert.NotNil(t, resp)
assert.Equal(t, 1, len(resp.Answer))
a, ok := resp.Answer[0].(*dns.A)
assert.True(t, ok)
assert.Equal(t, "10.0.0.2", a.A.String(), "should skip mgmtCache handler and resolve via upstream")
}
func TestHandlerChain_ResolveInternal_ErrorWhenNoMatch(t *testing.T) {
chain := nbdns.NewHandlerChain()
high := &answeringHandler{name: "high", ip: "10.0.0.1"}
chain.AddHandler("example.com.", high, nbdns.PriorityMgmtCache)
r := new(dns.Msg)
r.SetQuestion("example.com.", dns.TypeA)
_, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream)
assert.Error(t, err, "no handler at or below maxPriority should error")
}
// rawWriteHandler packs a response and calls ResponseWriter.Write directly
// (instead of WriteMsg), exercising the internalResponseWriter.Write path.
type rawWriteHandler struct {
ip string
}
func (h *rawWriteHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
resp := &dns.Msg{}
resp.SetReply(r)
resp.Answer = []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP(h.ip).To4(),
}}
packed, err := resp.Pack()
if err != nil {
return
}
_, _ = w.Write(packed)
}
func TestHandlerChain_ResolveInternal_CapturesRawWrite(t *testing.T) {
chain := nbdns.NewHandlerChain()
chain.AddHandler("example.com.", &rawWriteHandler{ip: "10.0.0.3"}, nbdns.PriorityUpstream)
r := new(dns.Msg)
r.SetQuestion("example.com.", dns.TypeA)
resp, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream)
assert.NoError(t, err)
require.NotNil(t, resp)
require.Len(t, resp.Answer, 1)
a, ok := resp.Answer[0].(*dns.A)
require.True(t, ok)
assert.Equal(t, "10.0.0.3", a.A.String(), "handlers calling Write(packed) must still surface their answer")
}
func TestHandlerChain_ResolveInternal_EmptyQuestion(t *testing.T) {
chain := nbdns.NewHandlerChain()
_, err := chain.ResolveInternal(context.Background(), new(dns.Msg), nbdns.PriorityUpstream)
assert.Error(t, err)
}
// hangingHandler blocks indefinitely until closed, simulating a wedged upstream.
type hangingHandler struct {
block chan struct{}
}
func (h *hangingHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
<-h.block
resp := &dns.Msg{}
resp.SetReply(r)
_ = w.WriteMsg(resp)
}
func (h *hangingHandler) String() string { return "hangingHandler" }
func TestHandlerChain_ResolveInternal_HonorsContextTimeout(t *testing.T) {
chain := nbdns.NewHandlerChain()
h := &hangingHandler{block: make(chan struct{})}
defer close(h.block)
chain.AddHandler("example.com.", h, nbdns.PriorityUpstream)
r := new(dns.Msg)
r.SetQuestion("example.com.", dns.TypeA)
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
start := time.Now()
_, err := chain.ResolveInternal(ctx, r, nbdns.PriorityUpstream)
elapsed := time.Since(start)
assert.Error(t, err)
assert.ErrorIs(t, err, context.DeadlineExceeded)
assert.Less(t, elapsed, 500*time.Millisecond, "ResolveInternal must return shortly after ctx deadline")
}
func TestHandlerChain_HasRootHandlerAtOrBelow(t *testing.T) {
chain := nbdns.NewHandlerChain()
h := &answeringHandler{name: "h", ip: "10.0.0.1"}
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "empty chain")
chain.AddHandler("example.com.", h, nbdns.PriorityUpstream)
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "non-root handler does not count")
chain.AddHandler(".", h, nbdns.PriorityMgmtCache)
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root handler above threshold excluded")
chain.AddHandler(".", h, nbdns.PriorityDefault)
assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root handler at PriorityDefault included")
chain.RemoveHandler(".", nbdns.PriorityDefault)
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream))
// Primary nsgroup case: root handler lands at PriorityUpstream.
chain.AddHandler(".", h, nbdns.PriorityUpstream)
assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root at PriorityUpstream included")
chain.RemoveHandler(".", nbdns.PriorityUpstream)
// Fallback case: original /etc/resolv.conf entries land at PriorityFallback.
chain.AddHandler(".", h, nbdns.PriorityFallback)
assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root at PriorityFallback included")
chain.RemoveHandler(".", nbdns.PriorityFallback)
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream))
}

View File

@@ -46,12 +46,12 @@ type restoreHostManager interface {
}
func newHostManager(wgInterface string) (hostManager, error) {
osManager, reason, err := getOSDNSManagerType()
osManager, err := getOSDNSManagerType()
if err != nil {
return nil, fmt.Errorf("get os dns manager type: %w", err)
}
log.Infof("System DNS manager discovered: %s (%s)", osManager, reason)
log.Infof("System DNS manager discovered: %s", osManager)
mgr, err := newHostManagerFromType(wgInterface, osManager)
// need to explicitly return nil mgr on error to avoid returning a non-nil interface containing a nil value
if err != nil {
@@ -74,49 +74,17 @@ func newHostManagerFromType(wgInterface string, osManager osManagerType) (restor
}
}
func getOSDNSManagerType() (osManagerType, string, error) {
resolved := isSystemdResolvedRunning()
nss := isLibnssResolveUsed()
stub := checkStub()
// Prefer systemd-resolved whenever it owns libc resolution, regardless of
// who wrote /etc/resolv.conf. File-mode rewrites do not affect lookups
// that go through nss-resolve, and in foreign mode they can loop back
// through resolved as an upstream.
if resolved && (nss || stub) {
return systemdManager, fmt.Sprintf("systemd-resolved active (nss-resolve=%t, stub=%t)", nss, stub), nil
}
mgr, reason, rejected, err := scanResolvConfHeader()
if err != nil {
return 0, "", err
}
if reason != "" {
return mgr, reason, nil
}
fallback := fmt.Sprintf("no manager matched (resolved=%t, nss-resolve=%t, stub=%t)", resolved, nss, stub)
if len(rejected) > 0 {
fallback += "; rejected: " + strings.Join(rejected, ", ")
}
return fileManager, fallback, nil
}
// scanResolvConfHeader walks /etc/resolv.conf header comments and returns the
// matching manager. If reason is empty the caller should pick file mode and
// use rejected for diagnostics.
func scanResolvConfHeader() (osManagerType, string, []string, error) {
func getOSDNSManagerType() (osManagerType, error) {
file, err := os.Open(defaultResolvConfPath)
if err != nil {
return 0, "", nil, fmt.Errorf("unable to open %s for checking owner, got error: %w", defaultResolvConfPath, err)
return 0, fmt.Errorf("unable to open %s for checking owner, got error: %w", defaultResolvConfPath, err)
}
defer func() {
if cerr := file.Close(); cerr != nil {
log.Errorf("close file %s: %s", defaultResolvConfPath, cerr)
if err := file.Close(); err != nil {
log.Errorf("close file %s: %s", defaultResolvConfPath, err)
}
}()
var rejected []string
scanner := bufio.NewScanner(file)
for scanner.Scan() {
text := scanner.Text()
@@ -124,48 +92,41 @@ func scanResolvConfHeader() (osManagerType, string, []string, error) {
continue
}
if text[0] != '#' {
break
return fileManager, nil
}
if mgr, reason, rej := matchResolvConfHeader(text); reason != "" {
return mgr, reason, nil, nil
} else if rej != "" {
rejected = append(rejected, rej)
if strings.Contains(text, fileGeneratedResolvConfContentHeader) {
return netbirdManager, nil
}
if strings.Contains(text, "NetworkManager") && isDbusListenerRunning(networkManagerDest, networkManagerDbusObjectNode) && isNetworkManagerSupported() {
return networkManager, nil
}
if strings.Contains(text, "systemd-resolved") && isSystemdResolvedRunning() {
if checkStub() {
return systemdManager, nil
} else {
return fileManager, nil
}
}
if strings.Contains(text, "resolvconf") {
if isSystemdResolveConfMode() {
return systemdManager, nil
}
return resolvConfManager, nil
}
}
if err := scanner.Err(); err != nil && err != io.EOF {
return 0, "", nil, fmt.Errorf("scan: %w", err)
return 0, fmt.Errorf("scan: %w", err)
}
return 0, "", rejected, nil
return fileManager, nil
}
// matchResolvConfHeader inspects a single comment line. Returns either a
// definitive (manager, reason) or a non-empty rejected diagnostic.
func matchResolvConfHeader(text string) (osManagerType, string, string) {
if strings.Contains(text, fileGeneratedResolvConfContentHeader) {
return netbirdManager, "netbird-managed resolv.conf header detected", ""
}
if strings.Contains(text, "NetworkManager") {
if isDbusListenerRunning(networkManagerDest, networkManagerDbusObjectNode) && isNetworkManagerSupported() {
return networkManager, "NetworkManager header + supported version on dbus", ""
}
return 0, "", "NetworkManager header (no dbus or unsupported version)"
}
if strings.Contains(text, "resolvconf") {
if isSystemdResolveConfMode() {
return systemdManager, "resolvconf header in systemd-resolved compatibility mode", ""
}
return resolvConfManager, "resolvconf header detected", ""
}
return 0, "", ""
}
// checkStub reports whether systemd-resolved's stub (127.0.0.53) is listed
// in /etc/resolv.conf. On parse failure we assume it is, to avoid dropping
// into file mode while resolved is active.
// checkStub checks if the stub resolver is disabled in systemd-resolved. If it is disabled, we fall back to file manager.
func checkStub() bool {
rConf, err := parseDefaultResolvConf()
if err != nil {
log.Warnf("failed to parse resolv conf, assuming stub is active: %s", err)
log.Warnf("failed to parse resolv conf: %s", err)
return true
}
@@ -178,36 +139,3 @@ func checkStub() bool {
return false
}
// isLibnssResolveUsed reports whether nss-resolve is listed before dns on
// the hosts: line of /etc/nsswitch.conf. When it is, libc lookups are
// delegated to systemd-resolved regardless of /etc/resolv.conf.
func isLibnssResolveUsed() bool {
bs, err := os.ReadFile(nsswitchConfPath)
if err != nil {
log.Debugf("read %s: %v", nsswitchConfPath, err)
return false
}
return parseNsswitchResolveAhead(bs)
}
func parseNsswitchResolveAhead(data []byte) bool {
for _, line := range strings.Split(string(data), "\n") {
if i := strings.IndexByte(line, '#'); i >= 0 {
line = line[:i]
}
fields := strings.Fields(line)
if len(fields) < 2 || fields[0] != "hosts:" {
continue
}
for _, module := range fields[1:] {
switch module {
case "dns":
return false
case "resolve":
return true
}
}
}
return false
}

View File

@@ -1,76 +0,0 @@
//go:build (linux && !android) || freebsd
package dns
import "testing"
func TestParseNsswitchResolveAhead(t *testing.T) {
tests := []struct {
name string
in string
want bool
}{
{
name: "resolve before dns with action token",
in: "hosts: mymachines resolve [!UNAVAIL=return] files myhostname dns\n",
want: true,
},
{
name: "dns before resolve",
in: "hosts: files mdns4_minimal [NOTFOUND=return] dns resolve\n",
want: false,
},
{
name: "debian default with only dns",
in: "hosts: files mdns4_minimal [NOTFOUND=return] dns mymachines\n",
want: false,
},
{
name: "neither resolve nor dns",
in: "hosts: files myhostname\n",
want: false,
},
{
name: "no hosts line",
in: "passwd: files systemd\ngroup: files systemd\n",
want: false,
},
{
name: "empty",
in: "",
want: false,
},
{
name: "comments and blank lines ignored",
in: "# comment\n\n# another\nhosts: resolve dns\n",
want: true,
},
{
name: "trailing inline comment",
in: "hosts: resolve [!UNAVAIL=return] dns # fallback\n",
want: true,
},
{
name: "hosts token must be the first field",
in: " hosts: resolve dns\n",
want: true,
},
{
name: "other db line mentioning resolve is ignored",
in: "networks: resolve\nhosts: dns\n",
want: false,
},
{
name: "only resolve, no dns",
in: "hosts: files resolve\n",
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := parseNsswitchResolveAhead([]byte(tt.in)); got != tt.want {
t.Errorf("parseNsswitchResolveAhead() = %v, want %v", got, tt.want)
}
})
}
}

View File

@@ -2,83 +2,40 @@ package mgmt
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"net/url"
"os"
"slices"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"golang.org/x/sync/singleflight"
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
"github.com/netbirdio/netbird/client/internal/dns/resutil"
"github.com/netbirdio/netbird/shared/management/domain"
)
const (
dnsTimeout = 5 * time.Second
defaultTTL = 300 * time.Second
refreshBackoff = 30 * time.Second
const dnsTimeout = 5 * time.Second
// envMgmtCacheTTL overrides defaultTTL for integration/dev testing.
envMgmtCacheTTL = "NB_MGMT_CACHE_TTL"
)
// ChainResolver lets the cache refresh stale entries through the DNS handler
// chain instead of net.DefaultResolver, avoiding loopback when NetBird is the
// system resolver.
type ChainResolver interface {
ResolveInternal(ctx context.Context, msg *dns.Msg, maxPriority int) (*dns.Msg, error)
HasRootHandlerAtOrBelow(maxPriority int) bool
}
// cachedRecord holds DNS records plus timestamps used for TTL refresh.
// records and cachedAt are set at construction and treated as immutable;
// lastFailedRefresh and consecFailures are mutable and must be accessed under
// Resolver.mutex.
type cachedRecord struct {
records []dns.RR
cachedAt time.Time
lastFailedRefresh time.Time
consecFailures int
}
// Resolver caches critical NetBird infrastructure domains.
// records, refreshing, mgmtDomain and serverDomains are all guarded by mutex.
// Resolver caches critical NetBird infrastructure domains
type Resolver struct {
records map[dns.Question]*cachedRecord
records map[dns.Question][]dns.RR
mgmtDomain *domain.Domain
serverDomains *dnsconfig.ServerDomains
mutex sync.RWMutex
}
chain ChainResolver
chainMaxPriority int
refreshGroup singleflight.Group
// refreshing tracks questions whose refresh is running via the OS
// fallback path. A ServeDNS hit for a question in this map indicates
// the OS resolver routed the recursive query back to us (loop). Only
// the OS path arms this so chain-path refreshes don't produce false
// positives. The atomic bool is CAS-flipped once per refresh to
// throttle the warning log.
refreshing map[dns.Question]*atomic.Bool
cacheTTL time.Duration
type ipsResponse struct {
ips []netip.Addr
err error
}
// NewResolver creates a new management domains cache resolver.
func NewResolver() *Resolver {
return &Resolver{
records: make(map[dns.Question]*cachedRecord),
refreshing: make(map[dns.Question]*atomic.Bool),
cacheTTL: resolveCacheTTL(),
records: make(map[dns.Question][]dns.RR),
}
}
@@ -87,19 +44,7 @@ func (m *Resolver) String() string {
return "MgmtCacheResolver"
}
// SetChainResolver wires the handler chain used to refresh stale cache entries.
// maxPriority caps which handlers may answer refresh queries (typically
// PriorityUpstream, so upstream/default/fallback handlers are consulted and
// mgmt/route/local handlers are skipped).
func (m *Resolver) SetChainResolver(chain ChainResolver, maxPriority int) {
m.mutex.Lock()
m.chain = chain
m.chainMaxPriority = maxPriority
m.mutex.Unlock()
}
// ServeDNS serves cached A/AAAA records. Stale entries are returned
// immediately and refreshed asynchronously (stale-while-revalidate).
// ServeDNS implements dns.Handler interface.
func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
if len(r.Question) == 0 {
m.continueToNext(w, r)
@@ -115,14 +60,7 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
}
m.mutex.RLock()
cached, found := m.records[question]
inflight := m.refreshing[question]
var shouldRefresh bool
if found {
stale := time.Since(cached.cachedAt) > m.cacheTTL
inBackoff := !cached.lastFailedRefresh.IsZero() && time.Since(cached.lastFailedRefresh) < refreshBackoff
shouldRefresh = stale && !inBackoff
}
records, found := m.records[question]
m.mutex.RUnlock()
if !found {
@@ -130,23 +68,12 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
return
}
if inflight != nil && inflight.CompareAndSwap(false, true) {
log.Warnf("mgmt cache: possible resolver loop for domain=%s: served stale while an OS-fallback refresh was inflight (if NetBird is the system resolver, the OS-path predicate is wrong)",
question.Name)
}
// Skip scheduling a refresh goroutine if one is already inflight for
// this question; singleflight would dedup anyway but skipping avoids
// a parked goroutine per stale hit under bursty load.
if shouldRefresh && inflight == nil {
m.scheduleRefresh(question, cached)
}
resp := &dns.Msg{}
resp.SetReply(r)
resp.Authoritative = false
resp.RecursionAvailable = true
resp.Answer = cloneRecordsWithTTL(cached.records, m.responseTTL(cached.cachedAt))
resp.Answer = append(resp.Answer, records...)
log.Debugf("serving %d cached records for domain=%s", len(resp.Answer), question.Name)
@@ -171,260 +98,101 @@ func (m *Resolver) continueToNext(w dns.ResponseWriter, r *dns.Msg) {
}
}
// AddDomain resolves a domain and stores its A/AAAA records in the cache.
// A family that resolves NODATA (nil err, zero records) evicts any stale
// entry for that qtype.
// AddDomain manually adds a domain to cache by resolving it.
func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
ctx, cancel := context.WithTimeout(ctx, dnsTimeout)
defer cancel()
aRecords, aaaaRecords, errA, errAAAA := m.lookupBoth(ctx, d, dnsName)
if errA != nil && errAAAA != nil {
return fmt.Errorf("resolve %s: %w", d.SafeString(), errors.Join(errA, errAAAA))
ips, err := lookupIPWithExtraTimeout(ctx, d)
if err != nil {
return err
}
if len(aRecords) == 0 && len(aaaaRecords) == 0 {
if err := errors.Join(errA, errAAAA); err != nil {
return fmt.Errorf("resolve %s: no A/AAAA records: %w", d.SafeString(), err)
var aRecords, aaaaRecords []dns.RR
for _, ip := range ips {
if ip.Is4() {
rr := &dns.A{
Hdr: dns.RR_Header{
Name: dnsName,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 300,
},
A: ip.AsSlice(),
}
aRecords = append(aRecords, rr)
} else if ip.Is6() {
rr := &dns.AAAA{
Hdr: dns.RR_Header{
Name: dnsName,
Rrtype: dns.TypeAAAA,
Class: dns.ClassINET,
Ttl: 300,
},
AAAA: ip.AsSlice(),
}
aaaaRecords = append(aaaaRecords, rr)
}
return fmt.Errorf("resolve %s: no A/AAAA records", d.SafeString())
}
now := time.Now()
m.mutex.Lock()
defer m.mutex.Unlock()
m.applyFamilyRecords(dnsName, dns.TypeA, aRecords, errA, now)
m.applyFamilyRecords(dnsName, dns.TypeAAAA, aaaaRecords, errAAAA, now)
if len(aRecords) > 0 {
aQuestion := dns.Question{
Name: dnsName,
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
}
m.records[aQuestion] = aRecords
}
log.Debugf("added/updated domain=%s with %d A records and %d AAAA records",
if len(aaaaRecords) > 0 {
aaaaQuestion := dns.Question{
Name: dnsName,
Qtype: dns.TypeAAAA,
Qclass: dns.ClassINET,
}
m.records[aaaaQuestion] = aaaaRecords
}
m.mutex.Unlock()
log.Debugf("added domain=%s with %d A records and %d AAAA records",
d.SafeString(), len(aRecords), len(aaaaRecords))
return nil
}
// applyFamilyRecords writes records, evicts on NODATA, leaves the cache
// untouched on error. Caller holds m.mutex.
func (m *Resolver) applyFamilyRecords(dnsName string, qtype uint16, records []dns.RR, err error, now time.Time) {
q := dns.Question{Name: dnsName, Qtype: qtype, Qclass: dns.ClassINET}
switch {
case len(records) > 0:
m.records[q] = &cachedRecord{records: records, cachedAt: now}
case err == nil:
delete(m.records, q)
}
}
func lookupIPWithExtraTimeout(ctx context.Context, d domain.Domain) ([]netip.Addr, error) {
log.Infof("looking up IP for mgmt domain=%s", d.SafeString())
defer log.Infof("done looking up IP for mgmt domain=%s", d.SafeString())
resultChan := make(chan *ipsResponse, 1)
// scheduleRefresh kicks off an async refresh. DoChan spawns one goroutine per
// unique in-flight key; bursty stale hits share its channel. expected is the
// cachedRecord pointer observed by the caller; the refresh only mutates the
// cache if that pointer is still the one stored, so a stale in-flight refresh
// can't clobber a newer entry written by AddDomain or a competing refresh.
func (m *Resolver) scheduleRefresh(question dns.Question, expected *cachedRecord) {
key := question.Name + "|" + dns.TypeToString[question.Qtype]
_ = m.refreshGroup.DoChan(key, func() (any, error) {
return nil, m.refreshQuestion(question, expected)
})
}
// refreshQuestion replaces the cached records on success, or marks the entry
// failed (arming the backoff) on failure. While this runs, ServeDNS can detect
// a resolver loop by spotting a query for this same question arriving on us.
// expected pins the cache entry observed at schedule time; mutations only apply
// if m.records[question] still points at it.
func (m *Resolver) refreshQuestion(question dns.Question, expected *cachedRecord) error {
ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
defer cancel()
d, err := domain.FromString(strings.TrimSuffix(question.Name, "."))
if err != nil {
m.markRefreshFailed(question, expected)
return fmt.Errorf("parse domain: %w", err)
}
records, err := m.lookupRecords(ctx, d, question)
if err != nil {
fails := m.markRefreshFailed(question, expected)
logf := log.Warnf
if fails == 0 || fails > 1 {
logf = log.Debugf
go func() {
ips, err := net.DefaultResolver.LookupNetIP(ctx, "ip", d.PunycodeString())
resultChan <- &ipsResponse{
err: err,
ips: ips,
}
logf("refresh mgmt cache domain=%s type=%s: %v (consecutive failures=%d)",
d.SafeString(), dns.TypeToString[question.Qtype], err, fails)
return err
}()
var resp *ipsResponse
select {
case <-time.After(dnsTimeout + time.Millisecond*500):
log.Warnf("timed out waiting for IP for mgmt domain=%s", d.SafeString())
return nil, fmt.Errorf("timed out waiting for ips to be available for domain %s", d.SafeString())
case <-ctx.Done():
return nil, ctx.Err()
case resp = <-resultChan:
}
// NOERROR/NODATA: family gone upstream, evict so we stop serving stale.
if len(records) == 0 {
m.mutex.Lock()
if m.records[question] == expected {
delete(m.records, question)
m.mutex.Unlock()
log.Infof("removed mgmt cache domain=%s type=%s: no records returned",
d.SafeString(), dns.TypeToString[question.Qtype])
return nil
}
m.mutex.Unlock()
log.Debugf("skipping refresh evict for domain=%s type=%s: entry changed during refresh",
d.SafeString(), dns.TypeToString[question.Qtype])
return nil
if resp.err != nil {
return nil, fmt.Errorf("resolve domain %s: %w", d.SafeString(), resp.err)
}
now := time.Now()
m.mutex.Lock()
if m.records[question] != expected {
m.mutex.Unlock()
log.Debugf("skipping refresh write for domain=%s type=%s: entry changed during refresh",
d.SafeString(), dns.TypeToString[question.Qtype])
return nil
}
m.records[question] = &cachedRecord{records: records, cachedAt: now}
m.mutex.Unlock()
log.Infof("refreshed mgmt cache domain=%s type=%s",
d.SafeString(), dns.TypeToString[question.Qtype])
return nil
}
func (m *Resolver) markRefreshing(question dns.Question) {
m.mutex.Lock()
m.refreshing[question] = &atomic.Bool{}
m.mutex.Unlock()
}
func (m *Resolver) clearRefreshing(question dns.Question) {
m.mutex.Lock()
delete(m.refreshing, question)
m.mutex.Unlock()
}
// markRefreshFailed arms the backoff and returns the new consecutive-failure
// count so callers can downgrade subsequent failure logs to debug.
func (m *Resolver) markRefreshFailed(question dns.Question, expected *cachedRecord) int {
m.mutex.Lock()
defer m.mutex.Unlock()
c, ok := m.records[question]
if !ok || c != expected {
return 0
}
c.lastFailedRefresh = time.Now()
c.consecFailures++
return c.consecFailures
}
// lookupBoth resolves A and AAAA via chain or OS. Per-family errors let
// callers tell records, NODATA (nil err, no records), and failure apart.
func (m *Resolver) lookupBoth(ctx context.Context, d domain.Domain, dnsName string) (aRecords, aaaaRecords []dns.RR, errA, errAAAA error) {
m.mutex.RLock()
chain := m.chain
maxPriority := m.chainMaxPriority
m.mutex.RUnlock()
if chain != nil && chain.HasRootHandlerAtOrBelow(maxPriority) {
aRecords, errA = m.lookupViaChain(ctx, chain, maxPriority, dnsName, dns.TypeA)
aaaaRecords, errAAAA = m.lookupViaChain(ctx, chain, maxPriority, dnsName, dns.TypeAAAA)
return
}
// TODO: drop once every supported OS registers a fallback resolver. Safe
// today: no root handler at priority ≤ PriorityUpstream means NetBird is
// not the system resolver, so net.DefaultResolver will not loop back.
aRecords, errA = m.osLookup(ctx, d, dnsName, dns.TypeA)
aaaaRecords, errAAAA = m.osLookup(ctx, d, dnsName, dns.TypeAAAA)
return
}
// lookupRecords resolves a single record type via chain or OS. The OS branch
// arms the loop detector for the duration of its call so that ServeDNS can
// spot the OS resolver routing the recursive query back to us.
func (m *Resolver) lookupRecords(ctx context.Context, d domain.Domain, q dns.Question) ([]dns.RR, error) {
m.mutex.RLock()
chain := m.chain
maxPriority := m.chainMaxPriority
m.mutex.RUnlock()
if chain != nil && chain.HasRootHandlerAtOrBelow(maxPriority) {
return m.lookupViaChain(ctx, chain, maxPriority, q.Name, q.Qtype)
}
// TODO: drop once every supported OS registers a fallback resolver.
m.markRefreshing(q)
defer m.clearRefreshing(q)
return m.osLookup(ctx, d, q.Name, q.Qtype)
}
// lookupViaChain resolves via the handler chain and rewrites each RR to use
// dnsName as owner and m.cacheTTL as TTL, so CNAME-backed domains don't cache
// target-owned records or upstream TTLs. NODATA returns (nil, nil).
func (m *Resolver) lookupViaChain(ctx context.Context, chain ChainResolver, maxPriority int, dnsName string, qtype uint16) ([]dns.RR, error) {
msg := &dns.Msg{}
msg.SetQuestion(dnsName, qtype)
msg.RecursionDesired = true
resp, err := chain.ResolveInternal(ctx, msg, maxPriority)
if err != nil {
return nil, fmt.Errorf("chain resolve: %w", err)
}
if resp == nil {
return nil, fmt.Errorf("chain resolve returned nil response")
}
if resp.Rcode != dns.RcodeSuccess {
return nil, fmt.Errorf("chain resolve rcode=%s", dns.RcodeToString[resp.Rcode])
}
ttl := uint32(m.cacheTTL.Seconds())
owners := cnameOwners(dnsName, resp.Answer)
var filtered []dns.RR
for _, rr := range resp.Answer {
h := rr.Header()
if h.Class != dns.ClassINET || h.Rrtype != qtype {
continue
}
if !owners[strings.ToLower(dns.Fqdn(h.Name))] {
continue
}
if cp := cloneIPRecord(rr, dnsName, ttl); cp != nil {
filtered = append(filtered, cp)
}
}
return filtered, nil
}
// osLookup resolves a single family via net.DefaultResolver using resutil,
// which disambiguates NODATA from NXDOMAIN and Unmaps v4-mapped-v6. NODATA
// returns (nil, nil).
func (m *Resolver) osLookup(ctx context.Context, d domain.Domain, dnsName string, qtype uint16) ([]dns.RR, error) {
network := resutil.NetworkForQtype(qtype)
if network == "" {
return nil, fmt.Errorf("unsupported qtype %s", dns.TypeToString[qtype])
}
log.Infof("looking up IP for mgmt domain=%s type=%s", d.SafeString(), dns.TypeToString[qtype])
defer log.Infof("done looking up IP for mgmt domain=%s type=%s", d.SafeString(), dns.TypeToString[qtype])
result := resutil.LookupIP(ctx, net.DefaultResolver, network, d.PunycodeString(), qtype)
if result.Rcode == dns.RcodeSuccess {
return resutil.IPsToRRs(dnsName, result.IPs, uint32(m.cacheTTL.Seconds())), nil
}
if result.Err != nil {
return nil, fmt.Errorf("resolve %s type=%s: %w", d.SafeString(), dns.TypeToString[qtype], result.Err)
}
return nil, fmt.Errorf("resolve %s type=%s: rcode=%s", d.SafeString(), dns.TypeToString[qtype], dns.RcodeToString[result.Rcode])
}
// responseTTL returns the remaining cache lifetime in seconds (rounded up),
// so downstream resolvers don't cache an answer for longer than we will.
func (m *Resolver) responseTTL(cachedAt time.Time) uint32 {
remaining := m.cacheTTL - time.Since(cachedAt)
if remaining <= 0 {
return 0
}
return uint32((remaining + time.Second - 1) / time.Second)
return resp.ips, nil
}
// PopulateFromConfig extracts and caches domains from the client configuration.
@@ -456,12 +224,19 @@ func (m *Resolver) RemoveDomain(d domain.Domain) error {
m.mutex.Lock()
defer m.mutex.Unlock()
qA := dns.Question{Name: dnsName, Qtype: dns.TypeA, Qclass: dns.ClassINET}
qAAAA := dns.Question{Name: dnsName, Qtype: dns.TypeAAAA, Qclass: dns.ClassINET}
delete(m.records, qA)
delete(m.records, qAAAA)
delete(m.refreshing, qA)
delete(m.refreshing, qAAAA)
aQuestion := dns.Question{
Name: dnsName,
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
}
delete(m.records, aQuestion)
aaaaQuestion := dns.Question{
Name: dnsName,
Qtype: dns.TypeAAAA,
Qclass: dns.ClassINET,
}
delete(m.records, aaaaQuestion)
log.Debugf("removed domain=%s from cache", d.SafeString())
return nil
@@ -619,73 +394,3 @@ func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.Serve
return domains
}
// cloneIPRecord returns a deep copy of rr retargeted to owner with ttl. Non
// A/AAAA records return nil.
func cloneIPRecord(rr dns.RR, owner string, ttl uint32) dns.RR {
switch r := rr.(type) {
case *dns.A:
cp := *r
cp.Hdr.Name = owner
cp.Hdr.Ttl = ttl
cp.A = slices.Clone(r.A)
return &cp
case *dns.AAAA:
cp := *r
cp.Hdr.Name = owner
cp.Hdr.Ttl = ttl
cp.AAAA = slices.Clone(r.AAAA)
return &cp
}
return nil
}
// cloneRecordsWithTTL clones A/AAAA records preserving their owner and
// stamping ttl so the response shares no memory with the cached slice.
func cloneRecordsWithTTL(records []dns.RR, ttl uint32) []dns.RR {
out := make([]dns.RR, 0, len(records))
for _, rr := range records {
if cp := cloneIPRecord(rr, rr.Header().Name, ttl); cp != nil {
out = append(out, cp)
}
}
return out
}
// cnameOwners returns dnsName plus every target reachable by following CNAMEs
// in answer, iterating until fixed point so out-of-order chains resolve.
func cnameOwners(dnsName string, answer []dns.RR) map[string]bool {
owners := map[string]bool{dnsName: true}
for {
added := false
for _, rr := range answer {
cname, ok := rr.(*dns.CNAME)
if !ok {
continue
}
name := strings.ToLower(dns.Fqdn(cname.Hdr.Name))
if !owners[name] {
continue
}
target := strings.ToLower(dns.Fqdn(cname.Target))
if !owners[target] {
owners[target] = true
added = true
}
}
if !added {
return owners
}
}
}
// resolveCacheTTL reads the cache TTL override env var; invalid or empty
// values fall back to defaultTTL. Called once per Resolver from NewResolver.
func resolveCacheTTL() time.Duration {
if v := os.Getenv(envMgmtCacheTTL); v != "" {
if d, err := time.ParseDuration(v); err == nil && d > 0 {
return d
}
}
return defaultTTL
}

View File

@@ -1,408 +0,0 @@
package mgmt
import (
"context"
"errors"
"net"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal/dns/test"
"github.com/netbirdio/netbird/shared/management/domain"
)
type fakeChain struct {
mu sync.Mutex
calls map[string]int
answers map[string][]dns.RR
err error
hasRoot bool
onLookup func()
}
func newFakeChain() *fakeChain {
return &fakeChain{
calls: map[string]int{},
answers: map[string][]dns.RR{},
hasRoot: true,
}
}
func (f *fakeChain) HasRootHandlerAtOrBelow(maxPriority int) bool {
f.mu.Lock()
defer f.mu.Unlock()
return f.hasRoot
}
func (f *fakeChain) ResolveInternal(ctx context.Context, msg *dns.Msg, maxPriority int) (*dns.Msg, error) {
f.mu.Lock()
q := msg.Question[0]
key := q.Name + "|" + dns.TypeToString[q.Qtype]
f.calls[key]++
answers := f.answers[key]
err := f.err
onLookup := f.onLookup
f.mu.Unlock()
if onLookup != nil {
onLookup()
}
if err != nil {
return nil, err
}
resp := &dns.Msg{}
resp.SetReply(msg)
resp.Answer = answers
return resp, nil
}
func (f *fakeChain) setAnswer(name string, qtype uint16, ip string) {
f.mu.Lock()
defer f.mu.Unlock()
key := name + "|" + dns.TypeToString[qtype]
hdr := dns.RR_Header{Name: name, Rrtype: qtype, Class: dns.ClassINET, Ttl: 60}
switch qtype {
case dns.TypeA:
f.answers[key] = []dns.RR{&dns.A{Hdr: hdr, A: net.ParseIP(ip).To4()}}
case dns.TypeAAAA:
f.answers[key] = []dns.RR{&dns.AAAA{Hdr: hdr, AAAA: net.ParseIP(ip).To16()}}
}
}
func (f *fakeChain) callCount(name string, qtype uint16) int {
f.mu.Lock()
defer f.mu.Unlock()
return f.calls[name+"|"+dns.TypeToString[qtype]]
}
// waitFor polls the predicate until it returns true or the deadline passes.
func waitFor(t *testing.T, d time.Duration, fn func() bool) {
t.Helper()
deadline := time.Now().Add(d)
for time.Now().Before(deadline) {
if fn() {
return
}
time.Sleep(5 * time.Millisecond)
}
t.Fatalf("condition not met within %s", d)
}
func queryA(t *testing.T, r *Resolver, name string) *dns.Msg {
t.Helper()
msg := new(dns.Msg)
msg.SetQuestion(name, dns.TypeA)
w := &test.MockResponseWriter{}
r.ServeDNS(w, msg)
return w.GetLastResponse()
}
func firstA(t *testing.T, resp *dns.Msg) string {
t.Helper()
require.NotNil(t, resp)
require.Greater(t, len(resp.Answer), 0, "expected at least one answer")
a, ok := resp.Answer[0].(*dns.A)
require.True(t, ok, "expected A record")
return a.A.String()
}
func TestResolver_CacheTTLGatesRefresh(t *testing.T) {
// Same cached entry age, different cacheTTL values: the shorter TTL must
// trigger a background refresh, the longer one must not. Proves that the
// per-Resolver cacheTTL field actually drives the stale decision.
cachedAt := time.Now().Add(-100 * time.Millisecond)
newRec := func() *cachedRecord {
return &cachedRecord{
records: []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: "mgmt.example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1").To4(),
}},
cachedAt: cachedAt,
}
}
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
t.Run("short TTL treats entry as stale and refreshes", func(t *testing.T) {
r := NewResolver()
r.cacheTTL = 10 * time.Millisecond
chain := newFakeChain()
chain.setAnswer(q.Name, dns.TypeA, "10.0.0.2")
r.SetChainResolver(chain, 50)
r.records[q] = newRec()
resp := queryA(t, r, q.Name)
assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry must be served while refresh runs")
waitFor(t, time.Second, func() bool {
return chain.callCount(q.Name, dns.TypeA) >= 1
})
})
t.Run("long TTL keeps entry fresh and skips refresh", func(t *testing.T) {
r := NewResolver()
r.cacheTTL = time.Hour
chain := newFakeChain()
chain.setAnswer(q.Name, dns.TypeA, "10.0.0.2")
r.SetChainResolver(chain, 50)
r.records[q] = newRec()
resp := queryA(t, r, q.Name)
assert.Equal(t, "10.0.0.1", firstA(t, resp))
time.Sleep(50 * time.Millisecond)
assert.Equal(t, 0, chain.callCount(q.Name, dns.TypeA), "fresh entry must not trigger refresh")
})
}
func TestResolver_ServeFresh_NoRefresh(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
r.SetChainResolver(chain, 50)
r.records[dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}] = &cachedRecord{
records: []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: "mgmt.example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1").To4(),
}},
cachedAt: time.Now(), // fresh
}
resp := queryA(t, r, "mgmt.example.com.")
assert.Equal(t, "10.0.0.1", firstA(t, resp))
time.Sleep(20 * time.Millisecond)
assert.Equal(t, 0, chain.callCount("mgmt.example.com.", dns.TypeA), "fresh entry must not trigger refresh")
}
func TestResolver_StaleTriggersAsyncRefresh(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
r.SetChainResolver(chain, 50)
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
r.records[q] = &cachedRecord{
records: []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1").To4(),
}},
cachedAt: time.Now().Add(-2 * defaultTTL), // stale
}
// First query: serves stale immediately.
resp := queryA(t, r, "mgmt.example.com.")
assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry must be served while refresh runs")
waitFor(t, time.Second, func() bool {
return chain.callCount("mgmt.example.com.", dns.TypeA) >= 1
})
// Next query should now return the refreshed IP.
waitFor(t, time.Second, func() bool {
resp := queryA(t, r, "mgmt.example.com.")
return resp != nil && len(resp.Answer) > 0 && firstA(t, resp) == "10.0.0.2"
})
}
func TestResolver_ConcurrentStaleHitsCollapseRefresh(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
var inflight atomic.Int32
var maxInflight atomic.Int32
chain.onLookup = func() {
cur := inflight.Add(1)
defer inflight.Add(-1)
for {
prev := maxInflight.Load()
if cur <= prev || maxInflight.CompareAndSwap(prev, cur) {
break
}
}
time.Sleep(50 * time.Millisecond) // hold inflight long enough to collide
}
r.SetChainResolver(chain, 50)
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
r.records[q] = &cachedRecord{
records: []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1").To4(),
}},
cachedAt: time.Now().Add(-2 * defaultTTL),
}
var wg sync.WaitGroup
for i := 0; i < 50; i++ {
wg.Add(1)
go func() {
defer wg.Done()
queryA(t, r, "mgmt.example.com.")
}()
}
wg.Wait()
waitFor(t, 2*time.Second, func() bool {
return inflight.Load() == 0
})
calls := chain.callCount("mgmt.example.com.", dns.TypeA)
assert.LessOrEqual(t, calls, 2, "singleflight must collapse concurrent refreshes (got %d)", calls)
assert.Equal(t, int32(1), maxInflight.Load(), "only one refresh should run concurrently")
}
func TestResolver_RefreshFailureArmsBackoff(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.err = errors.New("boom")
r.SetChainResolver(chain, 50)
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
r.records[q] = &cachedRecord{
records: []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1").To4(),
}},
cachedAt: time.Now().Add(-2 * defaultTTL),
}
// First stale hit triggers a refresh attempt that fails.
resp := queryA(t, r, "mgmt.example.com.")
assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry served while refresh fails")
waitFor(t, time.Second, func() bool {
return chain.callCount("mgmt.example.com.", dns.TypeA) == 1
})
waitFor(t, time.Second, func() bool {
r.mutex.RLock()
defer r.mutex.RUnlock()
c, ok := r.records[q]
return ok && !c.lastFailedRefresh.IsZero()
})
// Subsequent stale hits within backoff window should not schedule more refreshes.
for i := 0; i < 10; i++ {
queryA(t, r, "mgmt.example.com.")
}
time.Sleep(50 * time.Millisecond)
assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeA), "backoff must suppress further refreshes")
}
func TestResolver_NoRootHandler_SkipsChain(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.hasRoot = false
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
r.SetChainResolver(chain, 50)
// With hasRoot=false the chain must not be consulted. Use a short
// deadline so the OS fallback returns quickly without waiting on a
// real network call in CI.
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
defer cancel()
_, _, _, _ = r.lookupBoth(ctx, domain.Domain("mgmt.example.com"), "mgmt.example.com.")
assert.Equal(t, 0, chain.callCount("mgmt.example.com.", dns.TypeA),
"chain must not be used when no root handler is registered at the bound priority")
}
func TestResolver_ServeDuringRefreshSetsLoopFlag(t *testing.T) {
// ServeDNS being invoked for a question while a refresh for that question
// is inflight indicates a resolver loop (OS resolver sent the recursive
// query back to us). The inflightRefresh.loopLoggedOnce flag must be set.
r := NewResolver()
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
r.records[q] = &cachedRecord{
records: []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1").To4(),
}},
cachedAt: time.Now(),
}
// Simulate an inflight refresh.
r.markRefreshing(q)
defer r.clearRefreshing(q)
resp := queryA(t, r, "mgmt.example.com.")
assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry must still be served to avoid breaking external queries")
r.mutex.RLock()
inflight := r.refreshing[q]
r.mutex.RUnlock()
require.NotNil(t, inflight)
assert.True(t, inflight.Load(), "loop flag must be set once a ServeDNS during refresh was observed")
}
func TestResolver_LoopFlagOnlyTrippedOncePerRefresh(t *testing.T) {
r := NewResolver()
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
r.records[q] = &cachedRecord{
records: []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1").To4(),
}},
cachedAt: time.Now(),
}
r.markRefreshing(q)
defer r.clearRefreshing(q)
// Multiple ServeDNS calls during the same refresh must not re-set the flag
// (CompareAndSwap from false -> true returns true only on the first call).
for range 5 {
queryA(t, r, "mgmt.example.com.")
}
r.mutex.RLock()
inflight := r.refreshing[q]
r.mutex.RUnlock()
assert.True(t, inflight.Load())
}
func TestResolver_NoLoopFlagWhenNotRefreshing(t *testing.T) {
r := NewResolver()
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
r.records[q] = &cachedRecord{
records: []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1").To4(),
}},
cachedAt: time.Now(),
}
queryA(t, r, "mgmt.example.com.")
r.mutex.RLock()
_, ok := r.refreshing[q]
r.mutex.RUnlock()
assert.False(t, ok, "no refresh inflight means no loop tracking")
}
func TestResolver_AddDomain_UsesChainWhenRootRegistered(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
chain.setAnswer("mgmt.example.com.", dns.TypeAAAA, "fd00::2")
r.SetChainResolver(chain, 50)
require.NoError(t, r.AddDomain(context.Background(), domain.Domain("mgmt.example.com")))
resp := queryA(t, r, "mgmt.example.com.")
assert.Equal(t, "10.0.0.2", firstA(t, resp))
assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeA))
assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeAAAA))
}

View File

@@ -6,7 +6,6 @@ import (
"net/url"
"strings"
"testing"
"time"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
@@ -24,60 +23,6 @@ func TestResolver_NewResolver(t *testing.T) {
assert.False(t, resolver.MatchSubdomains())
}
func TestResolveCacheTTL(t *testing.T) {
tests := []struct {
name string
value string
want time.Duration
}{
{"unset falls back to default", "", defaultTTL},
{"valid duration", "45s", 45 * time.Second},
{"valid minutes", "2m", 2 * time.Minute},
{"malformed falls back to default", "not-a-duration", defaultTTL},
{"zero falls back to default", "0s", defaultTTL},
{"negative falls back to default", "-5s", defaultTTL},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Setenv(envMgmtCacheTTL, tc.value)
got := resolveCacheTTL()
assert.Equal(t, tc.want, got, "parsed TTL should match")
})
}
}
func TestNewResolver_CacheTTLFromEnv(t *testing.T) {
t.Setenv(envMgmtCacheTTL, "7s")
r := NewResolver()
assert.Equal(t, 7*time.Second, r.cacheTTL, "NewResolver should evaluate cacheTTL once from env")
}
func TestResolver_ResponseTTL(t *testing.T) {
now := time.Now()
tests := []struct {
name string
cacheTTL time.Duration
cachedAt time.Time
wantMin uint32
wantMax uint32
}{
{"fresh entry returns full TTL", 60 * time.Second, now, 59, 60},
{"half-aged entry returns half TTL", 60 * time.Second, now.Add(-30 * time.Second), 29, 31},
{"expired entry returns zero", 60 * time.Second, now.Add(-61 * time.Second), 0, 0},
{"exactly expired returns zero", 10 * time.Second, now.Add(-10 * time.Second), 0, 0},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
r := &Resolver{cacheTTL: tc.cacheTTL}
got := r.responseTTL(tc.cachedAt)
assert.GreaterOrEqual(t, got, tc.wantMin, "remaining TTL should be >= wantMin")
assert.LessOrEqual(t, got, tc.wantMax, "remaining TTL should be <= wantMax")
})
}
}
func TestResolver_ExtractDomainFromURL(t *testing.T) {
tests := []struct {
name string

View File

@@ -212,7 +212,6 @@ func newDefaultServer(
ctx, stop := context.WithCancel(ctx)
mgmtCacheResolver := mgmt.NewResolver()
mgmtCacheResolver.SetChainResolver(handlerChain, PriorityUpstream)
defaultServer := &DefaultServer{
ctx: ctx,

View File

@@ -30,6 +30,7 @@ import (
nbcache "github.com/netbirdio/netbird/management/server/cache"
nbContext "github.com/netbirdio/netbird/management/server/context"
nbhttp "github.com/netbirdio/netbird/management/server/http"
"github.com/netbirdio/netbird/management/server/http/middleware"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
@@ -109,7 +110,7 @@ func (s *BaseServer) EventStore() activity.Store {
func (s *BaseServer) APIHandler() http.Handler {
return Create(s, func() http.Handler {
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies)
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies, s.RateLimiter())
if err != nil {
log.Fatalf("failed to create API handler: %v", err)
}
@@ -117,6 +118,15 @@ func (s *BaseServer) APIHandler() http.Handler {
})
}
func (s *BaseServer) RateLimiter() *middleware.APIRateLimiter {
return Create(s, func() *middleware.APIRateLimiter {
cfg, enabled := middleware.RateLimiterConfigFromEnv()
limiter := middleware.NewAPIRateLimiter(cfg)
limiter.SetEnabled(enabled)
return limiter
})
}
func (s *BaseServer) GRPCServer() *grpc.Server {
return Create(s, func() *grpc.Server {
trustedPeers := s.Config.ReverseProxy.TrustedPeers

View File

@@ -2311,6 +2311,29 @@ func TestAccount_GetExpiredPeers(t *testing.T) {
}
}
func TestGetExpiredPeers_SkipsAlreadyExpired(t *testing.T) {
ctx := context.Background()
testStore, cleanUp, err := store.NewTestStoreFromSQL(ctx, "testdata/store_with_expired_peers.sql", t.TempDir())
t.Cleanup(cleanUp)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
// Verify the already-expired peer is excluded at the store level
peers, err := testStore.GetAccountPeersWithExpiration(ctx, store.LockingStrengthNone, accountID)
require.NoError(t, err)
for _, peer := range peers {
assert.NotEqual(t, "cg05lnblo1hkg2j514p0", peer.ID, "already expired peer should be excluded by the store query")
assert.False(t, peer.Status.LoginExpired, "returned peers should not already be marked as login expired")
}
// Only the non-expired peer with expiration enabled should be returned
require.Len(t, peers, 1)
assert.Equal(t, "notexpired01", peers[0].ID)
}
func TestAccount_GetInactivePeers(t *testing.T) {
type test struct {
name string
@@ -3230,6 +3253,13 @@ func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *update_channel.
return manager, updateManager, account, peer1, peer2, peer3
}
// peerUpdateTimeout bounds how long peerShouldReceiveUpdate and its outer
// wrappers wait for an expected update message. Sized for slow CI runners
// (MySQL, FreeBSD, loaded sqlite) where the channel publish can take
// seconds. Only runs down on failure; passing tests return immediately
// when the channel delivers.
const peerUpdateTimeout = 5 * time.Second
func peerShouldNotReceiveUpdate(t *testing.T, updateMessage <-chan *network_map.UpdateMessage) {
t.Helper()
select {
@@ -3248,7 +3278,7 @@ func peerShouldReceiveUpdate(t *testing.T, updateMessage <-chan *network_map.Upd
if msg == nil {
t.Errorf("Received nil update message, expected valid message")
}
case <-time.After(500 * time.Millisecond):
case <-time.After(peerUpdateTimeout):
t.Error("Timed out waiting for update message")
}
}

View File

@@ -458,7 +458,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
@@ -478,7 +478,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
@@ -518,7 +518,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})

View File

@@ -620,7 +620,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
@@ -638,7 +638,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
@@ -656,7 +656,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
@@ -689,7 +689,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
@@ -730,7 +730,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
@@ -757,7 +757,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
@@ -804,7 +804,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})

View File

@@ -5,9 +5,6 @@ import (
"fmt"
"net/http"
"net/netip"
"os"
"strconv"
"time"
"github.com/gorilla/mux"
"github.com/rs/cors"
@@ -66,14 +63,11 @@ import (
)
const (
apiPrefix = "/api"
rateLimitingEnabledKey = "NB_API_RATE_LIMITING_ENABLED"
rateLimitingBurstKey = "NB_API_RATE_LIMITING_BURST"
rateLimitingRPMKey = "NB_API_RATE_LIMITING_RPM"
apiPrefix = "/api"
)
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix) (http.Handler, error) {
func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix, rateLimiter *middleware.APIRateLimiter) (http.Handler, error) {
// Register bypass paths for unauthenticated endpoints
if err := bypass.AddBypassPath("/api/instance"); err != nil {
@@ -94,34 +88,10 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
return nil, fmt.Errorf("failed to add bypass path: %w", err)
}
var rateLimitingConfig *middleware.RateLimiterConfig
if os.Getenv(rateLimitingEnabledKey) == "true" {
rpm := 6
if v := os.Getenv(rateLimitingRPMKey); v != "" {
value, err := strconv.Atoi(v)
if err != nil {
log.Warnf("parsing %s env var: %v, using default %d", rateLimitingRPMKey, err, rpm)
} else {
rpm = value
}
}
burst := 500
if v := os.Getenv(rateLimitingBurstKey); v != "" {
value, err := strconv.Atoi(v)
if err != nil {
log.Warnf("parsing %s env var: %v, using default %d", rateLimitingBurstKey, err, burst)
} else {
burst = value
}
}
rateLimitingConfig = &middleware.RateLimiterConfig{
RequestsPerMinute: float64(rpm),
Burst: burst,
CleanupInterval: 6 * time.Hour,
LimiterTTL: 24 * time.Hour,
}
if rateLimiter == nil {
log.Warn("NewAPIHandler: nil rate limiter, rate limiting disabled")
rateLimiter = middleware.NewAPIRateLimiter(nil)
rateLimiter.SetEnabled(false)
}
authMiddleware := middleware.NewAuthMiddleware(
@@ -129,7 +99,7 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
accountManager.GetAccountIDFromUserAuth,
accountManager.SyncUserJWTGroups,
accountManager.GetUserFromUserAuth,
rateLimitingConfig,
rateLimiter,
appMetrics.GetMeter(),
)

View File

@@ -43,14 +43,9 @@ func NewAuthMiddleware(
ensureAccount EnsureAccountFunc,
syncUserJWTGroups SyncUserJWTGroupsFunc,
getUserFromUserAuth GetUserFromUserAuthFunc,
rateLimiterConfig *RateLimiterConfig,
rateLimiter *APIRateLimiter,
meter metric.Meter,
) *AuthMiddleware {
var rateLimiter *APIRateLimiter
if rateLimiterConfig != nil {
rateLimiter = NewAPIRateLimiter(rateLimiterConfig)
}
var patUsageTracker *PATUsageTracker
if meter != nil {
var err error
@@ -181,10 +176,8 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []
m.patUsageTracker.IncrementUsage(token)
}
if m.rateLimiter != nil && !isTerraformRequest(r) {
if !m.rateLimiter.Allow(token) {
return status.Errorf(status.TooManyRequests, "too many requests")
}
if !isTerraformRequest(r) && !m.rateLimiter.Allow(token) {
return status.Errorf(status.TooManyRequests, "too many requests")
}
ctx := r.Context()

View File

@@ -196,6 +196,8 @@ func TestAuthMiddleware_Handler(t *testing.T) {
GetPATInfoFunc: mockGetAccountInfoFromPAT,
}
disabledLimiter := NewAPIRateLimiter(nil)
disabledLimiter.SetEnabled(false)
authMiddleware := NewAuthMiddleware(
mockAuth,
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
@@ -207,7 +209,7 @@ func TestAuthMiddleware_Handler(t *testing.T) {
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
nil,
disabledLimiter,
nil,
)
@@ -266,7 +268,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
rateLimitConfig,
NewAPIRateLimiter(rateLimitConfig),
nil,
)
@@ -318,7 +320,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
rateLimitConfig,
NewAPIRateLimiter(rateLimitConfig),
nil,
)
@@ -361,7 +363,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
rateLimitConfig,
NewAPIRateLimiter(rateLimitConfig),
nil,
)
@@ -405,7 +407,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
rateLimitConfig,
NewAPIRateLimiter(rateLimitConfig),
nil,
)
@@ -469,7 +471,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
rateLimitConfig,
NewAPIRateLimiter(rateLimitConfig),
nil,
)
@@ -528,7 +530,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
rateLimitConfig,
NewAPIRateLimiter(rateLimitConfig),
nil,
)
@@ -583,7 +585,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
rateLimitConfig,
NewAPIRateLimiter(rateLimitConfig),
nil,
)
@@ -670,6 +672,8 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
GetPATInfoFunc: mockGetAccountInfoFromPAT,
}
disabledLimiter := NewAPIRateLimiter(nil)
disabledLimiter.SetEnabled(false)
authMiddleware := NewAuthMiddleware(
mockAuth,
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
@@ -681,7 +685,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
nil,
disabledLimiter,
nil,
)

View File

@@ -4,14 +4,27 @@ import (
"context"
"net"
"net/http"
"os"
"strconv"
"sync"
"sync/atomic"
"time"
log "github.com/sirupsen/logrus"
"golang.org/x/time/rate"
"github.com/netbirdio/netbird/shared/management/http/util"
)
const (
RateLimitingEnabledEnv = "NB_API_RATE_LIMITING_ENABLED"
RateLimitingBurstEnv = "NB_API_RATE_LIMITING_BURST"
RateLimitingRPMEnv = "NB_API_RATE_LIMITING_RPM"
defaultAPIRPM = 6
defaultAPIBurst = 500
)
// RateLimiterConfig holds configuration for the API rate limiter
type RateLimiterConfig struct {
// RequestsPerMinute defines the rate at which tokens are replenished
@@ -34,6 +47,43 @@ func DefaultRateLimiterConfig() *RateLimiterConfig {
}
}
func RateLimiterConfigFromEnv() (cfg *RateLimiterConfig, enabled bool) {
rpm := defaultAPIRPM
if v := os.Getenv(RateLimitingRPMEnv); v != "" {
value, err := strconv.Atoi(v)
if err != nil {
log.Warnf("parsing %s env var: %v, using default %d", RateLimitingRPMEnv, err, rpm)
} else {
rpm = value
}
}
if rpm <= 0 {
log.Warnf("%s=%d is non-positive, using default %d", RateLimitingRPMEnv, rpm, defaultAPIRPM)
rpm = defaultAPIRPM
}
burst := defaultAPIBurst
if v := os.Getenv(RateLimitingBurstEnv); v != "" {
value, err := strconv.Atoi(v)
if err != nil {
log.Warnf("parsing %s env var: %v, using default %d", RateLimitingBurstEnv, err, burst)
} else {
burst = value
}
}
if burst <= 0 {
log.Warnf("%s=%d is non-positive, using default %d", RateLimitingBurstEnv, burst, defaultAPIBurst)
burst = defaultAPIBurst
}
return &RateLimiterConfig{
RequestsPerMinute: float64(rpm),
Burst: burst,
CleanupInterval: 6 * time.Hour,
LimiterTTL: 24 * time.Hour,
}, os.Getenv(RateLimitingEnabledEnv) == "true"
}
// limiterEntry holds a rate limiter and its last access time
type limiterEntry struct {
limiter *rate.Limiter
@@ -46,6 +96,7 @@ type APIRateLimiter struct {
limiters map[string]*limiterEntry
mu sync.RWMutex
stopChan chan struct{}
enabled atomic.Bool
}
// NewAPIRateLimiter creates a new API rate limiter with the given configuration
@@ -59,14 +110,53 @@ func NewAPIRateLimiter(config *RateLimiterConfig) *APIRateLimiter {
limiters: make(map[string]*limiterEntry),
stopChan: make(chan struct{}),
}
rl.enabled.Store(true)
go rl.cleanupLoop()
return rl
}
func (rl *APIRateLimiter) SetEnabled(enabled bool) {
rl.enabled.Store(enabled)
}
func (rl *APIRateLimiter) Enabled() bool {
return rl.enabled.Load()
}
func (rl *APIRateLimiter) UpdateConfig(config *RateLimiterConfig) {
if config == nil {
return
}
if config.RequestsPerMinute <= 0 || config.Burst <= 0 {
log.Warnf("UpdateConfig: ignoring invalid rpm=%v burst=%d", config.RequestsPerMinute, config.Burst)
return
}
newRPS := rate.Limit(config.RequestsPerMinute / 60.0)
newBurst := config.Burst
rl.mu.Lock()
rl.config.RequestsPerMinute = config.RequestsPerMinute
rl.config.Burst = newBurst
snapshot := make([]*rate.Limiter, 0, len(rl.limiters))
for _, entry := range rl.limiters {
snapshot = append(snapshot, entry.limiter)
}
rl.mu.Unlock()
for _, l := range snapshot {
l.SetLimit(newRPS)
l.SetBurst(newBurst)
}
}
// Allow checks if a request for the given key (token) is allowed
func (rl *APIRateLimiter) Allow(key string) bool {
if !rl.enabled.Load() {
return true
}
limiter := rl.getLimiter(key)
return limiter.Allow()
}
@@ -74,6 +164,9 @@ func (rl *APIRateLimiter) Allow(key string) bool {
// Wait blocks until the rate limiter allows another request for the given key
// Returns an error if the context is canceled
func (rl *APIRateLimiter) Wait(ctx context.Context, key string) error {
if !rl.enabled.Load() {
return nil
}
limiter := rl.getLimiter(key)
return limiter.Wait(ctx)
}
@@ -153,6 +246,10 @@ func (rl *APIRateLimiter) Reset(key string) {
// Returns 429 Too Many Requests if the rate limit is exceeded.
func (rl *APIRateLimiter) Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !rl.enabled.Load() {
next.ServeHTTP(w, r)
return
}
clientIP := getClientIP(r)
if !rl.Allow(clientIP) {
util.WriteErrorResponse("rate limit exceeded, please try again later", http.StatusTooManyRequests, w)

View File

@@ -1,8 +1,10 @@
package middleware
import (
"fmt"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"
@@ -156,3 +158,172 @@ func TestAPIRateLimiter_Reset(t *testing.T) {
// Should be allowed again
assert.True(t, rl.Allow("test-key"))
}
func TestAPIRateLimiter_SetEnabled(t *testing.T) {
rl := NewAPIRateLimiter(&RateLimiterConfig{
RequestsPerMinute: 60,
Burst: 1,
CleanupInterval: time.Minute,
LimiterTTL: time.Minute,
})
defer rl.Stop()
assert.True(t, rl.Allow("key"))
assert.False(t, rl.Allow("key"), "burst exhausted while enabled")
rl.SetEnabled(false)
assert.False(t, rl.Enabled())
for i := 0; i < 5; i++ {
assert.True(t, rl.Allow("key"), "disabled limiter must always allow")
}
rl.SetEnabled(true)
assert.True(t, rl.Enabled())
assert.False(t, rl.Allow("key"), "re-enabled limiter retains prior bucket state")
}
func TestAPIRateLimiter_UpdateConfig(t *testing.T) {
rl := NewAPIRateLimiter(&RateLimiterConfig{
RequestsPerMinute: 60,
Burst: 2,
CleanupInterval: time.Minute,
LimiterTTL: time.Minute,
})
defer rl.Stop()
assert.True(t, rl.Allow("k1"))
assert.True(t, rl.Allow("k1"))
assert.False(t, rl.Allow("k1"), "burst=2 exhausted")
rl.UpdateConfig(&RateLimiterConfig{
RequestsPerMinute: 60,
Burst: 10,
CleanupInterval: time.Minute,
LimiterTTL: time.Minute,
})
// New burst applies to existing keys in place; bucket refills up to new burst over time,
// but importantly newly-added keys use the updated config immediately.
assert.True(t, rl.Allow("k2"))
for i := 0; i < 9; i++ {
assert.True(t, rl.Allow("k2"))
}
assert.False(t, rl.Allow("k2"), "new burst=10 exhausted")
}
func TestAPIRateLimiter_UpdateConfig_NilIgnored(t *testing.T) {
rl := NewAPIRateLimiter(&RateLimiterConfig{
RequestsPerMinute: 60,
Burst: 1,
CleanupInterval: time.Minute,
LimiterTTL: time.Minute,
})
defer rl.Stop()
rl.UpdateConfig(nil) // must not panic or zero the config
assert.True(t, rl.Allow("k"))
assert.False(t, rl.Allow("k"))
}
func TestAPIRateLimiter_UpdateConfig_NonPositiveIgnored(t *testing.T) {
rl := NewAPIRateLimiter(&RateLimiterConfig{
RequestsPerMinute: 60,
Burst: 1,
CleanupInterval: time.Minute,
LimiterTTL: time.Minute,
})
defer rl.Stop()
assert.True(t, rl.Allow("k"))
assert.False(t, rl.Allow("k"))
rl.UpdateConfig(&RateLimiterConfig{RequestsPerMinute: 0, Burst: 0, CleanupInterval: time.Minute, LimiterTTL: time.Minute})
rl.UpdateConfig(&RateLimiterConfig{RequestsPerMinute: -1, Burst: 5, CleanupInterval: time.Minute, LimiterTTL: time.Minute})
rl.UpdateConfig(&RateLimiterConfig{RequestsPerMinute: 60, Burst: -1, CleanupInterval: time.Minute, LimiterTTL: time.Minute})
rl.Reset("k")
assert.True(t, rl.Allow("k"))
assert.False(t, rl.Allow("k"), "burst should still be 1 — invalid UpdateConfig calls were ignored")
}
func TestAPIRateLimiter_ConcurrentAllowAndUpdate(t *testing.T) {
rl := NewAPIRateLimiter(&RateLimiterConfig{
RequestsPerMinute: 600,
Burst: 10,
CleanupInterval: time.Minute,
LimiterTTL: time.Minute,
})
defer rl.Stop()
var wg sync.WaitGroup
stop := make(chan struct{})
for i := 0; i < 8; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
key := fmt.Sprintf("k%d", id)
for {
select {
case <-stop:
return
default:
rl.Allow(key)
}
}
}(i)
}
wg.Add(1)
go func() {
defer wg.Done()
for i := 0; i < 200; i++ {
select {
case <-stop:
return
default:
rl.UpdateConfig(&RateLimiterConfig{
RequestsPerMinute: float64(30 + (i % 90)),
Burst: 1 + (i % 20),
CleanupInterval: time.Minute,
LimiterTTL: time.Minute,
})
rl.SetEnabled(i%2 == 0)
}
}
}()
time.Sleep(100 * time.Millisecond)
close(stop)
wg.Wait()
}
func TestRateLimiterConfigFromEnv(t *testing.T) {
t.Setenv(RateLimitingEnabledEnv, "true")
t.Setenv(RateLimitingRPMEnv, "42")
t.Setenv(RateLimitingBurstEnv, "7")
cfg, enabled := RateLimiterConfigFromEnv()
assert.True(t, enabled)
assert.Equal(t, float64(42), cfg.RequestsPerMinute)
assert.Equal(t, 7, cfg.Burst)
t.Setenv(RateLimitingEnabledEnv, "false")
_, enabled = RateLimiterConfigFromEnv()
assert.False(t, enabled)
t.Setenv(RateLimitingEnabledEnv, "")
t.Setenv(RateLimitingRPMEnv, "")
t.Setenv(RateLimitingBurstEnv, "")
cfg, enabled = RateLimiterConfigFromEnv()
assert.False(t, enabled)
assert.Equal(t, float64(defaultAPIRPM), cfg.RequestsPerMinute)
assert.Equal(t, defaultAPIBurst, cfg.Burst)
t.Setenv(RateLimitingRPMEnv, "0")
t.Setenv(RateLimitingBurstEnv, "-5")
cfg, _ = RateLimiterConfigFromEnv()
assert.Equal(t, float64(defaultAPIRPM), cfg.RequestsPerMinute, "non-positive rpm must fall back to default")
assert.Equal(t, defaultAPIBurst, cfg.Burst, "non-positive burst must fall back to default")
}

View File

@@ -135,7 +135,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil)
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil)
if err != nil {
t.Fatalf("Failed to create API handler: %v", err)
}
@@ -264,7 +264,7 @@ func BuildApiBlackBoxWithDBStateAndPeerChannel(t testing_tools.TB, sqlFile strin
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil)
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil)
if err != nil {
t.Fatalf("Failed to create API handler: %v", err)
}

View File

@@ -267,8 +267,8 @@ func Test_SyncProtocol(t *testing.T) {
}
// expired peers come separately.
if len(networkMap.GetOfflinePeers()) != 1 {
t.Fatal("expecting SyncResponse to have NetworkMap with 1 offline peer")
if len(networkMap.GetOfflinePeers()) != 2 {
t.Fatal("expecting SyncResponse to have NetworkMap with 2 offline peer")
}
expiredPeerPubKey := "RlSy2vzoG2HyMBTUImXOiVhCBiiBa5qD5xzMxkiFDW4="

View File

@@ -1087,7 +1087,7 @@ func TestNameServerAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
@@ -1105,7 +1105,7 @@ func TestNameServerAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})

View File

@@ -1405,6 +1405,10 @@ func (am *DefaultAccountManager) getExpiredPeers(ctx context.Context, accountID
var peers []*nbpeer.Peer
for _, peer := range peersWithExpiry {
if peer.Status.LoginExpired {
continue
}
expired, _ := peer.LoginExpired(settings.PeerLoginExpiration)
if expired {
peers = append(peers, peer)

View File

@@ -1907,7 +1907,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
@@ -1929,7 +1929,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
@@ -1994,7 +1994,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
@@ -2012,7 +2012,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
@@ -2058,7 +2058,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
@@ -2076,7 +2076,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
@@ -2113,7 +2113,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
@@ -2131,7 +2131,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})

View File

@@ -1231,7 +1231,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
@@ -1263,7 +1263,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
@@ -1294,7 +1294,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
@@ -1314,7 +1314,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
@@ -1355,7 +1355,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
@@ -1373,7 +1373,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
@@ -1393,7 +1393,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})

View File

@@ -244,7 +244,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
@@ -273,7 +273,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
@@ -292,7 +292,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
@@ -395,7 +395,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
@@ -438,7 +438,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})

View File

@@ -2070,7 +2070,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
@@ -2107,7 +2107,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
@@ -2127,7 +2127,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
@@ -2145,7 +2145,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
@@ -2185,7 +2185,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
@@ -2225,7 +2225,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})

View File

@@ -3310,7 +3310,7 @@ func (s *SqlStore) GetAccountPeersWithExpiration(ctx context.Context, lockStreng
var peers []*nbpeer.Peer
result := tx.
Where("login_expiration_enabled = ? AND user_id IS NOT NULL AND user_id != ''", true).
Where("login_expiration_enabled = ? AND peer_status_login_expired != ? AND user_id IS NOT NULL AND user_id != ''", true, true).
Find(&peers, accountIDCondition, accountID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to get peers with expiration from the store: %s", result.Error)

View File

@@ -2729,7 +2729,7 @@ func TestSqlStore_GetAccountPeers(t *testing.T) {
{
name: "should retrieve peers for an existing account ID",
accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
expectedCount: 4,
expectedCount: 5,
},
{
name: "should return no peers for a non-existing account ID",
@@ -2751,7 +2751,7 @@ func TestSqlStore_GetAccountPeers(t *testing.T) {
name: "should filter peers by partial name",
accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
nameFilter: "host",
expectedCount: 3,
expectedCount: 4,
},
{
name: "should filter peers by ip",
@@ -2777,14 +2777,16 @@ func TestSqlStore_GetAccountPeersWithExpiration(t *testing.T) {
require.NoError(t, err)
tests := []struct {
name string
accountID string
expectedCount int
name string
accountID string
expectedCount int
expectedPeerIDs []string
}{
{
name: "should retrieve peers with expiration for an existing account ID",
accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
expectedCount: 1,
name: "should retrieve only non-expired peers with expiration enabled",
accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
expectedCount: 1,
expectedPeerIDs: []string{"notexpired01"},
},
{
name: "should return no peers with expiration for a non-existing account ID",
@@ -2803,10 +2805,30 @@ func TestSqlStore_GetAccountPeersWithExpiration(t *testing.T) {
peers, err := store.GetAccountPeersWithExpiration(context.Background(), LockingStrengthNone, tt.accountID)
require.NoError(t, err)
require.Len(t, peers, tt.expectedCount)
for i, peer := range peers {
assert.Equal(t, tt.expectedPeerIDs[i], peer.ID)
}
})
}
}
func TestSqlStore_GetAccountPeersWithExpiration_ExcludesAlreadyExpired(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_with_expired_peers.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
peers, err := store.GetAccountPeersWithExpiration(context.Background(), LockingStrengthNone, accountID)
require.NoError(t, err)
// Verify the already-expired peer (cg05lnblo1hkg2j514p0) is not returned
for _, peer := range peers {
assert.NotEqual(t, "cg05lnblo1hkg2j514p0", peer.ID, "already expired peer should not be returned")
assert.False(t, peer.Status.LoginExpired, "returned peers should not have LoginExpired set")
}
}
func TestSqlStore_GetAccountPeersWithInactivity(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_with_expired_peers.sql", t.TempDir())
t.Cleanup(cleanup)
@@ -2887,7 +2909,7 @@ func TestSqlStore_GetUserPeers(t *testing.T) {
name: "should retrieve peers for another valid account ID and user ID",
accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
userID: "edafee4e-63fb-11ec-90d6-0242ac120003",
expectedCount: 2,
expectedCount: 3,
},
{
name: "should return no peers for existing account ID with empty user ID",

View File

@@ -31,6 +31,7 @@ INSERT INTO peers VALUES('cfvprsrlo1hqoo49ohog','bf1c8084-ba50-4ce7-9439-3465300
INSERT INTO peers VALUES('cg05lnblo1hkg2j514p0','bf1c8084-ba50-4ce7-9439-34653001fc3b','RlSy2vzoG2HyMBTUImXOiVhCBiiBa5qD5xzMxkiFDW4=','','"100.64.39.54"','expiredhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'expiredhost','expiredhost','2023-03-02 09:19:57.276717255+01:00',0,1,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIMbK5ZXJsGOOWoBT4OmkPtgdPZe2Q7bDuS/zjn2CZxhK',0,1,0,'2023-03-02 09:14:21.791679181+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
INSERT INTO peers VALUES('cg3161rlo1hs9cq94gdg','bf1c8084-ba50-4ce7-9439-34653001fc3b','mVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HU=','','"100.64.117.96"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost','2023-03-06 18:21:27.252010027+01:00',0,0,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,0,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
INSERT INTO peers VALUES('csrnkiq7qv9d8aitqd50','bf1c8084-ba50-4ce7-9439-34653001fc3b','nVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HX=','','"100.64.117.97"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost-1','2023-03-06 18:21:27.252010027+01:00',0,0,0,'f4f6d672-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,1,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
INSERT INTO peers VALUES('notexpired01','bf1c8084-ba50-4ce7-9439-34653001fc3b','oVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HY=','','"100.64.117.98"','activehost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'activehost','activehost','2023-03-06 18:21:27.252010027+01:00',0,0,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,1,0,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,NULL,'2024-10-02 17:00:32.528196+02:00','api',0,'');
INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,NULL,'2024-10-02 17:00:32.528196+02:00','api',0,'');
INSERT INTO installations VALUES(1,'');

View File

@@ -1586,7 +1586,7 @@ func TestUserAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
@@ -1609,7 +1609,7 @@ func TestUserAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})

View File

@@ -433,6 +433,7 @@ func setSessionCookie(w http.ResponseWriter, token string, expiration time.Durat
http.SetCookie(w, &http.Cookie{
Name: auth.SessionCookieName,
Value: token,
Path: "/",
HttpOnly: true,
Secure: true,
SameSite: http.SameSiteLaxMode,

View File

@@ -391,6 +391,15 @@ func TestProtect_SchemeAuthRedirectsWithCookie(t *testing.T) {
assert.Equal(t, http.SameSiteLaxMode, sessionCookie.SameSite)
}
func TestSetSessionCookieHasRootPath(t *testing.T) {
w := httptest.NewRecorder()
setSessionCookie(w, "test-token", time.Hour)
cookies := w.Result().Cookies()
require.Len(t, cookies, 1)
assert.Equal(t, "/", cookies[0].Path, "session cookie must be scoped to root so it applies to all paths")
}
func TestProtect_FailedAuthDoesNotSetCookie(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t)