Compare commits

..

7 Commits

Author SHA1 Message Date
Zoltan Papp
f23aaa9ae7 [client] iOS: structured ResolvedIPs collection for domain routes (#6090)
* [client] iOS: structured ResolvedIPs collection for domain routes

Replace comma-joined ResolvedIPs string with a gomobile-friendly
ResolvedIPs collection (Add/Get/Size), mirroring the Android bridge
in client/android/network_domains.go.

This allows the iOS app to match domain-route resolved IPs against
connected peer routes without parsing CSV strings, fixing the route
status indicator for dynamic (DNS) routes.

* [client] iOS: align dynamic route exposure with Android bridge

For dynamic (DNS) routes the Swift side previously received
"invalid Prefix" as the Network value, forcing UI code to special-case
that sentinel. The Android bridge uses Domains.SafeString() instead so
peer.routes entries (which also derive from Domains.SafeString()) match
directly. Mirror that here.

Also fix the resolved IP lookup: resolvedDomains is keyed by the
resolved domain (e.g. api.ipify.org), not the configured pattern
(e.g. *.ipify.org). Group entries by ParentDomain like the daemon does
in client/server/network.go, so wildcard route patterns get their
resolved IPs populated.
2026-05-06 17:14:11 +02:00
Viktor Liu
f532976e05 [client] Add public key to debug bundle config.txt (#6092) 2026-05-06 13:42:47 +02:00
Viktor Liu
71a400f90f [client] Include MTU and SSH auth/JWT cache config in debug bundle (#6071) 2026-05-06 13:23:43 +02:00
Pascal Fischer
bfeb9b19ec [management] remove permissions from geolocations api (#6091) 2026-05-06 13:07:01 +02:00
Pascal Fischer
b19b7464ea [management] fix flaky invite token test (#6077) 2026-05-05 18:48:51 +02:00
Pascal Fischer
cfb1b3fe31 [proxy] consolidate mapping update (#6072) 2026-05-05 18:40:42 +02:00
Bethuel Mmbaga
3c28d29725 [management] Map Entra oid claim as Dex user ID (#6067) 2026-05-05 18:12:18 +03:00
20 changed files with 1020 additions and 396 deletions

View File

@@ -19,7 +19,7 @@ jobs:
- name: codespell
uses: codespell-project/actions-codespell@v2
with:
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te,userA,ede,additionals
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te,userA
skip: go.mod,go.sum,**/proxy/web/**
golangci:
strategy:

View File

@@ -21,6 +21,7 @@ import (
"time"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/protobuf/encoding/protojson"
"github.com/netbirdio/netbird/client/anonymize"
@@ -583,6 +584,9 @@ func isSensitiveEnvVar(key string) bool {
func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder) {
configContent.WriteString("NetBird Client Configuration:\n\n")
if key, err := wgtypes.ParseKey(g.internalConfig.PrivateKey); err == nil {
configContent.WriteString(fmt.Sprintf("PublicKey: %s\n", key.PublicKey().String()))
}
configContent.WriteString(fmt.Sprintf("WgIface: %s\n", g.internalConfig.WgIface))
configContent.WriteString(fmt.Sprintf("WgPort: %d\n", g.internalConfig.WgPort))
if g.internalConfig.NetworkMonitor != nil {
@@ -607,6 +611,12 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder)
if g.internalConfig.EnableSSHRemotePortForwarding != nil {
configContent.WriteString(fmt.Sprintf("EnableSSHRemotePortForwarding: %v\n", *g.internalConfig.EnableSSHRemotePortForwarding))
}
if g.internalConfig.DisableSSHAuth != nil {
configContent.WriteString(fmt.Sprintf("DisableSSHAuth: %v\n", *g.internalConfig.DisableSSHAuth))
}
if g.internalConfig.SSHJWTCacheTTL != nil {
configContent.WriteString(fmt.Sprintf("SSHJWTCacheTTL: %d\n", *g.internalConfig.SSHJWTCacheTTL))
}
configContent.WriteString(fmt.Sprintf("DisableClientRoutes: %v\n", g.internalConfig.DisableClientRoutes))
configContent.WriteString(fmt.Sprintf("DisableServerRoutes: %v\n", g.internalConfig.DisableServerRoutes))
@@ -633,6 +643,7 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder)
}
configContent.WriteString(fmt.Sprintf("LazyConnectionEnabled: %v\n", g.internalConfig.LazyConnectionEnabled))
configContent.WriteString(fmt.Sprintf("MTU: %d\n", g.internalConfig.MTU))
}
func (g *BundleGenerator) addProf() (err error) {

View File

@@ -5,16 +5,21 @@ import (
"bytes"
"encoding/json"
"net"
"net/url"
"os"
"path/filepath"
"reflect"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/anonymize"
"github.com/netbirdio/netbird/client/configs"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/shared/management/domain"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
)
@@ -471,8 +476,8 @@ func TestSanitizeServiceEnvVars(t *testing.T) {
anonymize: false,
input: map[string]any{
jsonKeyServiceEnv: map[string]any{
"HOME": "/root",
"PATH": "/usr/bin",
"HOME": "/root",
"PATH": "/usr/bin",
"NB_LOG_LEVEL": "debug",
},
},
@@ -489,9 +494,9 @@ func TestSanitizeServiceEnvVars(t *testing.T) {
anonymize: false,
input: map[string]any{
jsonKeyServiceEnv: map[string]any{
"NB_SETUP_KEY": "abc123",
"NB_API_TOKEN": "tok_xyz",
"NB_LOG_LEVEL": "info",
"NB_SETUP_KEY": "abc123",
"NB_API_TOKEN": "tok_xyz",
"NB_LOG_LEVEL": "info",
},
},
check: func(t *testing.T, params map[string]any) {
@@ -766,3 +771,127 @@ Chain OUTPUT (policy ACCEPT 0 packets, 0 bytes)
assert.Contains(t, anonNftables, "chain input {")
assert.Contains(t, anonNftables, "type filter hook input priority filter; policy accept;")
}
// TestAddConfig_AllFieldsCovered uses reflection to ensure every field in
// profilemanager.Config is either rendered in the debug bundle or explicitly
// excluded. When a new field is added to Config, this test fails until the
// developer either dumps it in addConfig/addCommonConfigFields or adds it to
// the excluded set with a justification.
func TestAddConfig_AllFieldsCovered(t *testing.T) {
excluded := map[string]string{
"PrivateKey": "sensitive: WireGuard private key",
"PreSharedKey": "sensitive: WireGuard pre-shared key",
"SSHKey": "sensitive: SSH private key",
"ClientCertKeyPair": "non-config: parsed cert pair, not serialized",
}
mURL, _ := url.Parse("https://api.example.com:443")
aURL, _ := url.Parse("https://admin.example.com:443")
bTrue := true
iVal := 42
cfg := &profilemanager.Config{
PrivateKey: "priv",
PreSharedKey: "psk",
ManagementURL: mURL,
AdminURL: aURL,
WgIface: "wt0",
WgPort: 51820,
NetworkMonitor: &bTrue,
IFaceBlackList: []string{"eth0"},
DisableIPv6Discovery: true,
RosenpassEnabled: true,
RosenpassPermissive: true,
ServerSSHAllowed: &bTrue,
EnableSSHRoot: &bTrue,
EnableSSHSFTP: &bTrue,
EnableSSHLocalPortForwarding: &bTrue,
EnableSSHRemotePortForwarding: &bTrue,
DisableSSHAuth: &bTrue,
SSHJWTCacheTTL: &iVal,
DisableClientRoutes: true,
DisableServerRoutes: true,
DisableDNS: true,
DisableFirewall: true,
BlockLANAccess: true,
BlockInbound: true,
DisableNotifications: &bTrue,
DNSLabels: domain.List{},
SSHKey: "sshkey",
NATExternalIPs: []string{"1.2.3.4"},
CustomDNSAddress: "1.1.1.1:53",
DisableAutoConnect: true,
DNSRouteInterval: 5 * time.Second,
ClientCertPath: "/tmp/cert",
ClientCertKeyPath: "/tmp/key",
LazyConnectionEnabled: true,
MTU: 1280,
}
for _, anonymize := range []bool{false, true} {
t.Run("anonymize="+map[bool]string{true: "true", false: "false"}[anonymize], func(t *testing.T) {
g := &BundleGenerator{
anonymizer: newAnonymizerForTest(),
internalConfig: cfg,
anonymize: anonymize,
}
var sb strings.Builder
g.addCommonConfigFields(&sb)
rendered := sb.String() + renderAddConfigSpecific(g)
val := reflect.ValueOf(cfg).Elem()
typ := val.Type()
var missing []string
for i := 0; i < typ.NumField(); i++ {
name := typ.Field(i).Name
if _, ok := excluded[name]; ok {
continue
}
if !strings.Contains(rendered, name+":") {
missing = append(missing, name)
}
}
if len(missing) > 0 {
t.Fatalf("Config field(s) not present in debug bundle output: %v\n"+
"Either render the field in addCommonConfigFields/addConfig, "+
"or add it to the excluded map with a justification.", missing)
}
})
}
}
// renderAddConfigSpecific renders the fields handled by the anonymize/non-anonymize
// branches in addConfig (ManagementURL, AdminURL, NATExternalIPs, CustomDNSAddress).
// addCommonConfigFields covers the rest. Keeping this in the test mirrors the
// production shape without needing to write an actual zip.
func renderAddConfigSpecific(g *BundleGenerator) string {
var sb strings.Builder
if g.anonymize {
if g.internalConfig.ManagementURL != nil {
sb.WriteString("ManagementURL: " + g.anonymizer.AnonymizeURI(g.internalConfig.ManagementURL.String()) + "\n")
}
if g.internalConfig.AdminURL != nil {
sb.WriteString("AdminURL: " + g.anonymizer.AnonymizeURI(g.internalConfig.AdminURL.String()) + "\n")
}
sb.WriteString("NATExternalIPs: x\n")
if g.internalConfig.CustomDNSAddress != "" {
sb.WriteString("CustomDNSAddress: " + g.anonymizer.AnonymizeString(g.internalConfig.CustomDNSAddress) + "\n")
}
} else {
if g.internalConfig.ManagementURL != nil {
sb.WriteString("ManagementURL: " + g.internalConfig.ManagementURL.String() + "\n")
}
if g.internalConfig.AdminURL != nil {
sb.WriteString("AdminURL: " + g.internalConfig.AdminURL.String() + "\n")
}
sb.WriteString("NATExternalIPs: x\n")
if g.internalConfig.CustomDNSAddress != "" {
sb.WriteString("CustomDNSAddress: " + g.internalConfig.CustomDNSAddress + "\n")
}
}
return sb.String()
}
func newAnonymizerForTest() *anonymize.Anonymizer {
return anonymize.NewAnonymizer(anonymize.DefaultAddresses())
}

View File

@@ -29,27 +29,6 @@ import (
var currentMTU uint16 = iface.DefaultMTU
// nonRetryableEDECodes lists EDE info codes (RFC 8914) for which a SERVFAIL
// from one upstream means another upstream would return the same answer:
// DNSSEC validation outcomes and policy-based blocks. Transient errors
// (network, cached, not ready) are not included.
var nonRetryableEDECodes = map[uint16]struct{}{
dns.ExtendedErrorCodeUnsupportedDNSKEYAlgorithm: {},
dns.ExtendedErrorCodeUnsupportedDSDigestType: {},
dns.ExtendedErrorCodeDNSSECIndeterminate: {},
dns.ExtendedErrorCodeDNSBogus: {},
dns.ExtendedErrorCodeSignatureExpired: {},
dns.ExtendedErrorCodeSignatureNotYetValid: {},
dns.ExtendedErrorCodeDNSKEYMissing: {},
dns.ExtendedErrorCodeRRSIGsMissing: {},
dns.ExtendedErrorCodeNoZoneKeyBitSet: {},
dns.ExtendedErrorCodeNSECMissing: {},
dns.ExtendedErrorCodeBlocked: {},
dns.ExtendedErrorCodeCensored: {},
dns.ExtendedErrorCodeFiltered: {},
dns.ExtendedErrorCodeProhibited: {},
}
func SetCurrentMTU(mtu uint16) {
currentMTU = mtu
}
@@ -264,18 +243,6 @@ func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, w dns.Re
var t time.Duration
var err error
// Advertise EDNS0 so the upstream may include Extended DNS Errors
// (RFC 8914) in failure responses; we use those to short-circuit
// failover for definitive answers like DNSSEC validation failures.
// Operate on a copy so the inbound request is unchanged: a client that
// did not advertise EDNS0 must not see an OPT in the response.
hadEdns := r.IsEdns0() != nil
reqUp := r
if !hadEdns {
reqUp = r.Copy()
reqUp.SetEdns0(upstreamUDPSize(), false)
}
var startTime time.Time
var upstreamProto *upstreamProtocolResult
func() {
@@ -283,7 +250,7 @@ func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, w dns.Re
defer cancel()
ctx, upstreamProto = contextWithupstreamProtocolResult(ctx)
startTime = time.Now()
rm, t, err = u.upstreamClient.exchange(ctx, upstream.String(), reqUp)
rm, t, err = u.upstreamClient.exchange(ctx, upstream.String(), r)
}()
if err != nil {
@@ -295,49 +262,13 @@ func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, w dns.Re
}
if rm.Rcode == dns.RcodeServerFailure || rm.Rcode == dns.RcodeRefused {
if code, ok := nonRetryableEDE(rm); ok {
resutil.SetMeta(w, "ede", edeName(code))
if !hadEdns {
stripOPT(rm)
}
u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, upstreamProto, logger)
return nil
}
return &upstreamFailure{upstream: upstream, reason: dns.RcodeToString[rm.Rcode]}
}
if !hadEdns {
stripOPT(rm)
}
u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, upstreamProto, logger)
return nil
}
// upstreamUDPSize returns the EDNS0 UDP buffer size we advertise to upstreams,
// derived from the tunnel MTU and bounded against underflow.
func upstreamUDPSize() uint16 {
if currentMTU > ipUDPHeaderSize {
return currentMTU - ipUDPHeaderSize
}
return dns.MinMsgSize
}
// stripOPT removes any OPT pseudo-RRs from the response's Extra section so
// the response complies with RFC 6891 when the client did not advertise EDNS0.
func stripOPT(rm *dns.Msg) {
if len(rm.Extra) == 0 {
return
}
out := rm.Extra[:0]
for _, rr := range rm.Extra {
if _, ok := rr.(*dns.OPT); ok {
continue
}
out = append(out, rr)
}
rm.Extra = out
}
func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.AddrPort, startTime time.Time) *upstreamFailure {
if !errors.Is(err, context.DeadlineExceeded) && !isTimeout(err) {
return &upstreamFailure{upstream: upstream, reason: err.Error()}
@@ -399,34 +330,6 @@ func formatFailures(failures []upstreamFailure) string {
return strings.Join(parts, ", ")
}
// nonRetryableEDE returns the first non-retryable EDE code carried in the
// response, if any.
func nonRetryableEDE(rm *dns.Msg) (uint16, bool) {
opt := rm.IsEdns0()
if opt == nil {
return 0, false
}
for _, o := range opt.Option {
ede, ok := o.(*dns.EDNS0_EDE)
if !ok {
continue
}
if _, ok := nonRetryableEDECodes[ede.InfoCode]; ok {
return ede.InfoCode, true
}
}
return 0, false
}
// edeName returns a human-readable name for an EDE code, falling back to
// the numeric code when unknown.
func edeName(code uint16) string {
if name, ok := dns.ExtendedErrorCodeToString[code]; ok {
return name
}
return fmt.Sprintf("EDE %d", code)
}
// ProbeAvailability tests all upstream servers simultaneously and
// disables the resolver if none work
func (u *upstreamResolverBase) ProbeAvailability(ctx context.Context) {

View File

@@ -770,132 +770,3 @@ func TestExchangeWithFallback_TCPTruncatesToClientSize(t *testing.T) {
assert.Less(t, len(rm2.Answer), 20, "small EDNS0 client should get fewer records")
assert.True(t, rm2.Truncated, "response should be truncated for small buffer client")
}
func msgWithEDE(rcode int, codes ...uint16) *dns.Msg {
m := new(dns.Msg)
m.Response = true
m.Rcode = rcode
if len(codes) == 0 {
return m
}
opt := &dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}}
opt.SetUDPSize(dns.MinMsgSize)
for _, c := range codes {
opt.Option = append(opt.Option, &dns.EDNS0_EDE{InfoCode: c})
}
m.Extra = append(m.Extra, opt)
return m
}
func TestNonRetryableEDE(t *testing.T) {
tests := []struct {
name string
msg *dns.Msg
wantOK bool
wantCode uint16
}{
{name: "no edns0", msg: msgWithEDE(dns.RcodeServerFailure)},
{
name: "opt without ede",
msg: func() *dns.Msg {
m := msgWithEDE(dns.RcodeServerFailure)
opt := &dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}}
opt.Option = append(opt.Option, &dns.EDNS0_NSID{Code: dns.EDNS0NSID})
m.Extra = []dns.RR{opt}
return m
}(),
},
{name: "ede dnsbogus", msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeDNSBogus), wantOK: true, wantCode: dns.ExtendedErrorCodeDNSBogus},
{name: "ede signature expired", msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeSignatureExpired), wantOK: true, wantCode: dns.ExtendedErrorCodeSignatureExpired},
{name: "ede blocked", msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeBlocked), wantOK: true, wantCode: dns.ExtendedErrorCodeBlocked},
{name: "ede prohibited", msg: msgWithEDE(dns.RcodeRefused, dns.ExtendedErrorCodeProhibited), wantOK: true, wantCode: dns.ExtendedErrorCodeProhibited},
{name: "ede cached error retryable", msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeCachedError)},
{name: "ede network error retryable", msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeNetworkError)},
{name: "ede not ready retryable", msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeNotReady)},
{
name: "first non-retryable wins",
msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeNetworkError, dns.ExtendedErrorCodeDNSBogus),
wantOK: true,
wantCode: dns.ExtendedErrorCodeDNSBogus,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
code, ok := nonRetryableEDE(tc.msg)
assert.Equal(t, tc.wantOK, ok, "ok should match")
if tc.wantOK {
assert.Equal(t, tc.wantCode, code, "code should match")
}
})
}
}
func TestEDEName(t *testing.T) {
assert.Equal(t, "DNSSEC Bogus", edeName(dns.ExtendedErrorCodeDNSBogus))
assert.Equal(t, "Signature Expired", edeName(dns.ExtendedErrorCodeSignatureExpired))
assert.Equal(t, "EDE 9999", edeName(9999), "unknown code falls back to numeric")
}
func TestStripOPT(t *testing.T) {
rm := &dns.Msg{
Extra: []dns.RR{
&dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}},
&dns.A{Hdr: dns.RR_Header{Name: "x.", Rrtype: dns.TypeA}, A: net.IPv4(1, 2, 3, 4)},
},
}
stripOPT(rm)
assert.Len(t, rm.Extra, 1, "OPT should be removed, A kept")
_, isOPT := rm.Extra[0].(*dns.OPT)
assert.False(t, isOPT, "remaining record must not be OPT")
}
func TestUpstreamResolver_NonRetryableEDEShortCircuits(t *testing.T) {
upstream1 := netip.MustParseAddrPort("192.0.2.1:53")
upstream2 := netip.MustParseAddrPort("192.0.2.2:53")
servfailWithEDE := msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeDNSBogus)
successResp := buildMockResponse(dns.RcodeSuccess, "192.0.2.100")
var queried []string
tracking := &trackingMockClient{
inner: &mockUpstreamResolverPerServer{
responses: map[string]mockUpstreamResponse{
upstream1.String(): {msg: servfailWithEDE},
upstream2.String(): {msg: successResp},
},
rtt: time.Millisecond,
},
queriedUpstreams: &queried,
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
resolver := &upstreamResolverBase{
ctx: ctx,
upstreamClient: tracking,
upstreamServers: []netip.AddrPort{upstream1, upstream2},
upstreamTimeout: UpstreamTimeout,
}
var written *dns.Msg
w := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
written = m
return nil
},
}
// Client query without EDNS0 must not see an OPT in the response.
q := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
resolver.ServeDNS(w, q)
require.NotNil(t, written, "response must be written")
assert.Equal(t, dns.RcodeServerFailure, written.Rcode, "SERVFAIL must propagate")
assert.Len(t, queried, 1, "only first upstream should be queried")
assert.Equal(t, upstream1.String(), queried[0])
for _, rr := range written.Extra {
_, isOPT := rr.(*dns.OPT)
assert.False(t, isOPT, "synthetic OPT must not leak to a non-EDNS0 client")
}
}

View File

@@ -413,25 +413,40 @@ func (c *Client) GetRoutesSelectionDetails() (*RoutesSelectionDetails, error) {
func prepareRouteSelectionDetails(routes []*selectRoute, resolvedDomains map[domain.Domain]peer.ResolvedDomainInfo) *RoutesSelectionDetails {
var routeSelection []RoutesSelectionInfo
for _, r := range routes {
domainList := make([]DomainInfo, 0)
// resolvedDomains is keyed by the resolved domain (e.g. api.ipify.org),
// not the configured pattern (e.g. *.ipify.org). Group entries whose
// ParentDomain belongs to this route, mirroring the daemon logic in
// client/server/network.go.
domainList := make([]DomainInfo, 0, len(r.Domains))
domainIndex := make(map[domain.Domain]int, len(r.Domains))
for _, d := range r.Domains {
domainResp := DomainInfo{
Domain: d.SafeString(),
}
if info, exists := resolvedDomains[d]; exists {
var ipStrings []string
for _, prefix := range info.Prefixes {
ipStrings = append(ipStrings, prefix.Addr().String())
}
domainResp.ResolvedIPs = strings.Join(ipStrings, ", ")
}
domainList = append(domainList, domainResp)
domainIndex[d] = len(domainList)
domainList = append(domainList, DomainInfo{Domain: d.SafeString()})
}
for _, info := range resolvedDomains {
idx, ok := domainIndex[info.ParentDomain]
if !ok {
continue
}
for _, prefix := range info.Prefixes {
domainList[idx].AddResolvedIP(prefix.Addr().String())
}
}
domainDetails := DomainDetails{items: domainList}
// For dynamic (DNS) routes, expose the joined domain pattern as the
// Network value so it matches the peer.routes entries on the Swift
// side (mirroring the Android bridge in client/android/client.go).
netStr := r.Network.String()
if len(r.Domains) > 0 {
netStr = r.Domains.SafeString()
}
routeSelection = append(routeSelection, RoutesSelectionInfo{
ID: r.NetID,
Network: r.Network.String(),
Network: netStr,
Domains: &domainDetails,
Selected: r.Selected,
})

View File

@@ -34,7 +34,34 @@ type DomainDetails struct {
type DomainInfo struct {
Domain string
ResolvedIPs string
resolvedIPs ResolvedIPs
}
func (d *DomainInfo) AddResolvedIP(ipAddress string) {
d.resolvedIPs.Add(ipAddress)
}
func (d *DomainInfo) GetResolvedIPs() *ResolvedIPs {
return &d.resolvedIPs
}
type ResolvedIPs struct {
items []string
}
func (r *ResolvedIPs) Add(ipAddress string) {
r.items = append(r.items, ipAddress)
}
func (r *ResolvedIPs) Get(i int) string {
if i < 0 || i >= len(r.items) {
return ""
}
return r.items[i]
}
func (r *ResolvedIPs) Size() int {
return len(r.items)
}
// Add new PeerInfo to the collection

2
go.mod
View File

@@ -72,7 +72,7 @@ require (
github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81
github.com/mdlayher/socket v0.5.1
github.com/mdp/qrterminal/v3 v3.2.1
github.com/miekg/dns v1.1.72
github.com/miekg/dns v1.1.59
github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/netbirdio/management-integrations/integrations v0.0.0-20260416123949-2355d972be42
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45

4
go.sum
View File

@@ -421,8 +421,8 @@ github.com/mdp/qrterminal/v3 v3.2.1 h1:6+yQjiiOsSuXT5n9/m60E54vdgFsw0zhADHhHLrFe
github.com/mdp/qrterminal/v3 v3.2.1/go.mod h1:jOTmXvnBsMy5xqLniO0R++Jmjs2sTm9dFSuQ5kpz/SU=
github.com/mholt/acmez/v2 v2.0.1 h1:3/3N0u1pLjMK4sNEAFSI+bcvzbPhRpY383sy1kLHJ6k=
github.com/mholt/acmez/v2 v2.0.1/go.mod h1:fX4c9r5jYwMyMsC+7tkYRxHibkOTgta5DIFGoe67e1U=
github.com/miekg/dns v1.1.72 h1:vhmr+TF2A3tuoGNkLDFK9zi36F2LS+hKTRW0Uf8kbzI=
github.com/miekg/dns v1.1.72/go.mod h1:+EuEPhdHOsfk6Wk5TT2CzssZdqkmFhf8r+aVyDEToIs=
github.com/miekg/dns v1.1.59 h1:C9EXc/UToRwKLhK5wKU/I4QVsBUc8kE6MkHBkeypWZs=
github.com/miekg/dns v1.1.59/go.mod h1:nZpewl5p6IvctfgrckopVx2OlSEHPRO/U4SYkRklrEk=
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721 h1:RlZweED6sbSArvlE924+mUcZuXKLBHA35U7LN621Bws=
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721/go.mod h1:Ickgr2WtCLZ2MDGd4Gr0geeCH5HybhRJbonOgQpvSxc=
github.com/mitchellh/copystructure v1.2.0 h1:vpKXTN4ewci03Vljg/q9QvCGUDttBOGBIa15WveJJGw=

View File

@@ -89,21 +89,33 @@ func (p *Provider) ListConnectors(ctx context.Context) ([]*ConnectorConfig, erro
}
// UpdateConnector updates an existing connector in Dex storage.
// It merges incoming updates with existing values to prevent data loss on partial updates.
// It overlays user-mutable config fields (issuer, clientID, clientSecret,
// redirectURI) onto the stored connector config, and updates the connector name
// when cfg.Name is set. Empty fields on cfg leave stored values unchanged, so
// partial updates preserve create-time defaults such as scopes, claimMapping,
// and userIDKey.
func (p *Provider) UpdateConnector(ctx context.Context, cfg *ConnectorConfig) error {
if err := p.storage.UpdateConnector(ctx, cfg.ID, func(old storage.Connector) (storage.Connector, error) {
oldCfg, err := p.parseStorageConnector(old)
if err != nil {
return storage.Connector{}, fmt.Errorf("failed to parse existing connector: %w", err)
if cfg.Type != "" && cfg.Type != inferIdentityProviderType(old.Type, cfg.ID, nil) {
return storage.Connector{}, errors.New("connector type change not allowed")
}
mergeConnectorConfig(cfg, oldCfg)
storageConn, err := p.buildStorageConnector(cfg)
configData, err := overlayConnectorConfig(old.Config, cfg)
if err != nil {
return storage.Connector{}, fmt.Errorf("failed to build connector: %w", err)
return storage.Connector{}, fmt.Errorf("failed to overlay connector config: %w", err)
}
return storageConn, nil
name := cfg.Name
if name == "" {
name = old.Name
}
return storage.Connector{
ID: cfg.ID,
Type: old.Type,
Name: name,
Config: configData,
}, nil
}); err != nil {
return fmt.Errorf("failed to update connector: %w", err)
}
@@ -112,23 +124,27 @@ func (p *Provider) UpdateConnector(ctx context.Context, cfg *ConnectorConfig) er
return nil
}
// mergeConnectorConfig preserves existing values for empty fields in the update.
func mergeConnectorConfig(cfg, oldCfg *ConnectorConfig) {
if cfg.ClientSecret == "" {
cfg.ClientSecret = oldCfg.ClientSecret
// overlayConnectorConfig writes only the user-mutable fields onto the existing
// stored config, preserving every other field (scopes, claimMapping, userIDKey,
// insecure flags, etc.). Empty fields on cfg leave the existing value alone.
func overlayConnectorConfig(oldConfig []byte, cfg *ConnectorConfig) ([]byte, error) {
var m map[string]any
if err := decodeConnectorConfig(oldConfig, &m); err != nil {
return nil, err
}
if cfg.RedirectURI == "" {
cfg.RedirectURI = oldCfg.RedirectURI
if cfg.Issuer != "" {
m["issuer"] = cfg.Issuer
}
if cfg.Issuer == "" && cfg.Type == oldCfg.Type {
cfg.Issuer = oldCfg.Issuer
if cfg.ClientID != "" {
m["clientID"] = cfg.ClientID
}
if cfg.ClientID == "" {
cfg.ClientID = oldCfg.ClientID
if cfg.ClientSecret != "" {
m["clientSecret"] = cfg.ClientSecret
}
if cfg.Name == "" {
cfg.Name = oldCfg.Name
if cfg.RedirectURI != "" {
m["redirectURI"] = cfg.RedirectURI
}
return encodeConnectorConfig(m)
}
// DeleteConnector removes a connector from Dex storage.
@@ -216,6 +232,10 @@ func buildOIDCConnectorConfig(cfg *ConnectorConfig, redirectURI string) ([]byte,
oidcConfig["getUserInfo"] = true
case "entra":
oidcConfig["claimMapping"] = map[string]string{"email": "preferred_username"}
// Use the Entra Object ID (oid) instead of the default OIDC sub claim.
// Entra issues sub as a per-app pairwise identifier that does not match
// the stable Object ID.
oidcConfig["userIDKey"] = "oid"
case "okta":
oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"}
case "pocketid":

205
idp/dex/connector_test.go Normal file
View File

@@ -0,0 +1,205 @@
package dex
import (
"context"
"encoding/json"
"log/slog"
"os"
"path/filepath"
"testing"
"github.com/dexidp/dex/storage"
"github.com/dexidp/dex/storage/sql"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func newTestProvider(t *testing.T) (*Provider, func()) {
t.Helper()
tmpDir, err := os.MkdirTemp("", "dex-connector-test-*")
require.NoError(t, err)
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
s, err := (&sql.SQLite3{File: filepath.Join(tmpDir, "dex.db")}).Open(logger)
require.NoError(t, err)
return &Provider{storage: s, logger: logger}, func() {
_ = s.Close()
_ = os.RemoveAll(tmpDir)
}
}
func TestBuildOIDCConnectorConfig_EntraSetsUserIDKey(t *testing.T) {
cfg := &ConnectorConfig{
ID: "entra-test",
Name: "Entra",
Type: "entra",
Issuer: "https://login.microsoftonline.com/tid/v2.0",
ClientID: "client-id",
ClientSecret: "client-secret",
}
data, err := buildOIDCConnectorConfig(cfg, "https://example.com/oauth2/callback")
require.NoError(t, err)
var m map[string]any
require.NoError(t, json.Unmarshal(data, &m))
assert.Equal(t, "oid", m["userIDKey"], "entra connectors must default userIDKey to oid")
assert.Equal(t, map[string]any{"email": "preferred_username"}, m["claimMapping"])
}
func TestBuildOIDCConnectorConfig_NonEntraDoesNotSetUserIDKey(t *testing.T) {
// ensures the Entra userIDKey override does not leak into other OIDC providers,
// which already use a stable sub claim.
for _, typ := range []string{"oidc", "zitadel", "okta", "pocketid", "authentik", "keycloak", "adfs"} {
t.Run(typ, func(t *testing.T) {
data, err := buildOIDCConnectorConfig(&ConnectorConfig{Type: typ}, "https://example.com/oauth2/callback")
require.NoError(t, err)
var m map[string]any
require.NoError(t, json.Unmarshal(data, &m))
_, ok := m["userIDKey"]
assert.False(t, ok, "%s connectors must not have userIDKey set", typ)
})
}
}
func TestUpdateConnector_PreservesCreateTimeDefaults(t *testing.T) {
ctx := context.Background()
p, cleanup := newTestProvider(t)
defer cleanup()
created, err := p.CreateConnector(ctx, &ConnectorConfig{
ID: "entra-test",
Name: "Entra",
Type: "entra",
Issuer: "https://login.microsoftonline.com/tid/v2.0",
ClientID: "client-id",
ClientSecret: "old-secret",
RedirectURI: "https://example.com/oauth2/callback",
})
require.NoError(t, err)
require.Equal(t, "entra-test", created.ID)
// Rotate only the client secret.
err = p.UpdateConnector(ctx, &ConnectorConfig{
ID: "entra-test",
Type: "entra",
ClientSecret: "new-secret",
})
require.NoError(t, err)
conn, err := p.storage.GetConnector(ctx, "entra-test")
require.NoError(t, err)
var m map[string]any
require.NoError(t, json.Unmarshal(conn.Config, &m))
assert.Equal(t, "new-secret", m["clientSecret"], "clientSecret should be rotated")
assert.Equal(t, "client-id", m["clientID"], "clientID must survive (overlay should leave it alone)")
assert.Equal(t, "https://login.microsoftonline.com/tid/v2.0", m["issuer"])
assert.Equal(t, "oid", m["userIDKey"], "userIDKey must survive update")
assert.Equal(t, map[string]any{"email": "preferred_username"}, m["claimMapping"], "claimMapping must survive update")
}
func TestUpdateConnector_DoesNotAddUserIDKeyToExistingConnector(t *testing.T) {
ctx := context.Background()
p, cleanup := newTestProvider(t)
defer cleanup()
// Seed a connector directly into storage without userIDKey
preFixConfig, err := json.Marshal(map[string]any{
"issuer": "https://login.microsoftonline.com/tid/v2.0",
"clientID": "client-id",
"clientSecret": "old-secret",
"redirectURI": "https://example.com/oauth2/callback",
"scopes": []string{"openid", "profile", "email"},
"claimMapping": map[string]string{"email": "preferred_username"},
})
require.NoError(t, err)
require.NoError(t, p.storage.CreateConnector(ctx, storage.Connector{
ID: "entra-prefix",
Type: "oidc",
Name: "Entra",
Config: preFixConfig,
}))
// Rotate client secret via UpdateConnector.
err = p.UpdateConnector(ctx, &ConnectorConfig{
ID: "entra-prefix",
Type: "entra",
ClientSecret: "new-secret",
})
require.NoError(t, err)
conn, err := p.storage.GetConnector(ctx, "entra-prefix")
require.NoError(t, err)
var m map[string]any
require.NoError(t, json.Unmarshal(conn.Config, &m))
assert.Equal(t, "new-secret", m["clientSecret"])
_, has := m["userIDKey"]
assert.False(t, has, "userIDKey must not be auto-added to a connector that did not have it before")
}
func TestUpdateConnector_RejectsTypeChange(t *testing.T) {
ctx := context.Background()
p, cleanup := newTestProvider(t)
defer cleanup()
_, err := p.CreateConnector(ctx, &ConnectorConfig{
ID: "entra-test",
Name: "Entra",
Type: "entra",
Issuer: "https://login.microsoftonline.com/tid/v2.0",
ClientID: "client-id",
ClientSecret: "secret",
RedirectURI: "https://example.com/oauth2/callback",
})
require.NoError(t, err)
// Attempt to switch the connector to okta.
err = p.UpdateConnector(ctx, &ConnectorConfig{
ID: "entra-test",
Type: "okta",
})
require.Error(t, err)
assert.Contains(t, err.Error(), "connector type change not allowed")
// stored connector type/config unchanged after the rejected update.
conn, err := p.storage.GetConnector(ctx, "entra-test")
require.NoError(t, err)
assert.Equal(t, "oidc", conn.Type)
var m map[string]any
require.NoError(t, json.Unmarshal(conn.Config, &m))
assert.Equal(t, "oid", m["userIDKey"])
}
func TestUpdateConnector_AllowsSameTypeUpdate(t *testing.T) {
ctx := context.Background()
p, cleanup := newTestProvider(t)
defer cleanup()
_, err := p.CreateConnector(ctx, &ConnectorConfig{
ID: "entra-test",
Name: "Entra",
Type: "entra",
Issuer: "https://login.microsoftonline.com/old/v2.0",
ClientID: "client-id",
ClientSecret: "secret",
RedirectURI: "https://example.com/oauth2/callback",
})
require.NoError(t, err)
err = p.UpdateConnector(ctx, &ConnectorConfig{
ID: "entra-test",
Type: "entra",
Issuer: "https://login.microsoftonline.com/new/v2.0",
})
require.NoError(t, err)
conn, err := p.storage.GetConnector(ctx, "entra-test")
require.NoError(t, err)
var m map[string]any
require.NoError(t, json.Unmarshal(conn.Config, &m))
assert.Equal(t, "https://login.microsoftonline.com/new/v2.0", m["issuer"])
}

View File

@@ -11,6 +11,8 @@ import (
"fmt"
"net/http"
"net/url"
"os"
"strconv"
"strings"
"sync"
"time"
@@ -82,11 +84,40 @@ type ProxyServiceServer struct {
// Store for PKCE verifiers
pkceVerifierStore *PKCEVerifierStore
// tokenTTL is the lifetime of one-time tokens generated for proxy
// authentication. Defaults to defaultProxyTokenTTL when zero.
tokenTTL time.Duration
// snapshotBatchSize is the number of mappings per gRPC message during
// initial snapshot delivery. Configurable via NB_PROXY_SNAPSHOT_BATCH_SIZE.
snapshotBatchSize int
cancel context.CancelFunc
}
const pkceVerifierTTL = 10 * time.Minute
const defaultProxyTokenTTL = 5 * time.Minute
const defaultSnapshotBatchSize = 500
func snapshotBatchSizeFromEnv() int {
if v := os.Getenv("NB_PROXY_SNAPSHOT_BATCH_SIZE"); v != "" {
if n, err := strconv.Atoi(v); err == nil && n > 0 {
return n
}
}
return defaultSnapshotBatchSize
}
// proxyTokenTTL returns the configured token TTL or the default when unset.
func (s *ProxyServiceServer) proxyTokenTTL() time.Duration {
if s.tokenTTL > 0 {
return s.tokenTTL
}
return defaultProxyTokenTTL
}
// proxyConnection represents a connected proxy
type proxyConnection struct {
proxyID string
@@ -110,6 +141,7 @@ func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeT
peersManager: peersManager,
usersManager: usersManager,
proxyManager: proxyMgr,
snapshotBatchSize: snapshotBatchSizeFromEnv(),
cancel: cancel,
}
go s.cleanupStaleProxies(ctx)
@@ -192,11 +224,6 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
cancel: cancel,
}
s.connectedProxies.Store(proxyID, conn)
if err := s.proxyController.RegisterProxyToCluster(ctx, conn.address, proxyID); err != nil {
log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", proxyID, err)
}
// Register proxy in database with capabilities
var caps *proxy.Capabilities
if c := req.GetCapabilities(); c != nil {
@@ -209,13 +236,31 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
proxyRecord, err := s.proxyManager.Connect(ctx, proxyID, sessionID, proxyAddress, peerInfo, caps)
if err != nil {
log.WithContext(ctx).Warnf("failed to register proxy %s in database: %v", proxyID, err)
s.connectedProxies.CompareAndDelete(proxyID, conn)
if unregErr := s.proxyController.UnregisterProxyFromCluster(ctx, conn.address, proxyID); unregErr != nil {
log.WithContext(ctx).Debugf("cleanup after Connect failure for proxy %s: %v", proxyID, unregErr)
}
cancel()
return status.Errorf(codes.Internal, "register proxy in database: %v", err)
}
s.connectedProxies.Store(proxyID, conn)
if err := s.proxyController.RegisterProxyToCluster(ctx, conn.address, proxyID); err != nil {
log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", proxyID, err)
}
if err := s.sendSnapshot(ctx, conn); err != nil {
if s.connectedProxies.CompareAndDelete(proxyID, conn) {
if unregErr := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, proxyID); unregErr != nil {
log.WithContext(ctx).Debugf("cleanup after snapshot failure for proxy %s: %v", proxyID, unregErr)
}
}
cancel()
if disconnErr := s.proxyManager.Disconnect(context.Background(), proxyID, sessionID); disconnErr != nil {
log.WithContext(ctx).Debugf("cleanup after snapshot failure for proxy %s: %v", proxyID, disconnErr)
}
return fmt.Errorf("send snapshot to proxy %s: %w", proxyID, err)
}
errChan := make(chan error, 2)
go s.sender(conn, errChan)
log.WithFields(log.Fields{
"proxy_id": proxyID,
"session_id": sessionID,
@@ -241,13 +286,6 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
log.Infof("Proxy %s session %s disconnected", proxyID, sessionID)
}()
if err := s.sendSnapshot(ctx, conn); err != nil {
return fmt.Errorf("send snapshot to proxy %s: %w", proxyID, err)
}
errChan := make(chan error, 2)
go s.sender(conn, errChan)
go s.heartbeat(connCtx, proxyRecord)
select {
@@ -290,22 +328,27 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
return err
}
// Send mappings in batches to reduce per-message gRPC overhead while
// staying well within the default 4 MB message size limit.
for i := 0; i < len(mappings); i += s.snapshotBatchSize {
end := i + s.snapshotBatchSize
if end > len(mappings) {
end = len(mappings)
}
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{
Mapping: mappings[i:end],
InitialSyncComplete: end == len(mappings),
}); err != nil {
return fmt.Errorf("send snapshot batch: %w", err)
}
}
if len(mappings) == 0 {
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{
InitialSyncComplete: true,
}); err != nil {
return fmt.Errorf("send snapshot completion: %w", err)
}
return nil
}
for i, m := range mappings {
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{
Mapping: []*proto.ProxyMapping{m},
InitialSyncComplete: i == len(mappings)-1,
}); err != nil {
return fmt.Errorf("send proxy mapping: %w", err)
}
}
return nil
@@ -323,13 +366,9 @@ func (s *ProxyServiceServer) snapshotServiceMappings(ctx context.Context, conn *
continue
}
token, err := s.tokenStore.GenerateToken(service.AccountID, service.ID, 5*time.Minute)
token, err := s.tokenStore.GenerateToken(service.AccountID, service.ID, s.proxyTokenTTL())
if err != nil {
log.WithFields(log.Fields{
"service": service.Name,
"account": service.AccountID,
}).WithError(err).Error("failed to generate auth token for snapshot")
continue
return nil, fmt.Errorf("generate auth token for service %s: %w", service.ID, err)
}
m := service.ToProtoMapping(rpservice.Create, token, s.GetOIDCValidationConfig())
@@ -409,13 +448,16 @@ func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateRes
conn := value.(*proxyConnection)
resp := s.perProxyMessage(update, conn.proxyID)
if resp == nil {
log.Warnf("Token generation failed for proxy %s, disconnecting to force resync", conn.proxyID)
conn.cancel()
return true
}
select {
case conn.sendChan <- resp:
log.Debugf("Sent service update to proxy server %s", conn.proxyID)
default:
log.Warnf("Failed to send service update to proxy server %s (channel full)", conn.proxyID)
log.Warnf("Send channel full for proxy %s, disconnecting to force resync", conn.proxyID)
conn.cancel()
}
return true
})
@@ -495,13 +537,16 @@ func (s *ProxyServiceServer) SendServiceUpdateToCluster(ctx context.Context, upd
}
msg := s.perProxyMessage(updateResponse, proxyID)
if msg == nil {
log.WithContext(ctx).Warnf("Token generation failed for proxy %s in cluster %s, disconnecting to force resync", proxyID, clusterAddr)
conn.cancel()
continue
}
select {
case conn.sendChan <- msg:
log.WithContext(ctx).Debugf("Sent service update with id %s to proxy %s in cluster %s", update.Id, proxyID, clusterAddr)
default:
log.WithContext(ctx).Warnf("Failed to send service update to proxy %s in cluster %s (channel full)", proxyID, clusterAddr)
log.WithContext(ctx).Warnf("Send channel full for proxy %s in cluster %s, disconnecting to force resync", proxyID, clusterAddr)
conn.cancel()
}
}
}
@@ -527,7 +572,8 @@ func proxyAcceptsMapping(conn *proxyConnection, mapping *proto.ProxyMapping) boo
// perProxyMessage returns a copy of update with a fresh one-time token for
// create/update operations. For delete operations the original mapping is
// used unchanged because proxies do not need to authenticate for removal.
// Returns nil if token generation fails (the proxy should be skipped).
// Returns nil if token generation fails; the caller must disconnect the
// proxy so it can resync via a fresh snapshot on reconnect.
func (s *ProxyServiceServer) perProxyMessage(update *proto.GetMappingUpdateResponse, proxyID string) *proto.GetMappingUpdateResponse {
resp := make([]*proto.ProxyMapping, 0, len(update.Mapping))
for _, mapping := range update.Mapping {
@@ -536,7 +582,7 @@ func (s *ProxyServiceServer) perProxyMessage(update *proto.GetMappingUpdateRespo
continue
}
token, err := s.tokenStore.GenerateToken(mapping.AccountId, mapping.Id, 5*time.Minute)
token, err := s.tokenStore.GenerateToken(mapping.AccountId, mapping.Id, s.proxyTokenTTL())
if err != nil {
log.Warnf("Failed to generate token for proxy %s: %v", proxyID, err)
return nil

View File

@@ -0,0 +1,174 @@
package grpc
import (
"context"
"fmt"
"testing"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/shared/management/proto"
)
// recordingStream captures all messages sent via Send so tests can inspect
// batching behaviour without a real gRPC transport.
type recordingStream struct {
grpc.ServerStream
messages []*proto.GetMappingUpdateResponse
}
func (s *recordingStream) Send(m *proto.GetMappingUpdateResponse) error {
s.messages = append(s.messages, m)
return nil
}
func (s *recordingStream) Context() context.Context { return context.Background() }
func (s *recordingStream) SetHeader(metadata.MD) error { return nil }
func (s *recordingStream) SendHeader(metadata.MD) error { return nil }
func (s *recordingStream) SetTrailer(metadata.MD) {}
func (s *recordingStream) SendMsg(any) error { return nil }
func (s *recordingStream) RecvMsg(any) error { return nil }
// makeServices creates n enabled services assigned to the given cluster.
func makeServices(n int, cluster string) []*rpservice.Service {
services := make([]*rpservice.Service, n)
for i := range n {
services[i] = &rpservice.Service{
ID: fmt.Sprintf("svc-%d", i),
AccountID: "acct-1",
Name: fmt.Sprintf("svc-%d", i),
Domain: fmt.Sprintf("svc-%d.example.com", i),
ProxyCluster: cluster,
Enabled: true,
Targets: []*rpservice.Target{
{TargetType: rpservice.TargetTypeHost, TargetId: "host-1"},
},
}
}
return services
}
func newSnapshotTestServer(t *testing.T, batchSize int) *ProxyServiceServer {
t.Helper()
s := &ProxyServiceServer{
tokenStore: NewOneTimeTokenStore(context.Background(), testCacheStore(t)),
snapshotBatchSize: batchSize,
}
s.SetProxyController(newTestProxyController())
return s
}
func TestSendSnapshot_BatchesMappings(t *testing.T) {
const cluster = "cluster.example.com"
const batchSize = 3
const totalServices = 7 // 3 + 3 + 1
ctrl := gomock.NewController(t)
mgr := rpservice.NewMockManager(ctrl)
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)
s := newSnapshotTestServer(t, batchSize)
s.serviceManager = mgr
stream := &recordingStream{}
conn := &proxyConnection{
proxyID: "proxy-a",
address: cluster,
stream: stream,
}
err := s.sendSnapshot(context.Background(), conn)
require.NoError(t, err)
// Expect ceil(7/3) = 3 messages
require.Len(t, stream.messages, 3, "should send ceil(totalServices/batchSize) messages")
assert.Len(t, stream.messages[0].Mapping, 3)
assert.False(t, stream.messages[0].InitialSyncComplete, "first batch should not be sync-complete")
assert.Len(t, stream.messages[1].Mapping, 3)
assert.False(t, stream.messages[1].InitialSyncComplete, "middle batch should not be sync-complete")
assert.Len(t, stream.messages[2].Mapping, 1)
assert.True(t, stream.messages[2].InitialSyncComplete, "last batch must be sync-complete")
// Verify all service IDs are present exactly once
seen := make(map[string]bool)
for _, msg := range stream.messages {
for _, m := range msg.Mapping {
assert.False(t, seen[m.Id], "duplicate service ID %s", m.Id)
seen[m.Id] = true
}
}
assert.Len(t, seen, totalServices)
}
func TestSendSnapshot_ExactBatchMultiple(t *testing.T) {
const cluster = "cluster.example.com"
const batchSize = 3
const totalServices = 6 // exactly 2 batches
ctrl := gomock.NewController(t)
mgr := rpservice.NewMockManager(ctrl)
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)
s := newSnapshotTestServer(t, batchSize)
s.serviceManager = mgr
stream := &recordingStream{}
conn := &proxyConnection{proxyID: "proxy-a", address: cluster, stream: stream}
require.NoError(t, s.sendSnapshot(context.Background(), conn))
require.Len(t, stream.messages, 2)
assert.Len(t, stream.messages[0].Mapping, 3)
assert.False(t, stream.messages[0].InitialSyncComplete)
assert.Len(t, stream.messages[1].Mapping, 3)
assert.True(t, stream.messages[1].InitialSyncComplete)
}
func TestSendSnapshot_SingleBatch(t *testing.T) {
const cluster = "cluster.example.com"
const batchSize = 100
const totalServices = 5
ctrl := gomock.NewController(t)
mgr := rpservice.NewMockManager(ctrl)
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)
s := newSnapshotTestServer(t, batchSize)
s.serviceManager = mgr
stream := &recordingStream{}
conn := &proxyConnection{proxyID: "proxy-a", address: cluster, stream: stream}
require.NoError(t, s.sendSnapshot(context.Background(), conn))
require.Len(t, stream.messages, 1, "all mappings should fit in one batch")
assert.Len(t, stream.messages[0].Mapping, totalServices)
assert.True(t, stream.messages[0].InitialSyncComplete)
}
func TestSendSnapshot_EmptySnapshot(t *testing.T) {
const cluster = "cluster.example.com"
ctrl := gomock.NewController(t)
mgr := rpservice.NewMockManager(ctrl)
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(nil, nil)
s := newSnapshotTestServer(t, 500)
s.serviceManager = mgr
stream := &recordingStream{}
conn := &proxyConnection{proxyID: "proxy-a", address: cluster, stream: stream}
require.NoError(t, s.sendSnapshot(context.Background(), conn))
require.Len(t, stream.messages, 1, "empty snapshot must still send sync-complete")
assert.Empty(t, stream.messages[0].Mapping)
assert.True(t, stream.messages[0].InitialSyncComplete)
}

View File

@@ -85,11 +85,14 @@ func registerFakeProxy(s *ProxyServiceServer, proxyID, clusterAddr string) chan
// registerFakeProxyWithCaps adds a fake proxy connection with explicit capabilities.
func registerFakeProxyWithCaps(s *ProxyServiceServer, proxyID, clusterAddr string, caps *proto.ProxyCapabilities) chan *proto.GetMappingUpdateResponse {
ch := make(chan *proto.GetMappingUpdateResponse, 10)
ctx, cancel := context.WithCancel(context.Background())
conn := &proxyConnection{
proxyID: proxyID,
address: clusterAddr,
capabilities: caps,
sendChan: ch,
ctx: ctx,
cancel: cancel,
}
s.connectedProxies.Store(proxyID, conn)

View File

@@ -7,11 +7,8 @@ import (
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
@@ -45,11 +42,6 @@ func newGeolocationsHandlerHandler(accountManager account.Manager, geolocationMa
// getAllCountries retrieves a list of all countries
func (l *geolocationsHandler) getAllCountries(w http.ResponseWriter, r *http.Request) {
if err := l.authenticateUser(r); err != nil {
util.WriteError(r.Context(), err, w)
return
}
if l.geolocationManager == nil {
// TODO: update error message to include geo db self hosted doc link when ready
util.WriteError(r.Context(), status.Errorf(status.PreconditionFailed, "Geo location database is not initialized"), w)
@@ -71,11 +63,6 @@ func (l *geolocationsHandler) getAllCountries(w http.ResponseWriter, r *http.Req
// getCitiesByCountry retrieves a list of cities based on the given country code
func (l *geolocationsHandler) getCitiesByCountry(w http.ResponseWriter, r *http.Request) {
if err := l.authenticateUser(r); err != nil {
util.WriteError(r.Context(), err, w)
return
}
vars := mux.Vars(r)
countryCode := vars["country"]
if !countryCodeRegex.MatchString(countryCode) {
@@ -102,27 +89,6 @@ func (l *geolocationsHandler) getCitiesByCountry(w http.ResponseWriter, r *http.
util.WriteJSONObject(r.Context(), w, cities)
}
func (l *geolocationsHandler) authenticateUser(r *http.Request) error {
ctx := r.Context()
userAuth, err := nbcontext.GetUserAuthFromContext(ctx)
if err != nil {
return err
}
accountID, userID := userAuth.AccountId, userAuth.UserId
allowed, err := l.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Policies, operations.Read)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !allowed {
return status.NewPermissionDeniedError()
}
return nil
}
func toCountryResponse(country geolocation.Country) api.Country {
return api.Country{
CountryName: country.CountryName,

View File

@@ -144,8 +144,11 @@ func TestValidateInviteToken_ModifiedToken(t *testing.T) {
_, plainToken, err := GenerateInviteToken()
require.NoError(t, err)
// Modify one character in the secret part
modifiedToken := plainToken[:5] + "X" + plainToken[6:]
replacement := "X"
if plainToken[5] == 'X' {
replacement = "Y"
}
modifiedToken := plainToken[:5] + replacement + plainToken[6:]
err = ValidateInviteToken(modifiedToken)
require.Error(t, err)
}

View File

@@ -364,14 +364,16 @@ func TestIntegration_ProxyConnection_HappyPath(t *testing.T) {
})
require.NoError(t, err)
// Receive all mappings from the snapshot - server sends each mapping individually
mappingsByID := make(map[string]*proto.ProxyMapping)
for i := 0; i < 2; i++ {
for {
msg, err := stream.Recv()
require.NoError(t, err)
for _, m := range msg.GetMapping() {
mappingsByID[m.GetId()] = m
}
if msg.GetInitialSyncComplete() {
break
}
}
// Should receive 2 mappings total
@@ -411,12 +413,14 @@ func TestIntegration_ProxyConnection_SendsClusterAddress(t *testing.T) {
})
require.NoError(t, err)
// Receive all mappings - server sends each mapping individually
mappings := make([]*proto.ProxyMapping, 0)
for i := 0; i < 2; i++ {
for {
msg, err := stream.Recv()
require.NoError(t, err)
mappings = append(mappings, msg.GetMapping()...)
if msg.GetInitialSyncComplete() {
break
}
}
// Should receive the 2 mappings matching the cluster
@@ -440,13 +444,15 @@ func TestIntegration_ProxyConnection_Reconnect_ReceivesSameConfig(t *testing.T)
clusterAddress := "test.proxy.io"
proxyID := "test-proxy-reconnect"
// Helper to receive all mappings from a stream
receiveMappings := func(stream proto.ProxyService_GetMappingUpdateClient, count int) []*proto.ProxyMapping {
receiveMappings := func(stream proto.ProxyService_GetMappingUpdateClient) []*proto.ProxyMapping {
var mappings []*proto.ProxyMapping
for i := 0; i < count; i++ {
for {
msg, err := stream.Recv()
require.NoError(t, err)
mappings = append(mappings, msg.GetMapping()...)
if msg.GetInitialSyncComplete() {
break
}
}
return mappings
}
@@ -460,7 +466,7 @@ func TestIntegration_ProxyConnection_Reconnect_ReceivesSameConfig(t *testing.T)
})
require.NoError(t, err)
firstMappings := receiveMappings(stream1, 2)
firstMappings := receiveMappings(stream1)
cancel1()
time.Sleep(100 * time.Millisecond)
@@ -476,7 +482,7 @@ func TestIntegration_ProxyConnection_Reconnect_ReceivesSameConfig(t *testing.T)
})
require.NoError(t, err)
secondMappings := receiveMappings(stream2, 2)
secondMappings := receiveMappings(stream2)
// Should receive the same mappings
assert.Equal(t, len(firstMappings), len(secondMappings),
@@ -542,12 +548,14 @@ func TestIntegration_ProxyConnection_ReconnectDoesNotDuplicateState(t *testing.T
}
}
// Helper to receive and apply all mappings
receiveAndApply := func(stream proto.ProxyService_GetMappingUpdateClient) {
for i := 0; i < 2; i++ {
for {
msg, err := stream.Recv()
require.NoError(t, err)
applyMappings(msg.GetMapping())
if msg.GetInitialSyncComplete() {
break
}
}
}
@@ -636,12 +644,14 @@ func TestIntegration_ProxyConnection_MultipleProxiesReceiveUpdates(t *testing.T)
})
require.NoError(t, err)
// Receive all mappings - server sends each mapping individually
count := 0
for i := 0; i < 2; i++ {
for {
msg, err := stream.Recv()
require.NoError(t, err)
count += len(msg.GetMapping())
if msg.GetInitialSyncComplete() {
break
}
}
mu.Lock()
@@ -681,9 +691,12 @@ func TestIntegration_ProxyConnection_FastReconnectDoesNotLoseState(t *testing.T)
})
require.NoError(t, err)
for i := 0; i < 2; i++ {
_, err := stream1.Recv()
for {
msg, err := stream1.Recv()
require.NoError(t, err)
if msg.GetInitialSyncComplete() {
break
}
}
require.Contains(t, setup.proxyService.GetConnectedProxies(), proxyID,
@@ -699,9 +712,12 @@ func TestIntegration_ProxyConnection_FastReconnectDoesNotLoseState(t *testing.T)
})
require.NoError(t, err)
for i := 0; i < 2; i++ {
_, err := stream2.Recv()
for {
msg, err := stream2.Recv()
require.NoError(t, err)
if msg.GetInitialSyncComplete() {
break
}
}
cancel1()

View File

@@ -943,6 +943,8 @@ func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.Pr
operation := func() error {
s.Logger.Debug("connecting to management mapping stream")
initialSyncDone = false
if s.healthChecker != nil {
s.healthChecker.SetManagementConnected(false)
}
@@ -1000,6 +1002,11 @@ func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.Pr
return ctx.Err()
}
var snapshotIDs map[types.ServiceID]struct{}
if !*initialSyncDone {
snapshotIDs = make(map[types.ServiceID]struct{})
}
for {
// Check for context completion to gracefully shutdown.
select {
@@ -1020,17 +1027,45 @@ func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.Pr
s.processMappings(ctx, msg.GetMapping())
s.Logger.Debug("Processing mapping update completed")
if !*initialSyncDone && msg.GetInitialSyncComplete() {
if s.healthChecker != nil {
s.healthChecker.SetInitialSyncComplete()
if !*initialSyncDone {
for _, m := range msg.GetMapping() {
snapshotIDs[types.ServiceID(m.GetId())] = struct{}{}
}
if msg.GetInitialSyncComplete() {
s.reconcileSnapshot(ctx, snapshotIDs)
snapshotIDs = nil
if s.healthChecker != nil {
s.healthChecker.SetInitialSyncComplete()
}
*initialSyncDone = true
s.Logger.Info("Initial mapping sync complete")
}
*initialSyncDone = true
s.Logger.Info("Initial mapping sync complete")
}
}
}
}
// reconcileSnapshot removes local mappings that are absent from the snapshot.
// This ensures services deleted while the proxy was disconnected get cleaned up.
func (s *Server) reconcileSnapshot(ctx context.Context, snapshotIDs map[types.ServiceID]struct{}) {
s.portMu.RLock()
var stale []*proto.ProxyMapping
for svcID, mapping := range s.lastMappings {
if _, ok := snapshotIDs[svcID]; !ok {
stale = append(stale, mapping)
}
}
s.portMu.RUnlock()
for _, mapping := range stale {
s.Logger.WithFields(log.Fields{
"service_id": mapping.GetId(),
"domain": mapping.GetDomain(),
}).Info("Removing stale mapping absent from snapshot")
s.removeMapping(ctx, mapping)
}
}
func (s *Server) processMappings(ctx context.Context, mappings []*proto.ProxyMapping) {
for _, mapping := range mappings {
s.Logger.WithFields(log.Fields{

View File

@@ -0,0 +1,227 @@
package proxy
import (
"context"
"io"
"testing"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/proxy/internal/health"
"github.com/netbirdio/netbird/proxy/internal/types"
"github.com/netbirdio/netbird/shared/management/proto"
)
// collectStaleIDs mirrors the stale-detection logic in reconcileSnapshot
// so we can verify it without triggering removeMapping (which requires full
// server wiring). This keeps the test focused on the detection algorithm.
func collectStaleIDs(lastMappings map[types.ServiceID]*proto.ProxyMapping, snapshotIDs map[types.ServiceID]struct{}) []types.ServiceID {
var stale []types.ServiceID
for svcID := range lastMappings {
if _, ok := snapshotIDs[svcID]; !ok {
stale = append(stale, svcID)
}
}
return stale
}
// TestStaleDetection_PartialOverlap verifies that only services absent from
// the snapshot are flagged as stale.
func TestStaleDetection_PartialOverlap(t *testing.T) {
local := map[types.ServiceID]*proto.ProxyMapping{
"svc-1": {Id: "svc-1"},
"svc-2": {Id: "svc-2"},
"svc-stale-a": {Id: "svc-stale-a"},
"svc-stale-b": {Id: "svc-stale-b"},
}
snapshot := map[types.ServiceID]struct{}{
"svc-1": {},
"svc-2": {},
"svc-3": {}, // new service, not in local
}
stale := collectStaleIDs(local, snapshot)
assert.Len(t, stale, 2)
staleSet := make(map[types.ServiceID]struct{})
for _, id := range stale {
staleSet[id] = struct{}{}
}
assert.Contains(t, staleSet, types.ServiceID("svc-stale-a"))
assert.Contains(t, staleSet, types.ServiceID("svc-stale-b"))
}
// TestStaleDetection_AllStale verifies an empty snapshot flags everything.
func TestStaleDetection_AllStale(t *testing.T) {
local := map[types.ServiceID]*proto.ProxyMapping{
"svc-1": {Id: "svc-1"},
"svc-2": {Id: "svc-2"},
}
stale := collectStaleIDs(local, map[types.ServiceID]struct{}{})
assert.Len(t, stale, 2)
}
// TestStaleDetection_NoneStale verifies full overlap produces no stale entries.
func TestStaleDetection_NoneStale(t *testing.T) {
local := map[types.ServiceID]*proto.ProxyMapping{
"svc-1": {Id: "svc-1"},
"svc-2": {Id: "svc-2"},
}
snapshot := map[types.ServiceID]struct{}{
"svc-1": {},
"svc-2": {},
}
stale := collectStaleIDs(local, snapshot)
assert.Empty(t, stale)
}
// TestStaleDetection_EmptyLocal verifies no stale entries when local is empty.
func TestStaleDetection_EmptyLocal(t *testing.T) {
stale := collectStaleIDs(
map[types.ServiceID]*proto.ProxyMapping{},
map[types.ServiceID]struct{}{"svc-1": {}},
)
assert.Empty(t, stale)
}
// TestReconcileSnapshot_NoStale verifies reconciliation is a no-op when all
// local mappings are present in the snapshot (removeMapping is never called).
func TestReconcileSnapshot_NoStale(t *testing.T) {
s := &Server{
Logger: log.StandardLogger(),
lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
}
s.lastMappings["svc-1"] = &proto.ProxyMapping{Id: "svc-1"}
s.lastMappings["svc-2"] = &proto.ProxyMapping{Id: "svc-2"}
snapshotIDs := map[types.ServiceID]struct{}{
"svc-1": {},
"svc-2": {},
}
// This should not panic — no stale entries means removeMapping is never called.
s.reconcileSnapshot(context.Background(), snapshotIDs)
assert.Len(t, s.lastMappings, 2, "no mappings should be removed when all are in snapshot")
}
// TestReconcileSnapshot_EmptyLocal verifies reconciliation is a no-op with
// no local mappings.
func TestReconcileSnapshot_EmptyLocal(t *testing.T) {
s := &Server{
Logger: log.StandardLogger(),
lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
}
s.reconcileSnapshot(context.Background(), map[types.ServiceID]struct{}{"svc-1": {}})
assert.Empty(t, s.lastMappings)
}
// --- handleMappingStream tests for batched snapshot ID accumulation ---
// TestHandleMappingStream_BatchedSnapshotSyncComplete verifies that sync is
// marked done only after the final InitialSyncComplete message, even when
// the snapshot arrives in multiple batches.
func TestHandleMappingStream_BatchedSnapshotSyncComplete(t *testing.T) {
checker := health.NewChecker(nil, nil)
s := &Server{
Logger: log.StandardLogger(),
healthChecker: checker,
routerReady: closedChan(),
lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
}
stream := &mockMappingStream{
messages: []*proto.GetMappingUpdateResponse{
{}, // batch 1: no sync-complete
{}, // batch 2: no sync-complete
{InitialSyncComplete: true}, // batch 3: sync done
},
}
syncDone := false
err := s.handleMappingStream(context.Background(), stream, &syncDone)
assert.NoError(t, err)
assert.True(t, syncDone, "sync should be marked done after final batch")
}
// TestHandleMappingStream_PostSyncDoesNotReconcile verifies that messages
// arriving after InitialSyncComplete do not trigger a second reconciliation.
func TestHandleMappingStream_PostSyncDoesNotReconcile(t *testing.T) {
s := &Server{
Logger: log.StandardLogger(),
routerReady: closedChan(),
lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
}
// Simulate state left over from a previous sync.
s.lastMappings["svc-1"] = &proto.ProxyMapping{Id: "svc-1", AccountId: "acct-1"}
s.lastMappings["svc-2"] = &proto.ProxyMapping{Id: "svc-2", AccountId: "acct-1"}
stream := &mockMappingStream{
messages: []*proto.GetMappingUpdateResponse{
{}, // post-sync empty message — must not reconcile
},
}
syncDone := true // sync already completed in a previous stream
err := s.handleMappingStream(context.Background(), stream, &syncDone)
require.NoError(t, err)
assert.Len(t, s.lastMappings, 2,
"post-sync messages must not trigger reconciliation — all entries should survive")
}
// TestHandleMappingStream_ImmediateEOF_NoReconciliation verifies that if the
// stream closes before sync completes, no reconciliation occurs.
func TestHandleMappingStream_ImmediateEOF_NoReconciliation(t *testing.T) {
s := &Server{
Logger: log.StandardLogger(),
routerReady: closedChan(),
lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
}
s.lastMappings["svc-stale"] = &proto.ProxyMapping{Id: "svc-stale", AccountId: "acct-1"}
stream := &mockMappingStream{} // no messages → immediate EOF
syncDone := false
err := s.handleMappingStream(context.Background(), stream, &syncDone)
assert.NoError(t, err)
assert.False(t, syncDone, "sync should not be marked done on immediate EOF")
_, hasStale := s.lastMappings["svc-stale"]
assert.True(t, hasStale, "stale mapping should remain when sync never completed")
}
// mockErrRecvStream returns an error on the second Recv to verify
// handleMappingStream returns without completing sync.
type mockErrRecvStream struct {
mockMappingStream
calls int
}
func (m *mockErrRecvStream) Recv() (*proto.GetMappingUpdateResponse, error) {
m.calls++
if m.calls == 1 {
return &proto.GetMappingUpdateResponse{}, nil
}
return nil, io.ErrUnexpectedEOF
}
func TestHandleMappingStream_ErrorMidSync_NoReconciliation(t *testing.T) {
s := &Server{
Logger: log.StandardLogger(),
routerReady: closedChan(),
lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
}
s.lastMappings["svc-stale"] = &proto.ProxyMapping{Id: "svc-stale", AccountId: "acct-1"}
syncDone := false
err := s.handleMappingStream(context.Background(), &mockErrRecvStream{}, &syncDone)
assert.Error(t, err)
assert.False(t, syncDone)
_, hasStale := s.lastMappings["svc-stale"]
assert.True(t, hasStale, "stale mapping should remain when sync was interrupted by error")
}

View File

@@ -10,7 +10,6 @@ import (
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/miekg/dns"
)
// TextWriter writes human-readable one-line-per-packet summaries.
@@ -593,45 +592,19 @@ func formatDNSResponse(d *layers.DNS, rd string, plen int) string {
anCount := d.ANCount
nsCount := d.NSCount
arCount := d.ARCount
ede := formatEDE(d)
if d.ResponseCode != layers.DNSResponseCodeNoErr {
return fmt.Sprintf("%04x %d/%d/%d %s%s (%d)", d.ID, anCount, nsCount, arCount, d.ResponseCode, ede, plen)
return fmt.Sprintf("%04x %d/%d/%d %s (%d)", d.ID, anCount, nsCount, arCount, d.ResponseCode, plen)
}
if anCount > 0 && len(d.Answers) > 0 {
rr := d.Answers[0]
if rdata := shortRData(&rr); rdata != "" {
return fmt.Sprintf("%04x %d/%d/%d %s %s%s (%d)", d.ID, anCount, nsCount, arCount, rr.Type, rdata, ede, plen)
return fmt.Sprintf("%04x %d/%d/%d %s %s (%d)", d.ID, anCount, nsCount, arCount, rr.Type, rdata, plen)
}
}
return fmt.Sprintf("%04x %d/%d/%d%s (%d)", d.ID, anCount, nsCount, arCount, ede, plen)
}
// dnsOPTCodeEDE is the EDNS0 option code for Extended DNS Errors (RFC 8914).
const dnsOPTCodeEDE layers.DNSOptionCode = layers.DNSOptionCode(dns.EDNS0EDE)
// formatEDE returns " EDE=Name" for the first Extended DNS Error option
// found in the response, or empty string if none is present.
func formatEDE(d *layers.DNS) string {
for _, rr := range d.Additionals {
if rr.Type != layers.DNSTypeOPT {
continue
}
for _, opt := range rr.OPT {
if opt.Code != dnsOPTCodeEDE || len(opt.Data) < 2 {
continue
}
info := binary.BigEndian.Uint16(opt.Data[:2])
name, ok := dns.ExtendedErrorCodeToString[info]
if !ok {
name = fmt.Sprintf("%d", info)
}
return " EDE=" + name
}
}
return ""
return fmt.Sprintf("%04x %d/%d/%d (%d)", d.ID, anCount, nsCount, arCount, plen)
}
func shortRData(rr *layers.DNSResourceRecord) string {