mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-23 02:36:42 +00:00
Compare commits
8 Commits
reduce-emb
...
dependabot
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5730c638db | ||
|
|
5da05ecca6 | ||
|
|
801de8c68d | ||
|
|
a822a33240 | ||
|
|
57b23c5b25 | ||
|
|
1165058fad | ||
|
|
703353d354 | ||
|
|
2fb50aef6b |
@@ -239,8 +239,12 @@ func TestICEBind_HandlesConcurrentMixedTraffic(t *testing.T) {
|
||||
ipv6Count++
|
||||
}
|
||||
|
||||
assert.Equal(t, packetsPerFamily, ipv4Count)
|
||||
assert.Equal(t, packetsPerFamily, ipv6Count)
|
||||
// Allow some UDP packet loss under load (e.g. FreeBSD/QEMU runners). The
|
||||
// routing-correctness checks above are the real assertions; the counts
|
||||
// are a sanity bound to catch a totally silent path.
|
||||
minDelivered := packetsPerFamily * 80 / 100
|
||||
assert.GreaterOrEqual(t, ipv4Count, minDelivered, "IPv4 delivery below threshold")
|
||||
assert.GreaterOrEqual(t, ipv6Count, minDelivered, "IPv6 delivery below threshold")
|
||||
}
|
||||
|
||||
func TestICEBind_DetectsAddressFamilyFromConnection(t *testing.T) {
|
||||
|
||||
@@ -3,10 +3,12 @@ package debug
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
@@ -19,8 +21,10 @@ func TestUpload(t *testing.T) {
|
||||
t.Skip("Skipping upload test on docker ci")
|
||||
}
|
||||
testDir := t.TempDir()
|
||||
testURL := "http://localhost:8080"
|
||||
addr := reserveLoopbackPort(t)
|
||||
testURL := "http://" + addr
|
||||
t.Setenv("SERVER_URL", testURL)
|
||||
t.Setenv("SERVER_ADDRESS", addr)
|
||||
t.Setenv("STORE_DIR", testDir)
|
||||
srv := server.NewServer()
|
||||
go func() {
|
||||
@@ -33,6 +37,7 @@ func TestUpload(t *testing.T) {
|
||||
t.Errorf("Failed to stop server: %v", err)
|
||||
}
|
||||
})
|
||||
waitForServer(t, addr)
|
||||
|
||||
file := filepath.Join(t.TempDir(), "tmpfile")
|
||||
fileContent := []byte("test file content")
|
||||
@@ -47,3 +52,30 @@ func TestUpload(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, fileContent, createdFileContent)
|
||||
}
|
||||
|
||||
// reserveLoopbackPort binds an ephemeral port on loopback to learn a free
|
||||
// address, then releases it so the server under test can rebind. The close/
|
||||
// rebind window is racy in theory; on loopback with a kernel-assigned port
|
||||
// it's essentially never contended in practice.
|
||||
func reserveLoopbackPort(t *testing.T) string {
|
||||
t.Helper()
|
||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
addr := l.Addr().String()
|
||||
require.NoError(t, l.Close())
|
||||
return addr
|
||||
}
|
||||
|
||||
func waitForServer(t *testing.T, addr string) {
|
||||
t.Helper()
|
||||
deadline := time.Now().Add(5 * time.Second)
|
||||
for time.Now().Before(deadline) {
|
||||
c, err := net.DialTimeout("tcp", addr, 100*time.Millisecond)
|
||||
if err == nil {
|
||||
_ = c.Close()
|
||||
return
|
||||
}
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
}
|
||||
t.Fatalf("server did not start listening on %s in time", addr)
|
||||
}
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
"net"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -192,6 +195,12 @@ func (c *HandlerChain) logHandlers() {
|
||||
}
|
||||
|
||||
func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
c.dispatch(w, r, math.MaxInt)
|
||||
}
|
||||
|
||||
// dispatch routes a DNS request through the chain, skipping handlers with
|
||||
// priority > maxPriority. Shared by ServeDNS and ResolveInternal.
|
||||
func (c *HandlerChain) dispatch(w dns.ResponseWriter, r *dns.Msg, maxPriority int) {
|
||||
if len(r.Question) == 0 {
|
||||
return
|
||||
}
|
||||
@@ -216,6 +225,9 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
|
||||
// Try handlers in priority order
|
||||
for _, entry := range handlers {
|
||||
if entry.Priority > maxPriority {
|
||||
continue
|
||||
}
|
||||
if !c.isHandlerMatch(qname, entry) {
|
||||
continue
|
||||
}
|
||||
@@ -273,6 +285,55 @@ func (c *HandlerChain) logResponse(logger *log.Entry, cw *ResponseWriterChain, q
|
||||
cw.response.Len(), meta, time.Since(startTime))
|
||||
}
|
||||
|
||||
// ResolveInternal runs an in-process DNS query against the chain, skipping any
|
||||
// handler with priority > maxPriority. Used by internal callers (e.g. the mgmt
|
||||
// cache refresher) that must bypass themselves to avoid loops. Honors ctx
|
||||
// cancellation; on ctx.Done the dispatch goroutine is left to drain on its own
|
||||
// (bounded by the invoked handler's internal timeout).
|
||||
func (c *HandlerChain) ResolveInternal(ctx context.Context, r *dns.Msg, maxPriority int) (*dns.Msg, error) {
|
||||
if len(r.Question) == 0 {
|
||||
return nil, fmt.Errorf("empty question")
|
||||
}
|
||||
|
||||
base := &internalResponseWriter{}
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
c.dispatch(base, r, maxPriority)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-ctx.Done():
|
||||
// Prefer a completed response if dispatch finished concurrently with cancellation.
|
||||
select {
|
||||
case <-done:
|
||||
default:
|
||||
return nil, fmt.Errorf("resolve %s: %w", strings.ToLower(r.Question[0].Name), ctx.Err())
|
||||
}
|
||||
}
|
||||
|
||||
if base.response == nil || base.response.Rcode == dns.RcodeRefused {
|
||||
return nil, fmt.Errorf("no handler resolved %s at priority ≤ %d",
|
||||
strings.ToLower(r.Question[0].Name), maxPriority)
|
||||
}
|
||||
return base.response, nil
|
||||
}
|
||||
|
||||
// HasRootHandlerAtOrBelow reports whether any "." handler is registered at
|
||||
// priority ≤ maxPriority.
|
||||
func (c *HandlerChain) HasRootHandlerAtOrBelow(maxPriority int) bool {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
for _, h := range c.handlers {
|
||||
if h.Pattern == "." && h.Priority <= maxPriority {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
|
||||
switch {
|
||||
case entry.Pattern == ".":
|
||||
@@ -291,3 +352,36 @@ func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// internalResponseWriter captures a dns.Msg for in-process chain queries.
|
||||
type internalResponseWriter struct {
|
||||
response *dns.Msg
|
||||
}
|
||||
|
||||
func (w *internalResponseWriter) WriteMsg(m *dns.Msg) error { w.response = m; return nil }
|
||||
func (w *internalResponseWriter) LocalAddr() net.Addr { return nil }
|
||||
func (w *internalResponseWriter) RemoteAddr() net.Addr { return nil }
|
||||
|
||||
// Write unpacks raw DNS bytes so handlers that call Write instead of WriteMsg
|
||||
// still surface their answer to ResolveInternal.
|
||||
func (w *internalResponseWriter) Write(p []byte) (int, error) {
|
||||
msg := new(dns.Msg)
|
||||
if err := msg.Unpack(p); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
w.response = msg
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (w *internalResponseWriter) Close() error { return nil }
|
||||
func (w *internalResponseWriter) TsigStatus() error { return nil }
|
||||
|
||||
// TsigTimersOnly is part of dns.ResponseWriter.
|
||||
func (w *internalResponseWriter) TsigTimersOnly(bool) {
|
||||
// no-op: in-process queries carry no TSIG state.
|
||||
}
|
||||
|
||||
// Hijack is part of dns.ResponseWriter.
|
||||
func (w *internalResponseWriter) Hijack() {
|
||||
// no-op: in-process queries have no underlying connection to hand off.
|
||||
}
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
package dns_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/test"
|
||||
@@ -1042,3 +1046,163 @@ func TestHandlerChain_AddRemoveRoundtrip(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// answeringHandler writes a fixed A record to ack the query. Used to verify
|
||||
// which handler ResolveInternal dispatches to.
|
||||
type answeringHandler struct {
|
||||
name string
|
||||
ip string
|
||||
}
|
||||
|
||||
func (h *answeringHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
resp := &dns.Msg{}
|
||||
resp.SetReply(r)
|
||||
resp.Answer = []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: net.ParseIP(h.ip).To4(),
|
||||
}}
|
||||
_ = w.WriteMsg(resp)
|
||||
}
|
||||
|
||||
func (h *answeringHandler) String() string { return h.name }
|
||||
|
||||
func TestHandlerChain_ResolveInternal_SkipsAboveMaxPriority(t *testing.T) {
|
||||
chain := nbdns.NewHandlerChain()
|
||||
|
||||
high := &answeringHandler{name: "high", ip: "10.0.0.1"}
|
||||
low := &answeringHandler{name: "low", ip: "10.0.0.2"}
|
||||
|
||||
chain.AddHandler("example.com.", high, nbdns.PriorityMgmtCache)
|
||||
chain.AddHandler("example.com.", low, nbdns.PriorityUpstream)
|
||||
|
||||
r := new(dns.Msg)
|
||||
r.SetQuestion("example.com.", dns.TypeA)
|
||||
|
||||
resp, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, resp)
|
||||
assert.Equal(t, 1, len(resp.Answer))
|
||||
a, ok := resp.Answer[0].(*dns.A)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "10.0.0.2", a.A.String(), "should skip mgmtCache handler and resolve via upstream")
|
||||
}
|
||||
|
||||
func TestHandlerChain_ResolveInternal_ErrorWhenNoMatch(t *testing.T) {
|
||||
chain := nbdns.NewHandlerChain()
|
||||
high := &answeringHandler{name: "high", ip: "10.0.0.1"}
|
||||
chain.AddHandler("example.com.", high, nbdns.PriorityMgmtCache)
|
||||
|
||||
r := new(dns.Msg)
|
||||
r.SetQuestion("example.com.", dns.TypeA)
|
||||
|
||||
_, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream)
|
||||
assert.Error(t, err, "no handler at or below maxPriority should error")
|
||||
}
|
||||
|
||||
// rawWriteHandler packs a response and calls ResponseWriter.Write directly
|
||||
// (instead of WriteMsg), exercising the internalResponseWriter.Write path.
|
||||
type rawWriteHandler struct {
|
||||
ip string
|
||||
}
|
||||
|
||||
func (h *rawWriteHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
resp := &dns.Msg{}
|
||||
resp.SetReply(r)
|
||||
resp.Answer = []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: net.ParseIP(h.ip).To4(),
|
||||
}}
|
||||
packed, err := resp.Pack()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, _ = w.Write(packed)
|
||||
}
|
||||
|
||||
func TestHandlerChain_ResolveInternal_CapturesRawWrite(t *testing.T) {
|
||||
chain := nbdns.NewHandlerChain()
|
||||
chain.AddHandler("example.com.", &rawWriteHandler{ip: "10.0.0.3"}, nbdns.PriorityUpstream)
|
||||
|
||||
r := new(dns.Msg)
|
||||
r.SetQuestion("example.com.", dns.TypeA)
|
||||
|
||||
resp, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream)
|
||||
assert.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
require.Len(t, resp.Answer, 1)
|
||||
a, ok := resp.Answer[0].(*dns.A)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "10.0.0.3", a.A.String(), "handlers calling Write(packed) must still surface their answer")
|
||||
}
|
||||
|
||||
func TestHandlerChain_ResolveInternal_EmptyQuestion(t *testing.T) {
|
||||
chain := nbdns.NewHandlerChain()
|
||||
_, err := chain.ResolveInternal(context.Background(), new(dns.Msg), nbdns.PriorityUpstream)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
// hangingHandler blocks indefinitely until closed, simulating a wedged upstream.
|
||||
type hangingHandler struct {
|
||||
block chan struct{}
|
||||
}
|
||||
|
||||
func (h *hangingHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
<-h.block
|
||||
resp := &dns.Msg{}
|
||||
resp.SetReply(r)
|
||||
_ = w.WriteMsg(resp)
|
||||
}
|
||||
|
||||
func (h *hangingHandler) String() string { return "hangingHandler" }
|
||||
|
||||
func TestHandlerChain_ResolveInternal_HonorsContextTimeout(t *testing.T) {
|
||||
chain := nbdns.NewHandlerChain()
|
||||
h := &hangingHandler{block: make(chan struct{})}
|
||||
defer close(h.block)
|
||||
|
||||
chain.AddHandler("example.com.", h, nbdns.PriorityUpstream)
|
||||
|
||||
r := new(dns.Msg)
|
||||
r.SetQuestion("example.com.", dns.TypeA)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
start := time.Now()
|
||||
_, err := chain.ResolveInternal(ctx, r, nbdns.PriorityUpstream)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, context.DeadlineExceeded)
|
||||
assert.Less(t, elapsed, 500*time.Millisecond, "ResolveInternal must return shortly after ctx deadline")
|
||||
}
|
||||
|
||||
func TestHandlerChain_HasRootHandlerAtOrBelow(t *testing.T) {
|
||||
chain := nbdns.NewHandlerChain()
|
||||
h := &answeringHandler{name: "h", ip: "10.0.0.1"}
|
||||
|
||||
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "empty chain")
|
||||
|
||||
chain.AddHandler("example.com.", h, nbdns.PriorityUpstream)
|
||||
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "non-root handler does not count")
|
||||
|
||||
chain.AddHandler(".", h, nbdns.PriorityMgmtCache)
|
||||
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root handler above threshold excluded")
|
||||
|
||||
chain.AddHandler(".", h, nbdns.PriorityDefault)
|
||||
assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root handler at PriorityDefault included")
|
||||
|
||||
chain.RemoveHandler(".", nbdns.PriorityDefault)
|
||||
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream))
|
||||
|
||||
// Primary nsgroup case: root handler lands at PriorityUpstream.
|
||||
chain.AddHandler(".", h, nbdns.PriorityUpstream)
|
||||
assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root at PriorityUpstream included")
|
||||
chain.RemoveHandler(".", nbdns.PriorityUpstream)
|
||||
|
||||
// Fallback case: original /etc/resolv.conf entries land at PriorityFallback.
|
||||
chain.AddHandler(".", h, nbdns.PriorityFallback)
|
||||
assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root at PriorityFallback included")
|
||||
chain.RemoveHandler(".", nbdns.PriorityFallback)
|
||||
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream))
|
||||
}
|
||||
|
||||
@@ -2,40 +2,83 @@ package mgmt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"os"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sync/singleflight"
|
||||
|
||||
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/resutil"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
const dnsTimeout = 5 * time.Second
|
||||
const (
|
||||
dnsTimeout = 5 * time.Second
|
||||
defaultTTL = 300 * time.Second
|
||||
refreshBackoff = 30 * time.Second
|
||||
|
||||
// Resolver caches critical NetBird infrastructure domains
|
||||
// envMgmtCacheTTL overrides defaultTTL for integration/dev testing.
|
||||
envMgmtCacheTTL = "NB_MGMT_CACHE_TTL"
|
||||
)
|
||||
|
||||
// ChainResolver lets the cache refresh stale entries through the DNS handler
|
||||
// chain instead of net.DefaultResolver, avoiding loopback when NetBird is the
|
||||
// system resolver.
|
||||
type ChainResolver interface {
|
||||
ResolveInternal(ctx context.Context, msg *dns.Msg, maxPriority int) (*dns.Msg, error)
|
||||
HasRootHandlerAtOrBelow(maxPriority int) bool
|
||||
}
|
||||
|
||||
// cachedRecord holds DNS records plus timestamps used for TTL refresh.
|
||||
// records and cachedAt are set at construction and treated as immutable;
|
||||
// lastFailedRefresh and consecFailures are mutable and must be accessed under
|
||||
// Resolver.mutex.
|
||||
type cachedRecord struct {
|
||||
records []dns.RR
|
||||
cachedAt time.Time
|
||||
lastFailedRefresh time.Time
|
||||
consecFailures int
|
||||
}
|
||||
|
||||
// Resolver caches critical NetBird infrastructure domains.
|
||||
// records, refreshing, mgmtDomain and serverDomains are all guarded by mutex.
|
||||
type Resolver struct {
|
||||
records map[dns.Question][]dns.RR
|
||||
records map[dns.Question]*cachedRecord
|
||||
mgmtDomain *domain.Domain
|
||||
serverDomains *dnsconfig.ServerDomains
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
type ipsResponse struct {
|
||||
ips []netip.Addr
|
||||
err error
|
||||
chain ChainResolver
|
||||
chainMaxPriority int
|
||||
refreshGroup singleflight.Group
|
||||
|
||||
// refreshing tracks questions whose refresh is running via the OS
|
||||
// fallback path. A ServeDNS hit for a question in this map indicates
|
||||
// the OS resolver routed the recursive query back to us (loop). Only
|
||||
// the OS path arms this so chain-path refreshes don't produce false
|
||||
// positives. The atomic bool is CAS-flipped once per refresh to
|
||||
// throttle the warning log.
|
||||
refreshing map[dns.Question]*atomic.Bool
|
||||
|
||||
cacheTTL time.Duration
|
||||
}
|
||||
|
||||
// NewResolver creates a new management domains cache resolver.
|
||||
func NewResolver() *Resolver {
|
||||
return &Resolver{
|
||||
records: make(map[dns.Question][]dns.RR),
|
||||
records: make(map[dns.Question]*cachedRecord),
|
||||
refreshing: make(map[dns.Question]*atomic.Bool),
|
||||
cacheTTL: resolveCacheTTL(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -44,7 +87,19 @@ func (m *Resolver) String() string {
|
||||
return "MgmtCacheResolver"
|
||||
}
|
||||
|
||||
// ServeDNS implements dns.Handler interface.
|
||||
// SetChainResolver wires the handler chain used to refresh stale cache entries.
|
||||
// maxPriority caps which handlers may answer refresh queries (typically
|
||||
// PriorityUpstream, so upstream/default/fallback handlers are consulted and
|
||||
// mgmt/route/local handlers are skipped).
|
||||
func (m *Resolver) SetChainResolver(chain ChainResolver, maxPriority int) {
|
||||
m.mutex.Lock()
|
||||
m.chain = chain
|
||||
m.chainMaxPriority = maxPriority
|
||||
m.mutex.Unlock()
|
||||
}
|
||||
|
||||
// ServeDNS serves cached A/AAAA records. Stale entries are returned
|
||||
// immediately and refreshed asynchronously (stale-while-revalidate).
|
||||
func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
if len(r.Question) == 0 {
|
||||
m.continueToNext(w, r)
|
||||
@@ -60,7 +115,14 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
}
|
||||
|
||||
m.mutex.RLock()
|
||||
records, found := m.records[question]
|
||||
cached, found := m.records[question]
|
||||
inflight := m.refreshing[question]
|
||||
var shouldRefresh bool
|
||||
if found {
|
||||
stale := time.Since(cached.cachedAt) > m.cacheTTL
|
||||
inBackoff := !cached.lastFailedRefresh.IsZero() && time.Since(cached.lastFailedRefresh) < refreshBackoff
|
||||
shouldRefresh = stale && !inBackoff
|
||||
}
|
||||
m.mutex.RUnlock()
|
||||
|
||||
if !found {
|
||||
@@ -68,12 +130,23 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
return
|
||||
}
|
||||
|
||||
if inflight != nil && inflight.CompareAndSwap(false, true) {
|
||||
log.Warnf("mgmt cache: possible resolver loop for domain=%s: served stale while an OS-fallback refresh was inflight (if NetBird is the system resolver, the OS-path predicate is wrong)",
|
||||
question.Name)
|
||||
}
|
||||
|
||||
// Skip scheduling a refresh goroutine if one is already inflight for
|
||||
// this question; singleflight would dedup anyway but skipping avoids
|
||||
// a parked goroutine per stale hit under bursty load.
|
||||
if shouldRefresh && inflight == nil {
|
||||
m.scheduleRefresh(question, cached)
|
||||
}
|
||||
|
||||
resp := &dns.Msg{}
|
||||
resp.SetReply(r)
|
||||
resp.Authoritative = false
|
||||
resp.RecursionAvailable = true
|
||||
|
||||
resp.Answer = append(resp.Answer, records...)
|
||||
resp.Answer = cloneRecordsWithTTL(cached.records, m.responseTTL(cached.cachedAt))
|
||||
|
||||
log.Debugf("serving %d cached records for domain=%s", len(resp.Answer), question.Name)
|
||||
|
||||
@@ -98,101 +171,260 @@ func (m *Resolver) continueToNext(w dns.ResponseWriter, r *dns.Msg) {
|
||||
}
|
||||
}
|
||||
|
||||
// AddDomain manually adds a domain to cache by resolving it.
|
||||
// AddDomain resolves a domain and stores its A/AAAA records in the cache.
|
||||
// A family that resolves NODATA (nil err, zero records) evicts any stale
|
||||
// entry for that qtype.
|
||||
func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
|
||||
dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, dnsTimeout)
|
||||
defer cancel()
|
||||
|
||||
ips, err := lookupIPWithExtraTimeout(ctx, d)
|
||||
if err != nil {
|
||||
return err
|
||||
aRecords, aaaaRecords, errA, errAAAA := m.lookupBoth(ctx, d, dnsName)
|
||||
|
||||
if errA != nil && errAAAA != nil {
|
||||
return fmt.Errorf("resolve %s: %w", d.SafeString(), errors.Join(errA, errAAAA))
|
||||
}
|
||||
|
||||
var aRecords, aaaaRecords []dns.RR
|
||||
for _, ip := range ips {
|
||||
if ip.Is4() {
|
||||
rr := &dns.A{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: dnsName,
|
||||
Rrtype: dns.TypeA,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 300,
|
||||
},
|
||||
A: ip.AsSlice(),
|
||||
}
|
||||
aRecords = append(aRecords, rr)
|
||||
} else if ip.Is6() {
|
||||
rr := &dns.AAAA{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: dnsName,
|
||||
Rrtype: dns.TypeAAAA,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 300,
|
||||
},
|
||||
AAAA: ip.AsSlice(),
|
||||
}
|
||||
aaaaRecords = append(aaaaRecords, rr)
|
||||
if len(aRecords) == 0 && len(aaaaRecords) == 0 {
|
||||
if err := errors.Join(errA, errAAAA); err != nil {
|
||||
return fmt.Errorf("resolve %s: no A/AAAA records: %w", d.SafeString(), err)
|
||||
}
|
||||
return fmt.Errorf("resolve %s: no A/AAAA records", d.SafeString())
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if len(aRecords) > 0 {
|
||||
aQuestion := dns.Question{
|
||||
Name: dnsName,
|
||||
Qtype: dns.TypeA,
|
||||
Qclass: dns.ClassINET,
|
||||
}
|
||||
m.records[aQuestion] = aRecords
|
||||
}
|
||||
m.applyFamilyRecords(dnsName, dns.TypeA, aRecords, errA, now)
|
||||
m.applyFamilyRecords(dnsName, dns.TypeAAAA, aaaaRecords, errAAAA, now)
|
||||
|
||||
if len(aaaaRecords) > 0 {
|
||||
aaaaQuestion := dns.Question{
|
||||
Name: dnsName,
|
||||
Qtype: dns.TypeAAAA,
|
||||
Qclass: dns.ClassINET,
|
||||
}
|
||||
m.records[aaaaQuestion] = aaaaRecords
|
||||
}
|
||||
|
||||
m.mutex.Unlock()
|
||||
|
||||
log.Debugf("added domain=%s with %d A records and %d AAAA records",
|
||||
log.Debugf("added/updated domain=%s with %d A records and %d AAAA records",
|
||||
d.SafeString(), len(aRecords), len(aaaaRecords))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func lookupIPWithExtraTimeout(ctx context.Context, d domain.Domain) ([]netip.Addr, error) {
|
||||
log.Infof("looking up IP for mgmt domain=%s", d.SafeString())
|
||||
defer log.Infof("done looking up IP for mgmt domain=%s", d.SafeString())
|
||||
resultChan := make(chan *ipsResponse, 1)
|
||||
// applyFamilyRecords writes records, evicts on NODATA, leaves the cache
|
||||
// untouched on error. Caller holds m.mutex.
|
||||
func (m *Resolver) applyFamilyRecords(dnsName string, qtype uint16, records []dns.RR, err error, now time.Time) {
|
||||
q := dns.Question{Name: dnsName, Qtype: qtype, Qclass: dns.ClassINET}
|
||||
switch {
|
||||
case len(records) > 0:
|
||||
m.records[q] = &cachedRecord{records: records, cachedAt: now}
|
||||
case err == nil:
|
||||
delete(m.records, q)
|
||||
}
|
||||
}
|
||||
|
||||
go func() {
|
||||
ips, err := net.DefaultResolver.LookupNetIP(ctx, "ip", d.PunycodeString())
|
||||
resultChan <- &ipsResponse{
|
||||
err: err,
|
||||
ips: ips,
|
||||
// scheduleRefresh kicks off an async refresh. DoChan spawns one goroutine per
|
||||
// unique in-flight key; bursty stale hits share its channel. expected is the
|
||||
// cachedRecord pointer observed by the caller; the refresh only mutates the
|
||||
// cache if that pointer is still the one stored, so a stale in-flight refresh
|
||||
// can't clobber a newer entry written by AddDomain or a competing refresh.
|
||||
func (m *Resolver) scheduleRefresh(question dns.Question, expected *cachedRecord) {
|
||||
key := question.Name + "|" + dns.TypeToString[question.Qtype]
|
||||
_ = m.refreshGroup.DoChan(key, func() (any, error) {
|
||||
return nil, m.refreshQuestion(question, expected)
|
||||
})
|
||||
}
|
||||
|
||||
// refreshQuestion replaces the cached records on success, or marks the entry
|
||||
// failed (arming the backoff) on failure. While this runs, ServeDNS can detect
|
||||
// a resolver loop by spotting a query for this same question arriving on us.
|
||||
// expected pins the cache entry observed at schedule time; mutations only apply
|
||||
// if m.records[question] still points at it.
|
||||
func (m *Resolver) refreshQuestion(question dns.Question, expected *cachedRecord) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
|
||||
defer cancel()
|
||||
|
||||
d, err := domain.FromString(strings.TrimSuffix(question.Name, "."))
|
||||
if err != nil {
|
||||
m.markRefreshFailed(question, expected)
|
||||
return fmt.Errorf("parse domain: %w", err)
|
||||
}
|
||||
|
||||
records, err := m.lookupRecords(ctx, d, question)
|
||||
if err != nil {
|
||||
fails := m.markRefreshFailed(question, expected)
|
||||
logf := log.Warnf
|
||||
if fails == 0 || fails > 1 {
|
||||
logf = log.Debugf
|
||||
}
|
||||
}()
|
||||
|
||||
var resp *ipsResponse
|
||||
|
||||
select {
|
||||
case <-time.After(dnsTimeout + time.Millisecond*500):
|
||||
log.Warnf("timed out waiting for IP for mgmt domain=%s", d.SafeString())
|
||||
return nil, fmt.Errorf("timed out waiting for ips to be available for domain %s", d.SafeString())
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case resp = <-resultChan:
|
||||
logf("refresh mgmt cache domain=%s type=%s: %v (consecutive failures=%d)",
|
||||
d.SafeString(), dns.TypeToString[question.Qtype], err, fails)
|
||||
return err
|
||||
}
|
||||
|
||||
if resp.err != nil {
|
||||
return nil, fmt.Errorf("resolve domain %s: %w", d.SafeString(), resp.err)
|
||||
// NOERROR/NODATA: family gone upstream, evict so we stop serving stale.
|
||||
if len(records) == 0 {
|
||||
m.mutex.Lock()
|
||||
if m.records[question] == expected {
|
||||
delete(m.records, question)
|
||||
m.mutex.Unlock()
|
||||
log.Infof("removed mgmt cache domain=%s type=%s: no records returned",
|
||||
d.SafeString(), dns.TypeToString[question.Qtype])
|
||||
return nil
|
||||
}
|
||||
m.mutex.Unlock()
|
||||
log.Debugf("skipping refresh evict for domain=%s type=%s: entry changed during refresh",
|
||||
d.SafeString(), dns.TypeToString[question.Qtype])
|
||||
return nil
|
||||
}
|
||||
return resp.ips, nil
|
||||
|
||||
now := time.Now()
|
||||
m.mutex.Lock()
|
||||
if m.records[question] != expected {
|
||||
m.mutex.Unlock()
|
||||
log.Debugf("skipping refresh write for domain=%s type=%s: entry changed during refresh",
|
||||
d.SafeString(), dns.TypeToString[question.Qtype])
|
||||
return nil
|
||||
}
|
||||
m.records[question] = &cachedRecord{records: records, cachedAt: now}
|
||||
m.mutex.Unlock()
|
||||
|
||||
log.Infof("refreshed mgmt cache domain=%s type=%s",
|
||||
d.SafeString(), dns.TypeToString[question.Qtype])
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Resolver) markRefreshing(question dns.Question) {
|
||||
m.mutex.Lock()
|
||||
m.refreshing[question] = &atomic.Bool{}
|
||||
m.mutex.Unlock()
|
||||
}
|
||||
|
||||
func (m *Resolver) clearRefreshing(question dns.Question) {
|
||||
m.mutex.Lock()
|
||||
delete(m.refreshing, question)
|
||||
m.mutex.Unlock()
|
||||
}
|
||||
|
||||
// markRefreshFailed arms the backoff and returns the new consecutive-failure
|
||||
// count so callers can downgrade subsequent failure logs to debug.
|
||||
func (m *Resolver) markRefreshFailed(question dns.Question, expected *cachedRecord) int {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
c, ok := m.records[question]
|
||||
if !ok || c != expected {
|
||||
return 0
|
||||
}
|
||||
c.lastFailedRefresh = time.Now()
|
||||
c.consecFailures++
|
||||
return c.consecFailures
|
||||
}
|
||||
|
||||
// lookupBoth resolves A and AAAA via chain or OS. Per-family errors let
|
||||
// callers tell records, NODATA (nil err, no records), and failure apart.
|
||||
func (m *Resolver) lookupBoth(ctx context.Context, d domain.Domain, dnsName string) (aRecords, aaaaRecords []dns.RR, errA, errAAAA error) {
|
||||
m.mutex.RLock()
|
||||
chain := m.chain
|
||||
maxPriority := m.chainMaxPriority
|
||||
m.mutex.RUnlock()
|
||||
|
||||
if chain != nil && chain.HasRootHandlerAtOrBelow(maxPriority) {
|
||||
aRecords, errA = m.lookupViaChain(ctx, chain, maxPriority, dnsName, dns.TypeA)
|
||||
aaaaRecords, errAAAA = m.lookupViaChain(ctx, chain, maxPriority, dnsName, dns.TypeAAAA)
|
||||
return
|
||||
}
|
||||
|
||||
// TODO: drop once every supported OS registers a fallback resolver. Safe
|
||||
// today: no root handler at priority ≤ PriorityUpstream means NetBird is
|
||||
// not the system resolver, so net.DefaultResolver will not loop back.
|
||||
aRecords, errA = m.osLookup(ctx, d, dnsName, dns.TypeA)
|
||||
aaaaRecords, errAAAA = m.osLookup(ctx, d, dnsName, dns.TypeAAAA)
|
||||
return
|
||||
}
|
||||
|
||||
// lookupRecords resolves a single record type via chain or OS. The OS branch
|
||||
// arms the loop detector for the duration of its call so that ServeDNS can
|
||||
// spot the OS resolver routing the recursive query back to us.
|
||||
func (m *Resolver) lookupRecords(ctx context.Context, d domain.Domain, q dns.Question) ([]dns.RR, error) {
|
||||
m.mutex.RLock()
|
||||
chain := m.chain
|
||||
maxPriority := m.chainMaxPriority
|
||||
m.mutex.RUnlock()
|
||||
|
||||
if chain != nil && chain.HasRootHandlerAtOrBelow(maxPriority) {
|
||||
return m.lookupViaChain(ctx, chain, maxPriority, q.Name, q.Qtype)
|
||||
}
|
||||
|
||||
// TODO: drop once every supported OS registers a fallback resolver.
|
||||
m.markRefreshing(q)
|
||||
defer m.clearRefreshing(q)
|
||||
|
||||
return m.osLookup(ctx, d, q.Name, q.Qtype)
|
||||
}
|
||||
|
||||
// lookupViaChain resolves via the handler chain and rewrites each RR to use
|
||||
// dnsName as owner and m.cacheTTL as TTL, so CNAME-backed domains don't cache
|
||||
// target-owned records or upstream TTLs. NODATA returns (nil, nil).
|
||||
func (m *Resolver) lookupViaChain(ctx context.Context, chain ChainResolver, maxPriority int, dnsName string, qtype uint16) ([]dns.RR, error) {
|
||||
msg := &dns.Msg{}
|
||||
msg.SetQuestion(dnsName, qtype)
|
||||
msg.RecursionDesired = true
|
||||
|
||||
resp, err := chain.ResolveInternal(ctx, msg, maxPriority)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("chain resolve: %w", err)
|
||||
}
|
||||
if resp == nil {
|
||||
return nil, fmt.Errorf("chain resolve returned nil response")
|
||||
}
|
||||
if resp.Rcode != dns.RcodeSuccess {
|
||||
return nil, fmt.Errorf("chain resolve rcode=%s", dns.RcodeToString[resp.Rcode])
|
||||
}
|
||||
|
||||
ttl := uint32(m.cacheTTL.Seconds())
|
||||
owners := cnameOwners(dnsName, resp.Answer)
|
||||
var filtered []dns.RR
|
||||
for _, rr := range resp.Answer {
|
||||
h := rr.Header()
|
||||
if h.Class != dns.ClassINET || h.Rrtype != qtype {
|
||||
continue
|
||||
}
|
||||
if !owners[strings.ToLower(dns.Fqdn(h.Name))] {
|
||||
continue
|
||||
}
|
||||
if cp := cloneIPRecord(rr, dnsName, ttl); cp != nil {
|
||||
filtered = append(filtered, cp)
|
||||
}
|
||||
}
|
||||
return filtered, nil
|
||||
}
|
||||
|
||||
// osLookup resolves a single family via net.DefaultResolver using resutil,
|
||||
// which disambiguates NODATA from NXDOMAIN and Unmaps v4-mapped-v6. NODATA
|
||||
// returns (nil, nil).
|
||||
func (m *Resolver) osLookup(ctx context.Context, d domain.Domain, dnsName string, qtype uint16) ([]dns.RR, error) {
|
||||
network := resutil.NetworkForQtype(qtype)
|
||||
if network == "" {
|
||||
return nil, fmt.Errorf("unsupported qtype %s", dns.TypeToString[qtype])
|
||||
}
|
||||
|
||||
log.Infof("looking up IP for mgmt domain=%s type=%s", d.SafeString(), dns.TypeToString[qtype])
|
||||
defer log.Infof("done looking up IP for mgmt domain=%s type=%s", d.SafeString(), dns.TypeToString[qtype])
|
||||
|
||||
result := resutil.LookupIP(ctx, net.DefaultResolver, network, d.PunycodeString(), qtype)
|
||||
if result.Rcode == dns.RcodeSuccess {
|
||||
return resutil.IPsToRRs(dnsName, result.IPs, uint32(m.cacheTTL.Seconds())), nil
|
||||
}
|
||||
|
||||
if result.Err != nil {
|
||||
return nil, fmt.Errorf("resolve %s type=%s: %w", d.SafeString(), dns.TypeToString[qtype], result.Err)
|
||||
}
|
||||
return nil, fmt.Errorf("resolve %s type=%s: rcode=%s", d.SafeString(), dns.TypeToString[qtype], dns.RcodeToString[result.Rcode])
|
||||
}
|
||||
|
||||
// responseTTL returns the remaining cache lifetime in seconds (rounded up),
|
||||
// so downstream resolvers don't cache an answer for longer than we will.
|
||||
func (m *Resolver) responseTTL(cachedAt time.Time) uint32 {
|
||||
remaining := m.cacheTTL - time.Since(cachedAt)
|
||||
if remaining <= 0 {
|
||||
return 0
|
||||
}
|
||||
return uint32((remaining + time.Second - 1) / time.Second)
|
||||
}
|
||||
|
||||
// PopulateFromConfig extracts and caches domains from the client configuration.
|
||||
@@ -224,19 +456,12 @@ func (m *Resolver) RemoveDomain(d domain.Domain) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
aQuestion := dns.Question{
|
||||
Name: dnsName,
|
||||
Qtype: dns.TypeA,
|
||||
Qclass: dns.ClassINET,
|
||||
}
|
||||
delete(m.records, aQuestion)
|
||||
|
||||
aaaaQuestion := dns.Question{
|
||||
Name: dnsName,
|
||||
Qtype: dns.TypeAAAA,
|
||||
Qclass: dns.ClassINET,
|
||||
}
|
||||
delete(m.records, aaaaQuestion)
|
||||
qA := dns.Question{Name: dnsName, Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
||||
qAAAA := dns.Question{Name: dnsName, Qtype: dns.TypeAAAA, Qclass: dns.ClassINET}
|
||||
delete(m.records, qA)
|
||||
delete(m.records, qAAAA)
|
||||
delete(m.refreshing, qA)
|
||||
delete(m.refreshing, qAAAA)
|
||||
|
||||
log.Debugf("removed domain=%s from cache", d.SafeString())
|
||||
return nil
|
||||
@@ -394,3 +619,73 @@ func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.Serve
|
||||
|
||||
return domains
|
||||
}
|
||||
|
||||
// cloneIPRecord returns a deep copy of rr retargeted to owner with ttl. Non
|
||||
// A/AAAA records return nil.
|
||||
func cloneIPRecord(rr dns.RR, owner string, ttl uint32) dns.RR {
|
||||
switch r := rr.(type) {
|
||||
case *dns.A:
|
||||
cp := *r
|
||||
cp.Hdr.Name = owner
|
||||
cp.Hdr.Ttl = ttl
|
||||
cp.A = slices.Clone(r.A)
|
||||
return &cp
|
||||
case *dns.AAAA:
|
||||
cp := *r
|
||||
cp.Hdr.Name = owner
|
||||
cp.Hdr.Ttl = ttl
|
||||
cp.AAAA = slices.Clone(r.AAAA)
|
||||
return &cp
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// cloneRecordsWithTTL clones A/AAAA records preserving their owner and
|
||||
// stamping ttl so the response shares no memory with the cached slice.
|
||||
func cloneRecordsWithTTL(records []dns.RR, ttl uint32) []dns.RR {
|
||||
out := make([]dns.RR, 0, len(records))
|
||||
for _, rr := range records {
|
||||
if cp := cloneIPRecord(rr, rr.Header().Name, ttl); cp != nil {
|
||||
out = append(out, cp)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// cnameOwners returns dnsName plus every target reachable by following CNAMEs
|
||||
// in answer, iterating until fixed point so out-of-order chains resolve.
|
||||
func cnameOwners(dnsName string, answer []dns.RR) map[string]bool {
|
||||
owners := map[string]bool{dnsName: true}
|
||||
for {
|
||||
added := false
|
||||
for _, rr := range answer {
|
||||
cname, ok := rr.(*dns.CNAME)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
name := strings.ToLower(dns.Fqdn(cname.Hdr.Name))
|
||||
if !owners[name] {
|
||||
continue
|
||||
}
|
||||
target := strings.ToLower(dns.Fqdn(cname.Target))
|
||||
if !owners[target] {
|
||||
owners[target] = true
|
||||
added = true
|
||||
}
|
||||
}
|
||||
if !added {
|
||||
return owners
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// resolveCacheTTL reads the cache TTL override env var; invalid or empty
|
||||
// values fall back to defaultTTL. Called once per Resolver from NewResolver.
|
||||
func resolveCacheTTL() time.Duration {
|
||||
if v := os.Getenv(envMgmtCacheTTL); v != "" {
|
||||
if d, err := time.ParseDuration(v); err == nil && d > 0 {
|
||||
return d
|
||||
}
|
||||
}
|
||||
return defaultTTL
|
||||
}
|
||||
|
||||
408
client/internal/dns/mgmt/mgmt_refresh_test.go
Normal file
408
client/internal/dns/mgmt/mgmt_refresh_test.go
Normal file
@@ -0,0 +1,408 @@
|
||||
package mgmt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/dns/test"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
type fakeChain struct {
|
||||
mu sync.Mutex
|
||||
calls map[string]int
|
||||
answers map[string][]dns.RR
|
||||
err error
|
||||
hasRoot bool
|
||||
onLookup func()
|
||||
}
|
||||
|
||||
func newFakeChain() *fakeChain {
|
||||
return &fakeChain{
|
||||
calls: map[string]int{},
|
||||
answers: map[string][]dns.RR{},
|
||||
hasRoot: true,
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fakeChain) HasRootHandlerAtOrBelow(maxPriority int) bool {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
return f.hasRoot
|
||||
}
|
||||
|
||||
func (f *fakeChain) ResolveInternal(ctx context.Context, msg *dns.Msg, maxPriority int) (*dns.Msg, error) {
|
||||
f.mu.Lock()
|
||||
q := msg.Question[0]
|
||||
key := q.Name + "|" + dns.TypeToString[q.Qtype]
|
||||
f.calls[key]++
|
||||
answers := f.answers[key]
|
||||
err := f.err
|
||||
onLookup := f.onLookup
|
||||
f.mu.Unlock()
|
||||
|
||||
if onLookup != nil {
|
||||
onLookup()
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp := &dns.Msg{}
|
||||
resp.SetReply(msg)
|
||||
resp.Answer = answers
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (f *fakeChain) setAnswer(name string, qtype uint16, ip string) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
key := name + "|" + dns.TypeToString[qtype]
|
||||
hdr := dns.RR_Header{Name: name, Rrtype: qtype, Class: dns.ClassINET, Ttl: 60}
|
||||
switch qtype {
|
||||
case dns.TypeA:
|
||||
f.answers[key] = []dns.RR{&dns.A{Hdr: hdr, A: net.ParseIP(ip).To4()}}
|
||||
case dns.TypeAAAA:
|
||||
f.answers[key] = []dns.RR{&dns.AAAA{Hdr: hdr, AAAA: net.ParseIP(ip).To16()}}
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fakeChain) callCount(name string, qtype uint16) int {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
return f.calls[name+"|"+dns.TypeToString[qtype]]
|
||||
}
|
||||
|
||||
// waitFor polls the predicate until it returns true or the deadline passes.
|
||||
func waitFor(t *testing.T, d time.Duration, fn func() bool) {
|
||||
t.Helper()
|
||||
deadline := time.Now().Add(d)
|
||||
for time.Now().Before(deadline) {
|
||||
if fn() {
|
||||
return
|
||||
}
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
}
|
||||
t.Fatalf("condition not met within %s", d)
|
||||
}
|
||||
|
||||
func queryA(t *testing.T, r *Resolver, name string) *dns.Msg {
|
||||
t.Helper()
|
||||
msg := new(dns.Msg)
|
||||
msg.SetQuestion(name, dns.TypeA)
|
||||
w := &test.MockResponseWriter{}
|
||||
r.ServeDNS(w, msg)
|
||||
return w.GetLastResponse()
|
||||
}
|
||||
|
||||
func firstA(t *testing.T, resp *dns.Msg) string {
|
||||
t.Helper()
|
||||
require.NotNil(t, resp)
|
||||
require.Greater(t, len(resp.Answer), 0, "expected at least one answer")
|
||||
a, ok := resp.Answer[0].(*dns.A)
|
||||
require.True(t, ok, "expected A record")
|
||||
return a.A.String()
|
||||
}
|
||||
|
||||
func TestResolver_CacheTTLGatesRefresh(t *testing.T) {
|
||||
// Same cached entry age, different cacheTTL values: the shorter TTL must
|
||||
// trigger a background refresh, the longer one must not. Proves that the
|
||||
// per-Resolver cacheTTL field actually drives the stale decision.
|
||||
cachedAt := time.Now().Add(-100 * time.Millisecond)
|
||||
|
||||
newRec := func() *cachedRecord {
|
||||
return &cachedRecord{
|
||||
records: []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{Name: "mgmt.example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: net.ParseIP("10.0.0.1").To4(),
|
||||
}},
|
||||
cachedAt: cachedAt,
|
||||
}
|
||||
}
|
||||
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
||||
|
||||
t.Run("short TTL treats entry as stale and refreshes", func(t *testing.T) {
|
||||
r := NewResolver()
|
||||
r.cacheTTL = 10 * time.Millisecond
|
||||
chain := newFakeChain()
|
||||
chain.setAnswer(q.Name, dns.TypeA, "10.0.0.2")
|
||||
r.SetChainResolver(chain, 50)
|
||||
r.records[q] = newRec()
|
||||
|
||||
resp := queryA(t, r, q.Name)
|
||||
assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry must be served while refresh runs")
|
||||
|
||||
waitFor(t, time.Second, func() bool {
|
||||
return chain.callCount(q.Name, dns.TypeA) >= 1
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("long TTL keeps entry fresh and skips refresh", func(t *testing.T) {
|
||||
r := NewResolver()
|
||||
r.cacheTTL = time.Hour
|
||||
chain := newFakeChain()
|
||||
chain.setAnswer(q.Name, dns.TypeA, "10.0.0.2")
|
||||
r.SetChainResolver(chain, 50)
|
||||
r.records[q] = newRec()
|
||||
|
||||
resp := queryA(t, r, q.Name)
|
||||
assert.Equal(t, "10.0.0.1", firstA(t, resp))
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
assert.Equal(t, 0, chain.callCount(q.Name, dns.TypeA), "fresh entry must not trigger refresh")
|
||||
})
|
||||
}
|
||||
|
||||
func TestResolver_ServeFresh_NoRefresh(t *testing.T) {
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
r.records[dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}] = &cachedRecord{
|
||||
records: []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{Name: "mgmt.example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: net.ParseIP("10.0.0.1").To4(),
|
||||
}},
|
||||
cachedAt: time.Now(), // fresh
|
||||
}
|
||||
|
||||
resp := queryA(t, r, "mgmt.example.com.")
|
||||
assert.Equal(t, "10.0.0.1", firstA(t, resp))
|
||||
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
assert.Equal(t, 0, chain.callCount("mgmt.example.com.", dns.TypeA), "fresh entry must not trigger refresh")
|
||||
}
|
||||
|
||||
func TestResolver_StaleTriggersAsyncRefresh(t *testing.T) {
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
||||
r.records[q] = &cachedRecord{
|
||||
records: []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: net.ParseIP("10.0.0.1").To4(),
|
||||
}},
|
||||
cachedAt: time.Now().Add(-2 * defaultTTL), // stale
|
||||
}
|
||||
|
||||
// First query: serves stale immediately.
|
||||
resp := queryA(t, r, "mgmt.example.com.")
|
||||
assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry must be served while refresh runs")
|
||||
|
||||
waitFor(t, time.Second, func() bool {
|
||||
return chain.callCount("mgmt.example.com.", dns.TypeA) >= 1
|
||||
})
|
||||
|
||||
// Next query should now return the refreshed IP.
|
||||
waitFor(t, time.Second, func() bool {
|
||||
resp := queryA(t, r, "mgmt.example.com.")
|
||||
return resp != nil && len(resp.Answer) > 0 && firstA(t, resp) == "10.0.0.2"
|
||||
})
|
||||
}
|
||||
|
||||
func TestResolver_ConcurrentStaleHitsCollapseRefresh(t *testing.T) {
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
|
||||
|
||||
var inflight atomic.Int32
|
||||
var maxInflight atomic.Int32
|
||||
chain.onLookup = func() {
|
||||
cur := inflight.Add(1)
|
||||
defer inflight.Add(-1)
|
||||
for {
|
||||
prev := maxInflight.Load()
|
||||
if cur <= prev || maxInflight.CompareAndSwap(prev, cur) {
|
||||
break
|
||||
}
|
||||
}
|
||||
time.Sleep(50 * time.Millisecond) // hold inflight long enough to collide
|
||||
}
|
||||
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
||||
r.records[q] = &cachedRecord{
|
||||
records: []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: net.ParseIP("10.0.0.1").To4(),
|
||||
}},
|
||||
cachedAt: time.Now().Add(-2 * defaultTTL),
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 50; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
queryA(t, r, "mgmt.example.com.")
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
waitFor(t, 2*time.Second, func() bool {
|
||||
return inflight.Load() == 0
|
||||
})
|
||||
|
||||
calls := chain.callCount("mgmt.example.com.", dns.TypeA)
|
||||
assert.LessOrEqual(t, calls, 2, "singleflight must collapse concurrent refreshes (got %d)", calls)
|
||||
assert.Equal(t, int32(1), maxInflight.Load(), "only one refresh should run concurrently")
|
||||
}
|
||||
|
||||
func TestResolver_RefreshFailureArmsBackoff(t *testing.T) {
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
chain.err = errors.New("boom")
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
||||
r.records[q] = &cachedRecord{
|
||||
records: []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: net.ParseIP("10.0.0.1").To4(),
|
||||
}},
|
||||
cachedAt: time.Now().Add(-2 * defaultTTL),
|
||||
}
|
||||
|
||||
// First stale hit triggers a refresh attempt that fails.
|
||||
resp := queryA(t, r, "mgmt.example.com.")
|
||||
assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry served while refresh fails")
|
||||
|
||||
waitFor(t, time.Second, func() bool {
|
||||
return chain.callCount("mgmt.example.com.", dns.TypeA) == 1
|
||||
})
|
||||
waitFor(t, time.Second, func() bool {
|
||||
r.mutex.RLock()
|
||||
defer r.mutex.RUnlock()
|
||||
c, ok := r.records[q]
|
||||
return ok && !c.lastFailedRefresh.IsZero()
|
||||
})
|
||||
|
||||
// Subsequent stale hits within backoff window should not schedule more refreshes.
|
||||
for i := 0; i < 10; i++ {
|
||||
queryA(t, r, "mgmt.example.com.")
|
||||
}
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeA), "backoff must suppress further refreshes")
|
||||
}
|
||||
|
||||
func TestResolver_NoRootHandler_SkipsChain(t *testing.T) {
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
chain.hasRoot = false
|
||||
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
// With hasRoot=false the chain must not be consulted. Use a short
|
||||
// deadline so the OS fallback returns quickly without waiting on a
|
||||
// real network call in CI.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
||||
defer cancel()
|
||||
_, _, _, _ = r.lookupBoth(ctx, domain.Domain("mgmt.example.com"), "mgmt.example.com.")
|
||||
|
||||
assert.Equal(t, 0, chain.callCount("mgmt.example.com.", dns.TypeA),
|
||||
"chain must not be used when no root handler is registered at the bound priority")
|
||||
}
|
||||
|
||||
func TestResolver_ServeDuringRefreshSetsLoopFlag(t *testing.T) {
|
||||
// ServeDNS being invoked for a question while a refresh for that question
|
||||
// is inflight indicates a resolver loop (OS resolver sent the recursive
|
||||
// query back to us). The inflightRefresh.loopLoggedOnce flag must be set.
|
||||
r := NewResolver()
|
||||
|
||||
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
||||
r.records[q] = &cachedRecord{
|
||||
records: []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: net.ParseIP("10.0.0.1").To4(),
|
||||
}},
|
||||
cachedAt: time.Now(),
|
||||
}
|
||||
|
||||
// Simulate an inflight refresh.
|
||||
r.markRefreshing(q)
|
||||
defer r.clearRefreshing(q)
|
||||
|
||||
resp := queryA(t, r, "mgmt.example.com.")
|
||||
assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry must still be served to avoid breaking external queries")
|
||||
|
||||
r.mutex.RLock()
|
||||
inflight := r.refreshing[q]
|
||||
r.mutex.RUnlock()
|
||||
require.NotNil(t, inflight)
|
||||
assert.True(t, inflight.Load(), "loop flag must be set once a ServeDNS during refresh was observed")
|
||||
}
|
||||
|
||||
func TestResolver_LoopFlagOnlyTrippedOncePerRefresh(t *testing.T) {
|
||||
r := NewResolver()
|
||||
|
||||
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
||||
r.records[q] = &cachedRecord{
|
||||
records: []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: net.ParseIP("10.0.0.1").To4(),
|
||||
}},
|
||||
cachedAt: time.Now(),
|
||||
}
|
||||
|
||||
r.markRefreshing(q)
|
||||
defer r.clearRefreshing(q)
|
||||
|
||||
// Multiple ServeDNS calls during the same refresh must not re-set the flag
|
||||
// (CompareAndSwap from false -> true returns true only on the first call).
|
||||
for range 5 {
|
||||
queryA(t, r, "mgmt.example.com.")
|
||||
}
|
||||
|
||||
r.mutex.RLock()
|
||||
inflight := r.refreshing[q]
|
||||
r.mutex.RUnlock()
|
||||
assert.True(t, inflight.Load())
|
||||
}
|
||||
|
||||
func TestResolver_NoLoopFlagWhenNotRefreshing(t *testing.T) {
|
||||
r := NewResolver()
|
||||
|
||||
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
||||
r.records[q] = &cachedRecord{
|
||||
records: []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: net.ParseIP("10.0.0.1").To4(),
|
||||
}},
|
||||
cachedAt: time.Now(),
|
||||
}
|
||||
|
||||
queryA(t, r, "mgmt.example.com.")
|
||||
|
||||
r.mutex.RLock()
|
||||
_, ok := r.refreshing[q]
|
||||
r.mutex.RUnlock()
|
||||
assert.False(t, ok, "no refresh inflight means no loop tracking")
|
||||
}
|
||||
|
||||
func TestResolver_AddDomain_UsesChainWhenRootRegistered(t *testing.T) {
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
|
||||
chain.setAnswer("mgmt.example.com.", dns.TypeAAAA, "fd00::2")
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
require.NoError(t, r.AddDomain(context.Background(), domain.Domain("mgmt.example.com")))
|
||||
|
||||
resp := queryA(t, r, "mgmt.example.com.")
|
||||
assert.Equal(t, "10.0.0.2", firstA(t, resp))
|
||||
assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeA))
|
||||
assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeAAAA))
|
||||
}
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -23,6 +24,60 @@ func TestResolver_NewResolver(t *testing.T) {
|
||||
assert.False(t, resolver.MatchSubdomains())
|
||||
}
|
||||
|
||||
func TestResolveCacheTTL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
value string
|
||||
want time.Duration
|
||||
}{
|
||||
{"unset falls back to default", "", defaultTTL},
|
||||
{"valid duration", "45s", 45 * time.Second},
|
||||
{"valid minutes", "2m", 2 * time.Minute},
|
||||
{"malformed falls back to default", "not-a-duration", defaultTTL},
|
||||
{"zero falls back to default", "0s", defaultTTL},
|
||||
{"negative falls back to default", "-5s", defaultTTL},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Setenv(envMgmtCacheTTL, tc.value)
|
||||
got := resolveCacheTTL()
|
||||
assert.Equal(t, tc.want, got, "parsed TTL should match")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewResolver_CacheTTLFromEnv(t *testing.T) {
|
||||
t.Setenv(envMgmtCacheTTL, "7s")
|
||||
r := NewResolver()
|
||||
assert.Equal(t, 7*time.Second, r.cacheTTL, "NewResolver should evaluate cacheTTL once from env")
|
||||
}
|
||||
|
||||
func TestResolver_ResponseTTL(t *testing.T) {
|
||||
now := time.Now()
|
||||
tests := []struct {
|
||||
name string
|
||||
cacheTTL time.Duration
|
||||
cachedAt time.Time
|
||||
wantMin uint32
|
||||
wantMax uint32
|
||||
}{
|
||||
{"fresh entry returns full TTL", 60 * time.Second, now, 59, 60},
|
||||
{"half-aged entry returns half TTL", 60 * time.Second, now.Add(-30 * time.Second), 29, 31},
|
||||
{"expired entry returns zero", 60 * time.Second, now.Add(-61 * time.Second), 0, 0},
|
||||
{"exactly expired returns zero", 10 * time.Second, now.Add(-10 * time.Second), 0, 0},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
r := &Resolver{cacheTTL: tc.cacheTTL}
|
||||
got := r.responseTTL(tc.cachedAt)
|
||||
assert.GreaterOrEqual(t, got, tc.wantMin, "remaining TTL should be >= wantMin")
|
||||
assert.LessOrEqual(t, got, tc.wantMax, "remaining TTL should be <= wantMax")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolver_ExtractDomainFromURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
@@ -212,6 +212,7 @@ func newDefaultServer(
|
||||
ctx, stop := context.WithCancel(ctx)
|
||||
|
||||
mgmtCacheResolver := mgmt.NewResolver()
|
||||
mgmtCacheResolver.SetChainResolver(handlerChain, PriorityUpstream)
|
||||
|
||||
defaultServer := &DefaultServer{
|
||||
ctx: ctx,
|
||||
|
||||
@@ -457,6 +457,18 @@ func TestReceive_ProtocolErrorStreamReconnect(t *testing.T) {
|
||||
|
||||
client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Cleanups run LIFO: the goroutine-drain registered here runs after Close below,
|
||||
// which is when Receive has actually returned. Without this, the Receive goroutine
|
||||
// can outlive the test and call t.Logf after teardown, panicking.
|
||||
receiveDone := make(chan struct{})
|
||||
t.Cleanup(func() {
|
||||
select {
|
||||
case <-receiveDone:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Error("Receive goroutine did not exit after Close")
|
||||
}
|
||||
})
|
||||
t.Cleanup(func() {
|
||||
err := client.Close()
|
||||
assert.NoError(t, err, "failed to close flow")
|
||||
@@ -468,6 +480,7 @@ func TestReceive_ProtocolErrorStreamReconnect(t *testing.T) {
|
||||
receivedAfterReconnect := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
defer close(receiveDone)
|
||||
err := client.Receive(ctx, 1*time.Second, func(msg *proto.FlowEventAck) error {
|
||||
if msg.IsInitiator || len(msg.EventId) == 0 {
|
||||
return nil
|
||||
|
||||
6
go.mod
6
go.mod
@@ -63,7 +63,7 @@ require (
|
||||
github.com/hashicorp/go-multierror v1.1.1
|
||||
github.com/hashicorp/go-secure-stdlib/base62 v0.1.2
|
||||
github.com/hashicorp/go-version v1.7.0
|
||||
github.com/jackc/pgx/v5 v5.5.5
|
||||
github.com/jackc/pgx/v5 v5.9.2
|
||||
github.com/libdns/route53 v1.5.0
|
||||
github.com/libp2p/go-nat v0.2.0
|
||||
github.com/libp2p/go-netroute v0.2.1
|
||||
@@ -221,8 +221,8 @@ require (
|
||||
github.com/huin/goupnp v1.2.0 // indirect
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
|
||||
github.com/jackc/puddle/v2 v2.2.1 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
||||
github.com/jackc/puddle/v2 v2.2.2 // indirect
|
||||
github.com/jackpal/go-nat-pmp v1.0.2 // indirect
|
||||
github.com/jeandeaual/go-locale v0.0.0-20250612000132-0ef82f21eade // indirect
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
|
||||
12
go.sum
12
go.sum
@@ -324,12 +324,12 @@ github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2
|
||||
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk=
|
||||
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
|
||||
github.com/jackc/pgx/v5 v5.5.5 h1:amBjrZVmksIdNjxGW/IiIMzxMKZFelXbUoPNb+8sjQw=
|
||||
github.com/jackc/pgx/v5 v5.5.5/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A=
|
||||
github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk=
|
||||
github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
|
||||
github.com/jackc/pgx/v5 v5.9.2 h1:3ZhOzMWnR4yJ+RW1XImIPsD1aNSz4T4fyP7zlQb56hw=
|
||||
github.com/jackc/pgx/v5 v5.9.2/go.mod h1:mal1tBGAFfLHvZzaYh77YS/eC6IX9OWbRV1QIIM0Jn4=
|
||||
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
|
||||
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
|
||||
github.com/jackpal/go-nat-pmp v1.0.2 h1:KzKSgb7qkJvOUTqYl9/Hg/me3pWgBmERKrTGD7BdWus=
|
||||
github.com/jackpal/go-nat-pmp v1.0.2/go.mod h1:QPH045xvCAeXUZOxsnwmrtiCoxIr9eob+4orBN1SBKc=
|
||||
github.com/jcmturner/aescts/v2 v2.0.0 h1:9YKLH6ey7H4eDBXW8khjYslgyqG2xZikXP0EQFKrle8=
|
||||
|
||||
@@ -472,7 +472,7 @@ start_services_and_show_instructions() {
|
||||
if [[ "$ENABLE_CROWDSEC" == "true" ]]; then
|
||||
echo "Registering CrowdSec bouncer..."
|
||||
local cs_retries=0
|
||||
while ! $DOCKER_COMPOSE_COMMAND exec -T crowdsec cscli capi status >/dev/null 2>&1; do
|
||||
while ! $DOCKER_COMPOSE_COMMAND exec -T crowdsec cscli lapi status >/dev/null 2>&1; do
|
||||
cs_retries=$((cs_retries + 1))
|
||||
if [[ $cs_retries -ge 30 ]]; then
|
||||
echo "WARNING: CrowdSec did not become ready. Skipping CrowdSec setup." > /dev/stderr
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
|
||||
serverauth "github.com/netbirdio/netbird/management/server/auth"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
|
||||
@@ -87,17 +88,14 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
|
||||
|
||||
switch authType {
|
||||
case "bearer":
|
||||
request, err := m.checkJWTFromRequest(r, authHeader)
|
||||
if err != nil {
|
||||
if err := m.checkJWTFromRequest(r, authHeader); err != nil {
|
||||
log.WithContext(r.Context()).Errorf("Error when validating JWT: %s", err.Error())
|
||||
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w)
|
||||
return
|
||||
}
|
||||
|
||||
h.ServeHTTP(w, request)
|
||||
h.ServeHTTP(w, r)
|
||||
case "token":
|
||||
request, err := m.checkPATFromRequest(r, authHeader)
|
||||
if err != nil {
|
||||
if err := m.checkPATFromRequest(r, authHeader); err != nil {
|
||||
log.WithContext(r.Context()).Debugf("Error when validating PAT: %s", err.Error())
|
||||
// Check if it's a status error, otherwise default to Unauthorized
|
||||
if _, ok := status.FromError(err); !ok {
|
||||
@@ -106,7 +104,7 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
h.ServeHTTP(w, request)
|
||||
h.ServeHTTP(w, r)
|
||||
default:
|
||||
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "no valid authentication provided"), w)
|
||||
return
|
||||
@@ -115,19 +113,19 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
|
||||
}
|
||||
|
||||
// CheckJWTFromRequest checks if the JWT is valid
|
||||
func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []string) (*http.Request, error) {
|
||||
func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []string) error {
|
||||
token, err := getTokenFromJWTRequest(authHeaderParts)
|
||||
|
||||
// If an error occurs, call the error handler and return an error
|
||||
if err != nil {
|
||||
return r, fmt.Errorf("error extracting token: %w", err)
|
||||
return fmt.Errorf("error extracting token: %w", err)
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
|
||||
userAuth, validatedToken, err := m.authManager.ValidateAndParseToken(ctx, token)
|
||||
if err != nil {
|
||||
return r, err
|
||||
return err
|
||||
}
|
||||
|
||||
if impersonate, ok := r.URL.Query()["account"]; ok && len(impersonate) == 1 {
|
||||
@@ -143,7 +141,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []
|
||||
// we need to call this method because if user is new, we will automatically add it to existing or create a new account
|
||||
accountId, _, err := m.ensureAccount(ctx, userAuth)
|
||||
if err != nil {
|
||||
return r, err
|
||||
return err
|
||||
}
|
||||
|
||||
if userAuth.AccountId != accountId {
|
||||
@@ -153,7 +151,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []
|
||||
|
||||
userAuth, err = m.authManager.EnsureUserAccessByJWTGroups(ctx, userAuth, validatedToken)
|
||||
if err != nil {
|
||||
return r, err
|
||||
return err
|
||||
}
|
||||
|
||||
err = m.syncUserJWTGroups(ctx, userAuth)
|
||||
@@ -164,17 +162,19 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []
|
||||
_, err = m.getUserFromUserAuth(ctx, userAuth)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("HTTP server failed to update user from user auth: %s", err)
|
||||
return r, err
|
||||
return err
|
||||
}
|
||||
|
||||
return nbcontext.SetUserAuthInRequest(r, userAuth), nil
|
||||
// propagates ctx change to upstream middleware
|
||||
*r = *nbcontext.SetUserAuthInRequest(r, userAuth)
|
||||
return nil
|
||||
}
|
||||
|
||||
// CheckPATFromRequest checks if the PAT is valid
|
||||
func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []string) (*http.Request, error) {
|
||||
func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []string) error {
|
||||
token, err := getTokenFromPATRequest(authHeaderParts)
|
||||
if err != nil {
|
||||
return r, fmt.Errorf("error extracting token: %w", err)
|
||||
return fmt.Errorf("error extracting token: %w", err)
|
||||
}
|
||||
|
||||
if m.patUsageTracker != nil {
|
||||
@@ -183,22 +183,22 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []
|
||||
|
||||
if m.rateLimiter != nil && !isTerraformRequest(r) {
|
||||
if !m.rateLimiter.Allow(token) {
|
||||
return r, status.Errorf(status.TooManyRequests, "too many requests")
|
||||
return status.Errorf(status.TooManyRequests, "too many requests")
|
||||
}
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
user, pat, accDomain, accCategory, err := m.authManager.GetPATInfo(ctx, token)
|
||||
if err != nil {
|
||||
return r, fmt.Errorf("invalid Token: %w", err)
|
||||
return fmt.Errorf("invalid Token: %w", err)
|
||||
}
|
||||
if time.Now().After(pat.GetExpirationDate()) {
|
||||
return r, fmt.Errorf("token expired")
|
||||
return fmt.Errorf("token expired")
|
||||
}
|
||||
|
||||
err = m.authManager.MarkPATUsed(ctx, pat.ID)
|
||||
if err != nil {
|
||||
return r, err
|
||||
return err
|
||||
}
|
||||
|
||||
userAuth := auth.UserAuth{
|
||||
@@ -216,7 +216,9 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []
|
||||
}
|
||||
}
|
||||
|
||||
return nbcontext.SetUserAuthInRequest(r, userAuth), nil
|
||||
// propagates ctx change to upstream middleware
|
||||
*r = *nbcontext.SetUserAuthInRequest(r, userAuth)
|
||||
return nil
|
||||
}
|
||||
|
||||
func isTerraformRequest(r *http.Request) bool {
|
||||
|
||||
@@ -193,20 +193,12 @@ func (m *HTTPMiddleware) Handler(h http.Handler) http.Handler {
|
||||
}
|
||||
})
|
||||
|
||||
h.ServeHTTP(w, r.WithContext(ctx))
|
||||
// Hold on to req so auth's in-place ctx update is visible after ServeHTTP.
|
||||
req := r.WithContext(ctx)
|
||||
h.ServeHTTP(w, req)
|
||||
close(handlerDone)
|
||||
|
||||
userAuth, err := nbContext.GetUserAuthFromContext(r.Context())
|
||||
if err == nil {
|
||||
if userAuth.AccountId != "" {
|
||||
//nolint
|
||||
ctx = context.WithValue(ctx, nbContext.AccountIDKey, userAuth.AccountId)
|
||||
}
|
||||
if userAuth.UserId != "" {
|
||||
//nolint
|
||||
ctx = context.WithValue(ctx, nbContext.UserIDKey, userAuth.UserId)
|
||||
}
|
||||
}
|
||||
ctx = req.Context()
|
||||
|
||||
if w.Status() > 399 {
|
||||
log.WithContext(ctx).Errorf("HTTP response %v: %v %v status %v", reqID, r.Method, r.URL, w.Status())
|
||||
|
||||
@@ -30,6 +30,8 @@ import (
|
||||
|
||||
const ConnectTimeout = 10 * time.Second
|
||||
|
||||
const healthCheckTimeout = 5 * time.Second
|
||||
|
||||
const (
|
||||
// EnvMaxRecvMsgSize overrides the default gRPC max receive message size (4 MB)
|
||||
// for the management client connection. Value is in bytes.
|
||||
@@ -532,7 +534,7 @@ func (c *GrpcClient) IsHealthy() bool {
|
||||
case connectivity.Ready:
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(c.ctx, 1*time.Second)
|
||||
ctx, cancel := context.WithTimeout(c.ctx, healthCheckTimeout)
|
||||
defer cancel()
|
||||
|
||||
_, err := c.realClient.GetServerKey(ctx, &proto.Empty{})
|
||||
|
||||
@@ -23,6 +23,8 @@ import (
|
||||
"github.com/netbirdio/netbird/util/wsproxy"
|
||||
)
|
||||
|
||||
const healthCheckTimeout = 5 * time.Second
|
||||
|
||||
// ConnStateNotifier is a wrapper interface of the status recorder
|
||||
type ConnStateNotifier interface {
|
||||
MarkSignalDisconnected(error)
|
||||
@@ -263,7 +265,7 @@ func (c *GrpcClient) IsHealthy() bool {
|
||||
case connectivity.Ready:
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(c.ctx, 1*time.Second)
|
||||
ctx, cancel := context.WithTimeout(c.ctx, healthCheckTimeout)
|
||||
defer cancel()
|
||||
_, err := c.realClient.Send(ctx, &proto.EncryptedMessage{
|
||||
Key: c.key.PublicKey().String(),
|
||||
|
||||
Reference in New Issue
Block a user