Compare commits

..

39 Commits

Author SHA1 Message Date
crn4
13e41e432c idp dex fix 2026-05-21 15:21:28 +02:00
crn4
efa6a3f502 missed file 2026-05-20 12:41:05 +02:00
crn4
5fbcdeceac more comments 2026-05-19 21:41:08 +02:00
crn4
3a1bbeba90 review comments 2026-05-19 20:27:50 +02:00
crn4
728057ef15 missed files for client side and shared files 2026-05-19 14:46:23 +02:00
crn4
582cd70086 client side and components on shared folder 2026-05-19 14:46:09 +02:00
crn4
9bbbafaf69 int id for networks and posture checks migration 2026-05-19 14:45:40 +02:00
crn4
672b057aa0 fix Group.Copy losing AccountSeqID 2026-05-19 14:43:59 +02:00
crn4
b9a0186200 fix routes filtering in account componnents 2026-05-19 14:43:49 +02:00
crn4
9083bdb977 capabilities conditioning 2026-05-19 14:43:38 +02:00
crn4
b194af48b8 wire size benches fix 2026-05-19 14:43:28 +02:00
crn4
4543780ef0 grpc components encoding with optimisations 2026-05-19 14:43:17 +02:00
crn4
2de0283971 init int inds migration 2026-05-19 14:42:55 +02:00
Maycon Santos
af24fd7796 [management] Add metrics for peer status updates and ephemeral cleanup (#6196)
* [management] Add metrics for peer status updates and ephemeral cleanup

The session-fenced MarkPeerConnected / MarkPeerDisconnected path and
the ephemeral peer cleanup loop both run silently today: when fencing
rejects a stale stream, when a cleanup tick deletes peers, or when a
batch delete fails, we have no operational signal beyond log lines.

Add OpenTelemetry counters and a histogram so the same SLO-style
dashboards that already exist for the network-map controller can cover
peer connect/disconnect and ephemeral cleanup too.

All new attributes are bounded enums: operation in {connect,disconnect}
and outcome in {applied,stale,error,peer_not_found}. No account, peer,
or user ID is ever written as a metric label — total cardinality is
fixed at compile time (8 counter series, 2 histogram series, 4 unlabeled
ephemeral series).

Metric methods are nil-receiver safe so test composition that doesn't
wire telemetry (the bulk of the existing tests) works unchanged. The
ephemeral manager exposes a SetMetrics setter rather than taking the
collector through its constructor, keeping the constructor signature
stable across all test call sites.

* [management] Add OpenTelemetry metrics for ephemeral peer cleanup

Introduce counters for tracking ephemeral peer cleanup, including peers pending deletion, cleanup runs, successful deletions, and failed batches. Metrics are nil-receiver safe to ensure compatibility with test setups without telemetry.
2026-05-18 22:55:19 +02:00
Maycon Santos
13d32d274f [management] Fence peer status updates with a session token (#6193)
* [management] Fence peer status updates with a session token

The connect/disconnect path used a best-effort LastSeen-after-streamStart
comparison to decide whether a status update should land. Under contention
— a re-sync arriving while the previous stream's disconnect was still in
flight, or two management replicas seeing the same peer at once — the
check was a read-then-decide-then-write window: any UPDATE in between
caused the wrong row to be written. The Go-side time.Now() that fed the
comparison also drifted under lock contention, since it was captured
seconds before the write actually committed.

Replace it with an integer-nanosecond fencing token stored alongside the
status. Every gRPC sync stream uses its open time (UnixNano) as its token.
Connects only land when the incoming token is strictly greater than the
stored one; disconnects only land when the incoming token equals the
stored one (i.e. we're the stream that owns the current session). Both
are single optimistic-locked UPDATEs — no read-then-write, no transaction
wrapper.

LastSeen is now written by the database itself (CURRENT_TIMESTAMP). The
caller never supplies it, so the value always reflects the real moment
of the UPDATE rather than the moment the caller queued the work — which
was already off by minutes under heavy lock contention.

Side effects (geo lookup, peer-login-expiration scheduling, network-map
fan-out) are explicitly documented as running after the fence UPDATE
commits, never inside it. Geo also skips the update when realIP equals
the stored ConnectionIP, dropping a redundant SavePeerLocation call on
same-IP reconnects.

Tests cover the three semantic cases (matched disconnect lands, stale
disconnect dropped, stale connect dropped) plus a 16-goroutine race test
that asserts the highest token always wins.

* [management] Add SessionStartedAt to peer status updates

Stored `SessionStartedAt` for fencing token propagation across goroutines and updated database queries/functions to handle the new field. Removed outdated geolocation handling logic and adjusted tests for concurrency safety.

* Rename `peer_status_required_approval` to `peer_status_requires_approval` in SQL store fields
2026-05-18 20:25:12 +02:00
Nicolas Frati
705f87fc20 [management] fix: device redirect uri wasn't registered (#6191)
* fix: device redirect uri wasn't registered

* fix lint
2026-05-18 12:57:59 +02:00
Viktor Liu
3f91f49277 Clean up legacy 32-bit and HKCU registry entries on Windows install (#6176) 2026-05-16 16:52:57 +02:00
Maycon Santos
347c5bf317 Avoid context cancellation in cancelPeerRoutines (#6175)
When closing go routines and handling peer disconnect, we should avoid canceling the flow due to parent gRPC context cancellation.

This change triggers disconnection handling with a context that is not bound to the parent gRPC cancellation.
2026-05-16 16:29:01 +02:00
Viktor Liu
22e2519d71 [management] Avoid peer IP reallocation when account settings update preserves the network range (#6173) 2026-05-16 15:51:48 +02:00
Vlad
e916f12cca [proxy] auth token generation on mapping (#6157)
* [management / proxy] auth token generation on mapping

* fix tests
2026-05-15 19:13:44 +02:00
Viktor Liu
9ed2e2a5b4 [client] Drop DNS probes for passive health projection (#5971) 2026-05-15 17:07:38 +02:00
Viktor Liu
2ccae7ec47 [client] Mirror v4 exit selection onto v6 pair and honour SkipAutoApply per route (#6150) 2026-05-15 16:58:47 +02:00
Viktor Liu
07e5450117 [management] Bracket IPv6 reverse-proxy target hosts when building URL Host field (#6141) 2026-05-14 16:42:40 +02:00
Viktor Liu
3f914090cb [client] Bracket IPv6 in embed listeners, expand debug bundle (#6134) 2026-05-14 16:22:53 +02:00
Viktor Liu
ea9fab4396 [management] Allocate and preserve IPv6 overlay addresses for embedded proxy peers (#6132) 2026-05-14 16:05:33 +02:00
Vlad
77b479286e [management] fix offline statuses for public proxy clusters (#6133) 2026-05-14 13:27:50 +02:00
Maycon Santos
ab2a8794e7 [client] Add short flags for status command options (#6137)
* [client] Add short flags for status command options

* uppercase filters
2026-05-14 12:30:42 +02:00
Viktor Liu
9126a192ca [client] Set 0644 perms on SSH client config after os.CreateTemp (#6126) 2026-05-12 15:05:53 +02:00
Viktor Liu
1224d6e1ee [client] Persist management URL and pre-shared key overrides on login (#6065) 2026-05-12 14:52:56 +02:00
Nicolas Frati
96672dd1f8 [management] chores: update dex version (#6124)
* chores: update dex version

* chore: update dex fork
2026-05-12 13:50:35 +02:00
Viktor Liu
946ce4c3da [client] Fix --config flag default to point at profile path (#6122) 2026-05-11 17:48:21 +02:00
Vlad
07cbfdbede [proxy] feature: bring your own proxy (#5627) 2026-05-11 14:31:38 +02:00
Viktor Liu
a4114a5e45 [client] Skip DNS upstream failover on definitive EDE (#6089) 2026-05-11 10:00:23 +02:00
Viktor Liu
6b08e89c7b [relay] Preserve non-standard port in WS dialer URL prep (#6061) 2026-05-11 09:59:33 +02:00
Viktor Liu
a852b3bd34 [client, proxy] Harden uspfilter conntrack and share TCP relay (#5936) 2026-05-11 09:59:13 +02:00
Viktor Liu
afb83b3049 [client] Use unique temp file and clean up on failure when writing ssh config (#6064) 2026-05-11 09:58:49 +02:00
Nicolas Frati
e89aad09f5 [management] Enable MFA for local users (#5804)
* wip: totp for local users

* fix providers not getting populated

* polished UI and fix post_login_redirect_uri

* fix: make sure logout is only prompted from oidc flow

Signed-off-by: jnfrati <nicofrati@gmail.com>

* update templates

Signed-off-by: jnfrati <nicofrati@gmail.com>

* deps: update dex dependency

Signed-off-by: jnfrati <nicofrati@gmail.com>

* fix qube issues

Signed-off-by: jnfrati <nicofrati@gmail.com>

* replace window with globalThis on home html

Signed-off-by: jnfrati <nicofrati@gmail.com>

* fixed coderabbit comments

Signed-off-by: jnfrati <nicofrati@gmail.com>

* debug

* remove unused config and rename totp issuer

* deps: update dex reference to latest

* add dashboard post logout redirect uri to embedded config

* implemented api for mfa configuration

* update docs and config parsing

* catch error on idp manager init mfa

* fix tests

* Add remember me  for MFA

* Add cookie encryption and session share between tabs

* fixed logout showing non actionable error and session cookie encription key

* fixed missing mfa settings on sql query for account

* fix code index for mfa activity

---------

Signed-off-by: jnfrati <nicofrati@gmail.com>
Co-authored-by: braginini <bangvalo@gmail.com>
2026-05-08 16:31:20 +02:00
Maycon Santos
7da94a4956 [misc] Update CONTRIBUTING.md (#6076) 2026-05-07 16:16:48 +02:00
Pascal Fischer
39eac377e4 [management] add update reason to buffered calls (#6103) 2026-05-07 15:55:59 +02:00
187 changed files with 18231 additions and 2818 deletions

View File

@@ -19,7 +19,7 @@ jobs:
- name: codespell
uses: codespell-project/actions-codespell@v2
with:
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te,userA
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te,userA,ede,additionals
skip: go.mod,go.sum,**/proxy/web/**
golangci:
strategy:

View File

@@ -8,7 +8,7 @@ There are many ways that you can contribute:
- Sharing use cases in slack or Reddit
- Bug fix or feature enhancement
If you haven't already, join our slack workspace [here](https://join.slack.com/t/netbirdio/shared_invite/zt-vrahf41g-ik1v7fV8du6t0RwxSrJ96A), we would love to discuss topics that need community contribution and enhancements to existing features.
If you haven't already, join our slack workspace [here](https://docs.netbird.io/slack-url), we would love to discuss topics that need community contribution and enhancements to existing features.
## Contents

View File

@@ -143,7 +143,7 @@ func init() {
rootCmd.PersistentFlags().StringVar(&preSharedKey, preSharedKeyFlag, "", "Sets WireGuard PreSharedKey property. If set, then only peers that have the same key can communicate.")
rootCmd.PersistentFlags().StringVarP(&hostName, "hostname", "n", "", "Sets a custom hostname for the device")
rootCmd.PersistentFlags().BoolVarP(&anonymizeFlag, "anonymize", "A", false, "anonymize IP addresses and non-netbird.io domains in logs and status output")
rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", defaultConfigPath, "Overrides the default profile file location")
rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", profilemanager.DefaultConfigPath, "Overrides the default profile file location")
rootCmd.AddCommand(upCmd)
rootCmd.AddCommand(downCmd)

View File

@@ -43,16 +43,16 @@ func init() {
ipsFilterMap = make(map[string]struct{})
prefixNamesFilterMap = make(map[string]struct{})
statusCmd.PersistentFlags().BoolVarP(&detailFlag, "detail", "d", false, "display detailed status information in human-readable format")
statusCmd.PersistentFlags().BoolVar(&jsonFlag, "json", false, "display detailed status information in json format")
statusCmd.PersistentFlags().BoolVar(&yamlFlag, "yaml", false, "display detailed status information in yaml format")
statusCmd.PersistentFlags().BoolVar(&ipv4Flag, "ipv4", false, "display only NetBird IPv4 of this peer, e.g., --ipv4 will output 100.64.0.33")
statusCmd.PersistentFlags().BoolVar(&ipv6Flag, "ipv6", false, "display only NetBird IPv6 of this peer")
statusCmd.PersistentFlags().BoolVarP(&jsonFlag, "json", "j", false, "display detailed status information in json format")
statusCmd.PersistentFlags().BoolVarP(&yamlFlag, "yaml", "y", false, "display detailed status information in yaml format")
statusCmd.PersistentFlags().BoolVarP(&ipv4Flag, "ipv4", "4", false, "display only NetBird IPv4 of this peer, e.g., --ipv4 will output 100.64.0.33")
statusCmd.PersistentFlags().BoolVarP(&ipv6Flag, "ipv6", "6", false, "display only NetBird IPv6 of this peer")
statusCmd.MarkFlagsMutuallyExclusive("detail", "json", "yaml", "ipv4", "ipv6")
statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs (v4 or v6), e.g., --filter-by-ips 100.64.0.100,fd00::1")
statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud")
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(idle|connecting|connected), e.g., --filter-by-status connected")
statusCmd.PersistentFlags().StringVar(&connectionTypeFilter, "filter-by-connection-type", "", "filters the detailed output by connection type (P2P|Relayed), e.g., --filter-by-connection-type P2P")
statusCmd.PersistentFlags().StringVar(&checkFlag, "check", "", "run a health check and exit with code 0 on success, 1 on failure (live|ready|startup)")
statusCmd.PersistentFlags().StringSliceVarP(&ipsFilter, "filter-by-ips", "I", []string{}, "filters the detailed output by a list of one or more IPs (v4 or v6), e.g., --filter-by-ips 100.64.0.100,fd00::1")
statusCmd.PersistentFlags().StringSliceVarP(&prefixNamesFilter, "filter-by-names", "N", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud")
statusCmd.PersistentFlags().StringVarP(&statusFilter, "filter-by-status", "S", "", "filters the detailed output by connection status(idle|connecting|connected), e.g., --filter-by-status connected")
statusCmd.PersistentFlags().StringVarP(&connectionTypeFilter, "filter-by-connection-type", "T", "", "filters the detailed output by connection type (P2P|Relayed), e.g., --filter-by-connection-type P2P")
statusCmd.PersistentFlags().StringVarP(&checkFlag, "check", "C", "", "run a health check and exit with code 0 on success, 1 on failure (live|ready|startup)")
}
func statusFunc(cmd *cobra.Command, args []string) error {

View File

@@ -12,7 +12,6 @@ import (
"sync"
"github.com/sirupsen/logrus"
wgdevice "golang.zx2c4.com/wireguard/device"
wgnetstack "golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface"
@@ -95,26 +94,6 @@ type Options struct {
MTU *uint16
// DNSLabels defines additional DNS labels configured in the peer.
DNSLabels []string
// Performance configures the tunnel's buffer pool cap and batch size.
Performance Performance
}
// Performance configures the embedded client's tunnel memory/throughput knobs.
//
// These settings are process-global: any non-nil field also becomes the
// default for Clients constructed by later embed.New calls in the same
// process. Nil fields are ignored.
type Performance struct {
// PreallocatedBuffersPerPool caps the per-tunnel buffer pool. Zero
// leaves the pool unbounded. Lower values trade throughput for a
// tighter memory ceiling. May also be changed on a running Client via
// Client.SetPerformance, provided this field was nonzero at construction.
PreallocatedBuffersPerPool *uint32
// MaxBatchSize overrides the number of packets the tunnel reads or
// writes per syscall, which also bounds eager buffer allocation per
// worker. Zero uses the platform default. Applied at construction
// only; ignored by Client.SetPerformance.
MaxBatchSize *uint32
}
// validateCredentials checks that exactly one credential type is provided
@@ -213,13 +192,6 @@ func New(opts Options) (*Client, error) {
config.PrivateKey = opts.PrivateKey
}
if opts.Performance.PreallocatedBuffersPerPool != nil {
wgdevice.SetPreallocatedBuffersPerPool(*opts.Performance.PreallocatedBuffersPerPool)
}
if opts.Performance.MaxBatchSize != nil {
wgdevice.SetMaxBatchSizeOverride(*opts.Performance.MaxBatchSize)
}
return &Client{
deviceName: opts.DeviceName,
setupKey: opts.SetupKey,
@@ -364,7 +336,7 @@ func (c *Client) ListenTCP(address string) (net.Listener, error) {
if err != nil {
return nil, fmt.Errorf("split host port: %w", err)
}
listenAddr := fmt.Sprintf("%s:%s", addr, port)
listenAddr := net.JoinHostPort(addr.String(), port)
tcpAddr, err := net.ResolveTCPAddr("tcp", listenAddr)
if err != nil {
@@ -385,7 +357,7 @@ func (c *Client) ListenUDP(address string) (net.PacketConn, error) {
if err != nil {
return nil, fmt.Errorf("split host port: %w", err)
}
listenAddr := fmt.Sprintf("%s:%s", addr, port)
listenAddr := net.JoinHostPort(addr.String(), port)
udpAddr, err := net.ResolveUDPAddr("udp", listenAddr)
if err != nil {
@@ -501,25 +473,6 @@ func (c *Client) VerifySSHHostKey(peerAddress string, key []byte) error {
return sshcommon.VerifyHostKey(storedKey, key, peerAddress)
}
// SetPerformance retunes a running Client. Only PreallocatedBuffersPerPool
// takes effect, and only when it was nonzero at construction;
// MaxBatchSize is construction-only and returns an error if set here.
//
// Returns ErrClientNotStarted / ErrEngineNotStarted if the Client is not
// running yet.
func (c *Client) SetPerformance(t Performance) error {
if t.MaxBatchSize != nil {
return errors.New("MaxBatchSize is construction-only and cannot be changed at runtime")
}
engine, err := c.getEngine()
if err != nil {
return err
}
return engine.SetPerformance(internal.Performance{
PreallocatedBuffersPerPool: t.PreallocatedBuffersPerPool,
})
}
// StartCapture begins capturing packets on this client's tunnel device.
// Only one capture can be active at a time; starting a new one stops the previous.
// Call StopCapture (or CaptureSession.Stop) to end it.

View File

@@ -0,0 +1,125 @@
package conntrack
import (
"net/netip"
"testing"
"github.com/google/gopacket/layers"
"github.com/stretchr/testify/require"
)
func TestTCPCapEvicts(t *testing.T) {
t.Setenv(EnvTCPMaxEntries, "4")
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
require.Equal(t, 4, tracker.maxEntries)
src := netip.MustParseAddr("100.64.0.1")
dst := netip.MustParseAddr("100.64.0.2")
for i := 0; i < 10; i++ {
tracker.TrackOutbound(src, dst, uint16(10000+i), 80, TCPSyn, 0)
}
require.LessOrEqual(t, len(tracker.connections), 4,
"TCP table must not exceed the configured cap")
require.Greater(t, len(tracker.connections), 0,
"some entries must remain after eviction")
// The most recently admitted flow must be present: eviction must make
// room for new entries, not silently drop them.
require.Contains(t, tracker.connections,
ConnKey{SrcIP: src, DstIP: dst, SrcPort: uint16(10009), DstPort: 80},
"newest TCP flow must be admitted after eviction")
// A pre-cap flow must have been evicted to fit the last one.
require.NotContains(t, tracker.connections,
ConnKey{SrcIP: src, DstIP: dst, SrcPort: uint16(10000), DstPort: 80},
"oldest TCP flow should have been evicted")
}
func TestTCPCapPrefersTombstonedForEviction(t *testing.T) {
t.Setenv(EnvTCPMaxEntries, "3")
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
src := netip.MustParseAddr("100.64.0.1")
dst := netip.MustParseAddr("100.64.0.2")
// Fill to cap with 3 live connections.
for i := 0; i < 3; i++ {
tracker.TrackOutbound(src, dst, uint16(20000+i), 80, TCPSyn, 0)
}
require.Len(t, tracker.connections, 3)
// Tombstone one by sending RST through IsValidInbound.
tombstonedKey := ConnKey{SrcIP: src, DstIP: dst, SrcPort: 20001, DstPort: 80}
require.True(t, tracker.IsValidInbound(dst, src, 80, 20001, TCPRst|TCPAck, 0))
require.True(t, tracker.connections[tombstonedKey].IsTombstone())
// Another live connection forces eviction. The tombstone must go first.
tracker.TrackOutbound(src, dst, uint16(29999), 80, TCPSyn, 0)
_, tombstonedStillPresent := tracker.connections[tombstonedKey]
require.False(t, tombstonedStillPresent,
"tombstoned entry should be evicted before live entries")
require.LessOrEqual(t, len(tracker.connections), 3)
// Both live pre-cap entries must survive: eviction must prefer the
// tombstone, not just satisfy the size bound by dropping any entry.
require.Contains(t, tracker.connections,
ConnKey{SrcIP: src, DstIP: dst, SrcPort: uint16(20000), DstPort: 80},
"live entries must not be evicted while a tombstone exists")
require.Contains(t, tracker.connections,
ConnKey{SrcIP: src, DstIP: dst, SrcPort: uint16(20002), DstPort: 80},
"live entries must not be evicted while a tombstone exists")
}
func TestUDPCapEvicts(t *testing.T) {
t.Setenv(EnvUDPMaxEntries, "5")
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
defer tracker.Close()
require.Equal(t, 5, tracker.maxEntries)
src := netip.MustParseAddr("100.64.0.1")
dst := netip.MustParseAddr("100.64.0.2")
for i := 0; i < 12; i++ {
tracker.TrackOutbound(src, dst, uint16(30000+i), 53, 0)
}
require.LessOrEqual(t, len(tracker.connections), 5)
require.Greater(t, len(tracker.connections), 0)
require.Contains(t, tracker.connections,
ConnKey{SrcIP: src, DstIP: dst, SrcPort: uint16(30011), DstPort: 53},
"newest UDP flow must be admitted after eviction")
require.NotContains(t, tracker.connections,
ConnKey{SrcIP: src, DstIP: dst, SrcPort: uint16(30000), DstPort: 53},
"oldest UDP flow should have been evicted")
}
func TestICMPCapEvicts(t *testing.T) {
t.Setenv(EnvICMPMaxEntries, "3")
tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger)
defer tracker.Close()
require.Equal(t, 3, tracker.maxEntries)
src := netip.MustParseAddr("100.64.0.1")
dst := netip.MustParseAddr("100.64.0.2")
echoReq := layers.CreateICMPv4TypeCode(uint8(layers.ICMPv4TypeEchoRequest), 0)
for i := 0; i < 8; i++ {
tracker.TrackOutbound(src, dst, uint16(i), echoReq, nil, 64)
}
require.LessOrEqual(t, len(tracker.connections), 3)
require.Greater(t, len(tracker.connections), 0)
require.Contains(t, tracker.connections,
ICMPConnKey{SrcIP: src, DstIP: dst, ID: uint16(7)},
"newest ICMP flow must be admitted after eviction")
require.NotContains(t, tracker.connections,
ICMPConnKey{SrcIP: src, DstIP: dst, ID: uint16(0)},
"oldest ICMP flow should have been evicted")
}

View File

@@ -3,15 +3,61 @@ package conntrack
import (
"net"
"net/netip"
"os"
"strconv"
"sync/atomic"
"time"
"github.com/google/uuid"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
)
// evictSampleSize bounds how many map entries we scan per eviction call.
// Keeps eviction O(1) even at cap under sustained load; the sampled-LRU
// heuristic is good enough for a conntrack table that only overflows under
// abuse.
const evictSampleSize = 8
// envDuration parses an os.Getenv(name) as a time.Duration. Falls back to
// def on empty or invalid; logs a warning on invalid.
func envDuration(logger *nblog.Logger, name string, def time.Duration) time.Duration {
v := os.Getenv(name)
if v == "" {
return def
}
d, err := time.ParseDuration(v)
if err != nil {
logger.Warn3("invalid %s=%q: %v, using default", name, v, err)
return def
}
if d <= 0 {
logger.Warn2("invalid %s=%q: must be positive, using default", name, v)
return def
}
return d
}
// envInt parses an os.Getenv(name) as an int. Falls back to def on empty,
// invalid, or non-positive. Logs a warning on invalid input.
func envInt(logger *nblog.Logger, name string, def int) int {
v := os.Getenv(name)
if v == "" {
return def
}
n, err := strconv.Atoi(v)
switch {
case err != nil:
logger.Warn3("invalid %s=%q: %v, using default", name, v, err)
return def
case n <= 0:
logger.Warn2("invalid %s=%q: must be positive, using default", name, v)
return def
}
return n
}
// BaseConnTrack provides common fields and locking for all connection types
type BaseConnTrack struct {
FlowId uuid.UUID

View File

@@ -0,0 +1,11 @@
//go:build !ios && !android
package conntrack
// Default per-tracker entry caps on desktop/server platforms. These mirror
// typical Linux netfilter nf_conntrack_max territory with ample headroom.
const (
DefaultMaxTCPEntries = 65536
DefaultMaxUDPEntries = 16384
DefaultMaxICMPEntries = 2048
)

View File

@@ -0,0 +1,13 @@
//go:build ios || android
package conntrack
// Default per-tracker entry caps on mobile platforms. iOS network extensions
// are capped at ~50 MB; Android runs under aggressive memory pressure. These
// values keep conntrack footprint well under 5 MB worst case (TCPConnTrack
// is ~200 B plus map overhead).
const (
DefaultMaxTCPEntries = 4096
DefaultMaxUDPEntries = 2048
DefaultMaxICMPEntries = 512
)

View File

@@ -50,6 +50,9 @@ type ICMPConnTrack struct {
ICMPCode uint8
}
// EnvICMPMaxEntries caps the ICMP conntrack table size.
const EnvICMPMaxEntries = "NB_CONNTRACK_ICMP_MAX"
// ICMPTracker manages ICMP connection states
type ICMPTracker struct {
logger *nblog.Logger
@@ -58,6 +61,7 @@ type ICMPTracker struct {
cleanupTicker *time.Ticker
tickerCancel context.CancelFunc
mutex sync.RWMutex
maxEntries int
flowLogger nftypes.FlowLogger
}
@@ -171,6 +175,7 @@ func NewICMPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nfty
timeout: timeout,
cleanupTicker: time.NewTicker(ICMPCleanupInterval),
tickerCancel: cancel,
maxEntries: envInt(logger, EnvICMPMaxEntries, DefaultMaxICMPEntries),
flowLogger: flowLogger,
}
@@ -257,7 +262,9 @@ func (t *ICMPTracker) track(
// non echo requests don't need tracking
if typ != uint8(layers.ICMPv4TypeEchoRequest) {
t.logger.Trace3("New %s ICMP connection %s - %s", direction, key, icmpInfo)
if t.logger.Enabled(nblog.LevelTrace) {
t.logger.Trace3("New %s ICMP connection %s - %s", direction, key, icmpInfo)
}
t.sendStartEvent(direction, srcIP, dstIP, typ, code, ruleId, size)
return
}
@@ -276,10 +283,15 @@ func (t *ICMPTracker) track(
conn.UpdateCounters(direction, size)
t.mutex.Lock()
if t.maxEntries > 0 && len(t.connections) >= t.maxEntries {
t.evictOneLocked()
}
t.connections[key] = conn
t.mutex.Unlock()
t.logger.Trace3("New %s ICMP connection %s - %s", direction, key, icmpInfo)
if t.logger.Enabled(nblog.LevelTrace) {
t.logger.Trace3("New %s ICMP connection %s - %s", direction, key, icmpInfo)
}
t.sendEvent(nftypes.TypeStart, conn, ruleId)
}
@@ -323,6 +335,34 @@ func (t *ICMPTracker) cleanupRoutine(ctx context.Context) {
}
}
// evictOneLocked removes one entry to make room. Caller must hold t.mutex.
// Bounded sample scan: picks the oldest among up to evictSampleSize entries.
func (t *ICMPTracker) evictOneLocked() {
var candKey ICMPConnKey
var candSeen int64
haveCand := false
sampled := 0
for k, c := range t.connections {
seen := c.lastSeen.Load()
if !haveCand || seen < candSeen {
candKey = k
candSeen = seen
haveCand = true
}
sampled++
if sampled >= evictSampleSize {
break
}
}
if haveCand {
if evicted := t.connections[candKey]; evicted != nil {
t.sendEvent(nftypes.TypeEnd, evicted, nil)
}
delete(t.connections, candKey)
}
}
func (t *ICMPTracker) cleanup() {
t.mutex.Lock()
defer t.mutex.Unlock()
@@ -331,8 +371,10 @@ func (t *ICMPTracker) cleanup() {
if conn.timeoutExceeded(t.timeout) {
delete(t.connections, key)
t.logger.Trace5("Removed ICMP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]",
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
if t.logger.Enabled(nblog.LevelTrace) {
t.logger.Trace5("Removed ICMP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]",
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
}
t.sendEvent(nftypes.TypeEnd, conn, nil)
}
}

View File

@@ -38,6 +38,27 @@ const (
TCPHandshakeTimeout = 60 * time.Second
// TCPCleanupInterval is how often we check for stale connections
TCPCleanupInterval = 5 * time.Minute
// FinWaitTimeout bounds FIN_WAIT_1 / FIN_WAIT_2 / CLOSING states.
// Matches Linux netfilter nf_conntrack_tcp_timeout_fin_wait.
FinWaitTimeout = 60 * time.Second
// CloseWaitTimeout bounds CLOSE_WAIT. Matches Linux default; apps
// holding CloseWait longer than this should bump the env var.
CloseWaitTimeout = 60 * time.Second
// LastAckTimeout bounds LAST_ACK. Matches Linux default.
LastAckTimeout = 30 * time.Second
)
// Env vars to override per-state teardown timeouts. Values parsed by
// time.ParseDuration (e.g. "120s", "2m"). Invalid values fall back to the
// defaults above with a warning.
const (
EnvTCPFinWaitTimeout = "NB_CONNTRACK_TCP_FIN_WAIT_TIMEOUT"
EnvTCPCloseWaitTimeout = "NB_CONNTRACK_TCP_CLOSE_WAIT_TIMEOUT"
EnvTCPLastAckTimeout = "NB_CONNTRACK_TCP_LAST_ACK_TIMEOUT"
// EnvTCPMaxEntries caps the TCP conntrack table size. Oldest entries
// (tombstones first) are evicted when the cap is reached.
EnvTCPMaxEntries = "NB_CONNTRACK_TCP_MAX"
)
// TCPState represents the state of a TCP connection
@@ -133,14 +154,18 @@ func (t *TCPConnTrack) SetTombstone() {
// TCPTracker manages TCP connection states
type TCPTracker struct {
logger *nblog.Logger
connections map[ConnKey]*TCPConnTrack
mutex sync.RWMutex
cleanupTicker *time.Ticker
tickerCancel context.CancelFunc
timeout time.Duration
waitTimeout time.Duration
flowLogger nftypes.FlowLogger
logger *nblog.Logger
connections map[ConnKey]*TCPConnTrack
mutex sync.RWMutex
cleanupTicker *time.Ticker
tickerCancel context.CancelFunc
timeout time.Duration
waitTimeout time.Duration
finWaitTimeout time.Duration
closeWaitTimeout time.Duration
lastAckTimeout time.Duration
maxEntries int
flowLogger nftypes.FlowLogger
}
// NewTCPTracker creates a new TCP connection tracker
@@ -155,13 +180,17 @@ func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp
ctx, cancel := context.WithCancel(context.Background())
tracker := &TCPTracker{
logger: logger,
connections: make(map[ConnKey]*TCPConnTrack),
cleanupTicker: time.NewTicker(TCPCleanupInterval),
tickerCancel: cancel,
timeout: timeout,
waitTimeout: waitTimeout,
flowLogger: flowLogger,
logger: logger,
connections: make(map[ConnKey]*TCPConnTrack),
cleanupTicker: time.NewTicker(TCPCleanupInterval),
tickerCancel: cancel,
timeout: timeout,
waitTimeout: waitTimeout,
finWaitTimeout: envDuration(logger, EnvTCPFinWaitTimeout, FinWaitTimeout),
closeWaitTimeout: envDuration(logger, EnvTCPCloseWaitTimeout, CloseWaitTimeout),
lastAckTimeout: envDuration(logger, EnvTCPLastAckTimeout, LastAckTimeout),
maxEntries: envInt(logger, EnvTCPMaxEntries, DefaultMaxTCPEntries),
flowLogger: flowLogger,
}
go tracker.cleanupRoutine(ctx)
@@ -209,6 +238,12 @@ func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, fla
if exists || flags&TCPSyn == 0 {
return
}
// Reject illegal SYN combinations (SYN+FIN, SYN+RST, …) so they don't
// create spurious conntrack entries. Not mandated by RFC 9293 but a
// common hardening (Linux netfilter/nftables rejects these too).
if !isValidFlagCombination(flags) {
return
}
conn := &TCPConnTrack{
BaseConnTrack: BaseConnTrack{
@@ -225,20 +260,65 @@ func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, fla
conn.state.Store(int32(TCPStateNew))
conn.DNATOrigPort.Store(uint32(origPort))
if origPort != 0 {
t.logger.Trace4("New %s TCP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort)
} else {
t.logger.Trace2("New %s TCP connection: %s", direction, key)
if t.logger.Enabled(nblog.LevelTrace) {
if origPort != 0 {
t.logger.Trace4("New %s TCP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort)
} else {
t.logger.Trace2("New %s TCP connection: %s", direction, key)
}
}
t.updateState(key, conn, flags, direction, size)
t.mutex.Lock()
if t.maxEntries > 0 && len(t.connections) >= t.maxEntries {
t.evictOneLocked()
}
t.connections[key] = conn
t.mutex.Unlock()
t.sendEvent(nftypes.TypeStart, conn, ruleID)
}
// evictOneLocked removes one entry to make room. Caller must hold t.mutex.
// Bounded scan: samples up to evictSampleSize pseudo-random entries (Go map
// iteration order is randomized), preferring a tombstone. If no tombstone
// found in the sample, evicts the oldest among the sampled entries. O(1)
// worst case — cheap enough to run on every insert at cap during abuse.
func (t *TCPTracker) evictOneLocked() {
var candKey ConnKey
var candSeen int64
haveCand := false
sampled := 0
for k, c := range t.connections {
if c.IsTombstone() {
delete(t.connections, k)
return
}
seen := c.lastSeen.Load()
if !haveCand || seen < candSeen {
candKey = k
candSeen = seen
haveCand = true
}
sampled++
if sampled >= evictSampleSize {
break
}
}
if haveCand {
if evicted := t.connections[candKey]; evicted != nil {
// TypeEnd is already emitted at the state transition to
// TimeWait and when a connection is tombstoned. Only emit
// here when we're reaping a still-active flow.
if evicted.GetState() != TCPStateTimeWait && !evicted.IsTombstone() {
t.sendEvent(nftypes.TypeEnd, evicted, nil)
}
}
delete(t.connections, candKey)
}
}
// IsValidInbound checks if an inbound TCP packet matches a tracked connection
func (t *TCPTracker) IsValidInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) bool {
key := ConnKey{
@@ -256,12 +336,19 @@ func (t *TCPTracker) IsValidInbound(srcIP, dstIP netip.Addr, srcPort, dstPort ui
return false
}
// Reject illegal flag combinations regardless of state. These never belong
// to a legitimate flow and must not advance or tear down state.
if !isValidFlagCombination(flags) {
if t.logger.Enabled(nblog.LevelWarn) {
t.logger.Warn3("TCP illegal flag combination %x for connection %s (state %s)", flags, key, conn.GetState())
}
return false
}
currentState := conn.GetState()
if !t.isValidStateForFlags(currentState, flags) {
t.logger.Warn3("TCP state %s is not valid with flags %x for connection %s", currentState, flags, key)
// allow all flags for established for now
if currentState == TCPStateEstablished {
return true
if t.logger.Enabled(nblog.LevelWarn) {
t.logger.Warn3("TCP state %s is not valid with flags %x for connection %s", currentState, flags, key)
}
return false
}
@@ -270,116 +357,208 @@ func (t *TCPTracker) IsValidInbound(srcIP, dstIP netip.Addr, srcPort, dstPort ui
return true
}
// updateState updates the TCP connection state based on flags
// updateState updates the TCP connection state based on flags.
func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, packetDir nftypes.Direction, size int) {
conn.UpdateLastSeen()
conn.UpdateCounters(packetDir, size)
// Malformed flag combinations must not refresh lastSeen or drive state,
// otherwise spoofed packets keep a dead flow alive past its timeout.
if !isValidFlagCombination(flags) {
return
}
conn.UpdateLastSeen()
currentState := conn.GetState()
if flags&TCPRst != 0 {
if conn.CompareAndSwapState(currentState, TCPStateClosed) {
conn.SetTombstone()
t.logger.Trace6("TCP connection reset: %s (dir: %s) [in: %d Pkts/%d B, out: %d Pkts/%d B]",
key, packetDir, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
t.sendEvent(nftypes.TypeEnd, conn, nil)
}
// Hardening beyond RFC 9293 §3.10.7.4: without sequence tracking we
// cannot apply the RFC 5961 in-window RST check, so we conservatively
// reject RSTs that the spec would accept (TIME-WAIT with in-window
// SEQ, SynSent from same direction as own SYN, etc.).
t.handleRst(key, conn, currentState, packetDir)
return
}
var newState TCPState
switch currentState {
case TCPStateNew:
if flags&TCPSyn != 0 && flags&TCPAck == 0 {
if conn.Direction == nftypes.Egress {
newState = TCPStateSynSent
} else {
newState = TCPStateSynReceived
}
}
newState := nextState(currentState, conn.Direction, packetDir, flags)
if newState == 0 || !conn.CompareAndSwapState(currentState, newState) {
return
}
t.onTransition(key, conn, currentState, newState, packetDir)
}
case TCPStateSynSent:
if flags&TCPSyn != 0 && flags&TCPAck != 0 {
if packetDir != conn.Direction {
newState = TCPStateEstablished
} else {
// Simultaneous open
newState = TCPStateSynReceived
}
}
// handleRst processes a RST segment. Late RSTs in TimeWait and spoofed RSTs
// from the SYN direction are ignored; otherwise the flow is tombstoned.
func (t *TCPTracker) handleRst(key ConnKey, conn *TCPConnTrack, currentState TCPState, packetDir nftypes.Direction) {
// TimeWait exists to absorb late segments; don't let a late RST
// tombstone the entry and break same-4-tuple reuse.
if currentState == TCPStateTimeWait {
return
}
// A RST from the same direction as the SYN cannot be a legitimate
// response and must not tear down a half-open connection.
if currentState == TCPStateSynSent && packetDir == conn.Direction {
return
}
if !conn.CompareAndSwapState(currentState, TCPStateClosed) {
return
}
conn.SetTombstone()
if t.logger.Enabled(nblog.LevelTrace) {
t.logger.Trace6("TCP connection reset: %s (dir: %s) [in: %d Pkts/%d B, out: %d Pkts/%d B]",
key, packetDir, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
}
t.sendEvent(nftypes.TypeEnd, conn, nil)
}
case TCPStateSynReceived:
if flags&TCPAck != 0 && flags&TCPSyn == 0 {
if packetDir == conn.Direction {
newState = TCPStateEstablished
}
}
// stateTransition describes one state's transition logic. It receives the
// packet's flags plus whether the packet direction matches the connection's
// origin direction (same=true means same side as the SYN initiator). Return 0
// for no transition.
type stateTransition func(flags uint8, connDir nftypes.Direction, same bool) TCPState
case TCPStateEstablished:
if flags&TCPFin != 0 {
if packetDir == conn.Direction {
newState = TCPStateFinWait1
} else {
newState = TCPStateCloseWait
}
}
// stateTable maps each state to its transition function. Centralized here so
// nextState stays trivial and each rule is easy to read in isolation.
var stateTable = map[TCPState]stateTransition{
TCPStateNew: transNew,
TCPStateSynSent: transSynSent,
TCPStateSynReceived: transSynReceived,
TCPStateEstablished: transEstablished,
TCPStateFinWait1: transFinWait1,
TCPStateFinWait2: transFinWait2,
TCPStateClosing: transClosing,
TCPStateCloseWait: transCloseWait,
TCPStateLastAck: transLastAck,
}
case TCPStateFinWait1:
if packetDir != conn.Direction {
switch {
case flags&TCPFin != 0 && flags&TCPAck != 0:
newState = TCPStateClosing
case flags&TCPFin != 0:
newState = TCPStateClosing
case flags&TCPAck != 0:
newState = TCPStateFinWait2
}
}
// nextState returns the target TCP state for the given current state and
// packet, or 0 if the packet does not trigger a transition.
func nextState(currentState TCPState, connDir, packetDir nftypes.Direction, flags uint8) TCPState {
fn, ok := stateTable[currentState]
if !ok {
return 0
}
return fn(flags, connDir, packetDir == connDir)
}
case TCPStateFinWait2:
if flags&TCPFin != 0 {
newState = TCPStateTimeWait
func transNew(flags uint8, connDir nftypes.Direction, _ bool) TCPState {
if flags&TCPSyn != 0 && flags&TCPAck == 0 {
if connDir == nftypes.Egress {
return TCPStateSynSent
}
return TCPStateSynReceived
}
return 0
}
case TCPStateClosing:
if flags&TCPAck != 0 {
newState = TCPStateTimeWait
func transSynSent(flags uint8, _ nftypes.Direction, same bool) TCPState {
if flags&TCPSyn != 0 && flags&TCPAck != 0 {
if same {
return TCPStateSynReceived // simultaneous open
}
return TCPStateEstablished
}
return 0
}
case TCPStateCloseWait:
if flags&TCPFin != 0 {
newState = TCPStateLastAck
}
func transSynReceived(flags uint8, _ nftypes.Direction, same bool) TCPState {
if flags&TCPAck != 0 && flags&TCPSyn == 0 && same {
return TCPStateEstablished
}
return 0
}
case TCPStateLastAck:
if flags&TCPAck != 0 {
newState = TCPStateClosed
}
func transEstablished(flags uint8, _ nftypes.Direction, same bool) TCPState {
if flags&TCPFin == 0 {
return 0
}
if same {
return TCPStateFinWait1
}
return TCPStateCloseWait
}
// transFinWait1 handles the active-close peer response. A FIN carrying our
// ACK piggybacked goes straight to TIME-WAIT (RFC 9293 §3.10.7.4, FIN-WAIT-1:
// "if our FIN has been ACKed... enter the TIME-WAIT state"); a lone FIN moves
// to CLOSING; a pure ACK of our FIN moves to FIN-WAIT-2.
func transFinWait1(flags uint8, _ nftypes.Direction, same bool) TCPState {
if same {
return 0
}
if flags&TCPFin != 0 && flags&TCPAck != 0 {
return TCPStateTimeWait
}
switch {
case flags&TCPFin != 0:
return TCPStateClosing
case flags&TCPAck != 0:
return TCPStateFinWait2
}
return 0
}
// transFinWait2 ignores own-side FIN retransmits; only the peer's FIN advances.
func transFinWait2(flags uint8, _ nftypes.Direction, same bool) TCPState {
if flags&TCPFin != 0 && !same {
return TCPStateTimeWait
}
return 0
}
// transClosing completes a simultaneous close on the peer's ACK.
func transClosing(flags uint8, _ nftypes.Direction, same bool) TCPState {
if flags&TCPAck != 0 && !same {
return TCPStateTimeWait
}
return 0
}
// transCloseWait only advances to LastAck when WE send FIN, ignoring peer retransmits.
func transCloseWait(flags uint8, _ nftypes.Direction, same bool) TCPState {
if flags&TCPFin != 0 && same {
return TCPStateLastAck
}
return 0
}
// transLastAck closes the flow only on the peer's ACK (not our own ACK retransmits).
func transLastAck(flags uint8, _ nftypes.Direction, same bool) TCPState {
if flags&TCPAck != 0 && !same {
return TCPStateClosed
}
return 0
}
// onTransition handles logging and flow-event emission after a successful
// state transition. TimeWait and Closed are terminal for flow accounting.
func (t *TCPTracker) onTransition(key ConnKey, conn *TCPConnTrack, from, to TCPState, packetDir nftypes.Direction) {
traceOn := t.logger.Enabled(nblog.LevelTrace)
if traceOn {
t.logger.Trace4("TCP connection %s transitioned from %s to %s (dir: %s)", key, from, to, packetDir)
}
if newState != 0 && conn.CompareAndSwapState(currentState, newState) {
t.logger.Trace4("TCP connection %s transitioned from %s to %s (dir: %s)", key, currentState, newState, packetDir)
switch newState {
case TCPStateTimeWait:
switch to {
case TCPStateTimeWait:
if traceOn {
t.logger.Trace5("TCP connection %s completed [in: %d Pkts/%d B, out: %d Pkts/%d B]",
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
t.sendEvent(nftypes.TypeEnd, conn, nil)
case TCPStateClosed:
conn.SetTombstone()
}
t.sendEvent(nftypes.TypeEnd, conn, nil)
case TCPStateClosed:
conn.SetTombstone()
if traceOn {
t.logger.Trace5("TCP connection %s closed gracefully [in: %d Pkts/%d, B out: %d Pkts/%d B]",
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
t.sendEvent(nftypes.TypeEnd, conn, nil)
}
t.sendEvent(nftypes.TypeEnd, conn, nil)
}
}
// isValidStateForFlags checks if the TCP flags are valid for the current connection state
// isValidStateForFlags checks if the TCP flags are valid for the current
// connection state. Caller must have already verified the flag combination is
// legal via isValidFlagCombination.
func (t *TCPTracker) isValidStateForFlags(state TCPState, flags uint8) bool {
if !isValidFlagCombination(flags) {
return false
}
if flags&TCPRst != 0 {
if state == TCPStateSynSent {
return flags&TCPAck != 0
@@ -449,15 +628,24 @@ func (t *TCPTracker) cleanup() {
timeout = t.waitTimeout
case TCPStateEstablished:
timeout = t.timeout
case TCPStateFinWait1, TCPStateFinWait2, TCPStateClosing:
timeout = t.finWaitTimeout
case TCPStateCloseWait:
timeout = t.closeWaitTimeout
case TCPStateLastAck:
timeout = t.lastAckTimeout
default:
// SynSent / SynReceived / New
timeout = TCPHandshakeTimeout
}
if conn.timeoutExceeded(timeout) {
delete(t.connections, key)
t.logger.Trace6("Cleaned up timed-out TCP connection %s (%s) [in: %d Pkts/%d, B out: %d Pkts/%d B]",
key, conn.GetState(), conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
if t.logger.Enabled(nblog.LevelTrace) {
t.logger.Trace6("Cleaned up timed-out TCP connection %s (%s) [in: %d Pkts/%d, B out: %d Pkts/%d B]",
key, conn.GetState(), conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
}
// event already handled by state change
if currentState != TCPStateTimeWait {

View File

@@ -0,0 +1,100 @@
package conntrack
import (
"net/netip"
"testing"
"github.com/stretchr/testify/require"
)
// RST hygiene tests: the tracker currently closes the flow on any RST that
// matches the 4-tuple, regardless of direction or state. These tests cover
// the minimum checks we want (no SEQ tracking).
func TestTCPRstInSynSentWrongDirection(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
conn := tracker.connections[key]
require.Equal(t, TCPStateSynSent, conn.GetState())
// A RST arriving in the same direction as the SYN (i.e. TrackOutbound)
// cannot be a legitimate response. It must not close the connection.
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPRst|TCPAck, 0)
require.Equal(t, TCPStateSynSent, conn.GetState(),
"RST in same direction as SYN must not close connection")
require.False(t, conn.IsTombstone())
}
func TestTCPRstInTimeWaitIgnored(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
// Drive to TIME-WAIT via active close.
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0))
require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0))
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
conn := tracker.connections[key]
require.Equal(t, TCPStateTimeWait, conn.GetState())
require.False(t, conn.IsTombstone(), "TIME-WAIT must not be tombstoned")
// Late RST during TIME-WAIT must not tombstone the entry (TIME-WAIT
// exists to absorb late segments).
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
require.Equal(t, TCPStateTimeWait, conn.GetState(),
"RST in TIME-WAIT must not transition state")
require.False(t, conn.IsTombstone(),
"RST in TIME-WAIT must not tombstone the entry")
}
func TestTCPIllegalFlagCombos(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
conn := tracker.connections[key]
// Illegal combos must be rejected and must not change state.
combos := []struct {
name string
flags uint8
}{
{"SYN+RST", TCPSyn | TCPRst},
{"FIN+RST", TCPFin | TCPRst},
{"SYN+FIN", TCPSyn | TCPFin},
{"SYN+FIN+RST", TCPSyn | TCPFin | TCPRst},
}
for _, c := range combos {
t.Run(c.name, func(t *testing.T) {
before := conn.GetState()
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, c.flags, 0)
require.False(t, valid, "illegal flag combo must be rejected: %s", c.name)
require.Equal(t, before, conn.GetState(),
"illegal flag combo must not change state")
require.False(t, conn.IsTombstone())
})
}
}

View File

@@ -0,0 +1,235 @@
package conntrack
import (
"net/netip"
"testing"
"time"
"github.com/stretchr/testify/require"
)
// These tests exercise cases where the TCP state machine currently advances
// on retransmitted or wrong-direction segments and tears the flow down
// prematurely. They are expected to fail until the direction checks are added.
func TestTCPCloseWaitRetransmittedPeerFIN(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Peer sends FIN -> CloseWait (our app has not yet closed).
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
require.True(t, valid)
conn := tracker.connections[key]
require.Equal(t, TCPStateCloseWait, conn.GetState())
// Peer retransmits their FIN (ACK may have been delayed). We have NOT
// sent our FIN yet, so state must remain CloseWait.
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
require.True(t, valid, "retransmitted peer FIN must still be accepted")
require.Equal(t, TCPStateCloseWait, conn.GetState(),
"retransmitted peer FIN must not advance CloseWait to LastAck")
// Our app finally closes -> LastAck.
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
require.Equal(t, TCPStateLastAck, conn.GetState())
// Peer ACK closes.
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
require.True(t, valid)
require.Equal(t, TCPStateClosed, conn.GetState())
}
func TestTCPFinWait2RetransmittedOwnFIN(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// We initiate close.
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
require.True(t, valid)
conn := tracker.connections[key]
require.Equal(t, TCPStateFinWait2, conn.GetState())
// Stray retransmit of our own FIN (same direction as originator) must
// NOT advance FinWait2 to TimeWait; only the peer's FIN should.
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
require.Equal(t, TCPStateFinWait2, conn.GetState(),
"own FIN retransmit must not advance FinWait2 to TimeWait")
// Peer FIN -> TimeWait.
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
require.True(t, valid)
require.Equal(t, TCPStateTimeWait, conn.GetState())
}
func TestTCPLastAckDirectionCheck(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Drive to LastAck: peer FIN -> CloseWait, our FIN -> LastAck.
require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0))
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
conn := tracker.connections[key]
require.Equal(t, TCPStateLastAck, conn.GetState())
// Our own ACK retransmit (same direction as originator) must NOT close.
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
require.Equal(t, TCPStateLastAck, conn.GetState(),
"own ACK retransmit in LastAck must not transition to Closed")
// Peer's ACK -> Closed.
require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0))
require.Equal(t, TCPStateClosed, conn.GetState())
}
func TestTCPFinWait1OwnAckDoesNotAdvance(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
conn := tracker.connections[key]
require.Equal(t, TCPStateFinWait1, conn.GetState())
// Our own ACK retransmit (same direction as originator) must not advance.
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
require.Equal(t, TCPStateFinWait1, conn.GetState(),
"own ACK in FinWait1 must not advance to FinWait2")
}
func TestTCPPerStateTeardownTimeouts(t *testing.T) {
// Verify cleanup reaps entries in each teardown state at the configured
// per-state timeout, not at the single handshake timeout.
t.Setenv(EnvTCPFinWaitTimeout, "50ms")
t.Setenv(EnvTCPCloseWaitTimeout, "80ms")
t.Setenv(EnvTCPLastAckTimeout, "30ms")
dstIP := netip.MustParseAddr("100.64.0.2")
dstPort := uint16(80)
// Drives a connection to the target state, forces its lastSeen well
// beyond the configured timeout, runs cleanup, and asserts reaping.
cases := []struct {
name string
// drive takes a fresh tracker and returns the conn key after
// transitioning the flow into the intended teardown state.
drive func(t *testing.T, tr *TCPTracker, srcIP netip.Addr, srcPort uint16) (ConnKey, TCPState)
}{
{
name: "FinWait1",
drive: func(t *testing.T, tr *TCPTracker, srcIP netip.Addr, srcPort uint16) (ConnKey, TCPState) {
establishConnection(t, tr, srcIP, dstIP, srcPort, dstPort)
tr.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) // → FinWait1
return ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}, TCPStateFinWait1
},
},
{
name: "FinWait2",
drive: func(t *testing.T, tr *TCPTracker, srcIP netip.Addr, srcPort uint16) (ConnKey, TCPState) {
establishConnection(t, tr, srcIP, dstIP, srcPort, dstPort)
tr.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) // FinWait1
require.True(t, tr.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)) // → FinWait2
return ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}, TCPStateFinWait2
},
},
{
name: "CloseWait",
drive: func(t *testing.T, tr *TCPTracker, srcIP netip.Addr, srcPort uint16) (ConnKey, TCPState) {
establishConnection(t, tr, srcIP, dstIP, srcPort, dstPort)
require.True(t, tr.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)) // → CloseWait
return ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}, TCPStateCloseWait
},
},
{
name: "LastAck",
drive: func(t *testing.T, tr *TCPTracker, srcIP netip.Addr, srcPort uint16) (ConnKey, TCPState) {
establishConnection(t, tr, srcIP, dstIP, srcPort, dstPort)
require.True(t, tr.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)) // CloseWait
tr.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) // → LastAck
return ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}, TCPStateLastAck
},
},
}
// Use a unique source port per subtest so nothing aliases.
port := uint16(12345)
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
require.Equal(t, 50*time.Millisecond, tracker.finWaitTimeout)
require.Equal(t, 80*time.Millisecond, tracker.closeWaitTimeout)
require.Equal(t, 30*time.Millisecond, tracker.lastAckTimeout)
srcIP := netip.MustParseAddr("100.64.0.1")
port++
key, wantState := c.drive(t, tracker, srcIP, port)
conn := tracker.connections[key]
require.NotNil(t, conn)
require.Equal(t, wantState, conn.GetState())
// Age the entry past the largest per-state timeout.
conn.lastSeen.Store(time.Now().Add(-500 * time.Millisecond).UnixNano())
tracker.cleanup()
_, exists := tracker.connections[key]
require.False(t, exists, "%s entry should be reaped", c.name)
})
}
}
func TestTCPEstablishedPSHACKInFinStates(t *testing.T) {
// Verifies FIN|PSH|ACK and bare ACK keepalives are not dropped in FIN
// teardown states, which some stacks emit during close.
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Peer FIN -> CloseWait.
require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0))
// Peer pushes trailing data + FIN|PSH|ACK (legal).
require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPPush|TCPAck, 100),
"FIN|PSH|ACK in CloseWait must be accepted")
// Bare ACK keepalive from peer in CloseWait must be accepted.
require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0),
"bare ACK in CloseWait must be accepted")
}

View File

@@ -17,6 +17,9 @@ const (
DefaultUDPTimeout = 30 * time.Second
// UDPCleanupInterval is how often we check for stale connections
UDPCleanupInterval = 15 * time.Second
// EnvUDPMaxEntries caps the UDP conntrack table size.
EnvUDPMaxEntries = "NB_CONNTRACK_UDP_MAX"
)
// UDPConnTrack represents a UDP connection state
@@ -34,6 +37,7 @@ type UDPTracker struct {
cleanupTicker *time.Ticker
tickerCancel context.CancelFunc
mutex sync.RWMutex
maxEntries int
flowLogger nftypes.FlowLogger
}
@@ -51,6 +55,7 @@ func NewUDPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp
timeout: timeout,
cleanupTicker: time.NewTicker(UDPCleanupInterval),
tickerCancel: cancel,
maxEntries: envInt(logger, EnvUDPMaxEntries, DefaultMaxUDPEntries),
flowLogger: flowLogger,
}
@@ -117,13 +122,18 @@ func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d
conn.UpdateCounters(direction, size)
t.mutex.Lock()
if t.maxEntries > 0 && len(t.connections) >= t.maxEntries {
t.evictOneLocked()
}
t.connections[key] = conn
t.mutex.Unlock()
if origPort != 0 {
t.logger.Trace4("New %s UDP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort)
} else {
t.logger.Trace2("New %s UDP connection: %s", direction, key)
if t.logger.Enabled(nblog.LevelTrace) {
if origPort != 0 {
t.logger.Trace4("New %s UDP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort)
} else {
t.logger.Trace2("New %s UDP connection: %s", direction, key)
}
}
t.sendEvent(nftypes.TypeStart, conn, ruleID)
}
@@ -151,6 +161,34 @@ func (t *UDPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort
return true
}
// evictOneLocked removes one entry to make room. Caller must hold t.mutex.
// Bounded sample: picks the oldest among up to evictSampleSize entries.
func (t *UDPTracker) evictOneLocked() {
var candKey ConnKey
var candSeen int64
haveCand := false
sampled := 0
for k, c := range t.connections {
seen := c.lastSeen.Load()
if !haveCand || seen < candSeen {
candKey = k
candSeen = seen
haveCand = true
}
sampled++
if sampled >= evictSampleSize {
break
}
}
if haveCand {
if evicted := t.connections[candKey]; evicted != nil {
t.sendEvent(nftypes.TypeEnd, evicted, nil)
}
delete(t.connections, candKey)
}
}
// cleanupRoutine periodically removes stale connections
func (t *UDPTracker) cleanupRoutine(ctx context.Context) {
defer t.cleanupTicker.Stop()
@@ -173,8 +211,10 @@ func (t *UDPTracker) cleanup() {
if conn.timeoutExceeded(t.timeout) {
delete(t.connections, key)
t.logger.Trace5("Removed UDP connection %s (timeout) [in: %d Pkts/%d B, out: %d Pkts/%d B]",
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
if t.logger.Enabled(nblog.LevelTrace) {
t.logger.Trace5("Removed UDP connection %s (timeout) [in: %d Pkts/%d B, out: %d Pkts/%d B]",
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
}
t.sendEvent(nftypes.TypeEnd, conn, nil)
}
}

View File

@@ -787,7 +787,9 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool {
srcIP, dstIP := m.extractIPs(d)
if !srcIP.IsValid() {
m.logger.Error1("Unknown network layer: %v", d.decoded[0])
if m.logger.Enabled(nblog.LevelError) {
m.logger.Error1("Unknown network layer: %v", d.decoded[0])
}
return false
}
@@ -901,7 +903,9 @@ func (m *Manager) clampTCPMSS(packetData []byte, d *decoder) bool {
return false
}
m.logger.Trace2("Clamped TCP MSS from %d to %d", currentMSS, mssClampValue)
if m.logger.Enabled(nblog.LevelTrace) {
m.logger.Trace2("Clamped TCP MSS from %d to %d", currentMSS, mssClampValue)
}
return true
}
@@ -1044,11 +1048,13 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool {
// TODO: pass fragments of routed packets to forwarder
if fragment {
if d.decoded[0] == layers.LayerTypeIPv4 {
m.logger.Trace4("packet is a fragment: src=%v dst=%v id=%v flags=%v",
srcIP, dstIP, d.ip4.Id, d.ip4.Flags)
} else {
m.logger.Trace2("packet is an IPv6 fragment: src=%v dst=%v", srcIP, dstIP)
if m.logger.Enabled(nblog.LevelTrace) {
if d.decoded[0] == layers.LayerTypeIPv4 {
m.logger.Trace4("packet is a fragment: src=%v dst=%v id=%v flags=%v",
srcIP, dstIP, d.ip4.Id, d.ip4.Flags)
} else {
m.logger.Trace2("packet is an IPv6 fragment: src=%v dst=%v", srcIP, dstIP)
}
}
return false
}
@@ -1091,8 +1097,10 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packet
pnum := getProtocolFromPacket(d)
srcPort, dstPort := getPortsFromPacket(d)
m.logger.Trace6("Dropping local packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
ruleID, pnum, srcIP, srcPort, dstIP, dstPort)
if m.logger.Enabled(nblog.LevelTrace) {
m.logger.Trace6("Dropping local packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
ruleID, pnum, srcIP, srcPort, dstIP, dstPort)
}
m.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: uuid.New(),
@@ -1142,8 +1150,10 @@ func (m *Manager) handleForwardedLocalTraffic(packetData []byte) bool {
func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) bool {
// Drop if routing is disabled
if !m.routingEnabled.Load() {
m.logger.Trace2("Dropping routed packet (routing disabled): src=%s dst=%s",
srcIP, dstIP)
if m.logger.Enabled(nblog.LevelTrace) {
m.logger.Trace2("Dropping routed packet (routing disabled): src=%s dst=%s",
srcIP, dstIP)
}
return true
}
@@ -1160,8 +1170,10 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe
if !pass {
proto := getProtocolFromPacket(d)
m.logger.Trace6("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
ruleID, proto, srcIP, srcPort, dstIP, dstPort)
if m.logger.Enabled(nblog.LevelTrace) {
m.logger.Trace6("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
ruleID, proto, srcIP, srcPort, dstIP, dstPort)
}
m.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: uuid.New(),
@@ -1287,7 +1299,9 @@ func getPortsFromPacket(d *decoder) (srcPort, dstPort uint16) {
// It returns true, true if the packet is a fragment and valid.
func (m *Manager) isValidPacket(d *decoder, packetData []byte) (bool, bool) {
if err := d.decodePacket(packetData); err != nil {
m.logger.Trace1("couldn't decode packet, err: %s", err)
if m.logger.Enabled(nblog.LevelTrace) {
m.logger.Trace1("couldn't decode packet, err: %s", err)
}
return false, false
}

View File

@@ -13,6 +13,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
)
@@ -97,8 +98,10 @@ func (f *Forwarder) forwardICMPPacket(id stack.TransportEndpointID, payload []by
return nil, fmt.Errorf("write ICMP packet: %w", err)
}
f.logger.Trace3("forwarder: Forwarded ICMP packet %v type %v code %v",
epID(id), icmpType, icmpCode)
if f.logger.Enabled(nblog.LevelTrace) {
f.logger.Trace3("forwarder: Forwarded ICMP packet %v type %v code %v",
epID(id), icmpType, icmpCode)
}
return conn, nil
}
@@ -121,12 +124,14 @@ func (f *Forwarder) handleICMPViaSocket(flowID uuid.UUID, id stack.TransportEndp
txBytes := f.handleEchoResponse(conn, id, v6)
rtt := time.Since(sendTime).Round(10 * time.Microsecond)
proto := "ICMP"
if v6 {
proto = "ICMPv6"
if f.logger.Enabled(nblog.LevelTrace) {
proto := "ICMP"
if v6 {
proto = "ICMPv6"
}
f.logger.Trace5("forwarder: Forwarded %s echo reply %v type %v code %v (rtt=%v, raw socket)",
proto, epID(id), icmpType, icmpCode, rtt)
}
f.logger.Trace5("forwarder: Forwarded %s echo reply %v type %v code %v (rtt=%v, raw socket)",
proto, epID(id), icmpType, icmpCode, rtt)
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
}
@@ -224,13 +229,17 @@ func (f *Forwarder) handleICMPViaPing(flowID uuid.UUID, id stack.TransportEndpoi
}
rtt := time.Since(pingStart).Round(10 * time.Microsecond)
f.logger.Trace3("forwarder: Forwarded ICMP echo request %v type %v code %v",
epID(id), icmpType, icmpCode)
if f.logger.Enabled(nblog.LevelTrace) {
f.logger.Trace3("forwarder: Forwarded ICMP echo request %v type %v code %v",
epID(id), icmpType, icmpCode)
}
txBytes := f.synthesizeEchoReply(id, icmpData)
f.logger.Trace4("forwarder: Forwarded ICMP echo reply %v type %v code %v (rtt=%v, ping binary)",
epID(id), icmpType, icmpCode, rtt)
if f.logger.Enabled(nblog.LevelTrace) {
f.logger.Trace4("forwarder: Forwarded ICMP echo reply %v type %v code %v (rtt=%v, ping binary)",
epID(id), icmpType, icmpCode, rtt)
}
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
}

View File

@@ -1,11 +1,8 @@
package forwarder
import (
"context"
"io"
"net"
"strconv"
"sync"
"github.com/google/uuid"
@@ -15,7 +12,9 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/waiter"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/util/netrelay"
)
// handleTCP is called by the TCP forwarder for new connections.
@@ -37,7 +36,9 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr)
if err != nil {
r.Complete(true)
f.logger.Trace2("forwarder: dial error for %v: %v", epID(id), err)
if f.logger.Enabled(nblog.LevelTrace) {
f.logger.Trace2("forwarder: dial error for %v: %v", epID(id), err)
}
return
}
@@ -60,64 +61,22 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
inConn := gonet.NewTCPConn(&wq, ep)
success = true
f.logger.Trace1("forwarder: established TCP connection %v", epID(id))
if f.logger.Enabled(nblog.LevelTrace) {
f.logger.Trace1("forwarder: established TCP connection %v", epID(id))
}
go f.proxyTCP(id, inConn, outConn, ep, flowID)
}
func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn, outConn net.Conn, ep tcpip.Endpoint, flowID uuid.UUID) {
// netrelay.Relay copies bidirectionally with proper half-close propagation
// and fully closes both conns before returning.
bytesFromInToOut, bytesFromOutToIn := netrelay.Relay(f.ctx, inConn, outConn, netrelay.Options{
Logger: f.logger,
})
ctx, cancel := context.WithCancel(f.ctx)
defer cancel()
go func() {
<-ctx.Done()
// Close connections and endpoint.
if err := inConn.Close(); err != nil && !isClosedError(err) {
f.logger.Debug1("forwarder: inConn close error: %v", err)
}
if err := outConn.Close(); err != nil && !isClosedError(err) {
f.logger.Debug1("forwarder: outConn close error: %v", err)
}
ep.Close()
}()
var wg sync.WaitGroup
wg.Add(2)
var (
bytesFromInToOut int64 // bytes from client to server (tx for client)
bytesFromOutToIn int64 // bytes from server to client (rx for client)
errInToOut error
errOutToIn error
)
go func() {
bytesFromInToOut, errInToOut = io.Copy(outConn, inConn)
cancel()
wg.Done()
}()
go func() {
bytesFromOutToIn, errOutToIn = io.Copy(inConn, outConn)
cancel()
wg.Done()
}()
wg.Wait()
if errInToOut != nil {
if !isClosedError(errInToOut) {
f.logger.Error2("proxyTCP: copy error (in → out) for %s: %v", epID(id), errInToOut)
}
}
if errOutToIn != nil {
if !isClosedError(errOutToIn) {
f.logger.Error2("proxyTCP: copy error (out → in) for %s: %v", epID(id), errOutToIn)
}
}
// Close the netstack endpoint after both conns are drained.
ep.Close()
var rxPackets, txPackets uint64
if tcpStats, ok := ep.Stats().(*tcp.Stats); ok {
@@ -126,7 +85,9 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
txPackets = tcpStats.SegmentsReceived.Value()
}
f.logger.Trace5("forwarder: Removed TCP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, bytesFromOutToIn, txPackets, bytesFromInToOut)
if f.logger.Enabled(nblog.LevelTrace) {
f.logger.Trace5("forwarder: Removed TCP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, bytesFromOutToIn, txPackets, bytesFromInToOut)
}
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, uint64(bytesFromOutToIn), uint64(bytesFromInToOut), rxPackets, txPackets)
}

View File

@@ -125,7 +125,9 @@ func (f *udpForwarder) cleanup() {
delete(f.conns, idle.id)
f.Unlock()
f.logger.Trace1("forwarder: cleaned up idle UDP connection %v", epID(idle.id))
if f.logger.Enabled(nblog.LevelTrace) {
f.logger.Trace1("forwarder: cleaned up idle UDP connection %v", epID(idle.id))
}
}
}
}
@@ -144,7 +146,9 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool {
_, exists := f.udpForwarder.conns[id]
f.udpForwarder.RUnlock()
if exists {
f.logger.Trace1("forwarder: existing UDP connection for %v", epID(id))
if f.logger.Enabled(nblog.LevelTrace) {
f.logger.Trace1("forwarder: existing UDP connection for %v", epID(id))
}
return true
}
@@ -206,7 +210,9 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool {
f.udpForwarder.Unlock()
success = true
f.logger.Trace1("forwarder: established UDP connection %v", epID(id))
if f.logger.Enabled(nblog.LevelTrace) {
f.logger.Trace1("forwarder: established UDP connection %v", epID(id))
}
go f.proxyUDP(connCtx, pConn, id, ep)
return true
@@ -265,7 +271,9 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
txPackets = udpStats.PacketsReceived.Value()
}
f.logger.Trace5("forwarder: Removed UDP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, rxBytes, txPackets, txBytes)
if f.logger.Enabled(nblog.LevelTrace) {
f.logger.Trace5("forwarder: Removed UDP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, rxBytes, txPackets, txBytes)
}
f.udpForwarder.Lock()
delete(f.udpForwarder.conns, id)

View File

@@ -53,16 +53,17 @@ var levelStrings = map[Level]string{
}
type logMessage struct {
level Level
format string
arg1 any
arg2 any
arg3 any
arg4 any
arg5 any
arg6 any
arg7 any
arg8 any
level Level
argCount uint8
format string
arg1 any
arg2 any
arg3 any
arg4 any
arg5 any
arg6 any
arg7 any
arg8 any
}
// Logger is a high-performance, non-blocking logger
@@ -107,6 +108,13 @@ func (l *Logger) SetLevel(level Level) {
log.Debugf("Set uspfilter logger loglevel to %v", levelStrings[level])
}
// Enabled reports whether the given level is currently logged. Callers on the
// hot path should guard log sites with this to avoid boxing arguments into
// any when the level is off.
func (l *Logger) Enabled(level Level) bool {
return l.level.Load() >= uint32(level)
}
func (l *Logger) Error(format string) {
if l.level.Load() >= uint32(LevelError) {
select {
@@ -155,7 +163,7 @@ func (l *Logger) Trace(format string) {
func (l *Logger) Error1(format string, arg1 any) {
if l.level.Load() >= uint32(LevelError) {
select {
case l.msgChannel <- logMessage{level: LevelError, format: format, arg1: arg1}:
case l.msgChannel <- logMessage{level: LevelError, argCount: 1, format: format, arg1: arg1}:
default:
}
}
@@ -164,7 +172,16 @@ func (l *Logger) Error1(format string, arg1 any) {
func (l *Logger) Error2(format string, arg1, arg2 any) {
if l.level.Load() >= uint32(LevelError) {
select {
case l.msgChannel <- logMessage{level: LevelError, format: format, arg1: arg1, arg2: arg2}:
case l.msgChannel <- logMessage{level: LevelError, argCount: 2, format: format, arg1: arg1, arg2: arg2}:
default:
}
}
}
func (l *Logger) Warn2(format string, arg1, arg2 any) {
if l.level.Load() >= uint32(LevelWarn) {
select {
case l.msgChannel <- logMessage{level: LevelWarn, argCount: 2, format: format, arg1: arg1, arg2: arg2}:
default:
}
}
@@ -173,7 +190,7 @@ func (l *Logger) Error2(format string, arg1, arg2 any) {
func (l *Logger) Warn3(format string, arg1, arg2, arg3 any) {
if l.level.Load() >= uint32(LevelWarn) {
select {
case l.msgChannel <- logMessage{level: LevelWarn, format: format, arg1: arg1, arg2: arg2, arg3: arg3}:
case l.msgChannel <- logMessage{level: LevelWarn, argCount: 3, format: format, arg1: arg1, arg2: arg2, arg3: arg3}:
default:
}
}
@@ -182,7 +199,7 @@ func (l *Logger) Warn3(format string, arg1, arg2, arg3 any) {
func (l *Logger) Warn4(format string, arg1, arg2, arg3, arg4 any) {
if l.level.Load() >= uint32(LevelWarn) {
select {
case l.msgChannel <- logMessage{level: LevelWarn, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4}:
case l.msgChannel <- logMessage{level: LevelWarn, argCount: 4, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4}:
default:
}
}
@@ -191,7 +208,7 @@ func (l *Logger) Warn4(format string, arg1, arg2, arg3, arg4 any) {
func (l *Logger) Debug1(format string, arg1 any) {
if l.level.Load() >= uint32(LevelDebug) {
select {
case l.msgChannel <- logMessage{level: LevelDebug, format: format, arg1: arg1}:
case l.msgChannel <- logMessage{level: LevelDebug, argCount: 1, format: format, arg1: arg1}:
default:
}
}
@@ -200,7 +217,7 @@ func (l *Logger) Debug1(format string, arg1 any) {
func (l *Logger) Debug2(format string, arg1, arg2 any) {
if l.level.Load() >= uint32(LevelDebug) {
select {
case l.msgChannel <- logMessage{level: LevelDebug, format: format, arg1: arg1, arg2: arg2}:
case l.msgChannel <- logMessage{level: LevelDebug, argCount: 2, format: format, arg1: arg1, arg2: arg2}:
default:
}
}
@@ -209,16 +226,59 @@ func (l *Logger) Debug2(format string, arg1, arg2 any) {
func (l *Logger) Debug3(format string, arg1, arg2, arg3 any) {
if l.level.Load() >= uint32(LevelDebug) {
select {
case l.msgChannel <- logMessage{level: LevelDebug, format: format, arg1: arg1, arg2: arg2, arg3: arg3}:
case l.msgChannel <- logMessage{level: LevelDebug, argCount: 3, format: format, arg1: arg1, arg2: arg2, arg3: arg3}:
default:
}
}
}
// Debugf is the variadic shape. Dispatches to Debug/Debug1/Debug2/Debug3
// to avoid allocating an args slice on the fast path when the arg count is
// known (0-3). Args beyond 3 land on the general variadic path; callers on
// the hot path should prefer DebugN for known counts.
func (l *Logger) Debugf(format string, args ...any) {
if l.level.Load() < uint32(LevelDebug) {
return
}
switch len(args) {
case 0:
l.Debug(format)
case 1:
l.Debug1(format, args[0])
case 2:
l.Debug2(format, args[0], args[1])
case 3:
l.Debug3(format, args[0], args[1], args[2])
default:
l.sendVariadic(LevelDebug, format, args)
}
}
// sendVariadic packs a slice of arguments into a logMessage and non-blocking
// enqueues it. Used for arg counts beyond the fixed-arity fast paths. Args
// beyond the 8-arg slot limit are dropped so callers don't produce silently
// empty log lines via uint8 wraparound in argCount.
func (l *Logger) sendVariadic(level Level, format string, args []any) {
const maxArgs = 8
n := len(args)
if n > maxArgs {
n = maxArgs
}
msg := logMessage{level: level, argCount: uint8(n), format: format}
slots := [maxArgs]*any{&msg.arg1, &msg.arg2, &msg.arg3, &msg.arg4, &msg.arg5, &msg.arg6, &msg.arg7, &msg.arg8}
for i := 0; i < n; i++ {
*slots[i] = args[i]
}
select {
case l.msgChannel <- msg:
default:
}
}
func (l *Logger) Trace1(format string, arg1 any) {
if l.level.Load() >= uint32(LevelTrace) {
select {
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1}:
case l.msgChannel <- logMessage{level: LevelTrace, argCount: 1, format: format, arg1: arg1}:
default:
}
}
@@ -227,7 +287,7 @@ func (l *Logger) Trace1(format string, arg1 any) {
func (l *Logger) Trace2(format string, arg1, arg2 any) {
if l.level.Load() >= uint32(LevelTrace) {
select {
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2}:
case l.msgChannel <- logMessage{level: LevelTrace, argCount: 2, format: format, arg1: arg1, arg2: arg2}:
default:
}
}
@@ -236,7 +296,7 @@ func (l *Logger) Trace2(format string, arg1, arg2 any) {
func (l *Logger) Trace3(format string, arg1, arg2, arg3 any) {
if l.level.Load() >= uint32(LevelTrace) {
select {
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3}:
case l.msgChannel <- logMessage{level: LevelTrace, argCount: 3, format: format, arg1: arg1, arg2: arg2, arg3: arg3}:
default:
}
}
@@ -245,7 +305,7 @@ func (l *Logger) Trace3(format string, arg1, arg2, arg3 any) {
func (l *Logger) Trace4(format string, arg1, arg2, arg3, arg4 any) {
if l.level.Load() >= uint32(LevelTrace) {
select {
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4}:
case l.msgChannel <- logMessage{level: LevelTrace, argCount: 4, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4}:
default:
}
}
@@ -254,7 +314,7 @@ func (l *Logger) Trace4(format string, arg1, arg2, arg3, arg4 any) {
func (l *Logger) Trace5(format string, arg1, arg2, arg3, arg4, arg5 any) {
if l.level.Load() >= uint32(LevelTrace) {
select {
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5}:
case l.msgChannel <- logMessage{level: LevelTrace, argCount: 5, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5}:
default:
}
}
@@ -263,7 +323,7 @@ func (l *Logger) Trace5(format string, arg1, arg2, arg3, arg4, arg5 any) {
func (l *Logger) Trace6(format string, arg1, arg2, arg3, arg4, arg5, arg6 any) {
if l.level.Load() >= uint32(LevelTrace) {
select {
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5, arg6: arg6}:
case l.msgChannel <- logMessage{level: LevelTrace, argCount: 6, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5, arg6: arg6}:
default:
}
}
@@ -273,7 +333,7 @@ func (l *Logger) Trace6(format string, arg1, arg2, arg3, arg4, arg5, arg6 any) {
func (l *Logger) Trace8(format string, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8 any) {
if l.level.Load() >= uint32(LevelTrace) {
select {
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5, arg6: arg6, arg7: arg7, arg8: arg8}:
case l.msgChannel <- logMessage{level: LevelTrace, argCount: 8, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5, arg6: arg6, arg7: arg7, arg8: arg8}:
default:
}
}
@@ -286,35 +346,8 @@ func (l *Logger) formatMessage(buf *[]byte, msg logMessage) {
*buf = append(*buf, levelStrings[msg.level]...)
*buf = append(*buf, ' ')
// Count non-nil arguments for switch
argCount := 0
if msg.arg1 != nil {
argCount++
if msg.arg2 != nil {
argCount++
if msg.arg3 != nil {
argCount++
if msg.arg4 != nil {
argCount++
if msg.arg5 != nil {
argCount++
if msg.arg6 != nil {
argCount++
if msg.arg7 != nil {
argCount++
if msg.arg8 != nil {
argCount++
}
}
}
}
}
}
}
}
var formatted string
switch argCount {
switch msg.argCount {
case 0:
formatted = msg.format
case 1:

View File

@@ -11,6 +11,7 @@ import (
"github.com/google/gopacket/layers"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
)
var (
@@ -262,11 +263,15 @@ func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool {
}
if err := m.rewritePacketIP(packetData, d, translatedIP, false); err != nil {
m.logger.Error1("failed to rewrite packet destination: %v", err)
if m.logger.Enabled(nblog.LevelError) {
m.logger.Error1("failed to rewrite packet destination: %v", err)
}
return false
}
m.logger.Trace2("DNAT: %s -> %s", dstIP, translatedIP)
if m.logger.Enabled(nblog.LevelTrace) {
m.logger.Trace2("DNAT: %s -> %s", dstIP, translatedIP)
}
return true
}
@@ -283,11 +288,15 @@ func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
}
if err := m.rewritePacketIP(packetData, d, originalIP, true); err != nil {
m.logger.Error1("failed to rewrite packet source: %v", err)
if m.logger.Enabled(nblog.LevelError) {
m.logger.Error1("failed to rewrite packet source: %v", err)
}
return false
}
m.logger.Trace2("Reverse DNAT: %s -> %s", srcIP, originalIP)
if m.logger.Enabled(nblog.LevelTrace) {
m.logger.Trace2("Reverse DNAT: %s -> %s", srcIP, originalIP)
}
return true
}
@@ -612,7 +621,9 @@ func (m *Manager) applyPortRule(packetData []byte, d *decoder, srcIP, dstIP neti
}
if err := rewriteFn(packetData, d, rule.targetPort, destinationPortOffset); err != nil {
m.logger.Error1("failed to rewrite port: %v", err)
if m.logger.Enabled(nblog.LevelError) {
m.logger.Error1("failed to rewrite port: %v", err)
}
return false
}
d.dnatOrigPort = rule.origPort

View File

@@ -260,15 +260,23 @@ WriteRegStr ${REG_ROOT} "${UNINSTALL_PATH}" "Publisher" "${COMP_NAME}"
WriteRegStr ${REG_ROOT} "${UI_REG_APP_PATH}" "" "$INSTDIR\${UI_APP_EXE}"
; Create autostart registry entry based on checkbox
; Drop Run, App Paths and Uninstall entries left in the 32-bit registry view
; or HKCU by legacy installers.
DetailPrint "Cleaning legacy 32-bit / HKCU entries..."
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
SetRegView 32
DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
DeleteRegKey HKLM "${REG_APP_PATH}"
DeleteRegKey HKLM "${UI_REG_APP_PATH}"
DeleteRegKey HKLM "${UNINSTALL_PATH}"
SetRegView 64
DetailPrint "Autostart enabled: $AutostartEnabled"
${If} $AutostartEnabled == "1"
WriteRegStr HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}" '"$INSTDIR\${UI_APP_EXE}.exe"'
DetailPrint "Added autostart registry entry: $INSTDIR\${UI_APP_EXE}.exe"
${Else}
DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
; Legacy: pre-HKLM installs wrote to HKCU; clean that up too.
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
DetailPrint "Autostart not enabled by user"
${EndIf}
@@ -299,11 +307,16 @@ ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service uninstall'
DetailPrint "Terminating Netbird UI process..."
ExecWait `taskkill /im ${UI_APP_EXE}.exe /f`
; Remove autostart registry entry
; Remove autostart entries from every view a previous installer may have used.
DetailPrint "Removing autostart registry entry if exists..."
DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
; Legacy: pre-HKLM installs wrote to HKCU; clean that up too.
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
SetRegView 32
DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
DeleteRegKey HKLM "${REG_APP_PATH}"
DeleteRegKey HKLM "${UI_REG_APP_PATH}"
DeleteRegKey HKLM "${UNINSTALL_PATH}"
SetRegView 64
; Handle data deletion based on checkbox
DetailPrint "Checking if user requested data deletion..."

View File

@@ -116,7 +116,6 @@ func (c *ConnectClient) RunOniOS(
fileDescriptor int32,
networkChangeListener listener.NetworkChangeListener,
dnsManager dns.IosDnsManager,
dnsAddresses []netip.AddrPort,
stateFilePath string,
) error {
// Set GC percent to 5% to reduce memory usage as iOS only allows 50MB of memory for the extension.
@@ -126,7 +125,6 @@ func (c *ConnectClient) RunOniOS(
FileDescriptor: fileDescriptor,
NetworkChangeListener: networkChangeListener,
DnsManager: dnsManager,
HostDNSAddresses: dnsAddresses,
StateFilePath: stateFilePath,
}
return c.run(mobileDependency, nil, "")

View File

@@ -45,8 +45,11 @@ netbird.out: Most recent, anonymized stdout log file of the NetBird client.
routes.txt: Detailed system routing table in tabular format including destination, gateway, interface, metrics, and protocol information, if --system-info flag was provided.
interfaces.txt: Anonymized network interface information, if --system-info flag was provided.
ip_rules.txt: Detailed IP routing rules in tabular format including priority, source, destination, interfaces, table, and action information (Linux only), if --system-info flag was provided.
iptables.txt: Anonymized iptables rules with packet counters, if --system-info flag was provided.
nftables.txt: Anonymized nftables rules with packet counters, if --system-info flag was provided.
iptables.txt: Anonymized iptables (IPv4) rules with packet counters, if --system-info flag was provided.
ip6tables.txt: Anonymized ip6tables (IPv6) rules with packet counters, if --system-info flag was provided.
ipset.txt: Anonymized ipset list output, if --system-info flag was provided.
nftables.txt: Anonymized nftables rules with packet counters across all families (ip, ip6, inet, etc.), if --system-info flag was provided.
sysctls.txt: Forwarding, reverse-path filter, source-validation, and conntrack accounting sysctl values that the NetBird client may read or modify, if --system-info flag was provided (Linux only).
resolv.conf: DNS resolver configuration from /etc/resolv.conf (Unix systems only), if --system-info flag was provided.
scutil_dns.txt: DNS configuration from scutil --dns (macOS only), if --system-info flag was provided.
resolved_domains.txt: Anonymized resolved domain IP addresses from the status recorder.
@@ -165,22 +168,33 @@ The config.txt file contains anonymized configuration information of the NetBird
Other non-sensitive configuration options are included without anonymization.
Firewall Rules (Linux only)
The bundle includes two separate firewall rule files:
The bundle includes the following firewall-related files:
iptables.txt:
- Complete iptables ruleset with packet counters using 'iptables -v -n -L'
- IPv4 iptables ruleset with packet counters using 'iptables-save' and 'iptables -v -n -L'
- Includes all tables (filter, nat, mangle, raw, security)
- Shows packet and byte counters for each rule
- All IP addresses are anonymized
- Chain names, table names, and other non-sensitive information remain unchanged
ip6tables.txt:
- IPv6 ip6tables ruleset with packet counters using 'ip6tables-save' and 'ip6tables -v -n -L'
- Same table coverage and anonymization as iptables.txt
- Omitted when ip6tables is not installed or no IPv6 rules are present
ipset.txt:
- Output of 'ipset list' (family-agnostic)
- IP addresses are anonymized; set names and types remain unchanged
nftables.txt:
- Complete nftables ruleset obtained via 'nft -a list ruleset'
- Complete nftables ruleset across all families (ip, ip6, inet, arp, bridge, netdev) via 'nft -a list ruleset'
- Includes rule handle numbers and packet counters
- All tables, chains, and rules are included
- Shows packet and byte counters for each rule
- All IP addresses are anonymized
- Chain names, table names, and other non-sensitive information remain unchanged
- All IP addresses are anonymized; chain/table names remain unchanged
sysctls.txt:
- Forwarding (IPv4 + IPv6, global and per-interface), reverse-path filter, source-validation, conntrack accounting, and TCP-related sysctls that netbird may read or modify
- Per-interface keys are enumerated from /proc/sys/net/ipv{4,6}/conf
- Interface names anonymized when --anonymize is set
IP Rules (Linux only)
The ip_rules.txt file contains detailed IP routing rule information:
@@ -412,6 +426,10 @@ func (g *BundleGenerator) addSystemInfo() {
log.Errorf("failed to add firewall rules to debug bundle: %v", err)
}
if err := g.addSysctls(); err != nil {
log.Errorf("failed to add sysctls to debug bundle: %v", err)
}
if err := g.addDNSInfo(); err != nil {
log.Errorf("failed to add DNS info to debug bundle: %v", err)
}

View File

@@ -124,15 +124,18 @@ func getSystemdLogs(serviceName string) (string, error) {
// addFirewallRules collects and adds firewall rules to the archive
func (g *BundleGenerator) addFirewallRules() error {
log.Info("Collecting firewall rules")
iptablesRules, err := collectIPTablesRules()
g.addIPTablesRulesToBundle("iptables-save", "iptables", "iptables.txt")
g.addIPTablesRulesToBundle("ip6tables-save", "ip6tables", "ip6tables.txt")
ipsetOutput, err := collectIPSets()
if err != nil {
log.Warnf("Failed to collect iptables rules: %v", err)
log.Warnf("Failed to collect ipset information: %v", err)
} else {
if g.anonymize {
iptablesRules = g.anonymizer.AnonymizeString(iptablesRules)
ipsetOutput = g.anonymizer.AnonymizeString(ipsetOutput)
}
if err := g.addFileToZip(strings.NewReader(iptablesRules), "iptables.txt"); err != nil {
log.Warnf("Failed to add iptables rules to bundle: %v", err)
if err := g.addFileToZip(strings.NewReader(ipsetOutput), "ipset.txt"); err != nil {
log.Warnf("Failed to add ipset output to bundle: %v", err)
}
}
@@ -151,44 +154,65 @@ func (g *BundleGenerator) addFirewallRules() error {
return nil
}
// collectIPTablesRules collects rules using both iptables-save and verbose listing
func collectIPTablesRules() (string, error) {
var builder strings.Builder
saveOutput, err := collectIPTablesSave()
// addIPTablesRulesToBundle collects iptables/ip6tables rules and writes them to the bundle.
func (g *BundleGenerator) addIPTablesRulesToBundle(saveBin, listBin, filename string) {
rules, err := collectIPTablesRules(saveBin, listBin)
if err != nil {
log.Warnf("Failed to collect iptables rules using iptables-save: %v", err)
} else {
builder.WriteString("=== iptables-save output ===\n")
log.Warnf("Failed to collect %s rules: %v", listBin, err)
return
}
if g.anonymize {
rules = g.anonymizer.AnonymizeString(rules)
}
if err := g.addFileToZip(strings.NewReader(rules), filename); err != nil {
log.Warnf("Failed to add %s rules to bundle: %v", listBin, err)
}
}
// collectIPTablesRules collects rules using both <saveBin> and verbose listing via <listBin>.
// Returns an error when neither command produced any output (e.g. the binary is missing),
// so the caller can skip writing an empty file.
func collectIPTablesRules(saveBin, listBin string) (string, error) {
var builder strings.Builder
var collected bool
var firstErr error
saveOutput, err := runCommand(saveBin)
switch {
case err != nil:
firstErr = err
log.Warnf("Failed to collect %s output: %v", saveBin, err)
case strings.TrimSpace(saveOutput) == "":
log.Debugf("%s produced no output, skipping", saveBin)
default:
builder.WriteString(fmt.Sprintf("=== %s output ===\n", saveBin))
builder.WriteString(saveOutput)
builder.WriteString("\n")
collected = true
}
ipsetOutput, err := collectIPSets()
if err != nil {
log.Warnf("Failed to collect ipset information: %v", err)
} else {
builder.WriteString("=== ipset list output ===\n")
builder.WriteString(ipsetOutput)
builder.WriteString("\n")
}
builder.WriteString("=== iptables -v -n -L output ===\n")
listHeader := fmt.Sprintf("=== %s -v -n -L output ===\n", listBin)
builder.WriteString(listHeader)
tables := []string{"filter", "nat", "mangle", "raw", "security"}
for _, table := range tables {
builder.WriteString(fmt.Sprintf("*%s\n", table))
stats, err := getTableStatistics(table)
stats, err := runCommand(listBin, "-v", "-n", "-L", "-t", table)
if err != nil {
log.Warnf("Failed to get statistics for table %s: %v", table, err)
if firstErr == nil {
firstErr = err
}
log.Warnf("Failed to get %s statistics for table %s: %v", listBin, table, err)
continue
}
builder.WriteString(fmt.Sprintf("*%s\n", table))
builder.WriteString(stats)
builder.WriteString("\n")
collected = true
}
if !collected {
return "", fmt.Errorf("collect %s rules: %w", listBin, firstErr)
}
return builder.String(), nil
}
@@ -214,34 +238,15 @@ func collectIPSets() (string, error) {
return ipsets, nil
}
// collectIPTablesSave uses iptables-save to get rule definitions
func collectIPTablesSave() (string, error) {
cmd := exec.Command("iptables-save")
// runCommand executes a command and returns its stdout, wrapping stderr in the error on failure.
func runCommand(name string, args ...string) (string, error) {
cmd := exec.Command(name, args...)
var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr
if err := cmd.Run(); err != nil {
return "", fmt.Errorf("execute iptables-save: %w (stderr: %s)", err, stderr.String())
}
rules := stdout.String()
if strings.TrimSpace(rules) == "" {
return "", fmt.Errorf("no iptables rules found")
}
return rules, nil
}
// getTableStatistics gets verbose statistics for an entire table using iptables command
func getTableStatistics(table string) (string, error) {
cmd := exec.Command("iptables", "-v", "-n", "-L", "-t", table)
var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr
if err := cmd.Run(); err != nil {
return "", fmt.Errorf("execute iptables -v -n -L: %w (stderr: %s)", err, stderr.String())
return "", fmt.Errorf("execute %s: %w (stderr: %s)", name, err, stderr.String())
}
return stdout.String(), nil
@@ -804,3 +809,91 @@ func formatSetKeyType(keyType nftables.SetDatatype) string {
return fmt.Sprintf("type-%v", keyType)
}
}
// addSysctls collects forwarding and netbird-managed sysctl values and writes them to the bundle.
func (g *BundleGenerator) addSysctls() error {
log.Info("Collecting sysctls")
content := collectSysctls()
if g.anonymize {
content = g.anonymizer.AnonymizeString(content)
}
if err := g.addFileToZip(strings.NewReader(content), "sysctls.txt"); err != nil {
return fmt.Errorf("add sysctls to bundle: %w", err)
}
return nil
}
// collectSysctls reads every sysctl that the netbird client may modify, plus
// global IPv4/IPv6 forwarding, and returns a formatted dump grouped by topic.
// Per-interface values are enumerated by listing /proc/sys/net/ipv{4,6}/conf.
func collectSysctls() string {
var builder strings.Builder
writeSysctlGroup(&builder, "forwarding", []string{
"net.ipv4.ip_forward",
"net.ipv6.conf.all.forwarding",
"net.ipv6.conf.default.forwarding",
})
writeSysctlGroup(&builder, "ipv4 per-interface forwarding", listInterfaceSysctls("ipv4", "forwarding"))
writeSysctlGroup(&builder, "ipv6 per-interface forwarding", listInterfaceSysctls("ipv6", "forwarding"))
writeSysctlGroup(&builder, "rp_filter", append(
[]string{"net.ipv4.conf.all.rp_filter", "net.ipv4.conf.default.rp_filter"},
listInterfaceSysctls("ipv4", "rp_filter")...,
))
writeSysctlGroup(&builder, "src_valid_mark", append(
[]string{"net.ipv4.conf.all.src_valid_mark", "net.ipv4.conf.default.src_valid_mark"},
listInterfaceSysctls("ipv4", "src_valid_mark")...,
))
writeSysctlGroup(&builder, "conntrack", []string{
"net.netfilter.nf_conntrack_acct",
"net.netfilter.nf_conntrack_tcp_loose",
})
writeSysctlGroup(&builder, "tcp", []string{
"net.ipv4.tcp_tw_reuse",
})
return builder.String()
}
func writeSysctlGroup(builder *strings.Builder, title string, keys []string) {
builder.WriteString(fmt.Sprintf("=== %s ===\n", title))
for _, key := range keys {
value, err := readSysctl(key)
if err != nil {
builder.WriteString(fmt.Sprintf("%s = <error: %v>\n", key, err))
continue
}
builder.WriteString(fmt.Sprintf("%s = %s\n", key, value))
}
builder.WriteString("\n")
}
// listInterfaceSysctls returns net.ipvX.conf.<iface>.<leaf> keys for every
// interface present in /proc/sys/net/ipvX/conf, skipping "all" and "default"
// (callers add those explicitly so they appear first).
func listInterfaceSysctls(family, leaf string) []string {
dir := fmt.Sprintf("/proc/sys/net/%s/conf", family)
entries, err := os.ReadDir(dir)
if err != nil {
return nil
}
var keys []string
for _, e := range entries {
name := e.Name()
if name == "all" || name == "default" {
continue
}
keys = append(keys, fmt.Sprintf("net.%s.conf.%s.%s", family, name, leaf))
}
sort.Strings(keys)
return keys
}
func readSysctl(key string) (string, error) {
path := fmt.Sprintf("/proc/sys/%s", strings.ReplaceAll(key, ".", "/"))
value, err := os.ReadFile(path)
if err != nil {
return "", err
}
return strings.TrimSpace(string(value)), nil
}

View File

@@ -17,3 +17,8 @@ func (g *BundleGenerator) addIPRules() error {
// IP rules are only supported on Linux
return nil
}
func (g *BundleGenerator) addSysctls() error {
// Sysctl collection is only supported on Linux
return nil
}

View File

@@ -16,6 +16,10 @@ type hostManager interface {
restoreHostDNS() error
supportCustomPort() bool
string() string
// getOriginalNameservers returns the OS-side resolvers used as PriorityFallback
// upstreams: pre-takeover snapshots on desktop, the OS-pushed list on Android,
// hardcoded Quad9 on iOS, nil for noop / mock.
getOriginalNameservers() []netip.Addr
}
type SystemDNSSettings struct {
@@ -131,3 +135,11 @@ func (n noopHostConfigurator) supportCustomPort() bool {
func (n noopHostConfigurator) string() string {
return "noop"
}
func (n noopHostConfigurator) getOriginalNameservers() []netip.Addr {
return nil
}
func (m *mockHostConfigurator) getOriginalNameservers() []netip.Addr {
return nil
}

View File

@@ -1,14 +1,20 @@
package dns
import (
"net/netip"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
// androidHostManager is a noop on the OS side (Android's VPN service handles
// DNS for us) but tracks the OS-reported resolver list pushed via
// OnUpdatedHostDNSServer so it can serve as the fallback nameserver source.
type androidHostManager struct {
holder *hostsDNSHolder
}
func newHostManager() (*androidHostManager, error) {
return &androidHostManager{}, nil
func newHostManager(holder *hostsDNSHolder) (*androidHostManager, error) {
return &androidHostManager{holder: holder}, nil
}
func (a androidHostManager) applyDNSConfig(HostDNSConfig, *statemanager.Manager) error {
@@ -26,3 +32,12 @@ func (a androidHostManager) supportCustomPort() bool {
func (a androidHostManager) string() string {
return "none"
}
func (a androidHostManager) getOriginalNameservers() []netip.Addr {
hosts := a.holder.get()
out := make([]netip.Addr, 0, len(hosts))
for ap := range hosts {
out = append(out, ap.Addr())
}
return out
}

View File

@@ -3,6 +3,7 @@ package dns
import (
"encoding/json"
"fmt"
"net/netip"
log "github.com/sirupsen/logrus"
@@ -20,6 +21,14 @@ func newHostManager(dnsManager IosDnsManager) (*iosHostManager, error) {
}, nil
}
func (a iosHostManager) getOriginalNameservers() []netip.Addr {
// Quad9 v4+v6: 9.9.9.9, 2620:fe::fe.
return []netip.Addr{
netip.AddrFrom4([4]byte{9, 9, 9, 9}),
netip.AddrFrom16([16]byte{0x26, 0x20, 0x00, 0xfe, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xfe}),
}
}
func (a iosHostManager) applyDNSConfig(config HostDNSConfig, _ *statemanager.Manager) error {
jsonData, err := json.Marshal(config)
if err != nil {

View File

@@ -7,6 +7,7 @@ import (
"io"
"net/netip"
"os/exec"
"slices"
"strings"
"syscall"
"time"
@@ -44,9 +45,11 @@ const (
nrptMaxDomainsPerRule = 50
interfaceConfigPath = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces`
interfaceConfigNameServerKey = "NameServer"
interfaceConfigSearchListKey = "SearchList"
interfaceConfigPath = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces`
interfaceConfigPathV6 = `SYSTEM\CurrentControlSet\Services\Tcpip6\Parameters\Interfaces`
interfaceConfigNameServerKey = "NameServer"
interfaceConfigDhcpNameSrvKey = "DhcpNameServer"
interfaceConfigSearchListKey = "SearchList"
// Network interface DNS registration settings
disableDynamicUpdateKey = "DisableDynamicUpdate"
@@ -67,10 +70,11 @@ const (
)
type registryConfigurator struct {
guid string
routingAll bool
gpo bool
nrptEntryCount int
guid string
routingAll bool
gpo bool
nrptEntryCount int
origNameservers []netip.Addr
}
func newHostManager(wgInterface WGIface) (*registryConfigurator, error) {
@@ -94,6 +98,17 @@ func newHostManager(wgInterface WGIface) (*registryConfigurator, error) {
gpo: useGPO,
}
origNameservers, err := configurator.captureOriginalNameservers()
switch {
case err != nil:
log.Warnf("capture original nameservers from non-WG adapters: %v", err)
case len(origNameservers) == 0:
log.Warnf("no original nameservers captured from non-WG adapters; DNS fallback will be empty")
default:
log.Debugf("captured %d original nameservers from non-WG adapters: %v", len(origNameservers), origNameservers)
}
configurator.origNameservers = origNameservers
if err := configurator.configureInterface(); err != nil {
log.Errorf("failed to configure interface settings: %v", err)
}
@@ -101,6 +116,98 @@ func newHostManager(wgInterface WGIface) (*registryConfigurator, error) {
return configurator, nil
}
// captureOriginalNameservers reads DNS addresses from every Tcpip(6) interface
// registry key except the WG adapter. v4 and v6 servers live in separate
// hives (Tcpip vs Tcpip6) keyed by the same interface GUID.
func (r *registryConfigurator) captureOriginalNameservers() ([]netip.Addr, error) {
seen := make(map[netip.Addr]struct{})
var out []netip.Addr
var merr *multierror.Error
for _, root := range []string{interfaceConfigPath, interfaceConfigPathV6} {
addrs, err := r.captureFromTcpipRoot(root)
if err != nil {
merr = multierror.Append(merr, fmt.Errorf("%s: %w", root, err))
continue
}
for _, addr := range addrs {
if _, dup := seen[addr]; dup {
continue
}
seen[addr] = struct{}{}
out = append(out, addr)
}
}
return out, nberrors.FormatErrorOrNil(merr)
}
func (r *registryConfigurator) captureFromTcpipRoot(rootPath string) ([]netip.Addr, error) {
root, err := registry.OpenKey(registry.LOCAL_MACHINE, rootPath, registry.READ)
if err != nil {
return nil, fmt.Errorf("open key: %w", err)
}
defer closer(root)
guids, err := root.ReadSubKeyNames(-1)
if err != nil {
return nil, fmt.Errorf("read subkeys: %w", err)
}
var out []netip.Addr
for _, guid := range guids {
if strings.EqualFold(guid, r.guid) {
continue
}
out = append(out, readInterfaceNameservers(rootPath, guid)...)
}
return out, nil
}
func readInterfaceNameservers(rootPath, guid string) []netip.Addr {
keyPath := rootPath + "\\" + guid
k, err := registry.OpenKey(registry.LOCAL_MACHINE, keyPath, registry.QUERY_VALUE)
if err != nil {
return nil
}
defer closer(k)
// Static NameServer wins over DhcpNameServer for actual resolution.
for _, name := range []string{interfaceConfigNameServerKey, interfaceConfigDhcpNameSrvKey} {
raw, _, err := k.GetStringValue(name)
if err != nil || raw == "" {
continue
}
if out := parseRegistryNameservers(raw); len(out) > 0 {
return out
}
}
return nil
}
func parseRegistryNameservers(raw string) []netip.Addr {
var out []netip.Addr
for _, field := range strings.FieldsFunc(raw, func(r rune) bool { return r == ',' || r == ' ' || r == '\t' }) {
addr, err := netip.ParseAddr(strings.TrimSpace(field))
if err != nil {
continue
}
addr = addr.Unmap()
if !addr.IsValid() || addr.IsUnspecified() {
continue
}
// Drop unzoned link-local: not routable without a scope id. If
// the user wrote "fe80::1%eth0" ParseAddr preserves the zone.
if addr.IsLinkLocalUnicast() && addr.Zone() == "" {
continue
}
out = append(out, addr)
}
return out
}
func (r *registryConfigurator) getOriginalNameservers() []netip.Addr {
return slices.Clone(r.origNameservers)
}
func (r *registryConfigurator) supportCustomPort() bool {
return false
}

View File

@@ -25,6 +25,7 @@ func (h *hostsDNSHolder) set(list []netip.AddrPort) {
h.mutex.Unlock()
}
//nolint:unused
func (h *hostsDNSHolder) get() map[netip.AddrPort]struct{} {
h.mutex.RLock()
l := h.unprotectedDNSList

View File

@@ -76,8 +76,6 @@ func (d *Resolver) ID() types.HandlerID {
return "local-resolver"
}
func (d *Resolver) ProbeAvailability(context.Context) {}
// ServeDNS handles a DNS request
func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
logger := log.WithFields(log.Fields{

View File

@@ -9,6 +9,7 @@ import (
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
)
@@ -70,10 +71,6 @@ func (m *MockServer) SearchDomains() []string {
return make([]string, 0)
}
// ProbeAvailability mocks implementation of ProbeAvailability from the Server interface
func (m *MockServer) ProbeAvailability() {
}
func (m *MockServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error {
if m.UpdateServerConfigFunc != nil {
return m.UpdateServerConfigFunc(domains)
@@ -85,8 +82,8 @@ func (m *MockServer) PopulateManagementDomain(mgmtURL *url.URL) error {
return nil
}
// SetRouteChecker mock implementation of SetRouteChecker from Server interface
func (m *MockServer) SetRouteChecker(func(netip.Addr) bool) {
// SetRouteSources mock implementation of SetRouteSources from Server interface
func (m *MockServer) SetRouteSources(selected, active func() route.HAMap) {
// Mock implementation - no-op
}

View File

@@ -8,6 +8,7 @@ import (
"errors"
"fmt"
"net/netip"
"slices"
"strings"
"time"
@@ -32,6 +33,15 @@ const (
networkManagerDbusDeviceGetAppliedConnectionMethod = networkManagerDbusDeviceInterface + ".GetAppliedConnection"
networkManagerDbusDeviceReapplyMethod = networkManagerDbusDeviceInterface + ".Reapply"
networkManagerDbusDeviceDeleteMethod = networkManagerDbusDeviceInterface + ".Delete"
networkManagerDbusDeviceIp4ConfigProperty = networkManagerDbusDeviceInterface + ".Ip4Config"
networkManagerDbusDeviceIp6ConfigProperty = networkManagerDbusDeviceInterface + ".Ip6Config"
networkManagerDbusDeviceIfaceProperty = networkManagerDbusDeviceInterface + ".Interface"
networkManagerDbusGetDevicesMethod = networkManagerDest + ".GetDevices"
networkManagerDbusIp4ConfigInterface = "org.freedesktop.NetworkManager.IP4Config"
networkManagerDbusIp6ConfigInterface = "org.freedesktop.NetworkManager.IP6Config"
networkManagerDbusIp4ConfigNameserverDataProperty = networkManagerDbusIp4ConfigInterface + ".NameserverData"
networkManagerDbusIp4ConfigNameserversProperty = networkManagerDbusIp4ConfigInterface + ".Nameservers"
networkManagerDbusIp6ConfigNameserversProperty = networkManagerDbusIp6ConfigInterface + ".Nameservers"
networkManagerDbusDefaultBehaviorFlag networkManagerConfigBehavior = 0
networkManagerDbusIPv4Key = "ipv4"
networkManagerDbusIPv6Key = "ipv6"
@@ -51,9 +61,10 @@ var supportedNetworkManagerVersionConstraints = []string{
}
type networkManagerDbusConfigurator struct {
dbusLinkObject dbus.ObjectPath
routingAll bool
ifaceName string
dbusLinkObject dbus.ObjectPath
routingAll bool
ifaceName string
origNameservers []netip.Addr
}
// the types below are based on dbus specification, each field is mapped to a dbus type
@@ -92,10 +103,200 @@ func newNetworkManagerDbusConfigurator(wgInterface string) (*networkManagerDbusC
log.Debugf("got network manager dbus Link Object: %s from net interface %s", s, wgInterface)
return &networkManagerDbusConfigurator{
c := &networkManagerDbusConfigurator{
dbusLinkObject: dbus.ObjectPath(s),
ifaceName: wgInterface,
}, nil
}
origNameservers, err := c.captureOriginalNameservers()
switch {
case err != nil:
log.Warnf("capture original nameservers from NetworkManager: %v", err)
case len(origNameservers) == 0:
log.Warnf("no original nameservers captured from non-WG NetworkManager devices; DNS fallback will be empty")
default:
log.Debugf("captured %d original nameservers from non-WG NetworkManager devices: %v", len(origNameservers), origNameservers)
}
c.origNameservers = origNameservers
return c, nil
}
// captureOriginalNameservers reads DNS servers from every NM device's
// IP4Config / IP6Config except our WG device.
func (n *networkManagerDbusConfigurator) captureOriginalNameservers() ([]netip.Addr, error) {
devices, err := networkManagerListDevices()
if err != nil {
return nil, fmt.Errorf("list devices: %w", err)
}
seen := make(map[netip.Addr]struct{})
var out []netip.Addr
for _, dev := range devices {
if dev == n.dbusLinkObject {
continue
}
ifaceName := readNetworkManagerDeviceInterface(dev)
for _, addr := range readNetworkManagerDeviceDNS(dev) {
addr = addr.Unmap()
if !addr.IsValid() || addr.IsUnspecified() {
continue
}
// IP6Config.Nameservers is a byte slice without zone info;
// reattach the device's interface name so a captured fe80::…
// stays routable.
if addr.IsLinkLocalUnicast() && ifaceName != "" {
addr = addr.WithZone(ifaceName)
}
if _, dup := seen[addr]; dup {
continue
}
seen[addr] = struct{}{}
out = append(out, addr)
}
}
return out, nil
}
func readNetworkManagerDeviceInterface(devicePath dbus.ObjectPath) string {
obj, closeConn, err := getDbusObject(networkManagerDest, devicePath)
if err != nil {
return ""
}
defer closeConn()
v, err := obj.GetProperty(networkManagerDbusDeviceIfaceProperty)
if err != nil {
return ""
}
s, _ := v.Value().(string)
return s
}
func networkManagerListDevices() ([]dbus.ObjectPath, error) {
obj, closeConn, err := getDbusObject(networkManagerDest, networkManagerDbusObjectNode)
if err != nil {
return nil, fmt.Errorf("dbus NetworkManager: %w", err)
}
defer closeConn()
var devs []dbus.ObjectPath
if err := obj.Call(networkManagerDbusGetDevicesMethod, dbusDefaultFlag).Store(&devs); err != nil {
return nil, err
}
return devs, nil
}
func readNetworkManagerDeviceDNS(devicePath dbus.ObjectPath) []netip.Addr {
obj, closeConn, err := getDbusObject(networkManagerDest, devicePath)
if err != nil {
return nil
}
defer closeConn()
var out []netip.Addr
if path := readNetworkManagerConfigPath(obj, networkManagerDbusDeviceIp4ConfigProperty); path != "" {
out = append(out, readIPv4ConfigDNS(path)...)
}
if path := readNetworkManagerConfigPath(obj, networkManagerDbusDeviceIp6ConfigProperty); path != "" {
out = append(out, readIPv6ConfigDNS(path)...)
}
return out
}
func readNetworkManagerConfigPath(obj dbus.BusObject, property string) dbus.ObjectPath {
v, err := obj.GetProperty(property)
if err != nil {
return ""
}
path, ok := v.Value().(dbus.ObjectPath)
if !ok || path == "/" {
return ""
}
return path
}
func readIPv4ConfigDNS(path dbus.ObjectPath) []netip.Addr {
obj, closeConn, err := getDbusObject(networkManagerDest, path)
if err != nil {
return nil
}
defer closeConn()
// NameserverData (NM 1.13+) carries strings; older NMs only expose the
// legacy uint32 Nameservers property.
if out := readIPv4NameserverData(obj); len(out) > 0 {
return out
}
return readIPv4LegacyNameservers(obj)
}
func readIPv4NameserverData(obj dbus.BusObject) []netip.Addr {
v, err := obj.GetProperty(networkManagerDbusIp4ConfigNameserverDataProperty)
if err != nil {
return nil
}
entries, ok := v.Value().([]map[string]dbus.Variant)
if !ok {
return nil
}
var out []netip.Addr
for _, entry := range entries {
addrVar, ok := entry["address"]
if !ok {
continue
}
s, ok := addrVar.Value().(string)
if !ok {
continue
}
if a, err := netip.ParseAddr(s); err == nil {
out = append(out, a)
}
}
return out
}
func readIPv4LegacyNameservers(obj dbus.BusObject) []netip.Addr {
v, err := obj.GetProperty(networkManagerDbusIp4ConfigNameserversProperty)
if err != nil {
return nil
}
raw, ok := v.Value().([]uint32)
if !ok {
return nil
}
out := make([]netip.Addr, 0, len(raw))
for _, n := range raw {
var b [4]byte
binary.LittleEndian.PutUint32(b[:], n)
out = append(out, netip.AddrFrom4(b))
}
return out
}
func readIPv6ConfigDNS(path dbus.ObjectPath) []netip.Addr {
obj, closeConn, err := getDbusObject(networkManagerDest, path)
if err != nil {
return nil
}
defer closeConn()
v, err := obj.GetProperty(networkManagerDbusIp6ConfigNameserversProperty)
if err != nil {
return nil
}
raw, ok := v.Value().([][]byte)
if !ok {
return nil
}
out := make([]netip.Addr, 0, len(raw))
for _, b := range raw {
if a, ok := netip.AddrFromSlice(b); ok {
out = append(out, a)
}
}
return out
}
func (n *networkManagerDbusConfigurator) getOriginalNameservers() []netip.Addr {
return slices.Clone(n.origNameservers)
}
func (n *networkManagerDbusConfigurator) supportCustomPort() bool {

File diff suppressed because it is too large Load Diff

View File

@@ -1,5 +1,5 @@
package dns
func (s *DefaultServer) initialize() (manager hostManager, err error) {
return newHostManager()
return newHostManager(s.hostsDNSHolder)
}

View File

@@ -6,7 +6,7 @@ import (
"net"
"net/netip"
"os"
"strings"
"runtime"
"testing"
"time"
@@ -15,6 +15,7 @@ import (
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/tun/netstack"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
@@ -31,8 +32,10 @@ import (
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/proto"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/formatter"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
)
@@ -101,16 +104,17 @@ func init() {
formatter.SetTextFormatter(log.StandardLogger())
}
func generateDummyHandler(domain string, servers []nbdns.NameServer) *upstreamResolverBase {
func generateDummyHandler(d string, servers []nbdns.NameServer) *upstreamResolverBase {
var srvs []netip.AddrPort
for _, srv := range servers {
srvs = append(srvs, srv.AddrPort())
}
return &upstreamResolverBase{
domain: domain,
upstreamServers: srvs,
cancel: func() {},
u := &upstreamResolverBase{
domain: domain.Domain(d),
cancel: func() {},
}
u.addRace(srvs)
return u
}
func TestUpdateDNSServer(t *testing.T) {
@@ -653,74 +657,8 @@ func TestDNSServerStartStop(t *testing.T) {
}
}
func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
hostManager := &mockHostConfigurator{}
server := DefaultServer{
ctx: context.Background(),
service: NewServiceViaMemory(&mocWGIface{}),
localResolver: local.NewResolver(),
handlerChain: NewHandlerChain(),
hostManager: hostManager,
currentConfig: HostDNSConfig{
Domains: []DomainConfig{
{false, "domain0", false},
{false, "domain1", false},
{false, "domain2", false},
},
},
statusRecorder: peer.NewRecorder("mgm"),
}
var domainsUpdate string
hostManager.applyDNSConfigFunc = func(config HostDNSConfig, statemanager *statemanager.Manager) error {
domains := []string{}
for _, item := range config.Domains {
if item.Disabled {
continue
}
domains = append(domains, item.Domain)
}
domainsUpdate = strings.Join(domains, ",")
return nil
}
deactivate, reactivate := server.upstreamCallbacks(&nbdns.NameServerGroup{
Domains: []string{"domain1"},
NameServers: []nbdns.NameServer{
{IP: netip.MustParseAddr("8.8.0.0"), NSType: nbdns.UDPNameServerType, Port: 53},
},
}, nil, 0)
deactivate(nil)
expected := "domain0,domain2"
domains := []string{}
for _, item := range server.currentConfig.Domains {
if item.Disabled {
continue
}
domains = append(domains, item.Domain)
}
got := strings.Join(domains, ",")
if expected != got {
t.Errorf("expected domains list: %q, got %q", expected, got)
}
reactivate()
expected = "domain0,domain1,domain2"
domains = []string{}
for _, item := range server.currentConfig.Domains {
if item.Disabled {
continue
}
domains = append(domains, item.Domain)
}
got = strings.Join(domains, ",")
if expected != got {
t.Errorf("expected domains list: %q, got %q", expected, domainsUpdate)
}
}
func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) {
skipUnlessAndroid(t)
wgIFace, err := createWgInterfaceWithBind(t)
if err != nil {
t.Fatal("failed to initialize wg interface")
@@ -748,6 +686,7 @@ func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) {
}
func TestDNSPermanent_updateUpstream(t *testing.T) {
skipUnlessAndroid(t)
wgIFace, err := createWgInterfaceWithBind(t)
if err != nil {
t.Fatal("failed to initialize wg interface")
@@ -841,6 +780,7 @@ func TestDNSPermanent_updateUpstream(t *testing.T) {
}
func TestDNSPermanent_matchOnly(t *testing.T) {
skipUnlessAndroid(t)
wgIFace, err := createWgInterfaceWithBind(t)
if err != nil {
t.Fatal("failed to initialize wg interface")
@@ -913,6 +853,18 @@ func TestDNSPermanent_matchOnly(t *testing.T) {
}
}
// skipUnlessAndroid marks tests that exercise the mobile-permanent DNS path,
// which only matches a real production setup on android (NewDefaultServerPermanentUpstream
// + androidHostManager). On non-android the desktop host manager replaces it
// during Initialize and the assertion stops making sense. Skipped here until we
// have an android CI runner.
func skipUnlessAndroid(t *testing.T) {
t.Helper()
if runtime.GOOS != "android" {
t.Skip("requires android runner; mobile-permanent path doesn't match production on this OS")
}
}
func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
t.Helper()
ov := os.Getenv("NB_WG_KERNEL_DISABLED")
@@ -1065,7 +1017,6 @@ type mockHandler struct {
func (m *mockHandler) ServeDNS(dns.ResponseWriter, *dns.Msg) {}
func (m *mockHandler) Stop() {}
func (m *mockHandler) ProbeAvailability(context.Context) {}
func (m *mockHandler) ID() types.HandlerID { return types.HandlerID(m.Id) }
type mockService struct{}
@@ -2085,6 +2036,598 @@ func TestLocalResolverPriorityConstants(t *testing.T) {
assert.Equal(t, "local.example.com", localMuxUpdates[0].domain)
}
// TestBuildUpstreamHandler_MergesGroupsPerDomain verifies that multiple
// admin-defined nameserver groups targeting the same domain collapse into a
// single handler with each group preserved as a sequential inner list.
func TestBuildUpstreamHandler_MergesGroupsPerDomain(t *testing.T) {
wgInterface := &mocWGIface{}
service := NewServiceViaMemory(wgInterface)
server := &DefaultServer{
ctx: context.Background(),
wgInterface: wgInterface,
service: service,
localResolver: local.NewResolver(),
handlerChain: NewHandlerChain(),
hostManager: &noopHostConfigurator{},
dnsMuxMap: make(registeredHandlerMap),
}
groups := []*nbdns.NameServerGroup{
{
NameServers: []nbdns.NameServer{
{IP: netip.MustParseAddr("192.0.2.1"), NSType: nbdns.UDPNameServerType, Port: 53},
},
Domains: []string{"example.com"},
},
{
NameServers: []nbdns.NameServer{
{IP: netip.MustParseAddr("192.0.2.2"), NSType: nbdns.UDPNameServerType, Port: 53},
{IP: netip.MustParseAddr("192.0.2.3"), NSType: nbdns.UDPNameServerType, Port: 53},
},
Domains: []string{"example.com"},
},
}
muxUpdates, err := server.buildUpstreamHandlerUpdate(groups)
require.NoError(t, err)
require.Len(t, muxUpdates, 1, "same-domain groups should merge into one handler")
assert.Equal(t, "example.com", muxUpdates[0].domain)
assert.Equal(t, PriorityUpstream, muxUpdates[0].priority)
handler := muxUpdates[0].handler.(*upstreamResolver)
require.Len(t, handler.upstreamServers, 2, "handler should have two groups")
assert.Equal(t, upstreamRace{netip.MustParseAddrPort("192.0.2.1:53")}, handler.upstreamServers[0])
assert.Equal(t, upstreamRace{
netip.MustParseAddrPort("192.0.2.2:53"),
netip.MustParseAddrPort("192.0.2.3:53"),
}, handler.upstreamServers[1])
}
// TestEvaluateNSGroupHealth covers the records-only verdict. The gate
// (overlay route selected-but-no-active-peer) is intentionally NOT an
// input to the evaluator anymore: the verdict drives the Enabled flag,
// which must always reflect what we actually observed. Gate-aware event
// suppression is tested separately in the projection test.
//
// Matrix per upstream: {no record, fresh Ok, fresh Fail, stale Fail,
// stale Ok, Ok newer than Fail, Fail newer than Ok}.
// Group verdict: any fresh-working → Healthy; any fresh-broken with no
// fresh-working → Unhealthy; otherwise Undecided.
func TestEvaluateNSGroupHealth(t *testing.T) {
now := time.Now()
a := netip.MustParseAddrPort("192.0.2.1:53")
b := netip.MustParseAddrPort("192.0.2.2:53")
recentOk := UpstreamHealth{LastOk: now.Add(-2 * time.Second)}
recentFail := UpstreamHealth{LastFail: now.Add(-1 * time.Second), LastErr: "timeout"}
staleOk := UpstreamHealth{LastOk: now.Add(-10 * time.Minute)}
staleFail := UpstreamHealth{LastFail: now.Add(-10 * time.Minute), LastErr: "timeout"}
okThenFail := UpstreamHealth{
LastOk: now.Add(-10 * time.Second),
LastFail: now.Add(-1 * time.Second),
LastErr: "timeout",
}
failThenOk := UpstreamHealth{
LastOk: now.Add(-1 * time.Second),
LastFail: now.Add(-10 * time.Second),
LastErr: "timeout",
}
tests := []struct {
name string
health map[netip.AddrPort]UpstreamHealth
servers []netip.AddrPort
wantVerdict nsGroupVerdict
wantErrSubst string
}{
{
name: "no record, undecided",
servers: []netip.AddrPort{a},
wantVerdict: nsVerdictUndecided,
},
{
name: "fresh success, healthy",
health: map[netip.AddrPort]UpstreamHealth{a: recentOk},
servers: []netip.AddrPort{a},
wantVerdict: nsVerdictHealthy,
},
{
name: "fresh failure, unhealthy",
health: map[netip.AddrPort]UpstreamHealth{a: recentFail},
servers: []netip.AddrPort{a},
wantVerdict: nsVerdictUnhealthy,
wantErrSubst: "timeout",
},
{
name: "only stale success, undecided",
health: map[netip.AddrPort]UpstreamHealth{a: staleOk},
servers: []netip.AddrPort{a},
wantVerdict: nsVerdictUndecided,
},
{
name: "only stale failure, undecided",
health: map[netip.AddrPort]UpstreamHealth{a: staleFail},
servers: []netip.AddrPort{a},
wantVerdict: nsVerdictUndecided,
},
{
name: "both fresh, fail newer, unhealthy",
health: map[netip.AddrPort]UpstreamHealth{a: okThenFail},
servers: []netip.AddrPort{a},
wantVerdict: nsVerdictUnhealthy,
wantErrSubst: "timeout",
},
{
name: "both fresh, ok newer, healthy",
health: map[netip.AddrPort]UpstreamHealth{a: failThenOk},
servers: []netip.AddrPort{a},
wantVerdict: nsVerdictHealthy,
},
{
name: "two upstreams, one success wins",
health: map[netip.AddrPort]UpstreamHealth{
a: recentFail,
b: recentOk,
},
servers: []netip.AddrPort{a, b},
wantVerdict: nsVerdictHealthy,
},
{
name: "two upstreams, one fail one unseen, unhealthy",
health: map[netip.AddrPort]UpstreamHealth{
a: recentFail,
},
servers: []netip.AddrPort{a, b},
wantVerdict: nsVerdictUnhealthy,
wantErrSubst: "timeout",
},
{
name: "two upstreams, all recent failures, unhealthy",
health: map[netip.AddrPort]UpstreamHealth{
a: {LastFail: now.Add(-5 * time.Second), LastErr: "timeout"},
b: {LastFail: now.Add(-1 * time.Second), LastErr: "SERVFAIL"},
},
servers: []netip.AddrPort{a, b},
wantVerdict: nsVerdictUnhealthy,
wantErrSubst: "SERVFAIL",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
verdict, err := evaluateNSGroupHealth(tc.health, tc.servers, now)
assert.Equal(t, tc.wantVerdict, verdict, "verdict mismatch")
if tc.wantErrSubst != "" {
require.Error(t, err)
assert.Contains(t, err.Error(), tc.wantErrSubst)
} else {
assert.NoError(t, err)
}
})
}
}
// healthStubHandler is a minimal dnsMuxMap entry that exposes a fixed
// UpstreamHealth snapshot, letting tests drive recomputeNSGroupStates
// without spinning up real handlers.
type healthStubHandler struct {
health map[netip.AddrPort]UpstreamHealth
}
func (h *healthStubHandler) ServeDNS(dns.ResponseWriter, *dns.Msg) {}
func (h *healthStubHandler) Stop() {}
func (h *healthStubHandler) ID() types.HandlerID { return "health-stub" }
func (h *healthStubHandler) UpstreamHealth() map[netip.AddrPort]UpstreamHealth {
return h.health
}
// TestProjection_SteadyStateIsSilent guards against duplicate events:
// while a group stays Unhealthy tick after tick, only the first
// Unhealthy transition may emit. Same for staying Healthy.
func TestProjection_SteadyStateIsSilent(t *testing.T) {
fx := newProjTestFixture(t)
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
fx.tick()
fx.expectEvent("unreachable", "first fail emits warning")
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
fx.tick()
fx.tick()
fx.expectNoEvent("staying unhealthy must not re-emit")
fx.setHealth(UpstreamHealth{LastOk: time.Now()})
fx.tick()
fx.expectEvent("recovered", "recovery on transition")
fx.tick()
fx.tick()
fx.expectNoEvent("staying healthy must not re-emit")
}
// projTestFixture is the common setup for the projection tests: a
// single-upstream group whose route classification the test can flip by
// assigning to selected/active. Callers drive failures/successes by
// mutating stub.health and calling refreshHealth.
type projTestFixture struct {
t *testing.T
recorder *peer.Status
events <-chan *proto.SystemEvent
server *DefaultServer
stub *healthStubHandler
group *nbdns.NameServerGroup
srv netip.AddrPort
selected route.HAMap
active route.HAMap
}
func newProjTestFixture(t *testing.T) *projTestFixture {
t.Helper()
recorder := peer.NewRecorder("mgm")
sub := recorder.SubscribeToEvents()
t.Cleanup(func() { recorder.UnsubscribeFromEvents(sub) })
srv := netip.MustParseAddrPort("100.64.0.1:53")
fx := &projTestFixture{
t: t,
recorder: recorder,
events: sub.Events(),
stub: &healthStubHandler{health: map[netip.AddrPort]UpstreamHealth{}},
srv: srv,
group: &nbdns.NameServerGroup{
Domains: []string{"example.com"},
NameServers: []nbdns.NameServer{{IP: srv.Addr(), NSType: nbdns.UDPNameServerType, Port: int(srv.Port())}},
},
}
fx.server = &DefaultServer{
ctx: context.Background(),
wgInterface: &mocWGIface{},
statusRecorder: recorder,
dnsMuxMap: make(registeredHandlerMap),
selectedRoutes: func() route.HAMap { return fx.selected },
activeRoutes: func() route.HAMap { return fx.active },
warningDelayBase: defaultWarningDelayBase,
}
fx.server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: fx.stub, priority: PriorityUpstream}
fx.server.mux.Lock()
fx.server.updateNSGroupStates([]*nbdns.NameServerGroup{fx.group})
fx.server.mux.Unlock()
return fx
}
func (f *projTestFixture) setHealth(h UpstreamHealth) {
f.stub.health = map[netip.AddrPort]UpstreamHealth{f.srv: h}
}
func (f *projTestFixture) tick() []peer.NSGroupState {
f.server.refreshHealth()
return f.recorder.GetDNSStates()
}
func (f *projTestFixture) expectNoEvent(why string) {
f.t.Helper()
select {
case evt := <-f.events:
f.t.Fatalf("unexpected event (%s): %+v", why, evt)
case <-time.After(100 * time.Millisecond):
}
}
func (f *projTestFixture) expectEvent(substr, why string) *proto.SystemEvent {
f.t.Helper()
select {
case evt := <-f.events:
assert.Contains(f.t, evt.Message, substr, why)
return evt
case <-time.After(time.Second):
f.t.Fatalf("expected event (%s) with %q", why, substr)
return nil
}
}
var overlayNetForTest = netip.MustParsePrefix("100.64.0.0/16")
var overlayMapForTest = route.HAMap{"overlay": {{Network: overlayNetForTest}}}
// TestProjection_PublicFailEmitsImmediately covers rule 1: an upstream
// that is not inside any selected route (public DNS) fires the warning
// on the first Unhealthy tick, no grace period.
func TestProjection_PublicFailEmitsImmediately(t *testing.T) {
fx := newProjTestFixture(t)
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
states := fx.tick()
require.Len(t, states, 1)
assert.False(t, states[0].Enabled)
fx.expectEvent("unreachable", "public DNS failure")
}
// TestProjection_OverlayConnectedFailEmitsImmediately covers rule 2:
// the upstream is inside a selected route AND the route has a Connected
// peer. Tunnel is up, failure is real, emit immediately.
func TestProjection_OverlayConnectedFailEmitsImmediately(t *testing.T) {
fx := newProjTestFixture(t)
fx.selected = overlayMapForTest
fx.active = overlayMapForTest
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
states := fx.tick()
require.Len(t, states, 1)
assert.False(t, states[0].Enabled)
fx.expectEvent("unreachable", "overlay + connected failure")
}
// TestProjection_OverlayNotConnectedDelaysWarning covers rule 3: the
// upstream is routed but no peer is Connected (Connecting/Idle/missing).
// First tick: Unhealthy display, no warning. After the grace window
// elapses with no recovery, the warning fires.
func TestProjection_OverlayNotConnectedDelaysWarning(t *testing.T) {
grace := 50 * time.Millisecond
fx := newProjTestFixture(t)
fx.server.warningDelayBase = grace
fx.selected = overlayMapForTest
// active stays nil: routed but not connected.
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
states := fx.tick()
require.Len(t, states, 1)
assert.False(t, states[0].Enabled, "display must reflect failure even during grace window")
fx.expectNoEvent("first fail tick within grace window")
time.Sleep(grace + 10*time.Millisecond)
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
fx.tick()
fx.expectEvent("unreachable", "warning after grace window")
}
// TestProjection_OverlayAddrNoRouteDelaysWarning covers an upstream
// whose address is inside the WireGuard overlay range but is not
// covered by any selected route (peer-to-peer DNS without an explicit
// route). Until a peer reports Connected for that address, startup
// failures must be held just like the routed case.
func TestProjection_OverlayAddrNoRouteDelaysWarning(t *testing.T) {
recorder := peer.NewRecorder("mgm")
sub := recorder.SubscribeToEvents()
t.Cleanup(func() { recorder.UnsubscribeFromEvents(sub) })
overlayPeer := netip.MustParseAddrPort("100.66.100.5:53")
server := &DefaultServer{
ctx: context.Background(),
wgInterface: &mocWGIface{},
statusRecorder: recorder,
dnsMuxMap: make(registeredHandlerMap),
selectedRoutes: func() route.HAMap { return nil },
activeRoutes: func() route.HAMap { return nil },
warningDelayBase: 50 * time.Millisecond,
}
group := &nbdns.NameServerGroup{
Domains: []string{"example.com"},
NameServers: []nbdns.NameServer{{IP: overlayPeer.Addr(), NSType: nbdns.UDPNameServerType, Port: int(overlayPeer.Port())}},
}
stub := &healthStubHandler{health: map[netip.AddrPort]UpstreamHealth{
overlayPeer: {LastFail: time.Now(), LastErr: "timeout"},
}}
server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: stub, priority: PriorityUpstream}
server.mux.Lock()
server.updateNSGroupStates([]*nbdns.NameServerGroup{group})
server.mux.Unlock()
server.refreshHealth()
select {
case evt := <-sub.Events():
t.Fatalf("unexpected event during grace window: %+v", evt)
case <-time.After(100 * time.Millisecond):
}
time.Sleep(60 * time.Millisecond)
stub.health = map[netip.AddrPort]UpstreamHealth{overlayPeer: {LastFail: time.Now(), LastErr: "timeout"}}
server.refreshHealth()
select {
case evt := <-sub.Events():
assert.Contains(t, evt.Message, "unreachable")
case <-time.After(time.Second):
t.Fatal("expected warning after grace window")
}
}
// TestProjection_StopClearsHealthState verifies that Stop wipes the
// per-group projection state so a subsequent Start doesn't inherit
// sticky flags (notably everHealthy) that would bypass the grace
// window during the next peer handshake.
func TestProjection_StopClearsHealthState(t *testing.T) {
wgIface := &mocWGIface{}
server := &DefaultServer{
ctx: context.Background(),
wgInterface: wgIface,
service: NewServiceViaMemory(wgIface),
hostManager: &noopHostConfigurator{},
extraDomains: map[domain.Domain]int{},
dnsMuxMap: make(registeredHandlerMap),
statusRecorder: peer.NewRecorder("mgm"),
selectedRoutes: func() route.HAMap { return nil },
activeRoutes: func() route.HAMap { return nil },
warningDelayBase: defaultWarningDelayBase,
currentConfigHash: ^uint64(0),
}
server.ctx, server.ctxCancel = context.WithCancel(context.Background())
srv := netip.MustParseAddrPort("8.8.8.8:53")
group := &nbdns.NameServerGroup{
Domains: []string{"example.com"},
NameServers: []nbdns.NameServer{{IP: srv.Addr(), NSType: nbdns.UDPNameServerType, Port: int(srv.Port())}},
}
stub := &healthStubHandler{health: map[netip.AddrPort]UpstreamHealth{srv: {LastOk: time.Now()}}}
server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: stub, priority: PriorityUpstream}
server.mux.Lock()
server.updateNSGroupStates([]*nbdns.NameServerGroup{group})
server.mux.Unlock()
server.refreshHealth()
server.healthProjectMu.Lock()
p, ok := server.nsGroupProj[generateGroupKey(group)]
server.healthProjectMu.Unlock()
require.True(t, ok, "projection state should exist after tick")
require.True(t, p.everHealthy, "tick with success must set everHealthy")
server.Stop()
server.healthProjectMu.Lock()
cleared := server.nsGroupProj == nil
server.healthProjectMu.Unlock()
assert.True(t, cleared, "Stop must clear nsGroupProj")
}
// TestProjection_OverlayRecoversDuringGrace covers the happy path of
// rule 3: startup failures while the peer is handshaking, then the peer
// comes up and a query succeeds before the grace window elapses. No
// warning should ever have fired, and no recovery either.
func TestProjection_OverlayRecoversDuringGrace(t *testing.T) {
fx := newProjTestFixture(t)
fx.server.warningDelayBase = 200 * time.Millisecond
fx.selected = overlayMapForTest
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
fx.tick()
fx.expectNoEvent("fail within grace, warning suppressed")
fx.active = overlayMapForTest
fx.setHealth(UpstreamHealth{LastOk: time.Now()})
states := fx.tick()
require.Len(t, states, 1)
assert.True(t, states[0].Enabled)
fx.expectNoEvent("recovery without prior warning must not emit")
}
// TestProjection_RecoveryOnlyAfterWarning enforces the invariant the
// whole design leans on: recovery events only appear when a warning
// event was actually emitted for the current streak. A Healthy verdict
// without a prior warning is silent, so the user never sees "recovered"
// out of thin air.
func TestProjection_RecoveryOnlyAfterWarning(t *testing.T) {
fx := newProjTestFixture(t)
fx.setHealth(UpstreamHealth{LastOk: time.Now()})
states := fx.tick()
require.Len(t, states, 1)
assert.True(t, states[0].Enabled)
fx.expectNoEvent("first healthy tick should not recover anything")
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
fx.tick()
fx.expectEvent("unreachable", "public fail emits immediately")
fx.setHealth(UpstreamHealth{LastOk: time.Now()})
fx.tick()
fx.expectEvent("recovered", "recovery follows real warning")
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
fx.tick()
fx.expectEvent("unreachable", "second cycle warning")
fx.setHealth(UpstreamHealth{LastOk: time.Now()})
fx.tick()
fx.expectEvent("recovered", "second cycle recovery")
}
// TestProjection_EverHealthyOverridesDelay covers rule 4: once a group
// has ever been Healthy, subsequent failures skip the grace window even
// if classification says "routed + not connected". The system has
// proved it can work, so any new failure is real.
func TestProjection_EverHealthyOverridesDelay(t *testing.T) {
fx := newProjTestFixture(t)
// Large base so any emission must come from the everHealthy bypass, not elapsed time.
fx.server.warningDelayBase = time.Hour
fx.selected = overlayMapForTest
fx.active = overlayMapForTest
// Establish "ever healthy".
fx.setHealth(UpstreamHealth{LastOk: time.Now()})
fx.tick()
fx.expectNoEvent("first healthy tick")
// Peer drops. Query fails. Routed + not connected → normally grace,
// but everHealthy flag bypasses it.
fx.active = nil
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
fx.tick()
fx.expectEvent("unreachable", "failure after ever-healthy must be immediate")
}
// TestProjection_ReconnectBlipEmitsPair covers the explicit tradeoff
// from the design discussion: once a group has been healthy, a brief
// reconnect that produces a failing tick will fire warning + recovery.
// This is by design: user-visible blips are accurate signal, not noise.
func TestProjection_ReconnectBlipEmitsPair(t *testing.T) {
fx := newProjTestFixture(t)
fx.selected = overlayMapForTest
fx.active = overlayMapForTest
fx.setHealth(UpstreamHealth{LastOk: time.Now()})
fx.tick()
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
fx.tick()
fx.expectEvent("unreachable", "blip warning")
fx.setHealth(UpstreamHealth{LastOk: time.Now()})
fx.tick()
fx.expectEvent("recovered", "blip recovery")
}
// TestProjection_MixedGroupEmitsImmediately covers the multi-upstream
// rule: a group with at least one public upstream is in the "immediate"
// category regardless of the other upstreams' routing, because the
// public one has no peer-startup excuse. Prevents public-DNS failures
// from being hidden behind a routed sibling.
func TestProjection_MixedGroupEmitsImmediately(t *testing.T) {
recorder := peer.NewRecorder("mgm")
sub := recorder.SubscribeToEvents()
t.Cleanup(func() { recorder.UnsubscribeFromEvents(sub) })
events := sub.Events()
public := netip.MustParseAddrPort("8.8.8.8:53")
overlay := netip.MustParseAddrPort("100.64.0.1:53")
overlayMap := route.HAMap{"overlay": {{Network: netip.MustParsePrefix("100.64.0.0/16")}}}
server := &DefaultServer{
ctx: context.Background(),
statusRecorder: recorder,
dnsMuxMap: make(registeredHandlerMap),
selectedRoutes: func() route.HAMap { return overlayMap },
activeRoutes: func() route.HAMap { return nil },
warningDelayBase: time.Hour,
}
group := &nbdns.NameServerGroup{
Domains: []string{"example.com"},
NameServers: []nbdns.NameServer{
{IP: public.Addr(), NSType: nbdns.UDPNameServerType, Port: int(public.Port())},
{IP: overlay.Addr(), NSType: nbdns.UDPNameServerType, Port: int(overlay.Port())},
},
}
stub := &healthStubHandler{
health: map[netip.AddrPort]UpstreamHealth{
public: {LastFail: time.Now(), LastErr: "servfail"},
overlay: {LastFail: time.Now(), LastErr: "timeout"},
},
}
server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: stub, priority: PriorityUpstream}
server.mux.Lock()
server.updateNSGroupStates([]*nbdns.NameServerGroup{group})
server.mux.Unlock()
server.refreshHealth()
select {
case evt := <-events:
assert.Contains(t, evt.Message, "unreachable")
case <-time.After(time.Second):
t.Fatal("expected immediate warning because group contains a public upstream")
}
}
func TestDNSLoopPrevention(t *testing.T) {
wgInterface := &mocWGIface{}
service := NewServiceViaMemory(wgInterface)
@@ -2183,17 +2726,18 @@ func TestDNSLoopPrevention(t *testing.T) {
if tt.expectedHandlers > 0 {
handler := muxUpdates[0].handler.(*upstreamResolver)
assert.Len(t, handler.upstreamServers, len(tt.expectedServers))
flat := handler.flatUpstreams()
assert.Len(t, flat, len(tt.expectedServers))
if tt.shouldFilterOwnIP {
for _, upstream := range handler.upstreamServers {
for _, upstream := range flat {
assert.NotEqual(t, dnsServerIP, upstream.Addr())
}
}
for _, expected := range tt.expectedServers {
found := false
for _, upstream := range handler.upstreamServers {
for _, upstream := range flat {
if upstream.Addr() == expected {
found = true
break

View File

@@ -8,6 +8,7 @@ import (
"fmt"
"net"
"net/netip"
"slices"
"time"
"github.com/godbus/dbus/v5"
@@ -40,10 +41,17 @@ const (
)
type systemdDbusConfigurator struct {
dbusLinkObject dbus.ObjectPath
ifaceName string
dbusLinkObject dbus.ObjectPath
ifaceName string
wgIndex int
origNameservers []netip.Addr
}
const (
systemdDbusLinkDNSProperty = systemdDbusLinkInterface + ".DNS"
systemdDbusLinkDefaultRouteProperty = systemdDbusLinkInterface + ".DefaultRoute"
)
// the types below are based on dbus specification, each field is mapped to a dbus type
// see https://dbus.freedesktop.org/doc/dbus-specification.html#basic-types for more details on dbus types
// see https://www.freedesktop.org/software/systemd/man/org.freedesktop.resolve1.html on resolve1 input types
@@ -79,10 +87,145 @@ func newSystemdDbusConfigurator(wgInterface string) (*systemdDbusConfigurator, e
log.Debugf("got dbus Link interface: %s from net interface %s and index %d", s, iface.Name, iface.Index)
return &systemdDbusConfigurator{
c := &systemdDbusConfigurator{
dbusLinkObject: dbus.ObjectPath(s),
ifaceName: wgInterface,
}, nil
wgIndex: iface.Index,
}
origNameservers, err := c.captureOriginalNameservers()
switch {
case err != nil:
log.Warnf("capture original nameservers from systemd-resolved: %v", err)
case len(origNameservers) == 0:
log.Warnf("no original nameservers captured from systemd-resolved default-route links; DNS fallback will be empty")
default:
log.Debugf("captured %d original nameservers from systemd-resolved default-route links: %v", len(origNameservers), origNameservers)
}
c.origNameservers = origNameservers
return c, nil
}
// captureOriginalNameservers reads per-link DNS from systemd-resolved for
// every default-route link except our own WG link. Non-default-route links
// (VPNs, docker bridges) are skipped because their upstreams wouldn't
// actually serve host queries.
func (s *systemdDbusConfigurator) captureOriginalNameservers() ([]netip.Addr, error) {
ifaces, err := net.Interfaces()
if err != nil {
return nil, fmt.Errorf("list interfaces: %w", err)
}
seen := make(map[netip.Addr]struct{})
var out []netip.Addr
for _, iface := range ifaces {
if !s.isCandidateLink(iface) {
continue
}
linkPath, err := getSystemdLinkPath(iface.Index)
if err != nil || !isSystemdLinkDefaultRoute(linkPath) {
continue
}
for _, addr := range readSystemdLinkDNS(linkPath) {
addr = normalizeSystemdAddr(addr, iface.Name)
if !addr.IsValid() {
continue
}
if _, dup := seen[addr]; dup {
continue
}
seen[addr] = struct{}{}
out = append(out, addr)
}
}
return out, nil
}
func (s *systemdDbusConfigurator) isCandidateLink(iface net.Interface) bool {
if iface.Index == s.wgIndex {
return false
}
if iface.Flags&net.FlagLoopback != 0 || iface.Flags&net.FlagUp == 0 {
return false
}
return true
}
// normalizeSystemdAddr unmaps v4-mapped-v6, drops unspecified, and reattaches
// the link's iface name as zone for link-local v6 (Link.DNS strips it).
// Returns the zero Addr to signal "skip this entry".
func normalizeSystemdAddr(addr netip.Addr, ifaceName string) netip.Addr {
addr = addr.Unmap()
if !addr.IsValid() || addr.IsUnspecified() {
return netip.Addr{}
}
if addr.IsLinkLocalUnicast() {
return addr.WithZone(ifaceName)
}
return addr
}
func getSystemdLinkPath(ifIndex int) (dbus.ObjectPath, error) {
obj, closeConn, err := getDbusObject(systemdResolvedDest, systemdDbusObjectNode)
if err != nil {
return "", fmt.Errorf("dbus resolve1: %w", err)
}
defer closeConn()
var p string
if err := obj.Call(systemdDbusGetLinkMethod, dbusDefaultFlag, int32(ifIndex)).Store(&p); err != nil {
return "", err
}
return dbus.ObjectPath(p), nil
}
func isSystemdLinkDefaultRoute(linkPath dbus.ObjectPath) bool {
obj, closeConn, err := getDbusObject(systemdResolvedDest, linkPath)
if err != nil {
return false
}
defer closeConn()
v, err := obj.GetProperty(systemdDbusLinkDefaultRouteProperty)
if err != nil {
return false
}
b, ok := v.Value().(bool)
return ok && b
}
func readSystemdLinkDNS(linkPath dbus.ObjectPath) []netip.Addr {
obj, closeConn, err := getDbusObject(systemdResolvedDest, linkPath)
if err != nil {
return nil
}
defer closeConn()
v, err := obj.GetProperty(systemdDbusLinkDNSProperty)
if err != nil {
return nil
}
entries, ok := v.Value().([][]any)
if !ok {
return nil
}
var out []netip.Addr
for _, entry := range entries {
if len(entry) < 2 {
continue
}
raw, ok := entry[1].([]byte)
if !ok {
continue
}
addr, ok := netip.AddrFromSlice(raw)
if !ok {
continue
}
out = append(out, addr)
}
return out
}
func (s *systemdDbusConfigurator) getOriginalNameservers() []netip.Addr {
return slices.Clone(s.origNameservers)
}
func (s *systemdDbusConfigurator) supportCustomPort() bool {

View File

@@ -1,3 +1,32 @@
// Package dns implements the client-side DNS stack: listener/service on the
// peer's tunnel address, handler chain that routes questions by domain and
// priority, and upstream resolvers that forward what remains to configured
// nameservers.
//
// # Upstream resolution and the race model
//
// When two or more nameserver groups target the same domain, DefaultServer
// merges them into one upstream handler whose state is:
//
// upstreamResolverBase
// └── upstreamServers []upstreamRace // one entry per source NS group
// └── []netip.AddrPort // primary, fallback, ...
//
// Each source nameserver group contributes one upstreamRace. Within a race
// upstreams are tried in order: the next is used only on failure (timeout,
// SERVFAIL, REFUSED, no response). NXDOMAIN is a valid answer and stops
// the walk. When more than one race exists, ServeDNS fans out one
// goroutine per race and returns the first valid answer, cancelling the
// rest. A handler with a single race skips the fan-out.
//
// # Health projection
//
// Query outcomes are recorded per-upstream in UpstreamHealth. The server
// periodically merges these snapshots across handlers and projects them
// into peer.NSGroupState. There is no active probing: a group is marked
// unhealthy only when every seen upstream has a recent failure and none
// has a recent success. Healthy→unhealthy fires a single
// SystemEvent_WARNING; steady-state refreshes do not duplicate it.
package dns
import (
@@ -11,11 +40,8 @@ import (
"slices"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/cenkalti/backoff/v4"
"github.com/hashicorp/go-multierror"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/tun/netstack"
@@ -25,11 +51,33 @@ import (
"github.com/netbirdio/netbird/client/internal/dns/resutil"
"github.com/netbirdio/netbird/client/internal/dns/types"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
)
var currentMTU uint16 = iface.DefaultMTU
// nonRetryableEDECodes lists EDE info codes (RFC 8914) for which a SERVFAIL
// from one upstream means another upstream would return the same answer:
// DNSSEC validation outcomes and policy-based blocks. Transient errors
// (network, cached, not ready) are not included.
var nonRetryableEDECodes = map[uint16]struct{}{
dns.ExtendedErrorCodeUnsupportedDNSKEYAlgorithm: {},
dns.ExtendedErrorCodeUnsupportedDSDigestType: {},
dns.ExtendedErrorCodeDNSSECIndeterminate: {},
dns.ExtendedErrorCodeDNSBogus: {},
dns.ExtendedErrorCodeSignatureExpired: {},
dns.ExtendedErrorCodeSignatureNotYetValid: {},
dns.ExtendedErrorCodeDNSKEYMissing: {},
dns.ExtendedErrorCodeRRSIGsMissing: {},
dns.ExtendedErrorCodeNoZoneKeyBitSet: {},
dns.ExtendedErrorCodeNSECMissing: {},
dns.ExtendedErrorCodeBlocked: {},
dns.ExtendedErrorCodeCensored: {},
dns.ExtendedErrorCodeFiltered: {},
dns.ExtendedErrorCodeProhibited: {},
}
// privateClientIface is the subset of the WireGuard interface needed by GetClientPrivate.
type privateClientIface interface {
Name() string
@@ -46,15 +94,17 @@ const (
// Set longer than UpstreamTimeout to ensure context timeout takes precedence
ClientTimeout = 5 * time.Second
reactivatePeriod = 30 * time.Second
probeTimeout = 2 * time.Second
// ipv6HeaderSize + udpHeaderSize, used to derive the maximum DNS UDP
// payload from the tunnel MTU.
ipUDPHeaderSize = 60 + 8
)
const testRecord = "com."
// raceMaxTotalTimeout caps the combined time spent walking all upstreams
// within one race, so a slow primary can't eat the whole race budget.
raceMaxTotalTimeout = 5 * time.Second
// raceMinPerUpstreamTimeout is the floor applied when dividing
// raceMaxTotalTimeout across upstreams within a race.
raceMinPerUpstreamTimeout = 2 * time.Second
)
const (
protoUDP = "udp"
@@ -63,6 +113,69 @@ const (
type dnsProtocolKey struct{}
type upstreamProtocolKey struct{}
// upstreamProtocolResult holds the protocol used for the upstream exchange.
// Stored as a pointer in context so the exchange function can set it.
type upstreamProtocolResult struct {
protocol string
}
type upstreamClient interface {
exchange(ctx context.Context, upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error)
}
type UpstreamResolver interface {
serveDNS(r *dns.Msg) (*dns.Msg, time.Duration, error)
upstreamExchange(upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error)
}
// upstreamRace is an ordered list of upstreams derived from one configured
// nameserver group. Order matters: the first upstream is tried first, the
// second only on failure, and so on. Multiple upstreamRace values coexist
// inside one resolver when overlapping nameserver groups target the same
// domain; those races run in parallel and the first valid answer wins.
type upstreamRace []netip.AddrPort
// UpstreamHealth is the last query-path outcome for a single upstream,
// consumed by nameserver-group status projection.
type UpstreamHealth struct {
LastOk time.Time
LastFail time.Time
LastErr string
}
type upstreamResolverBase struct {
ctx context.Context
cancel context.CancelFunc
upstreamClient upstreamClient
upstreamServers []upstreamRace
domain domain.Domain
upstreamTimeout time.Duration
healthMu sync.RWMutex
health map[netip.AddrPort]*UpstreamHealth
statusRecorder *peer.Status
// selectedRoutes returns the current set of client routes the admin
// has enabled. Called lazily from the query hot path when an upstream
// might need a tunnel-bound client (iOS) and from health projection.
selectedRoutes func() route.HAMap
}
type upstreamFailure struct {
upstream netip.AddrPort
reason string
}
type raceResult struct {
msg *dns.Msg
upstream netip.AddrPort
protocol string
ede string
failures []upstreamFailure
}
// contextWithDNSProtocol stores the inbound DNS protocol ("udp" or "tcp") in context.
func contextWithDNSProtocol(ctx context.Context, network string) context.Context {
return context.WithValue(ctx, dnsProtocolKey{}, network)
@@ -79,16 +192,8 @@ func dnsProtocolFromContext(ctx context.Context) string {
return ""
}
type upstreamProtocolKey struct{}
// upstreamProtocolResult holds the protocol used for the upstream exchange.
// Stored as a pointer in context so the exchange function can set it.
type upstreamProtocolResult struct {
protocol string
}
// contextWithupstreamProtocolResult stores a mutable result holder in the context.
func contextWithupstreamProtocolResult(ctx context.Context) (context.Context, *upstreamProtocolResult) {
// contextWithUpstreamProtocolResult stores a mutable result holder in the context.
func contextWithUpstreamProtocolResult(ctx context.Context) (context.Context, *upstreamProtocolResult) {
r := &upstreamProtocolResult{}
return context.WithValue(ctx, upstreamProtocolKey{}, r), r
}
@@ -103,67 +208,37 @@ func setUpstreamProtocol(ctx context.Context, protocol string) {
}
}
type upstreamClient interface {
exchange(ctx context.Context, upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error)
}
type UpstreamResolver interface {
serveDNS(r *dns.Msg) (*dns.Msg, time.Duration, error)
upstreamExchange(upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error)
}
type upstreamResolverBase struct {
ctx context.Context
cancel context.CancelFunc
upstreamClient upstreamClient
upstreamServers []netip.AddrPort
domain string
disabled bool
successCount atomic.Int32
mutex sync.Mutex
reactivatePeriod time.Duration
upstreamTimeout time.Duration
wg sync.WaitGroup
deactivate func(error)
reactivate func()
statusRecorder *peer.Status
routeMatch func(netip.Addr) bool
}
type upstreamFailure struct {
upstream netip.AddrPort
reason string
}
func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, domain string) *upstreamResolverBase {
func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, d domain.Domain) *upstreamResolverBase {
ctx, cancel := context.WithCancel(ctx)
return &upstreamResolverBase{
ctx: ctx,
cancel: cancel,
domain: domain,
upstreamTimeout: UpstreamTimeout,
reactivatePeriod: reactivatePeriod,
statusRecorder: statusRecorder,
ctx: ctx,
cancel: cancel,
domain: d,
upstreamTimeout: UpstreamTimeout,
statusRecorder: statusRecorder,
}
}
// String returns a string representation of the upstream resolver
func (u *upstreamResolverBase) String() string {
return fmt.Sprintf("Upstream %s", u.upstreamServers)
return fmt.Sprintf("Upstream %s", u.flatUpstreams())
}
// ID returns the unique handler ID
// ID returns the unique handler ID. Race groupings and within-race
// ordering are both part of the identity: [[A,B]] and [[A],[B]] query
// the same servers but with different semantics (serial fallback vs
// parallel race), so their handlers must not collide.
func (u *upstreamResolverBase) ID() types.HandlerID {
servers := slices.Clone(u.upstreamServers)
slices.SortFunc(servers, func(a, b netip.AddrPort) int { return a.Compare(b) })
hash := sha256.New()
hash.Write([]byte(u.domain + ":"))
for _, s := range servers {
hash.Write([]byte(s.String()))
hash.Write([]byte("|"))
hash.Write([]byte(u.domain.PunycodeString() + ":"))
for _, race := range u.upstreamServers {
hash.Write([]byte("["))
for _, s := range race {
hash.Write([]byte(s.String()))
hash.Write([]byte("|"))
}
hash.Write([]byte("]"))
}
return types.HandlerID("upstream-" + hex.EncodeToString(hash.Sum(nil)[:8]))
}
@@ -173,13 +248,31 @@ func (u *upstreamResolverBase) MatchSubdomains() bool {
}
func (u *upstreamResolverBase) Stop() {
log.Debugf("stopping serving DNS for upstreams %s", u.upstreamServers)
log.Debugf("stopping serving DNS for upstreams %s", u.flatUpstreams())
u.cancel()
}
u.mutex.Lock()
u.wg.Wait()
u.mutex.Unlock()
// flatUpstreams is for logging and ID hashing only, not for dispatch.
func (u *upstreamResolverBase) flatUpstreams() []netip.AddrPort {
var out []netip.AddrPort
for _, g := range u.upstreamServers {
out = append(out, g...)
}
return out
}
// setSelectedRoutes swaps the accessor used to classify overlay-routed
// upstreams. Called when route sources are wired after the handler was
// built (permanent / iOS constructors).
func (u *upstreamResolverBase) setSelectedRoutes(selected func() route.HAMap) {
u.selectedRoutes = selected
}
func (u *upstreamResolverBase) addRace(servers []netip.AddrPort) {
if len(servers) == 0 {
return
}
u.upstreamServers = append(u.upstreamServers, slices.Clone(servers))
}
// ServeDNS handles a DNS request
@@ -221,59 +314,226 @@ func (u *upstreamResolverBase) prepareRequest(r *dns.Msg) {
}
func (u *upstreamResolverBase) tryUpstreamServers(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) (bool, []upstreamFailure) {
timeout := u.upstreamTimeout
if len(u.upstreamServers) > 1 {
maxTotal := 5 * time.Second
minPerUpstream := 2 * time.Second
scaledTimeout := maxTotal / time.Duration(len(u.upstreamServers))
if scaledTimeout > minPerUpstream {
timeout = scaledTimeout
} else {
timeout = minPerUpstream
}
groups := u.upstreamServers
switch len(groups) {
case 0:
return false, nil
case 1:
return u.tryOnlyRace(ctx, w, r, groups[0], logger)
default:
return u.raceAll(ctx, w, r, groups, logger)
}
}
func (u *upstreamResolverBase) tryOnlyRace(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, group upstreamRace, logger *log.Entry) (bool, []upstreamFailure) {
res := u.tryRace(ctx, r, group)
if res.msg == nil {
return false, res.failures
}
if res.ede != "" {
resutil.SetMeta(w, "ede", res.ede)
}
u.writeSuccessResponse(w, res.msg, res.upstream, r.Question[0].Name, res.protocol, logger)
return true, res.failures
}
// raceAll runs one worker per group in parallel, taking the first valid
// answer and cancelling the rest.
func (u *upstreamResolverBase) raceAll(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, groups []upstreamRace, logger *log.Entry) (bool, []upstreamFailure) {
raceCtx, cancel := context.WithCancel(ctx)
defer cancel()
// Buffer sized to len(groups) so workers never block on send, even
// after the coordinator has returned.
results := make(chan raceResult, len(groups))
for _, g := range groups {
// tryRace clones the request per attempt, so workers never share
// a *dns.Msg and concurrent EDNS0 mutations can't race.
go func(g upstreamRace) {
results <- u.tryRace(raceCtx, r, g)
}(g)
}
var failures []upstreamFailure
for _, upstream := range u.upstreamServers {
if failure := u.queryUpstream(ctx, w, r, upstream, timeout, logger); failure != nil {
failures = append(failures, *failure)
} else {
return true, failures
for range groups {
select {
case res := <-results:
failures = append(failures, res.failures...)
if res.msg != nil {
if res.ede != "" {
resutil.SetMeta(w, "ede", res.ede)
}
u.writeSuccessResponse(w, res.msg, res.upstream, r.Question[0].Name, res.protocol, logger)
return true, failures
}
case <-ctx.Done():
return false, failures
}
}
return false, failures
}
// queryUpstream queries a single upstream server. Returns nil on success, or failure info to try next upstream.
func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, w dns.ResponseWriter, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration, logger *log.Entry) *upstreamFailure {
var rm *dns.Msg
var t time.Duration
var err error
var startTime time.Time
var upstreamProto *upstreamProtocolResult
func() {
ctx, cancel := context.WithTimeout(parentCtx, timeout)
func (u *upstreamResolverBase) tryRace(ctx context.Context, r *dns.Msg, group upstreamRace) raceResult {
timeout := u.upstreamTimeout
if len(group) > 1 {
// Cap the whole walk at raceMaxTotalTimeout: per-upstream timeouts
// still honor raceMinPerUpstreamTimeout as a floor for correctness
// on slow links, but the outer context ensures the combined walk
// cannot exceed the cap regardless of group size.
timeout = max(raceMaxTotalTimeout/time.Duration(len(group)), raceMinPerUpstreamTimeout)
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, raceMaxTotalTimeout)
defer cancel()
ctx, upstreamProto = contextWithupstreamProtocolResult(ctx)
startTime = time.Now()
rm, t, err = u.upstreamClient.exchange(ctx, upstream.String(), r)
}()
}
var failures []upstreamFailure
for _, upstream := range group {
if ctx.Err() != nil {
return raceResult{failures: failures}
}
// Clone the request per attempt: the exchange path mutates EDNS0
// options in-place, so reusing the same *dns.Msg across sequential
// upstreams would carry those mutations (e.g. a reduced UDP size)
// into the next attempt.
res, failure := u.queryUpstream(ctx, r.Copy(), upstream, timeout)
if failure != nil {
failures = append(failures, *failure)
continue
}
res.failures = failures
return res
}
return raceResult{failures: failures}
}
func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration) (raceResult, *upstreamFailure) {
ctx, cancel := context.WithTimeout(parentCtx, timeout)
defer cancel()
ctx, upstreamProto := contextWithUpstreamProtocolResult(ctx)
// Advertise EDNS0 so the upstream may include Extended DNS Errors
// (RFC 8914) in failure responses; we use those to short-circuit
// failover for definitive answers like DNSSEC validation failures.
// The caller already passed a per-attempt copy, so we can mutate r
// directly; hadEdns reflects the original client request's state and
// controls whether we strip the OPT from the response.
hadEdns := r.IsEdns0() != nil
if !hadEdns {
r.SetEdns0(upstreamUDPSize(), false)
}
startTime := time.Now()
rm, _, err := u.upstreamClient.exchange(ctx, upstream.String(), r)
if err != nil {
return u.handleUpstreamError(err, upstream, startTime)
// A parent cancellation (e.g., another race won and the coordinator
// cancelled the losers) is not an upstream failure. Check both the
// error chain and the parent context: a transport may surface the
// cancellation as a read/deadline error rather than context.Canceled.
if errors.Is(err, context.Canceled) || errors.Is(parentCtx.Err(), context.Canceled) {
return raceResult{}, &upstreamFailure{upstream: upstream, reason: "canceled"}
}
failure := u.handleUpstreamError(err, upstream, startTime)
u.markUpstreamFail(upstream, failure.reason)
return raceResult{}, failure
}
if rm == nil || !rm.Response {
return &upstreamFailure{upstream: upstream, reason: "no response"}
u.markUpstreamFail(upstream, "no response")
return raceResult{}, &upstreamFailure{upstream: upstream, reason: "no response"}
}
proto := ""
if upstreamProto != nil {
proto = upstreamProto.protocol
}
if rm.Rcode == dns.RcodeServerFailure || rm.Rcode == dns.RcodeRefused {
return &upstreamFailure{upstream: upstream, reason: dns.RcodeToString[rm.Rcode]}
if code, ok := nonRetryableEDE(rm); ok {
if !hadEdns {
stripOPT(rm)
}
u.markUpstreamOk(upstream)
return raceResult{msg: rm, upstream: upstream, protocol: proto, ede: edeName(code)}, nil
}
reason := dns.RcodeToString[rm.Rcode]
u.markUpstreamFail(upstream, reason)
return raceResult{}, &upstreamFailure{upstream: upstream, reason: reason}
}
u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, upstreamProto, logger)
return nil
if !hadEdns {
stripOPT(rm)
}
u.markUpstreamOk(upstream)
return raceResult{msg: rm, upstream: upstream, protocol: proto}, nil
}
// healthEntry returns the mutable health record for addr, lazily creating
// the map and the entry. Caller must hold u.healthMu.
func (u *upstreamResolverBase) healthEntry(addr netip.AddrPort) *UpstreamHealth {
if u.health == nil {
u.health = make(map[netip.AddrPort]*UpstreamHealth)
}
h := u.health[addr]
if h == nil {
h = &UpstreamHealth{}
u.health[addr] = h
}
return h
}
func (u *upstreamResolverBase) markUpstreamOk(addr netip.AddrPort) {
u.healthMu.Lock()
defer u.healthMu.Unlock()
h := u.healthEntry(addr)
h.LastOk = time.Now()
h.LastFail = time.Time{}
h.LastErr = ""
}
func (u *upstreamResolverBase) markUpstreamFail(addr netip.AddrPort, reason string) {
u.healthMu.Lock()
defer u.healthMu.Unlock()
h := u.healthEntry(addr)
h.LastFail = time.Now()
h.LastErr = reason
}
// UpstreamHealth returns a snapshot of per-upstream query outcomes.
func (u *upstreamResolverBase) UpstreamHealth() map[netip.AddrPort]UpstreamHealth {
u.healthMu.RLock()
defer u.healthMu.RUnlock()
out := make(map[netip.AddrPort]UpstreamHealth, len(u.health))
for k, v := range u.health {
out[k] = *v
}
return out
}
// upstreamUDPSize returns the EDNS0 UDP buffer size we advertise to upstreams,
// derived from the tunnel MTU and bounded against underflow.
func upstreamUDPSize() uint16 {
if currentMTU > ipUDPHeaderSize {
return currentMTU - ipUDPHeaderSize
}
return dns.MinMsgSize
}
// stripOPT removes any OPT pseudo-RRs from the response's Extra section so
// the response complies with RFC 6891 when the client did not advertise EDNS0.
func stripOPT(rm *dns.Msg) {
if len(rm.Extra) == 0 {
return
}
out := rm.Extra[:0]
for _, rr := range rm.Extra {
if _, ok := rr.(*dns.OPT); ok {
continue
}
out = append(out, rr)
}
rm.Extra = out
}
func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.AddrPort, startTime time.Time) *upstreamFailure {
@@ -289,12 +549,23 @@ func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.Add
return &upstreamFailure{upstream: upstream, reason: reason}
}
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, upstreamProto *upstreamProtocolResult, logger *log.Entry) bool {
u.successCount.Add(1)
func (u *upstreamResolverBase) debugUpstreamTimeout(upstream netip.AddrPort) string {
if u.statusRecorder == nil {
return ""
}
peerInfo := findPeerForIP(upstream.Addr(), u.statusRecorder)
if peerInfo == nil {
return ""
}
return fmt.Sprintf("(routes through NetBird peer %s)", FormatPeerStatus(peerInfo))
}
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, proto string, logger *log.Entry) {
resutil.SetMeta(w, "upstream", upstream.String())
if upstreamProto != nil && upstreamProto.protocol != "" {
resutil.SetMeta(w, "upstream_protocol", upstreamProto.protocol)
if proto != "" {
resutil.SetMeta(w, "upstream_protocol", proto)
}
// Clear Zero bit from external responses to prevent upstream servers from
@@ -303,14 +574,11 @@ func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dn
if err := w.WriteMsg(rm); err != nil {
logger.Errorf("failed to write DNS response for question domain=%s: %s", domain, err)
return true
}
return true
}
func (u *upstreamResolverBase) logUpstreamFailures(domain string, failures []upstreamFailure, succeeded bool, logger *log.Entry) {
totalUpstreams := len(u.upstreamServers)
totalUpstreams := len(u.flatUpstreams())
failedCount := len(failures)
failureSummary := formatFailures(failures)
@@ -337,117 +605,32 @@ func formatFailures(failures []upstreamFailure) string {
return strings.Join(parts, ", ")
}
// ProbeAvailability tests all upstream servers simultaneously and
// disables the resolver if none work
func (u *upstreamResolverBase) ProbeAvailability(ctx context.Context) {
u.mutex.Lock()
defer u.mutex.Unlock()
// avoid probe if upstreams could resolve at least one query
if u.successCount.Load() > 0 {
return
// nonRetryableEDE returns the first non-retryable EDE code carried in the
// response, if any.
func nonRetryableEDE(rm *dns.Msg) (uint16, bool) {
opt := rm.IsEdns0()
if opt == nil {
return 0, false
}
var success bool
var mu sync.Mutex
var wg sync.WaitGroup
var errs *multierror.Error
for _, upstream := range u.upstreamServers {
wg.Add(1)
go func(upstream netip.AddrPort) {
defer wg.Done()
err := u.testNameserver(u.ctx, ctx, upstream, 500*time.Millisecond)
if err != nil {
mu.Lock()
errs = multierror.Append(errs, err)
mu.Unlock()
log.Warnf("probing upstream nameserver %s: %s", upstream, err)
return
}
mu.Lock()
success = true
mu.Unlock()
}(upstream)
}
wg.Wait()
select {
case <-ctx.Done():
return
case <-u.ctx.Done():
return
default:
}
// didn't find a working upstream server, let's disable and try later
if !success {
u.disable(errs.ErrorOrNil())
if u.statusRecorder == nil {
return
for _, o := range opt.Option {
ede, ok := o.(*dns.EDNS0_EDE)
if !ok {
continue
}
if _, ok := nonRetryableEDECodes[ede.InfoCode]; ok {
return ede.InfoCode, true
}
u.statusRecorder.PublishEvent(
proto.SystemEvent_WARNING,
proto.SystemEvent_DNS,
"All upstream servers failed (probe failed)",
"Unable to reach one or more DNS servers. This might affect your ability to connect to some services.",
map[string]string{"upstreams": u.upstreamServersString()},
)
}
return 0, false
}
// waitUntilResponse retries, in an exponential interval, querying the upstream servers until it gets a positive response
func (u *upstreamResolverBase) waitUntilResponse() {
exponentialBackOff := &backoff.ExponentialBackOff{
InitialInterval: 500 * time.Millisecond,
RandomizationFactor: 0.5,
Multiplier: 1.1,
MaxInterval: u.reactivatePeriod,
MaxElapsedTime: 0,
Stop: backoff.Stop,
Clock: backoff.SystemClock,
// edeName returns a human-readable name for an EDE code, falling back to
// the numeric code when unknown.
func edeName(code uint16) string {
if name, ok := dns.ExtendedErrorCodeToString[code]; ok {
return name
}
operation := func() error {
select {
case <-u.ctx.Done():
return backoff.Permanent(fmt.Errorf("exiting upstream retry loop for upstreams %s: parent context has been canceled", u.upstreamServersString()))
default:
}
for _, upstream := range u.upstreamServers {
if err := u.testNameserver(u.ctx, nil, upstream, probeTimeout); err != nil {
log.Tracef("upstream check for %s: %s", upstream, err)
} else {
// at least one upstream server is available, stop probing
return nil
}
}
log.Tracef("checking connectivity with upstreams %s failed. Retrying in %s", u.upstreamServersString(), exponentialBackOff.NextBackOff())
return fmt.Errorf("upstream check call error")
}
err := backoff.Retry(operation, backoff.WithContext(exponentialBackOff, u.ctx))
if err != nil {
if errors.Is(err, context.Canceled) {
log.Debugf("upstream retry loop exited for upstreams %s", u.upstreamServersString())
} else {
log.Warnf("upstream retry loop exited for upstreams %s: %v", u.upstreamServersString(), err)
}
return
}
log.Infof("upstreams %s are responsive again. Adding them back to system", u.upstreamServersString())
u.successCount.Add(1)
u.reactivate()
u.mutex.Lock()
u.disabled = false
u.mutex.Unlock()
return fmt.Sprintf("EDE %d", code)
}
// isTimeout returns true if the given error is a network timeout error.
@@ -461,45 +644,6 @@ func isTimeout(err error) bool {
return false
}
func (u *upstreamResolverBase) disable(err error) {
if u.disabled {
return
}
log.Warnf("Upstream resolving is Disabled for %v", reactivatePeriod)
u.successCount.Store(0)
u.deactivate(err)
u.disabled = true
u.wg.Add(1)
go func() {
defer u.wg.Done()
u.waitUntilResponse()
}()
}
func (u *upstreamResolverBase) upstreamServersString() string {
var servers []string
for _, server := range u.upstreamServers {
servers = append(servers, server.String())
}
return strings.Join(servers, ", ")
}
func (u *upstreamResolverBase) testNameserver(baseCtx context.Context, externalCtx context.Context, server netip.AddrPort, timeout time.Duration) error {
mergedCtx, cancel := context.WithTimeout(baseCtx, timeout)
defer cancel()
if externalCtx != nil {
stop2 := context.AfterFunc(externalCtx, cancel)
defer stop2()
}
r := new(dns.Msg).SetQuestion(testRecord, dns.TypeSOA)
_, _, err := u.upstreamClient.exchange(mergedCtx, server.String(), r)
return err
}
// clientUDPMaxSize returns the maximum UDP response size the client accepts.
func clientUDPMaxSize(r *dns.Msg) int {
if opt := r.IsEdns0(); opt != nil {
@@ -511,13 +655,10 @@ func clientUDPMaxSize(r *dns.Msg) int {
// ExchangeWithFallback exchanges a DNS message with the upstream server.
// It first tries to use UDP, and if it is truncated, it falls back to TCP.
// If the inbound request came over TCP (via context), it skips the UDP attempt.
// If the passed context is nil, this will use Exchange instead of ExchangeContext.
func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, upstream string) (*dns.Msg, time.Duration, error) {
// If the request came in over TCP, go straight to TCP upstream.
if dnsProtocolFromContext(ctx) == protoTCP {
tcpClient := *client
tcpClient.Net = protoTCP
rm, t, err := tcpClient.ExchangeContext(ctx, r, upstream)
rm, t, err := toTCPClient(client).ExchangeContext(ctx, r, upstream)
if err != nil {
return nil, t, fmt.Errorf("with tcp: %w", err)
}
@@ -537,18 +678,7 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
opt.SetUDPSize(maxUDPPayload)
}
var (
rm *dns.Msg
t time.Duration
err error
)
if ctx == nil {
rm, t, err = client.Exchange(r, upstream)
} else {
rm, t, err = client.ExchangeContext(ctx, r, upstream)
}
rm, t, err := client.ExchangeContext(ctx, r, upstream)
if err != nil {
return nil, t, fmt.Errorf("with udp: %w", err)
}
@@ -562,15 +692,7 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
// data than the client's buffer, we could truncate locally and skip
// the TCP retry.
tcpClient := *client
tcpClient.Net = protoTCP
if ctx == nil {
rm, t, err = tcpClient.Exchange(r, upstream)
} else {
rm, t, err = tcpClient.ExchangeContext(ctx, r, upstream)
}
rm, t, err = toTCPClient(client).ExchangeContext(ctx, r, upstream)
if err != nil {
return nil, t, fmt.Errorf("with tcp: %w", err)
}
@@ -584,6 +706,25 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
return rm, t, nil
}
// toTCPClient returns a copy of c configured for TCP. If c's Dialer has a
// *net.UDPAddr bound as LocalAddr (iOS does this to keep the source IP on
// the tunnel interface), it is converted to the equivalent *net.TCPAddr
// so net.Dialer doesn't reject the TCP dial with "mismatched local
// address type".
func toTCPClient(c *dns.Client) *dns.Client {
tcp := *c
tcp.Net = protoTCP
if tcp.Dialer == nil {
return &tcp
}
d := *tcp.Dialer
if ua, ok := d.LocalAddr.(*net.UDPAddr); ok {
d.LocalAddr = &net.TCPAddr{IP: ua.IP, Port: ua.Port, Zone: ua.Zone}
}
tcp.Dialer = &d
return &tcp
}
// ExchangeWithNetstack performs a DNS exchange using netstack for dialing.
// This is needed when netstack is enabled to reach peer IPs through the tunnel.
func ExchangeWithNetstack(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upstream string) (*dns.Msg, error) {
@@ -725,15 +866,36 @@ func findPeerForIP(ip netip.Addr, statusRecorder *peer.Status) *peer.State {
return bestMatch
}
func (u *upstreamResolverBase) debugUpstreamTimeout(upstream netip.AddrPort) string {
if u.statusRecorder == nil {
return ""
// haMapRouteCount returns the total number of routes across all HA
// groups in the map. route.HAMap is keyed by HAUniqueID with slices of
// routes per key, so len(hm) is the number of HA groups, not routes.
func haMapRouteCount(hm route.HAMap) int {
total := 0
for _, routes := range hm {
total += len(routes)
}
peerInfo := findPeerForIP(upstream.Addr(), u.statusRecorder)
if peerInfo == nil {
return ""
}
return fmt.Sprintf("(routes through NetBird peer %s)", FormatPeerStatus(peerInfo))
return total
}
// haMapContains checks whether ip is covered by any concrete prefix in
// the HA map. haveDynamic is reported separately: dynamic (domain-based)
// routes carry a placeholder Network that can't be prefix-checked, so we
// can't know at this point whether ip is reached through one. Callers
// decide how to interpret the unknown: health projection treats it as
// "possibly routed" to avoid emitting false-positive warnings during
// startup, while iOS dial selection requires a concrete match before
// binding to the tunnel.
func haMapContains(hm route.HAMap, ip netip.Addr) (matched, haveDynamic bool) {
for _, routes := range hm {
for _, r := range routes {
if r.IsDynamic() {
haveDynamic = true
continue
}
if r.Network.Contains(ip) {
return true, haveDynamic
}
}
}
return false, haveDynamic
}

View File

@@ -11,6 +11,7 @@ import (
"github.com/netbirdio/netbird/client/internal/peer"
nbnet "github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/shared/management/domain"
)
type upstreamResolver struct {
@@ -26,9 +27,9 @@ func newUpstreamResolver(
_ WGIface,
statusRecorder *peer.Status,
hostsDNSHolder *hostsDNSHolder,
domain string,
d domain.Domain,
) (*upstreamResolver, error) {
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain)
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, d)
c := &upstreamResolver{
upstreamResolverBase: upstreamResolverBase,
hostsDNSHolder: hostsDNSHolder,

View File

@@ -12,6 +12,7 @@ import (
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/shared/management/domain"
)
type upstreamResolver struct {
@@ -24,9 +25,9 @@ func newUpstreamResolver(
wgIface WGIface,
statusRecorder *peer.Status,
_ *hostsDNSHolder,
domain string,
d domain.Domain,
) (*upstreamResolver, error) {
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain)
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, d)
nonIOS := &upstreamResolver{
upstreamResolverBase: upstreamResolverBase,
nsNet: wgIface.GetNet(),

View File

@@ -15,6 +15,7 @@ import (
"golang.org/x/sys/unix"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/shared/management/domain"
)
type upstreamResolverIOS struct {
@@ -27,9 +28,9 @@ func newUpstreamResolver(
wgIface WGIface,
statusRecorder *peer.Status,
_ *hostsDNSHolder,
domain string,
d domain.Domain,
) (*upstreamResolverIOS, error) {
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain)
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, d)
ios := &upstreamResolverIOS{
upstreamResolverBase: upstreamResolverBase,
@@ -62,9 +63,16 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
upstreamIP = upstreamIP.Unmap()
}
addr := u.wgIface.Address()
var routed bool
if u.selectedRoutes != nil {
// Only a concrete prefix match binds to the tunnel: dialing
// through a private client for an upstream we can't prove is
// routed would break public resolvers.
routed, _ = haMapContains(u.selectedRoutes(), upstreamIP)
}
needsPrivate := addr.Network.Contains(upstreamIP) ||
addr.IPv6Net.Contains(upstreamIP) ||
(u.routeMatch != nil && u.routeMatch(upstreamIP))
routed
if needsPrivate {
log.Debugf("using private client to query %s via upstream %s", r.Question[0].Name, upstream)
client, err = GetClientPrivate(u.wgIface, upstreamIP, timeout)
@@ -73,8 +81,7 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
}
}
// Cannot use client.ExchangeContext because it overwrites our Dialer
return ExchangeWithFallback(nil, client, r, upstream)
return ExchangeWithFallback(ctx, client, r, upstream)
}
// GetClientPrivate returns a new DNS client bound to the local IP of the Netbird interface.

View File

@@ -6,6 +6,7 @@ import (
"net"
"net/netip"
"strings"
"sync/atomic"
"testing"
"time"
@@ -73,7 +74,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
servers = append(servers, netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port()))
}
}
resolver.upstreamServers = servers
resolver.addRace(servers)
resolver.upstreamTimeout = testCase.timeout
if testCase.cancelCTX {
cancel()
@@ -132,20 +133,10 @@ func (m *mockNetstackProvider) GetInterfaceGUIDString() (string, error) {
return "", nil
}
type mockUpstreamResolver struct {
r *dns.Msg
rtt time.Duration
err error
}
// exchange mock implementation of exchange from upstreamResolver
func (c mockUpstreamResolver) exchange(_ context.Context, _ string, _ *dns.Msg) (*dns.Msg, time.Duration, error) {
return c.r, c.rtt, c.err
}
type mockUpstreamResponse struct {
msg *dns.Msg
err error
msg *dns.Msg
err error
delay time.Duration
}
type mockUpstreamResolverPerServer struct {
@@ -153,63 +144,19 @@ type mockUpstreamResolverPerServer struct {
rtt time.Duration
}
func (c mockUpstreamResolverPerServer) exchange(_ context.Context, upstream string, _ *dns.Msg) (*dns.Msg, time.Duration, error) {
if r, ok := c.responses[upstream]; ok {
return r.msg, c.rtt, r.err
func (c mockUpstreamResolverPerServer) exchange(ctx context.Context, upstream string, _ *dns.Msg) (*dns.Msg, time.Duration, error) {
r, ok := c.responses[upstream]
if !ok {
return nil, c.rtt, fmt.Errorf("no mock response for %s", upstream)
}
return nil, c.rtt, fmt.Errorf("no mock response for %s", upstream)
}
func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
mockClient := &mockUpstreamResolver{
err: dns.ErrTime,
r: new(dns.Msg),
rtt: time.Millisecond,
}
resolver := &upstreamResolverBase{
ctx: context.TODO(),
upstreamClient: mockClient,
upstreamTimeout: UpstreamTimeout,
reactivatePeriod: time.Microsecond * 100,
}
addrPort, _ := netip.ParseAddrPort("0.0.0.0:1") // Use valid port for parsing, test will still fail on connection
resolver.upstreamServers = []netip.AddrPort{netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port())}
failed := false
resolver.deactivate = func(error) {
failed = true
// After deactivation, make the mock client work again
mockClient.err = nil
}
reactivated := false
resolver.reactivate = func() {
reactivated = true
}
resolver.ProbeAvailability(context.TODO())
if !failed {
t.Errorf("expected that resolving was deactivated")
return
}
if !resolver.disabled {
t.Errorf("resolver should be Disabled")
return
}
time.Sleep(time.Millisecond * 200)
if !reactivated {
t.Errorf("expected that resolving was reactivated")
return
}
if resolver.disabled {
t.Errorf("should be enabled")
if r.delay > 0 {
select {
case <-time.After(r.delay):
case <-ctx.Done():
return nil, c.rtt, ctx.Err()
}
}
return r.msg, c.rtt, r.err
}
func TestUpstreamResolver_Failover(t *testing.T) {
@@ -339,9 +286,9 @@ func TestUpstreamResolver_Failover(t *testing.T) {
resolver := &upstreamResolverBase{
ctx: ctx,
upstreamClient: trackingClient,
upstreamServers: []netip.AddrPort{upstream1, upstream2},
upstreamTimeout: UpstreamTimeout,
}
resolver.addRace([]netip.AddrPort{upstream1, upstream2})
var responseMSG *dns.Msg
responseWriter := &test.MockResponseWriter{
@@ -421,9 +368,9 @@ func TestUpstreamResolver_SingleUpstreamFailure(t *testing.T) {
resolver := &upstreamResolverBase{
ctx: ctx,
upstreamClient: mockClient,
upstreamServers: []netip.AddrPort{upstream},
upstreamTimeout: UpstreamTimeout,
}
resolver.addRace([]netip.AddrPort{upstream})
var responseMSG *dns.Msg
responseWriter := &test.MockResponseWriter{
@@ -440,6 +387,136 @@ func TestUpstreamResolver_SingleUpstreamFailure(t *testing.T) {
assert.Equal(t, dns.RcodeServerFailure, responseMSG.Rcode, "single upstream SERVFAIL should return SERVFAIL")
}
// TestUpstreamResolver_RaceAcrossGroups covers two nameserver groups
// configured for the same domain, with one broken group. The merge+race
// path should answer as fast as the working group and not pay the timeout
// of the broken one on every query.
func TestUpstreamResolver_RaceAcrossGroups(t *testing.T) {
broken := netip.MustParseAddrPort("192.0.2.1:53")
working := netip.MustParseAddrPort("192.0.2.2:53")
successAnswer := "192.0.2.100"
timeoutErr := &net.OpError{Op: "read", Err: fmt.Errorf("i/o timeout")}
mockClient := &mockUpstreamResolverPerServer{
responses: map[string]mockUpstreamResponse{
// Force the broken upstream to only unblock via timeout /
// cancellation so the assertion below can't pass if races
// were run serially.
broken.String(): {err: timeoutErr, delay: 500 * time.Millisecond},
working.String(): {msg: buildMockResponse(dns.RcodeSuccess, successAnswer)},
},
rtt: time.Millisecond,
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
resolver := &upstreamResolverBase{
ctx: ctx,
upstreamClient: mockClient,
upstreamTimeout: 250 * time.Millisecond,
}
resolver.addRace([]netip.AddrPort{broken})
resolver.addRace([]netip.AddrPort{working})
var responseMSG *dns.Msg
responseWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
responseMSG = m
return nil
},
}
inputMSG := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
start := time.Now()
resolver.ServeDNS(responseWriter, inputMSG)
elapsed := time.Since(start)
require.NotNil(t, responseMSG, "should write a response")
assert.Equal(t, dns.RcodeSuccess, responseMSG.Rcode)
require.NotEmpty(t, responseMSG.Answer)
assert.Contains(t, responseMSG.Answer[0].String(), successAnswer)
// Working group answers in a single RTT; the broken group's
// timeout (100ms) must not block the response.
assert.Less(t, elapsed, 100*time.Millisecond, "race must not wait for broken group's timeout")
}
// TestUpstreamResolver_AllGroupsFail checks that when every group fails the
// resolver returns SERVFAIL rather than leaking a partial response.
func TestUpstreamResolver_AllGroupsFail(t *testing.T) {
a := netip.MustParseAddrPort("192.0.2.1:53")
b := netip.MustParseAddrPort("192.0.2.2:53")
mockClient := &mockUpstreamResolverPerServer{
responses: map[string]mockUpstreamResponse{
a.String(): {msg: buildMockResponse(dns.RcodeServerFailure, "")},
b.String(): {msg: buildMockResponse(dns.RcodeServerFailure, "")},
},
rtt: time.Millisecond,
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
resolver := &upstreamResolverBase{
ctx: ctx,
upstreamClient: mockClient,
upstreamTimeout: UpstreamTimeout,
}
resolver.addRace([]netip.AddrPort{a})
resolver.addRace([]netip.AddrPort{b})
var responseMSG *dns.Msg
responseWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
responseMSG = m
return nil
},
}
resolver.ServeDNS(responseWriter, new(dns.Msg).SetQuestion("example.com.", dns.TypeA))
require.NotNil(t, responseMSG)
assert.Equal(t, dns.RcodeServerFailure, responseMSG.Rcode)
}
// TestUpstreamResolver_HealthTracking verifies that query-path results are
// recorded into per-upstream health, which is what projects back to
// NSGroupState for status reporting.
func TestUpstreamResolver_HealthTracking(t *testing.T) {
ok := netip.MustParseAddrPort("192.0.2.10:53")
bad := netip.MustParseAddrPort("192.0.2.11:53")
mockClient := &mockUpstreamResolverPerServer{
responses: map[string]mockUpstreamResponse{
ok.String(): {msg: buildMockResponse(dns.RcodeSuccess, "192.0.2.100")},
bad.String(): {msg: buildMockResponse(dns.RcodeServerFailure, "")},
},
rtt: time.Millisecond,
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
resolver := &upstreamResolverBase{
ctx: ctx,
upstreamClient: mockClient,
upstreamTimeout: UpstreamTimeout,
}
resolver.addRace([]netip.AddrPort{ok, bad})
responseWriter := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { return nil }}
resolver.ServeDNS(responseWriter, new(dns.Msg).SetQuestion("example.com.", dns.TypeA))
health := resolver.UpstreamHealth()
require.Contains(t, health, ok)
assert.False(t, health[ok].LastOk.IsZero(), "ok upstream should have LastOk set")
assert.Empty(t, health[ok].LastErr)
// bad upstream was never tried because ok answered first; its health
// should remain unset.
assert.NotContains(t, health, bad, "sibling upstream should not be queried when primary answers")
}
func TestFormatFailures(t *testing.T) {
testCases := []struct {
name string
@@ -665,10 +742,10 @@ func TestExchangeWithFallback_EDNS0Capped(t *testing.T) {
// Verify that a client EDNS0 larger than our MTU-derived limit gets
// capped in the outgoing request so the upstream doesn't send a
// response larger than our read buffer.
var receivedUDPSize uint16
var receivedUDPSize atomic.Uint32
udpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
if opt := r.IsEdns0(); opt != nil {
receivedUDPSize = opt.UDPSize()
receivedUDPSize.Store(uint32(opt.UDPSize()))
}
m := new(dns.Msg)
m.SetReply(r)
@@ -699,7 +776,7 @@ func TestExchangeWithFallback_EDNS0Capped(t *testing.T) {
require.NotNil(t, rm)
expectedMax := uint16(currentMTU - ipUDPHeaderSize)
assert.Equal(t, expectedMax, receivedUDPSize,
assert.Equal(t, expectedMax, uint16(receivedUDPSize.Load()),
"upstream should see capped EDNS0, not the client's 4096")
}
@@ -770,3 +847,132 @@ func TestExchangeWithFallback_TCPTruncatesToClientSize(t *testing.T) {
assert.Less(t, len(rm2.Answer), 20, "small EDNS0 client should get fewer records")
assert.True(t, rm2.Truncated, "response should be truncated for small buffer client")
}
func msgWithEDE(rcode int, codes ...uint16) *dns.Msg {
m := new(dns.Msg)
m.Response = true
m.Rcode = rcode
if len(codes) == 0 {
return m
}
opt := &dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}}
opt.SetUDPSize(dns.MinMsgSize)
for _, c := range codes {
opt.Option = append(opt.Option, &dns.EDNS0_EDE{InfoCode: c})
}
m.Extra = append(m.Extra, opt)
return m
}
func TestNonRetryableEDE(t *testing.T) {
tests := []struct {
name string
msg *dns.Msg
wantOK bool
wantCode uint16
}{
{name: "no edns0", msg: msgWithEDE(dns.RcodeServerFailure)},
{
name: "opt without ede",
msg: func() *dns.Msg {
m := msgWithEDE(dns.RcodeServerFailure)
opt := &dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}}
opt.Option = append(opt.Option, &dns.EDNS0_NSID{Code: dns.EDNS0NSID})
m.Extra = []dns.RR{opt}
return m
}(),
},
{name: "ede dnsbogus", msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeDNSBogus), wantOK: true, wantCode: dns.ExtendedErrorCodeDNSBogus},
{name: "ede signature expired", msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeSignatureExpired), wantOK: true, wantCode: dns.ExtendedErrorCodeSignatureExpired},
{name: "ede blocked", msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeBlocked), wantOK: true, wantCode: dns.ExtendedErrorCodeBlocked},
{name: "ede prohibited", msg: msgWithEDE(dns.RcodeRefused, dns.ExtendedErrorCodeProhibited), wantOK: true, wantCode: dns.ExtendedErrorCodeProhibited},
{name: "ede cached error retryable", msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeCachedError)},
{name: "ede network error retryable", msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeNetworkError)},
{name: "ede not ready retryable", msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeNotReady)},
{
name: "first non-retryable wins",
msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeNetworkError, dns.ExtendedErrorCodeDNSBogus),
wantOK: true,
wantCode: dns.ExtendedErrorCodeDNSBogus,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
code, ok := nonRetryableEDE(tc.msg)
assert.Equal(t, tc.wantOK, ok, "ok should match")
if tc.wantOK {
assert.Equal(t, tc.wantCode, code, "code should match")
}
})
}
}
func TestEDEName(t *testing.T) {
assert.Equal(t, "DNSSEC Bogus", edeName(dns.ExtendedErrorCodeDNSBogus))
assert.Equal(t, "Signature Expired", edeName(dns.ExtendedErrorCodeSignatureExpired))
assert.Equal(t, "EDE 9999", edeName(9999), "unknown code falls back to numeric")
}
func TestStripOPT(t *testing.T) {
rm := &dns.Msg{
Extra: []dns.RR{
&dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}},
&dns.A{Hdr: dns.RR_Header{Name: "x.", Rrtype: dns.TypeA}, A: net.IPv4(1, 2, 3, 4)},
},
}
stripOPT(rm)
assert.Len(t, rm.Extra, 1, "OPT should be removed, A kept")
_, isOPT := rm.Extra[0].(*dns.OPT)
assert.False(t, isOPT, "remaining record must not be OPT")
}
func TestUpstreamResolver_NonRetryableEDEShortCircuits(t *testing.T) {
upstream1 := netip.MustParseAddrPort("192.0.2.1:53")
upstream2 := netip.MustParseAddrPort("192.0.2.2:53")
servfailWithEDE := msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeDNSBogus)
successResp := buildMockResponse(dns.RcodeSuccess, "192.0.2.100")
var queried []string
tracking := &trackingMockClient{
inner: &mockUpstreamResolverPerServer{
responses: map[string]mockUpstreamResponse{
upstream1.String(): {msg: servfailWithEDE},
upstream2.String(): {msg: successResp},
},
rtt: time.Millisecond,
},
queriedUpstreams: &queried,
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
resolver := &upstreamResolverBase{
ctx: ctx,
upstreamClient: tracking,
upstreamServers: []upstreamRace{{upstream1, upstream2}},
upstreamTimeout: UpstreamTimeout,
}
var written *dns.Msg
w := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
written = m
return nil
},
}
// Client query without EDNS0 must not see an OPT in the response.
q := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
resolver.ServeDNS(w, q)
require.NotNil(t, written, "response must be written")
assert.Equal(t, dns.RcodeServerFailure, written.Rcode, "SERVFAIL must propagate")
assert.Len(t, queried, 1, "only first upstream should be queried")
assert.Equal(t, upstream1.String(), queried[0])
for _, rr := range written.Extra {
_, isOPT := rr.(*dns.OPT)
assert.False(t, isOPT, "synthetic OPT must not leak to a non-EDNS0 client")
}
}

View File

@@ -61,9 +61,11 @@ import (
cProto "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/system"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/route"
mgm "github.com/netbirdio/netbird/shared/management/client"
"github.com/netbirdio/netbird/shared/management/domain"
nbnetworkmap "github.com/netbirdio/netbird/shared/management/networkmap"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/netiputil"
auth "github.com/netbirdio/netbird/shared/relay/auth/hmac"
@@ -202,6 +204,13 @@ type Engine struct {
// networkSerial is the latest CurrentSerial (state ID) of the network sent by the Management service
networkSerial uint64
// latestComponents is the most-recent NetworkMapComponents decoded from
// a NetworkMapEnvelope (capability=3 peers only). Held alongside the
// NetworkMap that Calculate() produced from it so Step 3 incremental
// updates have a base to apply changes against. nil for legacy-format
// peers. Guarded by syncMsgMux.
latestComponents *types.NetworkMapComponents
networkMonitor *networkmonitor.NetworkMonitor
sshServer sshServer
@@ -512,16 +521,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener)
e.dnsServer.SetRouteChecker(func(ip netip.Addr) bool {
for _, routes := range e.routeManager.GetSelectedClientRoutes() {
for _, r := range routes {
if r.Network.Contains(ip) {
return true
}
}
}
return false
})
e.dnsServer.SetRouteSources(e.routeManager.GetSelectedClientRoutes, e.routeManager.GetActiveClientRoutes)
if err = e.wgInterfaceCreate(); err != nil {
log.Errorf("failed creating tunnel interface %s: [%s]", e.config.WgIfaceName, err.Error())
@@ -874,8 +874,12 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
return e.ctx.Err()
}
if update.NetworkMap != nil && update.NetworkMap.PeerConfig != nil {
e.handleAutoUpdateVersion(update.NetworkMap.PeerConfig.AutoUpdate)
// Envelope sync responses carry PeerConfig at the top level; legacy
// NetworkMap syncs carry it under NetworkMap.PeerConfig.
if pc := update.GetPeerConfig(); pc != nil {
e.handleAutoUpdateVersion(pc.GetAutoUpdate())
} else if nm := update.GetNetworkMap(); nm != nil && nm.GetPeerConfig() != nil {
e.handleAutoUpdateVersion(nm.GetPeerConfig().GetAutoUpdate())
}
if update.GetNetbirdConfig() != nil {
@@ -916,11 +920,45 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
return err
}
nm := update.GetNetworkMap()
var (
nm *mgmProto.NetworkMap
components *types.NetworkMapComponents
)
if envelope := update.GetNetworkMapEnvelope(); envelope != nil {
// Components-format peer: decode the envelope back to typed
// components, run Calculate() locally, and convert to the wire
// NetworkMap shape the rest of the engine consumes. Components are
// retained so future incremental updates (Step 3) can apply deltas
// instead of doing a full reconstruction.
localKey := e.config.WgPrivateKey.PublicKey().String()
dnsName := ""
if pc := update.GetPeerConfig(); pc != nil {
// PeerConfig.Fqdn = "<dns_label>.<dns_domain>" — extract the
// shared domain by stripping the peer's own label prefix. Falls
// back to empty if the FQDN doesn't have the expected shape.
dnsName = extractDNSDomainFromFQDN(pc.GetFqdn())
}
result, err := nbnetworkmap.EnvelopeToNetworkMap(e.ctx, envelope, localKey, dnsName)
if err != nil {
return fmt.Errorf("decode network map envelope: %w", err)
}
nm = result.NetworkMap
components = result.Components
} else {
nm = update.GetNetworkMap()
}
if nm == nil {
return nil
}
// Only retain the components view when the server sent the envelope
// path. A legacy proto.NetworkMap means components == nil; writing it
// here would clobber a previously-cached snapshot, breaking the Step 3
// incremental-delta base on a future envelope sync.
if components != nil {
e.latestComponents = components
}
// Persist sync response under the dedicated lock (syncRespMux), not under syncMsgMux.
// Read the storage-enabled flag under the syncRespMux too.
e.syncRespMux.RLock()
@@ -946,6 +984,19 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
return nil
}
// extractDNSDomainFromFQDN returns the trailing dotted domain part of the
// receiving peer's FQDN — the same value the management server fills as
// dnsName when it builds the legacy NetworkMap. "peer42.netbird.cloud" →
// "netbird.cloud". An empty string is returned for unrecognized formats.
func extractDNSDomainFromFQDN(fqdn string) string {
for i := 0; i < len(fqdn); i++ {
if fqdn[i] == '.' && i+1 < len(fqdn) {
return fqdn[i+1:]
}
}
return ""
}
func (e *Engine) handleRelayUpdate(update *mgmProto.RelayConfig) error {
if update != nil {
// when we receive token we expect valid address list too
@@ -1386,9 +1437,6 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
e.networkSerial = serial
// Test received (upstream) servers for availability right away instead of upon usage.
// If no server of a server group responds this will disable the respective handler and retry later.
go e.dnsServer.ProbeAvailability()
return nil
}
@@ -1932,7 +1980,7 @@ func (e *Engine) newDnsServer(dnsConfig *nbdns.Config) (dns.Server, error) {
return dnsServer, nil
case "ios":
dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.mobileDep.HostDNSAddresses, e.statusRecorder, e.config.DisableDNS)
dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.statusRecorder, e.config.DisableDNS)
return dnsServer, nil
default:
@@ -1979,29 +2027,6 @@ func (e *Engine) GetClientMetrics() *metrics.ClientMetrics {
return e.clientMetrics
}
// Performance bundles runtime-adjustable tunnel pool knobs.
// See Engine.SetPerformance. Nil fields are ignored.
type Performance struct {
PreallocatedBuffersPerPool *uint32
}
// SetPerformance applies the given tuning to this engine's live Device.
func (e *Engine) SetPerformance(t Performance) error {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
if e.wgInterface == nil {
return fmt.Errorf("wg interface not initialized")
}
dev := e.wgInterface.GetWGDevice()
if dev == nil {
return fmt.Errorf("wg device not initialized")
}
if t.PreallocatedBuffersPerPool != nil {
dev.SetPreallocatedBuffersPerPool(*t.PreallocatedBuffersPerPool)
}
return nil
}
func findIPFromInterfaceName(ifaceName string) (net.IP, error) {
iface, err := net.InterfaceByName(ifaceName)
if err != nil {

View File

@@ -53,6 +53,7 @@ type Manager interface {
GetRouteSelector() *routeselector.RouteSelector
GetClientRoutes() route.HAMap
GetSelectedClientRoutes() route.HAMap
GetActiveClientRoutes() route.HAMap
GetClientRoutesWithNetID() map[route.NetID][]*route.Route
SetRouteChangeListener(listener listener.NetworkChangeListener)
InitialRouteRange() []string
@@ -485,6 +486,39 @@ func (m *DefaultManager) GetSelectedClientRoutes() route.HAMap {
return m.routeSelector.FilterSelectedExitNodes(maps.Clone(m.clientRoutes))
}
// GetActiveClientRoutes returns the subset of selected client routes
// that are currently reachable: the route's peer is Connected and is
// the one actively carrying the route (not just an HA sibling).
func (m *DefaultManager) GetActiveClientRoutes() route.HAMap {
m.mux.Lock()
selected := m.routeSelector.FilterSelectedExitNodes(maps.Clone(m.clientRoutes))
recorder := m.statusRecorder
m.mux.Unlock()
if recorder == nil {
return selected
}
out := make(route.HAMap, len(selected))
for id, routes := range selected {
for _, r := range routes {
st, err := recorder.GetPeer(r.Peer)
if err != nil {
continue
}
if st.ConnStatus != peer.StatusConnected {
continue
}
if _, hasRoute := st.GetRoutes()[r.Network.String()]; !hasRoute {
continue
}
out[id] = routes
break
}
}
return out
}
// GetClientRoutesWithNetID returns the current routes from the route map, but the keys consist of the network ID only
func (m *DefaultManager) GetClientRoutesWithNetID() map[route.NetID][]*route.Route {
m.mux.Lock()
@@ -704,7 +738,10 @@ func (m *DefaultManager) collectExitNodeInfo(clientRoutes route.HAMap) exitNodeI
}
func (m *DefaultManager) isExitNodeRoute(routes []*route.Route) bool {
return len(routes) > 0 && routes[0].Network.String() == vars.ExitNodeCIDR
if len(routes) == 0 {
return false
}
return route.IsV4DefaultRoute(routes[0].Network) || route.IsV6DefaultRoute(routes[0].Network)
}
func (m *DefaultManager) categorizeUserSelection(netID route.NetID, info *exitNodeInfo) {

View File

@@ -19,6 +19,7 @@ type MockManager struct {
GetRouteSelectorFunc func() *routeselector.RouteSelector
GetClientRoutesFunc func() route.HAMap
GetSelectedClientRoutesFunc func() route.HAMap
GetActiveClientRoutesFunc func() route.HAMap
GetClientRoutesWithNetIDFunc func() map[route.NetID][]*route.Route
StopFunc func(manager *statemanager.Manager)
}
@@ -78,6 +79,14 @@ func (m *MockManager) GetSelectedClientRoutes() route.HAMap {
return nil
}
// GetActiveClientRoutes mock implementation of GetActiveClientRoutes from the Manager interface
func (m *MockManager) GetActiveClientRoutes() route.HAMap {
if m.GetActiveClientRoutesFunc != nil {
return m.GetActiveClientRoutesFunc()
}
return nil
}
// GetClientRoutesWithNetID mock implementation of GetClientRoutesWithNetID from Manager interface
func (m *MockManager) GetClientRoutesWithNetID() map[route.NetID][]*route.Route {
if m.GetClientRoutesWithNetIDFunc != nil {

View File

@@ -4,6 +4,7 @@ import (
"encoding/json"
"fmt"
"slices"
"strings"
"sync"
"github.com/hashicorp/go-multierror"
@@ -12,10 +13,6 @@ import (
"github.com/netbirdio/netbird/route"
)
const (
exitNodeCIDR = "0.0.0.0/0"
)
type RouteSelector struct {
mu sync.RWMutex
deselectedRoutes map[route.NetID]struct{}
@@ -124,13 +121,7 @@ func (rs *RouteSelector) IsSelected(routeID route.NetID) bool {
rs.mu.RLock()
defer rs.mu.RUnlock()
if rs.deselectAll {
return false
}
_, deselected := rs.deselectedRoutes[routeID]
isSelected := !deselected
return isSelected
return rs.isSelectedLocked(routeID)
}
// FilterSelected removes unselected routes from the provided map.
@@ -144,23 +135,22 @@ func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap {
filtered := route.HAMap{}
for id, rt := range routes {
netID := id.NetID()
_, deselected := rs.deselectedRoutes[netID]
if !deselected {
if !rs.isDeselectedLocked(id.NetID()) {
filtered[id] = rt
}
}
return filtered
}
// HasUserSelectionForRoute returns true if the user has explicitly selected or deselected this specific route
// HasUserSelectionForRoute returns true if the user has explicitly selected or deselected this route.
// Intended for exit-node code paths: a v6 exit-node pair (e.g. "MyExit-v6") with no explicit state of
// its own inherits its v4 base's state, so legacy persisted selections that predate v6 pairing
// transparently apply to the synthesized v6 entry.
func (rs *RouteSelector) HasUserSelectionForRoute(routeID route.NetID) bool {
rs.mu.RLock()
defer rs.mu.RUnlock()
_, selected := rs.selectedRoutes[routeID]
_, deselected := rs.deselectedRoutes[routeID]
return selected || deselected
return rs.hasUserSelectionForRouteLocked(rs.effectiveNetID(routeID))
}
func (rs *RouteSelector) FilterSelectedExitNodes(routes route.HAMap) route.HAMap {
@@ -174,7 +164,7 @@ func (rs *RouteSelector) FilterSelectedExitNodes(routes route.HAMap) route.HAMap
filtered := make(route.HAMap, len(routes))
for id, rt := range routes {
netID := id.NetID()
if rs.isDeselected(netID) {
if rs.isDeselectedLocked(netID) {
continue
}
@@ -189,13 +179,48 @@ func (rs *RouteSelector) FilterSelectedExitNodes(routes route.HAMap) route.HAMap
return filtered
}
func (rs *RouteSelector) isDeselected(netID route.NetID) bool {
// effectiveNetID returns the v4 base for a "-v6" exit pair entry that has no explicit
// state of its own, so selections made on the v4 entry govern the v6 entry automatically.
// Only call this from exit-node-specific code paths: applying it to a non-exit "-v6" route
// would make it inherit unrelated v4 state. Must be called with rs.mu held.
func (rs *RouteSelector) effectiveNetID(id route.NetID) route.NetID {
name := string(id)
if !strings.HasSuffix(name, route.V6ExitSuffix) {
return id
}
if _, ok := rs.selectedRoutes[id]; ok {
return id
}
if _, ok := rs.deselectedRoutes[id]; ok {
return id
}
return route.NetID(strings.TrimSuffix(name, route.V6ExitSuffix))
}
func (rs *RouteSelector) isSelectedLocked(routeID route.NetID) bool {
if rs.deselectAll {
return false
}
_, deselected := rs.deselectedRoutes[routeID]
return !deselected
}
func (rs *RouteSelector) isDeselectedLocked(netID route.NetID) bool {
if rs.deselectAll {
return true
}
_, deselected := rs.deselectedRoutes[netID]
return deselected || rs.deselectAll
return deselected
}
func (rs *RouteSelector) hasUserSelectionForRouteLocked(routeID route.NetID) bool {
_, selected := rs.selectedRoutes[routeID]
_, deselected := rs.deselectedRoutes[routeID]
return selected || deselected
}
func isExitNode(rt []*route.Route) bool {
return len(rt) > 0 && rt[0].Network.String() == exitNodeCIDR
return len(rt) > 0 && (route.IsV4DefaultRoute(rt[0].Network) || route.IsV6DefaultRoute(rt[0].Network))
}
func (rs *RouteSelector) applyExitNodeFilter(
@@ -204,26 +229,23 @@ func (rs *RouteSelector) applyExitNodeFilter(
rt []*route.Route,
out route.HAMap,
) {
if rs.hasUserSelections() {
// user made explicit selects/deselects
if rs.IsSelected(netID) {
// Exit-node path: apply the v4/v6 pair mirror so a deselect on the v4 base also
// drops the synthesized v6 entry that lacks its own explicit state.
effective := rs.effectiveNetID(netID)
if rs.hasUserSelectionForRouteLocked(effective) {
if rs.isSelectedLocked(effective) {
out[id] = rt
}
return
}
// no explicit selections: only include routes marked !SkipAutoApply (=AutoApply)
// no explicit selection for this route: defer to management's SkipAutoApply flag
sel := collectSelected(rt)
if len(sel) > 0 {
out[id] = sel
}
}
func (rs *RouteSelector) hasUserSelections() bool {
return len(rs.selectedRoutes) > 0 || len(rs.deselectedRoutes) > 0
}
func collectSelected(rt []*route.Route) []*route.Route {
var sel []*route.Route
for _, r := range rt {

View File

@@ -330,6 +330,137 @@ func TestRouteSelector_FilterSelectedExitNodes(t *testing.T) {
assert.Len(t, filtered, 0) // No routes should be selected
}
// TestRouteSelector_V6ExitPairInherits covers the v4/v6 exit-node pair selection
// mirror. The mirror is scoped to exit-node code paths: HasUserSelectionForRoute
// and FilterSelectedExitNodes resolve a "-v6" entry without explicit state to its
// v4 base, so legacy persisted selections that predate v6 pairing transparently
// apply to the synthesized v6 entry. General lookups (IsSelected, FilterSelected)
// stay literal so unrelated routes named "*-v6" don't inherit unrelated state.
func TestRouteSelector_V6ExitPairInherits(t *testing.T) {
all := []route.NetID{"exit1", "exit1-v6", "exit2", "exit2-v6", "corp", "corp-v6"}
t.Run("HasUserSelectionForRoute mirrors deselected v4 base", func(t *testing.T) {
rs := routeselector.NewRouteSelector()
require.NoError(t, rs.DeselectRoutes([]route.NetID{"exit1"}, all))
assert.True(t, rs.HasUserSelectionForRoute("exit1-v6"), "v6 pair sees v4 base's user selection")
// unrelated v6 with no v4 base touched is unaffected
assert.False(t, rs.HasUserSelectionForRoute("exit2-v6"))
})
t.Run("IsSelected stays literal for non-exit lookups", func(t *testing.T) {
rs := routeselector.NewRouteSelector()
require.NoError(t, rs.DeselectRoutes([]route.NetID{"corp"}, all))
// A non-exit route literally named "corp-v6" must not inherit "corp"'s state
// via the mirror; the mirror only applies in exit-node code paths.
assert.False(t, rs.IsSelected("corp"))
assert.True(t, rs.IsSelected("corp-v6"), "non-exit *-v6 routes must not inherit unrelated v4 state")
})
t.Run("explicit v6 state overrides v4 base in filter", func(t *testing.T) {
rs := routeselector.NewRouteSelector()
require.NoError(t, rs.DeselectRoutes([]route.NetID{"exit1"}, all))
require.NoError(t, rs.SelectRoutes([]route.NetID{"exit1-v6"}, true, all))
v4Route := &route.Route{NetID: "exit1", Network: netip.MustParsePrefix("0.0.0.0/0")}
v6Route := &route.Route{NetID: "exit1-v6", Network: netip.MustParsePrefix("::/0")}
routes := route.HAMap{
"exit1|0.0.0.0/0": {v4Route},
"exit1-v6|::/0": {v6Route},
}
filtered := rs.FilterSelectedExitNodes(routes)
assert.NotContains(t, filtered, route.HAUniqueID("exit1|0.0.0.0/0"))
assert.Contains(t, filtered, route.HAUniqueID("exit1-v6|::/0"), "explicit v6 select wins over v4 base")
})
t.Run("non-v6-suffix routes unaffected", func(t *testing.T) {
rs := routeselector.NewRouteSelector()
require.NoError(t, rs.DeselectRoutes([]route.NetID{"exit1"}, all))
// A route literally named "exit1-something" must not pair-resolve.
assert.False(t, rs.HasUserSelectionForRoute("exit1-something"))
})
t.Run("filter v6 paired with deselected v4 base", func(t *testing.T) {
rs := routeselector.NewRouteSelector()
require.NoError(t, rs.DeselectRoutes([]route.NetID{"exit1"}, all))
v4Route := &route.Route{NetID: "exit1", Network: netip.MustParsePrefix("0.0.0.0/0")}
v6Route := &route.Route{NetID: "exit1-v6", Network: netip.MustParsePrefix("::/0")}
routes := route.HAMap{
"exit1|0.0.0.0/0": {v4Route},
"exit1-v6|::/0": {v6Route},
}
filtered := rs.FilterSelectedExitNodes(routes)
assert.Empty(t, filtered, "deselecting v4 base must also drop the v6 pair")
})
t.Run("non-exit *-v6 routes pass through FilterSelectedExitNodes", func(t *testing.T) {
rs := routeselector.NewRouteSelector()
require.NoError(t, rs.DeselectRoutes([]route.NetID{"corp"}, all))
// A non-default-route entry named "corp-v6" is not an exit node and
// must not be skipped because its v4 base "corp" is deselected.
corpV6 := &route.Route{NetID: "corp-v6", Network: netip.MustParsePrefix("10.0.0.0/8")}
routes := route.HAMap{
"corp-v6|10.0.0.0/8": {corpV6},
}
filtered := rs.FilterSelectedExitNodes(routes)
assert.Contains(t, filtered, route.HAUniqueID("corp-v6|10.0.0.0/8"),
"non-exit *-v6 routes must not inherit unrelated v4 state in FilterSelectedExitNodes")
})
}
// TestRouteSelector_SkipAutoApplyPerRoute verifies that management's
// SkipAutoApply flag governs each untouched route independently, even when
// the user has explicit selections on other routes.
func TestRouteSelector_SkipAutoApplyPerRoute(t *testing.T) {
autoApplied := &route.Route{
NetID: "Auto",
Network: netip.MustParsePrefix("0.0.0.0/0"),
SkipAutoApply: false,
}
skipApply := &route.Route{
NetID: "Skip",
Network: netip.MustParsePrefix("0.0.0.0/0"),
SkipAutoApply: true,
}
routes := route.HAMap{
"Auto|0.0.0.0/0": {autoApplied},
"Skip|0.0.0.0/0": {skipApply},
}
rs := routeselector.NewRouteSelector()
// User makes an unrelated explicit selection elsewhere.
require.NoError(t, rs.DeselectRoutes([]route.NetID{"Unrelated"}, []route.NetID{"Auto", "Skip", "Unrelated"}))
filtered := rs.FilterSelectedExitNodes(routes)
assert.Contains(t, filtered, route.HAUniqueID("Auto|0.0.0.0/0"), "AutoApply route should be included")
assert.NotContains(t, filtered, route.HAUniqueID("Skip|0.0.0.0/0"), "SkipAutoApply route should be excluded without explicit user selection")
}
// TestRouteSelector_V6ExitIsExitNode verifies that ::/0 routes are recognized
// as exit nodes by the selector's filter path.
func TestRouteSelector_V6ExitIsExitNode(t *testing.T) {
v6Exit := &route.Route{
NetID: "V6Only",
Network: netip.MustParsePrefix("::/0"),
SkipAutoApply: true,
}
routes := route.HAMap{
"V6Only|::/0": {v6Exit},
}
rs := routeselector.NewRouteSelector()
filtered := rs.FilterSelectedExitNodes(routes)
assert.Empty(t, filtered, "::/0 should be treated as an exit node and respect SkipAutoApply")
}
func TestRouteSelector_NewRoutesBehavior(t *testing.T) {
initialRoutes := []route.NetID{"route1", "route2", "route3"}
newRoutes := []route.NetID{"route1", "route2", "route3", "route4", "route5"}

View File

@@ -162,11 +162,7 @@ func (c *Client) Run(fd int32, interfaceName string, envList *EnvList) error {
cfg.WgIface = interfaceName
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
hostDNS := []netip.AddrPort{
netip.MustParseAddrPort("9.9.9.9:53"),
netip.MustParseAddrPort("149.112.112.112:53"),
}
return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, hostDNS, c.stateFile)
return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, c.stateFile)
}
// Stop the internal client and free the resources

View File

@@ -64,6 +64,13 @@
<RegistryValue Name="InstalledByMSI" Type="integer" Value="1" KeyPath="yes" />
</RegistryKey>
</Component>
<!-- Drop the HKCU Run\Netbird value written by legacy NSIS installers. -->
<Component Id="NetbirdLegacyHKCUCleanup" Guid="*">
<RegistryValue Root="HKCU" Key="Software\NetBird GmbH\Installer"
Name="LegacyHKCUCleanup" Type="integer" Value="1" KeyPath="yes" />
<RemoveRegistryValue Root="HKCU"
Key="Software\Microsoft\Windows\CurrentVersion\Run" Name="Netbird" />
</Component>
</StandardDirectory>
<StandardDirectory Id="CommonAppDataFolder">
@@ -76,10 +83,28 @@
</Directory>
</StandardDirectory>
<!-- Drop Run, App Paths and Uninstall entries written by legacy NSIS
installers into the 32-bit registry view (HKLM\Software\Wow6432Node). -->
<Component Id="NetbirdLegacyWow6432Cleanup" Directory="NetbirdInstallDir"
Guid="bda5d628-16bd-4086-b2c1-5099d8d51763" Bitness="always32">
<RegistryValue Root="HKLM" Key="Software\NetBird GmbH\Installer"
Name="LegacyWow6432Cleanup" Type="integer" Value="1" KeyPath="yes" />
<RemoveRegistryValue Root="HKLM"
Key="Software\Microsoft\Windows\CurrentVersion\Run" Name="Netbird" />
<RemoveRegistryKey Action="removeOnInstall" Root="HKLM"
Key="Software\Microsoft\Windows\CurrentVersion\App Paths\Netbird" />
<RemoveRegistryKey Action="removeOnInstall" Root="HKLM"
Key="Software\Microsoft\Windows\CurrentVersion\App Paths\Netbird-ui" />
<RemoveRegistryKey Action="removeOnInstall" Root="HKLM"
Key="Software\Microsoft\Windows\CurrentVersion\Uninstall\Netbird" />
</Component>
<ComponentGroup Id="NetbirdFilesComponent">
<ComponentRef Id="NetbirdFiles" />
<ComponentRef Id="NetbirdAumidRegistry" />
<ComponentRef Id="NetbirdAutoStart" />
<ComponentRef Id="NetbirdLegacyHKCUCleanup" />
<ComponentRef Id="NetbirdLegacyWow6432Cleanup" />
</ComponentGroup>
<util:CloseApplication Id="CloseNetBird" CloseMessage="no" Target="netbird.exe" RebootPrompt="no" />

View File

@@ -0,0 +1,93 @@
package server
import (
"path/filepath"
"testing"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal/profilemanager"
)
func TestPersistLoginOverrides(t *testing.T) {
strPtr := func(s string) *string { return &s }
tests := []struct {
name string
initialMgmtURL string
initialPSK string
newMgmtURL string
newPSK *string
wantMgmtURL string
wantPSK string
}{
{
name: "persist new management URL",
initialMgmtURL: "https://old.example.com:33073",
newMgmtURL: "https://new.example.com:33073",
wantMgmtURL: "https://new.example.com:33073",
},
{
name: "persist new pre-shared key",
initialMgmtURL: "https://existing.example.com:33073",
initialPSK: "old-key",
newPSK: strPtr("new-key"),
wantMgmtURL: "https://existing.example.com:33073",
wantPSK: "new-key",
},
{
name: "persist both",
initialMgmtURL: "https://old.example.com:33073",
initialPSK: "old-key",
newMgmtURL: "https://new.example.com:33073",
newPSK: strPtr("new-key"),
wantMgmtURL: "https://new.example.com:33073",
wantPSK: "new-key",
},
{
name: "no inputs preserves existing",
initialMgmtURL: "https://existing.example.com:33073",
initialPSK: "existing-key",
wantMgmtURL: "https://existing.example.com:33073",
wantPSK: "existing-key",
},
{
name: "empty PSK pointer is ignored",
initialMgmtURL: "https://existing.example.com:33073",
initialPSK: "existing-key",
newPSK: strPtr(""),
wantMgmtURL: "https://existing.example.com:33073",
wantPSK: "existing-key",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
origDefault := profilemanager.DefaultConfigPath
t.Cleanup(func() { profilemanager.DefaultConfigPath = origDefault })
dir := t.TempDir()
profilemanager.DefaultConfigPath = filepath.Join(dir, "default.json")
seed := profilemanager.ConfigInput{
ConfigPath: profilemanager.DefaultConfigPath,
ManagementURL: tt.initialMgmtURL,
}
if tt.initialPSK != "" {
seed.PreSharedKey = strPtr(tt.initialPSK)
}
_, err := profilemanager.UpdateOrCreateConfig(seed)
require.NoError(t, err, "seed config")
activeProf := &profilemanager.ActiveProfileState{Name: "default"}
err = persistLoginOverrides(activeProf, tt.newMgmtURL, tt.newPSK)
require.NoError(t, err, "persistLoginOverrides")
cfg, err := profilemanager.ReadConfig(profilemanager.DefaultConfigPath)
require.NoError(t, err, "read back config")
require.Equal(t, tt.wantMgmtURL, cfg.ManagementURL.String(), "management URL")
require.Equal(t, tt.wantPSK, cfg.PreSharedKey, "pre-shared key")
})
}
}

View File

@@ -490,6 +490,11 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
s.mutex.Unlock()
if err := persistLoginOverrides(activeProf, msg.ManagementUrl, msg.OptionalPreSharedKey); err != nil {
log.Errorf("failed to persist login overrides: %v", err)
return nil, fmt.Errorf("persist login overrides: %w", err)
}
config, _, err := s.getConfig(activeProf)
if err != nil {
log.Errorf("failed to get active profile config: %v", err)
@@ -964,7 +969,7 @@ func (s *Server) handleActiveProfileLogout(ctx context.Context) (*proto.LogoutRe
return &proto.LogoutResponse{}, nil
}
// GetConfig reads config file and returns Config and whether the config file already existed. Errors out if it does not exist
// getConfig reads config file and returns Config and whether the config file already existed. Errors out if it does not exist
func (s *Server) getConfig(activeProf *profilemanager.ActiveProfileState) (*profilemanager.Config, bool, error) {
cfgPath, err := activeProf.FilePath()
if err != nil {
@@ -1766,3 +1771,29 @@ func sendTerminalNotification() error {
return wallCmd.Wait()
}
// persistLoginOverrides writes management URL and pre-shared key from a LoginRequest to the
// active profile config so that subsequent reads pick them up. Empty/nil values are ignored.
func persistLoginOverrides(activeProf *profilemanager.ActiveProfileState, managementURL string, preSharedKey *string) error {
if preSharedKey != nil && *preSharedKey == "" {
preSharedKey = nil
}
if managementURL == "" && preSharedKey == nil {
return nil
}
cfgPath, err := activeProf.FilePath()
if err != nil {
return fmt.Errorf("active profile file path: %w", err)
}
input := profilemanager.ConfigInput{
ConfigPath: cfgPath,
ManagementURL: managementURL,
PreSharedKey: preSharedKey,
}
if _, err := profilemanager.UpdateOrCreateConfig(input); err != nil {
return fmt.Errorf("update config: %w", err)
}
return nil
}

View File

@@ -25,6 +25,7 @@ import (
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/ssh/detection"
"github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/util/netrelay"
)
const (
@@ -536,7 +537,7 @@ func (c *Client) LocalPortForward(ctx context.Context, localAddr, remoteAddr str
continue
}
go c.handleLocalForward(localConn, remoteAddr)
go c.handleLocalForward(ctx, localConn, remoteAddr)
}
}()
@@ -548,7 +549,7 @@ func (c *Client) LocalPortForward(ctx context.Context, localAddr, remoteAddr str
}
// handleLocalForward handles a single local port forwarding connection
func (c *Client) handleLocalForward(localConn net.Conn, remoteAddr string) {
func (c *Client) handleLocalForward(ctx context.Context, localConn net.Conn, remoteAddr string) {
defer func() {
if err := localConn.Close(); err != nil {
log.Debugf("local port forwarding: close local connection: %v", err)
@@ -571,7 +572,7 @@ func (c *Client) handleLocalForward(localConn net.Conn, remoteAddr string) {
}
}()
nbssh.BidirectionalCopy(log.NewEntry(log.StandardLogger()), localConn, channel)
netrelay.Relay(ctx, localConn, channel, netrelay.Options{Logger: log.NewEntry(log.StandardLogger())})
}
// RemotePortForward sets up remote port forwarding, binding on remote and forwarding to localAddr
@@ -653,16 +654,19 @@ func (c *Client) handleRemoteForwardChannels(ctx context.Context, localAddr stri
select {
case <-ctx.Done():
return
case newChan := <-channelRequests:
case newChan, ok := <-channelRequests:
if !ok {
return
}
if newChan != nil {
go c.handleRemoteForwardChannel(newChan, localAddr)
go c.handleRemoteForwardChannel(ctx, newChan, localAddr)
}
}
}
}
// handleRemoteForwardChannel handles a single forwarded-tcpip channel
func (c *Client) handleRemoteForwardChannel(newChan ssh.NewChannel, localAddr string) {
func (c *Client) handleRemoteForwardChannel(ctx context.Context, newChan ssh.NewChannel, localAddr string) {
channel, reqs, err := newChan.Accept()
if err != nil {
return
@@ -675,8 +679,14 @@ func (c *Client) handleRemoteForwardChannel(newChan ssh.NewChannel, localAddr st
go ssh.DiscardRequests(reqs)
localConn, err := net.Dial("tcp", localAddr)
// Bound the dial so a black-holed localAddr can't pin the accepted SSH
// channel open indefinitely; the relay itself runs under the outer ctx.
dialCtx, cancelDial := context.WithTimeout(ctx, 10*time.Second)
var dialer net.Dialer
localConn, err := dialer.DialContext(dialCtx, "tcp", localAddr)
cancelDial()
if err != nil {
log.Debugf("remote port forwarding: dial %s: %v", localAddr, err)
return
}
defer func() {
@@ -685,7 +695,7 @@ func (c *Client) handleRemoteForwardChannel(newChan ssh.NewChannel, localAddr st
}
}()
nbssh.BidirectionalCopy(log.NewEntry(log.StandardLogger()), localConn, channel)
netrelay.Relay(ctx, localConn, channel, netrelay.Options{Logger: log.NewEntry(log.StandardLogger())})
}
// tcpipForwardMsg represents the structure for tcpip-forward requests

View File

@@ -194,63 +194,3 @@ func buildAddressList(hostname string, remote net.Addr) []string {
return addresses
}
// BidirectionalCopy copies data bidirectionally between two io.ReadWriter connections.
// It waits for both directions to complete before returning.
// The caller is responsible for closing the connections.
func BidirectionalCopy(logger *log.Entry, rw1, rw2 io.ReadWriter) {
done := make(chan struct{}, 2)
go func() {
if _, err := io.Copy(rw2, rw1); err != nil && !isExpectedCopyError(err) {
logger.Debugf("copy error (1->2): %v", err)
}
done <- struct{}{}
}()
go func() {
if _, err := io.Copy(rw1, rw2); err != nil && !isExpectedCopyError(err) {
logger.Debugf("copy error (2->1): %v", err)
}
done <- struct{}{}
}()
<-done
<-done
}
func isExpectedCopyError(err error) bool {
return errors.Is(err, io.EOF) || errors.Is(err, context.Canceled)
}
// BidirectionalCopyWithContext copies data bidirectionally between two io.ReadWriteCloser connections.
// It waits for both directions to complete or for context cancellation before returning.
// Both connections are closed when the function returns.
func BidirectionalCopyWithContext(logger *log.Entry, ctx context.Context, conn1, conn2 io.ReadWriteCloser) {
done := make(chan struct{}, 2)
go func() {
if _, err := io.Copy(conn2, conn1); err != nil && !isExpectedCopyError(err) {
logger.Debugf("copy error (1->2): %v", err)
}
done <- struct{}{}
}()
go func() {
if _, err := io.Copy(conn1, conn2); err != nil && !isExpectedCopyError(err) {
logger.Debugf("copy error (2->1): %v", err)
}
done <- struct{}{}
}()
select {
case <-ctx.Done():
case <-done:
select {
case <-ctx.Done():
case <-done:
}
}
_ = conn1.Close()
_ = conn2.Close()
}

View File

@@ -229,18 +229,35 @@ func (m *Manager) buildHostPatterns(peer PeerSSHInfo) []string {
func (m *Manager) writeSSHConfig(sshConfig string) error {
sshConfigPath := filepath.Join(m.sshConfigDir, m.sshConfigFile)
sshConfigPathTmp := sshConfigPath + ".tmp"
if err := os.MkdirAll(m.sshConfigDir, 0755); err != nil {
return fmt.Errorf("create SSH config directory %s: %w", m.sshConfigDir, err)
}
if err := writeFileWithTimeout(sshConfigPathTmp, []byte(sshConfig), 0644); err != nil {
return fmt.Errorf("write SSH config file %s: %w", sshConfigPath, err)
tmp, err := os.CreateTemp(m.sshConfigDir, m.sshConfigFile+".*.tmp")
if err != nil {
return fmt.Errorf("create temp SSH config: %w", err)
}
tmpPath := tmp.Name()
defer func() {
if err := os.Remove(tmpPath); err != nil && !os.IsNotExist(err) {
log.Debugf("remove temp SSH config %s: %v", tmpPath, err)
}
}()
if err := tmp.Close(); err != nil {
return fmt.Errorf("close temp SSH config %s: %w", tmpPath, err)
}
if err := os.Rename(sshConfigPathTmp, sshConfigPath); err != nil {
return fmt.Errorf("rename ssh config %s -> %s: %w", sshConfigPathTmp, sshConfigPath, err)
if err := writeFileWithTimeout(tmpPath, []byte(sshConfig), 0644); err != nil {
return fmt.Errorf("write SSH config file %s: %w", tmpPath, err)
}
if err := os.Chmod(tmpPath, 0644); err != nil {
return fmt.Errorf("chmod SSH config file %s: %w", tmpPath, err)
}
if err := os.Rename(tmpPath, sshConfigPath); err != nil {
return fmt.Errorf("rename SSH config %s -> %s: %w", tmpPath, sshConfigPath, err)
}
log.Infof("Created NetBird SSH client config: %s", sshConfigPath)

View File

@@ -23,6 +23,7 @@ import (
"github.com/netbirdio/netbird/client/proto"
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/ssh/detection"
"github.com/netbirdio/netbird/util/netrelay"
"github.com/netbirdio/netbird/version"
)
@@ -352,7 +353,7 @@ func (p *SSHProxy) directTCPIPHandler(_ *ssh.Server, _ *cryptossh.ServerConn, ne
}
go cryptossh.DiscardRequests(clientReqs)
nbssh.BidirectionalCopyWithContext(log.NewEntry(log.StandardLogger()), sshCtx, clientChan, backendChan)
netrelay.Relay(sshCtx, clientChan, backendChan, netrelay.Options{Logger: log.NewEntry(log.StandardLogger())})
}
func (p *SSHProxy) sftpSubsystemHandler(s ssh.Session, jwtToken string) {
@@ -591,7 +592,7 @@ func (p *SSHProxy) handleForwardedChannel(sshCtx ssh.Context, sshConn *cryptossh
}
go cryptossh.DiscardRequests(clientReqs)
nbssh.BidirectionalCopyWithContext(log.NewEntry(log.StandardLogger()), sshCtx, clientChan, backendChan)
netrelay.Relay(sshCtx, clientChan, backendChan, netrelay.Options{Logger: log.NewEntry(log.StandardLogger())})
}
func (p *SSHProxy) dialBackend(ctx context.Context, addr, user, jwtToken string) (*cryptossh.Client, error) {

View File

@@ -17,7 +17,7 @@ import (
log "github.com/sirupsen/logrus"
cryptossh "golang.org/x/crypto/ssh"
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/util/netrelay"
)
const privilegedPortThreshold = 1024
@@ -357,7 +357,7 @@ func (s *Server) handleRemoteForwardConnection(ctx ssh.Context, conn net.Conn, h
return
}
nbssh.BidirectionalCopyWithContext(logger, ctx, conn, channel)
netrelay.Relay(ctx, conn, channel, netrelay.Options{Logger: logger})
}
// openForwardChannel creates an SSH forwarded-tcpip channel

View File

@@ -8,9 +8,9 @@ import (
"fmt"
"io"
"net"
"strconv"
"net/netip"
"slices"
"strconv"
"strings"
"sync"
"time"
@@ -27,6 +27,7 @@ import (
"github.com/netbirdio/netbird/client/ssh/detection"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/auth/jwt"
"github.com/netbirdio/netbird/util/netrelay"
"github.com/netbirdio/netbird/version"
)
@@ -53,6 +54,10 @@ const (
DefaultJWTMaxTokenAge = 10 * 60
)
// directTCPIPDialTimeout bounds how long relayDirectTCPIP waits on a dial to
// the forwarded destination before rejecting the SSH channel.
const directTCPIPDialTimeout = 30 * time.Second
var (
ErrPrivilegedUserDisabled = errors.New(msgPrivilegedUserDisabled)
ErrUserNotFound = errors.New("user not found")
@@ -933,5 +938,29 @@ func (s *Server) directTCPIPHandler(srv *ssh.Server, conn *cryptossh.ServerConn,
s.addConnectionPortForward(ctx.User(), ctx.RemoteAddr(), forwardAddr)
logger.Infof("local port forwarding: %s", hostPort)
ssh.DirectTCPIPHandler(srv, conn, newChan, ctx)
s.relayDirectTCPIP(ctx, newChan, payload.Host, int(payload.Port), logger)
}
// relayDirectTCPIP is a netrelay-based replacement for gliderlabs'
// DirectTCPIPHandler. The upstream handler closes both sides on the first
// EOF; netrelay.Relay propagates CloseWrite so each direction drains on its
// own terms.
func (s *Server) relayDirectTCPIP(ctx ssh.Context, newChan cryptossh.NewChannel, host string, port int, logger *log.Entry) {
dest := net.JoinHostPort(host, strconv.Itoa(port))
dialer := net.Dialer{Timeout: directTCPIPDialTimeout}
dconn, err := dialer.DialContext(ctx, "tcp", dest)
if err != nil {
_ = newChan.Reject(cryptossh.ConnectionFailed, err.Error())
return
}
ch, reqs, err := newChan.Accept()
if err != nil {
_ = dconn.Close()
return
}
go cryptossh.DiscardRequests(reqs)
netrelay.Relay(ctx, dconn, ch, netrelay.Options{Logger: logger})
}

View File

@@ -193,7 +193,15 @@ func getOverlappingNetworks(routes []*proto.Network) []*proto.Network {
}
func isDefaultRoute(routeRange string) bool {
return routeRange == "0.0.0.0/0" || routeRange == "::/0"
// routeRange is the merged display string from the daemon, e.g. "0.0.0.0/0",
// "::/0", or "0.0.0.0/0, ::/0" when a v4 exit node has a paired v6 entry.
for _, part := range strings.Split(routeRange, ",") {
switch strings.TrimSpace(part) {
case "0.0.0.0/0", "::/0":
return true
}
}
return false
}
func getExitNodeNetworks(routes []*proto.Network) []*proto.Network {

View File

@@ -133,13 +133,18 @@ type ManagementConfig struct {
// AuthConfig contains authentication/identity provider settings
type AuthConfig struct {
Issuer string `yaml:"issuer"`
LocalAuthDisabled bool `yaml:"localAuthDisabled"`
SignKeyRefreshEnabled bool `yaml:"signKeyRefreshEnabled"`
Storage AuthStorageConfig `yaml:"storage"`
DashboardRedirectURIs []string `yaml:"dashboardRedirectURIs"`
CLIRedirectURIs []string `yaml:"cliRedirectURIs"`
Owner *AuthOwnerConfig `yaml:"owner,omitempty"`
Issuer string `yaml:"issuer"`
LocalAuthDisabled bool `yaml:"localAuthDisabled"`
SignKeyRefreshEnabled bool `yaml:"signKeyRefreshEnabled"`
MfaSessionMaxLifetime string `yaml:"mfaSessionMaxLifetime"`
MfaSessionIdleTimeout string `yaml:"mfaSessionIdleTimeout"`
MfaSessionRememberMe bool `yaml:"mfaSessionRememberMe"`
SessionCookieEncryptionKey string `yaml:"sessionCookieEncryptionKey"`
Storage AuthStorageConfig `yaml:"storage"`
DashboardRedirectURIs []string `yaml:"dashboardRedirectURIs"`
CLIRedirectURIs []string `yaml:"cliRedirectURIs"`
Owner *AuthOwnerConfig `yaml:"owner,omitempty"`
DashboardPostLogoutRedirectURIs []string `yaml:"dashboardPostLogoutRedirectURIs"`
}
// AuthStorageConfig contains auth storage settings
@@ -581,10 +586,14 @@ func (c *CombinedConfig) buildEmbeddedIdPConfig(mgmt ManagementConfig) (*idp.Emb
}
cfg := &idp.EmbeddedIdPConfig{
Enabled: true,
Issuer: mgmt.Auth.Issuer,
LocalAuthDisabled: mgmt.Auth.LocalAuthDisabled,
SignKeyRefreshEnabled: mgmt.Auth.SignKeyRefreshEnabled,
Enabled: true,
Issuer: mgmt.Auth.Issuer,
LocalAuthDisabled: mgmt.Auth.LocalAuthDisabled,
SignKeyRefreshEnabled: mgmt.Auth.SignKeyRefreshEnabled,
MfaSessionMaxLifetime: mgmt.Auth.MfaSessionMaxLifetime,
MfaSessionIdleTimeout: mgmt.Auth.MfaSessionIdleTimeout,
MfaSessionRememberMe: mgmt.Auth.MfaSessionRememberMe,
SessionCookieEncryptionKey: mgmt.Auth.SessionCookieEncryptionKey,
Storage: idp.EmbeddedStorageConfig{
Type: authStorageType,
Config: idp.EmbeddedStorageTypeConfig{
@@ -592,8 +601,9 @@ func (c *CombinedConfig) buildEmbeddedIdPConfig(mgmt ManagementConfig) (*idp.Emb
DSN: authStorageDSN,
},
},
DashboardRedirectURIs: mgmt.Auth.DashboardRedirectURIs,
CLIRedirectURIs: mgmt.Auth.CLIRedirectURIs,
DashboardRedirectURIs: mgmt.Auth.DashboardRedirectURIs,
CLIRedirectURIs: mgmt.Auth.CLIRedirectURIs,
DashboardPostLogoutRedirectURIs: mgmt.Auth.DashboardPostLogoutRedirectURIs,
}
if mgmt.Auth.Owner != nil && mgmt.Auth.Owner.Email != "" {

View File

@@ -86,6 +86,13 @@ server:
issuer: "https://example.com/oauth2"
localAuthDisabled: false
signKeyRefreshEnabled: false
# MFA session settings (applies when TOTP is enabled for an account)
# mfaSessionMaxLifetime: "24h" # Max duration for an MFA session from creation
# mfaSessionIdleTimeout: "1h" # MFA session expires after this idle period
# mfaSessionRememberMe: false # Pre-check "remember me" on login so the MFA session persists across tabs/restarts
# Optional AES key for encrypting embedded IdP session cookies. Can also be set via NB_IDP_SESSION_COOKIE_ENCRYPTION_KEY.
# Must be 16/24/32 raw bytes or base64-encoded to one of those lengths (for example: openssl rand -hex 16).
# sessionCookieEncryptionKey: ""
# OAuth2 redirect URIs for dashboard
dashboardRedirectURIs:
- "https://app.example.com/nb-auth"
@@ -93,6 +100,9 @@ server:
# OAuth2 redirect URIs for CLI
cliRedirectURIs:
- "http://localhost:53000/"
# OAuth2 post-logout redirect URIs for dashboard (RP-initiated logout)
# dashboardPostLogoutRedirectURIs:
# - "https://app.example.com/"
# Optional initial admin user
# owner:
# email: "admin@example.com"

View File

@@ -53,6 +53,9 @@ type NameServerGroup struct {
ID string `gorm:"primaryKey"`
// AccountID is a reference to Account that this object belongs
AccountID string `gorm:"index"`
// AccountSeqID is a per-account monotonically increasing identifier used as the
// compact wire id when sending NetworkMap components to capable peers.
AccountSeqID uint32 `json:"-" gorm:"index:idx_nameserver_groups_account_seq_id;not null;default:0"`
// Name group name
Name string
// Description group description

48
go.mod
View File

@@ -14,7 +14,7 @@ require (
github.com/onsi/gomega v1.27.6
github.com/rs/cors v1.8.0
github.com/sirupsen/logrus v1.9.4
github.com/spf13/cobra v1.10.1
github.com/spf13/cobra v1.10.2
github.com/spf13/pflag v1.0.9
github.com/vishvananda/netlink v1.3.1
golang.org/x/crypto v0.50.0
@@ -41,11 +41,11 @@ require (
github.com/cilium/ebpf v0.15.0
github.com/coder/websocket v1.8.14
github.com/coreos/go-iptables v0.7.0
github.com/coreos/go-oidc/v3 v3.14.1
github.com/coreos/go-oidc/v3 v3.18.0
github.com/creack/pty v1.1.24
github.com/crowdsecurity/crowdsec v1.7.7
github.com/crowdsecurity/go-cs-bouncer v0.0.21
github.com/dexidp/dex v0.0.0-00010101000000-000000000000
github.com/dexidp/dex v2.13.0+incompatible
github.com/dexidp/dex/api/v2 v2.4.0
github.com/ebitengine/purego v0.8.4
github.com/eko/gocache/lib/v4 v4.2.0
@@ -53,9 +53,9 @@ require (
github.com/eko/gocache/store/redis/v4 v4.2.2
github.com/fsnotify/fsnotify v1.9.0
github.com/gliderlabs/ssh v0.3.8
github.com/go-jose/go-jose/v4 v4.1.3
github.com/go-jose/go-jose/v4 v4.1.4
github.com/godbus/dbus/v5 v5.1.0
github.com/golang-jwt/jwt/v5 v5.3.0
github.com/golang-jwt/jwt/v5 v5.3.1
github.com/golang/mock v1.6.0
github.com/google/go-cmp v0.7.0
github.com/google/gopacket v1.1.19
@@ -72,7 +72,7 @@ require (
github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81
github.com/mdlayher/socket v0.5.1
github.com/mdp/qrterminal/v3 v3.2.1
github.com/miekg/dns v1.1.59
github.com/miekg/dns v1.1.72
github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/netbirdio/management-integrations/integrations v0.0.0-20260416123949-2355d972be42
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45
@@ -113,7 +113,7 @@ require (
go.opentelemetry.io/otel/exporters/prometheus v0.64.0
go.opentelemetry.io/otel/metric v1.43.0
go.opentelemetry.io/otel/sdk/metric v1.43.0
go.uber.org/mock v0.5.2
go.uber.org/mock v0.6.0
go.uber.org/zap v1.27.0
goauthentik.io/api/v3 v3.2023051.3
golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b
@@ -141,7 +141,7 @@ require (
filippo.io/edwards25519 v1.1.1 // indirect
github.com/AppsFlyer/go-sundheit v0.6.0 // indirect
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect
github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358 // indirect
github.com/Azure/go-ntlmssp v0.1.0 // indirect
github.com/BurntSushi/toml v1.5.0 // indirect
github.com/Masterminds/goutils v1.1.1 // indirect
github.com/Masterminds/semver/v3 v3.3.0 // indirect
@@ -168,6 +168,7 @@ require (
github.com/aws/smithy-go v1.23.0 // indirect
github.com/beevik/etree v1.6.0 // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc // indirect
github.com/caddyserver/zerossl v0.1.3 // indirect
github.com/cenkalti/backoff/v5 v5.0.3 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
@@ -183,6 +184,7 @@ require (
github.com/docker/go-units v0.5.0 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/fredbi/uri v1.1.1 // indirect
github.com/fxamacker/cbor/v2 v2.9.1 // indirect
github.com/fyne-io/gl-js v0.2.0 // indirect
github.com/fyne-io/glfw-js v0.3.0 // indirect
github.com/fyne-io/image v0.1.1 // indirect
@@ -190,7 +192,7 @@ require (
github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 // indirect
github.com/go-gl/gl v0.0.0-20231021071112-07e5d0ea2e71 // indirect
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20240506104042-037f3cc74f2a // indirect
github.com/go-ldap/ldap/v3 v3.4.12 // indirect
github.com/go-ldap/ldap/v3 v3.4.13 // indirect
github.com/go-logr/logr v1.4.3 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-ole/go-ole v1.3.0 // indirect
@@ -206,11 +208,15 @@ require (
github.com/go-sql-driver/mysql v1.9.3 // indirect
github.com/go-text/render v0.2.0 // indirect
github.com/go-text/typesetting v0.2.1 // indirect
github.com/go-viper/mapstructure/v2 v2.5.0 // indirect
github.com/go-webauthn/webauthn v0.16.4 // indirect
github.com/go-webauthn/x v0.2.3 // indirect
github.com/goccy/go-yaml v1.18.0 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/golang-jwt/jwt/v4 v4.5.2 // indirect
github.com/google/btree v1.1.2 // indirect
github.com/google/go-querystring v1.1.0 // indirect
github.com/google/go-tpm v0.9.8 // indirect
github.com/google/s2a-go v0.1.9 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.14 // indirect
github.com/googleapis/gax-go/v2 v2.21.0 // indirect
@@ -218,7 +224,13 @@ require (
github.com/hack-pad/go-indexeddb v0.3.2 // indirect
github.com/hack-pad/safejs v0.1.0 // indirect
github.com/hashicorp/errwrap v1.1.0 // indirect
github.com/hashicorp/go-cleanhttp v0.5.2 // indirect
github.com/hashicorp/go-retryablehttp v0.7.8 // indirect
github.com/hashicorp/go-secure-stdlib/parseutil v0.2.0 // indirect
github.com/hashicorp/go-secure-stdlib/strutil v0.1.2 // indirect
github.com/hashicorp/go-sockaddr v1.0.7 // indirect
github.com/hashicorp/go-uuid v1.0.3 // indirect
github.com/hashicorp/hcl v1.0.1-vault-7 // indirect
github.com/huandu/xstrings v1.5.0 // indirect
github.com/huin/goupnp v1.2.0 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
@@ -238,13 +250,13 @@ require (
github.com/klauspost/cpuid/v2 v2.2.10 // indirect
github.com/koron/go-ssdp v0.0.4 // indirect
github.com/kr/fs v0.1.0 // indirect
github.com/lib/pq v1.10.9 // indirect
github.com/lib/pq v1.12.3 // indirect
github.com/libdns/libdns v0.2.2 // indirect
github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae // indirect
github.com/magiconair/properties v1.8.10 // indirect
github.com/mailru/easyjson v0.9.0 // indirect
github.com/mattermost/xml-roundtrip-validator v0.1.0 // indirect
github.com/mattn/go-sqlite3 v1.14.32 // indirect
github.com/mattn/go-sqlite3 v1.14.42 // indirect
github.com/mdelapenya/tlscert v0.2.0 // indirect
github.com/mdlayher/genetlink v1.3.2 // indirect
github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42 // indirect
@@ -265,8 +277,10 @@ require (
github.com/nxadm/tail v1.4.11 // indirect
github.com/oklog/ulid v1.3.1 // indirect
github.com/onsi/ginkgo/v2 v2.9.5 // indirect
github.com/openbao/openbao/api/v2 v2.5.1 // indirect
github.com/opencontainers/go-digest v1.0.0 // indirect
github.com/opencontainers/image-spec v1.1.1 // indirect
github.com/philhofer/fwd v1.2.0 // indirect
github.com/pion/dtls/v2 v2.2.10 // indirect
github.com/pion/dtls/v3 v3.0.9 // indirect
github.com/pion/mdns/v2 v2.0.7 // indirect
@@ -275,11 +289,13 @@ require (
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect
github.com/pquerna/otp v1.5.0 // indirect
github.com/prometheus/client_model v0.6.2 // indirect
github.com/prometheus/common v0.67.5 // indirect
github.com/prometheus/otlptranslator v1.0.0 // indirect
github.com/prometheus/procfs v0.19.2 // indirect
github.com/russellhaering/goxmldsig v1.6.0 // indirect
github.com/ryanuber/go-glob v1.0.0 // indirect
github.com/rymdport/portal v0.4.2 // indirect
github.com/shirou/gopsutil/v4 v4.25.8 // indirect
github.com/shoenig/go-m1cpu v0.2.1 // indirect
@@ -288,11 +304,13 @@ require (
github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c // indirect
github.com/srwiley/rasterx v0.0.0-20220730225603-2ab79fcdd4ef // indirect
github.com/stretchr/objx v0.5.2 // indirect
github.com/tinylib/msgp v1.6.3 // indirect
github.com/tklauser/go-sysconf v0.3.15 // indirect
github.com/tklauser/numcpus v0.10.0 // indirect
github.com/vishvananda/netns v0.0.5 // indirect
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
github.com/wlynxg/anet v0.0.5 // indirect
github.com/x448/float16 v0.8.4 // indirect
github.com/yuin/goldmark v1.7.8 // indirect
github.com/zeebo/blake3 v0.2.3 // indirect
go.mongodb.org/mongo-driver v1.17.9 // indirect
@@ -317,12 +335,14 @@ replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-2024
replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20260422100739-63c67f59bf58
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20260107100953-33b7c9d03db0
replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6
replace github.com/cloudflare/circl => codeberg.org/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6
replace github.com/pion/ice/v4 => github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51
replace github.com/dexidp/dex => github.com/netbirdio/dex v0.244.0
replace github.com/dexidp/dex => github.com/netbirdio/dex v0.244.1-0.20260512110716-8d70ad8647c1
replace github.com/dexidp/dex/api/v2 => github.com/netbirdio/dex/api/v2 v2.0.0-20260512110716-8d70ad8647c1
replace github.com/mailru/easyjson => github.com/netbirdio/easyjson v0.9.0

105
go.sum
View File

@@ -5,6 +5,8 @@ cloud.google.com/go/auth/oauth2adapt v0.2.8/go.mod h1:XQ9y31RkqZCcwJWNSx2Xvric3R
cloud.google.com/go/compute/metadata v0.2.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k=
cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdBtwLoEkH9Zs=
cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10=
codeberg.org/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6 h1:b8xUw3004wk+3ipBhu0VU4RtUJsegMIiqjxSK4++lzA=
codeberg.org/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6/go.mod h1:+CauBF6R70Jqcyl8N2hC8pAXYbWkGIezuSbuGLtRhnw=
cunicu.li/go-rosenpass v0.4.0 h1:LtPtBgFWY/9emfgC4glKLEqS0MJTylzV6+ChRhiZERw=
cunicu.li/go-rosenpass v0.4.0/go.mod h1:MPbjH9nxV4l3vEagKVdFNwHOketqgS5/To1VYJplf/M=
dario.cat/mergo v1.0.1 h1:Ra4+bf83h2ztPIQYNP99R6m+Y7KfnARDfID+a+vLl4s=
@@ -23,8 +25,8 @@ github.com/AppsFlyer/go-sundheit v0.6.0 h1:d2hBvCjBSb2lUsEWGfPigr4MCOt04sxB+Rppl
github.com/AppsFlyer/go-sundheit v0.6.0/go.mod h1:LDdBHD6tQBtmHsdW+i1GwdTt6Wqc0qazf5ZEJVTbTME=
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 h1:L/gRVlceqvL25UVaW/CKtUDjefjrs0SPonmDGUVOYP0=
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358 h1:mFRzDkZVAjdal+s7s0MwaRv9igoPqLRdzOLzw/8Xvq8=
github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358/go.mod h1:chxPXzSsl7ZWRAuOIE23GDNzjWuZquvFlgA8xmpunjU=
github.com/Azure/go-ntlmssp v0.1.0 h1:DjFo6YtWzNqNvQdrwEyr/e4nhU3vRiwenz5QX7sFz+A=
github.com/Azure/go-ntlmssp v0.1.0/go.mod h1:NYqdhxd/8aAct/s4qSYZEerdPuH1liG2/X9DiVTbhpk=
github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg=
github.com/BurntSushi/toml v1.5.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
github.com/Masterminds/goutils v1.1.1 h1:5nUrii3FMTL5diU80unEVvNevw1nH4+ZV4DSLVJLSYI=
@@ -91,6 +93,8 @@ github.com/beevik/etree v1.6.0/go.mod h1:bh4zJxiIr62SOf9pRzN7UUYaEDa9HEKafK25+sL
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/bmatcuk/doublestar v1.1.1/go.mod h1:UD6OnuiIn0yFxxA2le/rnRU1G4RaI4UvFv1sNto9p6w=
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc h1:biVzkmvwrH8WK8raXaxBx6fRVTlJILwEwQGL1I/ByEI=
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
@@ -117,8 +121,8 @@ github.com/containerd/platforms v0.2.1 h1:zvwtM3rz2YHPQsF2CHYM8+KtB5dvhISiXh5ZpS
github.com/containerd/platforms v0.2.1/go.mod h1:XHCb+2/hzowdiut9rkudds9bE5yJ7npe7dG/wG+uFPw=
github.com/coreos/go-iptables v0.7.0 h1:XWM3V+MPRr5/q51NuWSgU0fqMad64Zyxs8ZUoMsamr8=
github.com/coreos/go-iptables v0.7.0/go.mod h1:Qe8Bv2Xik5FyTXwgIbLAnv2sWSBmvWdFETJConOQ//Q=
github.com/coreos/go-oidc/v3 v3.14.1 h1:9ePWwfdwC4QKRlCXsJGou56adA/owXczOzwKdOumLqk=
github.com/coreos/go-oidc/v3 v3.14.1/go.mod h1:HaZ3szPaZ0e4r6ebqvsLWlk2Tn+aejfmrfah6hnSYEU=
github.com/coreos/go-oidc/v3 v3.18.0 h1:V9orjXynvu5wiC9SemFTWnG4F45v403aIcjWo0d41+A=
github.com/coreos/go-oidc/v3 v3.18.0/go.mod h1:DYCf24+ncYi+XkIH97GY1+dqoRlbaSI26KVTCI9SrY4=
github.com/cpuguy83/dockercfg v0.3.2 h1:DlJTyZGBDlXqUZ2Dk2Q3xHs/FtnooJJVaad2S9GKorA=
github.com/cpuguy83/dockercfg v0.3.2/go.mod h1:sugsbF4//dDlL/i+S+rtpIWp+5h0BHJHfjj5/jFyUJc=
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
@@ -130,14 +134,10 @@ github.com/crowdsecurity/go-cs-bouncer v0.0.21 h1:arPz0VtdVSaz+auOSfHythzkZVLyy1
github.com/crowdsecurity/go-cs-bouncer v0.0.21/go.mod h1:4JiH0XXA4KKnnWThItUpe5+heJHWzsLOSA2IWJqUDBA=
github.com/crowdsecurity/go-cs-lib v0.0.25 h1:Ov6VPW9yV+OPsbAIQk1iTkEWhwkpaG0v3lrBzeqjzj4=
github.com/crowdsecurity/go-cs-lib v0.0.25/go.mod h1:X0GMJY2CxdA1S09SpuqIKaWQsvRGxXmecUp9cP599dE=
github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6 h1:/DS5cDX3FJdl+XaN2D7XAwFpuanTxnp52DBLZAaJKx0=
github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6/go.mod h1:+CauBF6R70Jqcyl8N2hC8pAXYbWkGIezuSbuGLtRhnw=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dexidp/dex/api/v2 v2.4.0 h1:gNba7n6BKVp8X4Jp24cxYn5rIIGhM6kDOXcZoL6tr9A=
github.com/dexidp/dex/api/v2 v2.4.0/go.mod h1:/p550ADvFFh7K95VmhUD+jgm15VdaNnab9td8DHOpyI=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
@@ -156,6 +156,8 @@ github.com/eko/gocache/store/go_cache/v4 v4.2.2 h1:tAI9nl6TLoJyKG1ujF0CS0n/IgTEM
github.com/eko/gocache/store/go_cache/v4 v4.2.2/go.mod h1:T9zkHokzr8K9EiC7RfMbDg6HSwaV6rv3UdcNu13SGcA=
github.com/eko/gocache/store/redis/v4 v4.2.2 h1:Thw31fzGuH3WzJywsdbMivOmP550D6JS7GDHhvCJPA0=
github.com/eko/gocache/store/redis/v4 v4.2.2/go.mod h1:LaTxLKx9TG/YUEybQvPMij++D7PBTIJ4+pzvk0ykz0w=
github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM=
github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU=
github.com/felixge/fgprof v0.9.3 h1:VvyZxILNuCiUCSXtPtYmmtGvb65nqXh2QFWc0Wpf2/g=
github.com/felixge/fgprof v0.9.3/go.mod h1:RdbpDgzqYVh/T9fPELJyV7EYJuHB55UTEULNun8eiPw=
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
@@ -171,6 +173,8 @@ github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4
github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw=
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
github.com/fxamacker/cbor/v2 v2.9.1 h1:2rWm8B193Ll4VdjsJY28jxs70IdDsHRWgQYAI80+rMQ=
github.com/fxamacker/cbor/v2 v2.9.1/go.mod h1:vM4b+DJCtHn+zz7h3FFp/hDAI9WNWCsZj23V5ytsSxQ=
github.com/fyne-io/gl-js v0.2.0 h1:+EXMLVEa18EfkXBVKhifYB6OGs3HwKO3lUElA0LlAjs=
github.com/fyne-io/gl-js v0.2.0/go.mod h1:ZcepK8vmOYLu96JoxbCKJy2ybr+g1pTnaBDdl7c3ajI=
github.com/fyne-io/glfw-js v0.3.0 h1:d8k2+Y7l+zy2pc7wlGRyPfTgZoqDf3AI4G+2zOWhWUk=
@@ -189,10 +193,10 @@ github.com/go-gl/gl v0.0.0-20231021071112-07e5d0ea2e71 h1:5BVwOaUSBTlVZowGO6VZGw
github.com/go-gl/gl v0.0.0-20231021071112-07e5d0ea2e71/go.mod h1:9YTyiznxEY1fVinfM7RvRcjRHbw2xLBJ3AAGIT0I4Nw=
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20240506104042-037f3cc74f2a h1:vxnBhFDDT+xzxf1jTJKMKZw3H0swfWk9RpWbBbDK5+0=
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20240506104042-037f3cc74f2a/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
github.com/go-jose/go-jose/v4 v4.1.3 h1:CVLmWDhDVRa6Mi/IgCgaopNosCaHz7zrMeF9MlZRkrs=
github.com/go-jose/go-jose/v4 v4.1.3/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08=
github.com/go-ldap/ldap/v3 v3.4.12 h1:1b81mv7MagXZ7+1r7cLTWmyuTqVqdwbtJSjC0DAp9s4=
github.com/go-ldap/ldap/v3 v3.4.12/go.mod h1:+SPAGcTtOfmGsCb3h1RFiq4xpp4N636G75OEace8lNo=
github.com/go-jose/go-jose/v4 v4.1.4 h1:moDMcTHmvE6Groj34emNPLs/qtYXRVcd6S7NHbHz3kA=
github.com/go-jose/go-jose/v4 v4.1.4/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08=
github.com/go-ldap/ldap/v3 v3.4.13 h1:+x1nG9h+MZN7h/lUi5Q3UZ0fJ1GyDQYbPvbuH38baDQ=
github.com/go-ldap/ldap/v3 v3.4.13/go.mod h1:LxsGZV6vbaK0sIvYfsv47rfh4ca0JXokCoKjZxsszv0=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
@@ -229,12 +233,20 @@ github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI6
github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE=
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI=
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls=
github.com/go-test/deep v1.1.1 h1:0r/53hagsehfO4bzD2Pgr/+RgHqhmf+k1Bpse2cTu1U=
github.com/go-test/deep v1.1.1/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE=
github.com/go-text/render v0.2.0 h1:LBYoTmp5jYiJ4NPqDc2pz17MLmA3wHw1dZSVGcOdeAc=
github.com/go-text/render v0.2.0/go.mod h1:CkiqfukRGKJA5vZZISkjSYrcdtgKQWRa2HIzvwNN5SU=
github.com/go-text/typesetting v0.2.1 h1:x0jMOGyO3d1qFAPI0j4GSsh7M0Q3Ypjzr4+CEVg82V8=
github.com/go-text/typesetting v0.2.1/go.mod h1:mTOxEwasOFpAMBjEQDhdWRckoLLeI/+qrQeBCTGEt6M=
github.com/go-text/typesetting-utils v0.0.0-20241103174707-87a29e9e6066 h1:qCuYC+94v2xrb1PoS4NIDe7DGYtLnU2wWiQe9a1B1c0=
github.com/go-text/typesetting-utils v0.0.0-20241103174707-87a29e9e6066/go.mod h1:DDxDdQEnB70R8owOx3LVpEFvpMK9eeH1o2r0yZhFI9o=
github.com/go-viper/mapstructure/v2 v2.5.0 h1:vM5IJoUAy3d7zRSVtIwQgBj7BiWtMPfmPEgAXnvj1Ro=
github.com/go-viper/mapstructure/v2 v2.5.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM=
github.com/go-webauthn/webauthn v0.16.4 h1:R9jqR/cYZa7hRquFF7Za/8qoH/K/TIs1/Q/4CyGN+1Q=
github.com/go-webauthn/webauthn v0.16.4/go.mod h1:SU2ljAgToTV/YLPI0C05QS4qn+e04WpB5g1RMfcZfS4=
github.com/go-webauthn/x v0.2.3 h1:8oArS+Rc1SWFLXhE17KZNx258Z4kUSyaDgsSncCO5RA=
github.com/go-webauthn/x v0.2.3/go.mod h1:tM04GF3V6VYq79AZMl7vbj4q6pz9r7L2criWRzbWhPk=
github.com/goccy/go-yaml v1.18.0 h1:8W7wMFS12Pcas7KU+VVkaiCng+kG8QiFeFwzFb+rwuw=
github.com/goccy/go-yaml v1.18.0/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA=
github.com/godbus/dbus/v5 v5.1.0 h1:4KLkAxT3aOY8Li4FRJe/KvhoNFFxo0m6fNuFUO8QJUk=
@@ -243,8 +255,8 @@ github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
github.com/golang-jwt/jwt/v4 v4.5.2 h1:YtQM7lnr8iZ+j5q71MGKkNw9Mn7AjHM68uc9g5fXeUI=
github.com/golang-jwt/jwt/v4 v4.5.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY=
github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc=
github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
@@ -276,6 +288,10 @@ github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8=
github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU=
github.com/google/go-tpm v0.9.8 h1:slArAR9Ft+1ybZu0lBwpSmpwhRXaa85hWtMinMyRAWo=
github.com/google/go-tpm v0.9.8/go.mod h1:h9jEsEECg7gtLis0upRBQU+GhYVH6jMjrFxI8u6bVUY=
github.com/google/go-tpm-tools v0.3.13-0.20230620182252-4639ecce2aba h1:qJEJcuLzH5KDR0gKc0zcktin6KSAwL7+jWKBYceddTc=
github.com/google/go-tpm-tools v0.3.13-0.20230620182252-4639ecce2aba/go.mod h1:EFYHy8/1y2KfgTAsx7Luu7NGhoxtuVHnNo8jE7FikKc=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8=
github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo=
@@ -308,15 +324,29 @@ github.com/hack-pad/safejs v0.1.0/go.mod h1:HdS+bKF1NrE72VoXZeWzxFOVQVUSqZJAG0xN
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I=
github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9neXJWAZQ=
github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/SoxCXGY6BqNFT48=
github.com/hashicorp/go-hclog v1.6.3 h1:Qr2kF+eVWjTiYmU7Y31tYlP1h0q/X3Nl3tPGdaB11/k=
github.com/hashicorp/go-hclog v1.6.3/go.mod h1:W4Qnvbt70Wk/zYJryRzDRU/4r0kIg0PVHBcfoyhpF5M=
github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo=
github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM=
github.com/hashicorp/go-retryablehttp v0.7.8 h1:ylXZWnqa7Lhqpk0L1P1LzDtGcCR0rPVUrx/c8Unxc48=
github.com/hashicorp/go-retryablehttp v0.7.8/go.mod h1:rjiScheydd+CxvumBsIrFKlx3iS0jrZ7LvzFGFmuKbw=
github.com/hashicorp/go-secure-stdlib/base62 v0.1.2 h1:ET4pqyjiGmY09R5y+rSd70J2w45CtbWDNvGqWp/R3Ng=
github.com/hashicorp/go-secure-stdlib/base62 v0.1.2/go.mod h1:EdWO6czbmthiwZ3/PUsDV+UD1D5IRU4ActiaWGwt0Yw=
github.com/hashicorp/go-secure-stdlib/parseutil v0.2.0 h1:U+kC2dOhMFQctRfhK0gRctKAPTloZdMU5ZJxaesJ/VM=
github.com/hashicorp/go-secure-stdlib/parseutil v0.2.0/go.mod h1:Ll013mhdmsVDuoIXVfBtvgGJsXDYkTw1kooNcoCXuE0=
github.com/hashicorp/go-secure-stdlib/strutil v0.1.2 h1:kes8mmyCpxJsI7FTwtzRqEy9CdjCtrXrXGuOpxEA7Ts=
github.com/hashicorp/go-secure-stdlib/strutil v0.1.2/go.mod h1:Gou2R9+il93BqX25LAKCLuM+y9U2T4hlwvT1yprcna4=
github.com/hashicorp/go-sockaddr v1.0.7 h1:G+pTkSO01HpR5qCxg7lxfsFEZaG+C0VssTy/9dbT+Fw=
github.com/hashicorp/go-sockaddr v1.0.7/go.mod h1:FZQbEYa1pxkQ7WLpyXJ6cbjpT8q0YgQaK/JakXqGyWw=
github.com/hashicorp/go-uuid v1.0.2/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
github.com/hashicorp/go-uuid v1.0.3 h1:2gKiV6YVmrJ1i2CKKa9obLvRieoRGviZFL26PcT/Co8=
github.com/hashicorp/go-uuid v1.0.3/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
github.com/hashicorp/go-version v1.7.0 h1:5tqGy27NaOTB8yJKUZELlFAS/LTKJkrmONwQKeRZfjY=
github.com/hashicorp/go-version v1.7.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA=
github.com/hashicorp/hcl v1.0.1-vault-7 h1:ag5OxFVy3QYTFTJODRzTKVZ6xvdfLLCA1cy/Y6xGI0I=
github.com/hashicorp/hcl v1.0.1-vault-7/go.mod h1:XYhtn6ijBSAj6n4YqAaf7RBPS4I06AItNorpy+MoQNM=
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
github.com/huandu/xstrings v1.5.0 h1:2ag3IFq9ZDANvthTwTiqSSZLjDc+BedvHPAp5tJy2TI=
github.com/huandu/xstrings v1.5.0/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE=
@@ -387,8 +417,8 @@ github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/leodido/go-urn v1.1.0/go.mod h1:+cyI34gQWZcE1eQU7NVgKkkzdXDQHr1dBMtdAPozLkw=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/lib/pq v1.12.3 h1:tTWxr2YLKwIvK90ZXEw8GP7UFHtcbTtty8zsI+YjrfQ=
github.com/lib/pq v1.12.3/go.mod h1:/p+8NSbOcwzAEI7wiMXFlgydTwcgTr3OSKMsD2BitpA=
github.com/libdns/libdns v0.2.2 h1:O6ws7bAfRPaBsgAYt8MDe2HcNBGC29hkZ9MX2eUSX3s=
github.com/libdns/libdns v0.2.2/go.mod h1:4Bj9+5CQiNMVGf87wjX4CY3HQJypUHRuLvlsfsZqLWQ=
github.com/libdns/route53 v1.5.0 h1:2SKdpPFl/qgWsXQvsLNJJAoX7rSxlk7zgoL4jnWdXVA=
@@ -406,9 +436,13 @@ github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8S
github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
github.com/mattermost/xml-roundtrip-validator v0.1.0 h1:RXbVD2UAl7A7nOTR4u7E3ILa4IbtvKBHw64LDsmu9hU=
github.com/mattermost/xml-roundtrip-validator v0.1.0/go.mod h1:qccnGMcpgwcNaBnxqpJpWWUiPNr5H3O8eDgGV9gT5To=
github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ=
github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs=
github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-sqlite3 v1.14.42 h1:MigqEP4ZmHw3aIdIT7T+9TLa90Z6smwcthx+Azv4Cgo=
github.com/mattn/go-sqlite3 v1.14.42/go.mod h1:pjEuOr8IwzLJP2MfGeTb0A35jauH+C2kbHKBr7yXKVQ=
github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI=
github.com/mdelapenya/tlscert v0.2.0/go.mod h1:O4njj3ELLnJjGdkN7M/vIVCpZ+Cf0L6muqOG4tLSl8o=
github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy56gw=
@@ -421,8 +455,8 @@ github.com/mdp/qrterminal/v3 v3.2.1 h1:6+yQjiiOsSuXT5n9/m60E54vdgFsw0zhADHhHLrFe
github.com/mdp/qrterminal/v3 v3.2.1/go.mod h1:jOTmXvnBsMy5xqLniO0R++Jmjs2sTm9dFSuQ5kpz/SU=
github.com/mholt/acmez/v2 v2.0.1 h1:3/3N0u1pLjMK4sNEAFSI+bcvzbPhRpY383sy1kLHJ6k=
github.com/mholt/acmez/v2 v2.0.1/go.mod h1:fX4c9r5jYwMyMsC+7tkYRxHibkOTgta5DIFGoe67e1U=
github.com/miekg/dns v1.1.59 h1:C9EXc/UToRwKLhK5wKU/I4QVsBUc8kE6MkHBkeypWZs=
github.com/miekg/dns v1.1.59/go.mod h1:nZpewl5p6IvctfgrckopVx2OlSEHPRO/U4SYkRklrEk=
github.com/miekg/dns v1.1.72 h1:vhmr+TF2A3tuoGNkLDFK9zi36F2LS+hKTRW0Uf8kbzI=
github.com/miekg/dns v1.1.72/go.mod h1:+EuEPhdHOsfk6Wk5TT2CzssZdqkmFhf8r+aVyDEToIs=
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721 h1:RlZweED6sbSArvlE924+mUcZuXKLBHA35U7LN621Bws=
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721/go.mod h1:Ickgr2WtCLZ2MDGd4Gr0geeCH5HybhRJbonOgQpvSxc=
github.com/mitchellh/copystructure v1.2.0 h1:vpKXTN4ewci03Vljg/q9QvCGUDttBOGBIa15WveJJGw=
@@ -451,8 +485,10 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
github.com/netbirdio/dex v0.244.0 h1:1GOvi8wnXYassnKGildzNqRHq0RbcfEUw7LKYpKIN7U=
github.com/netbirdio/dex v0.244.0/go.mod h1:STGInJhPcAflrHmDO7vyit2kSq03PdL+8zQPoGALtcU=
github.com/netbirdio/dex v0.244.1-0.20260512110716-8d70ad8647c1 h1:4TaYr9O4xX0D2kszeOLclTiCbA3eHq3xWV+9ILJbIYs=
github.com/netbirdio/dex v0.244.1-0.20260512110716-8d70ad8647c1/go.mod h1:IHH+H8vK2GfqtIt5u/5OdPh18yk0oDHuj2vz5+Goetg=
github.com/netbirdio/dex/api/v2 v2.0.0-20260512110716-8d70ad8647c1 h1:neE7z+FPUkldl3faK/Jt+hJK2L+1XfQ1W33TQhU9m88=
github.com/netbirdio/dex/api/v2 v2.0.0-20260512110716-8d70ad8647c1/go.mod h1:awuTyT29CYALpEyET0S307EgNlPWrc7fFKRAyhsO45M=
github.com/netbirdio/easyjson v0.9.0 h1:6Nw2lghSVuy8RSkAYDhDv1thBVEmfVbKZnV7T7Z6Aus=
github.com/netbirdio/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU=
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI=
@@ -463,8 +499,8 @@ github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9ax
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
github.com/netbirdio/wireguard-go v0.0.0-20260422100739-63c67f59bf58 h1:6REpBYpJBLTTgqCcLGpTqvRDoEoLbA5r2nAXqMd2La0=
github.com/netbirdio/wireguard-go v0.0.0-20260422100739-63c67f59bf58/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw=
github.com/netbirdio/wireguard-go v0.0.0-20260107100953-33b7c9d03db0 h1:h/QnNzm7xzHPm+gajcblYUOclrW2FeNeDlUNj6tTWKQ=
github.com/netbirdio/wireguard-go v0.0.0-20260107100953-33b7c9d03db0/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw=
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ=
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8=
github.com/nicksnyder/go-i18n/v2 v2.5.1 h1:IxtPxYsR9Gp60cGXjfuR/llTqV8aYMsC472zD0D1vHk=
@@ -489,6 +525,8 @@ github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7J
github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo=
github.com/onsi/gomega v1.27.6 h1:ENqfyGeS5AX/rlXDd/ETokDz93u0YufY1Pgxuy/PvWE=
github.com/onsi/gomega v1.27.6/go.mod h1:PIQNjfQwkP3aQAH7lf7j87O/5FiNr+ZR8+ipb+qQlhg=
github.com/openbao/openbao/api/v2 v2.5.1 h1:Br79D6L20SbAa5P7xqENxmvv8LyI4HoKosPy7klhn4o=
github.com/openbao/openbao/api/v2 v2.5.1/go.mod h1:Dh5un77tqGgMbmlVEqjqN+8/dMyUohnkaQVg/wXW0Ig=
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
@@ -501,6 +539,8 @@ github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
github.com/petermattis/goid v0.0.0-20250303134427-723919f7f203 h1:E7Kmf11E4K7B5hDti2K2NqPb1nlYlGYsu02S1JNd/Bs=
github.com/petermattis/goid v0.0.0-20250303134427-723919f7f203/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4=
github.com/philhofer/fwd v1.2.0 h1:e6DnBTl7vGY+Gz322/ASL4Gyp1FspeMvx1RNDoToZuM=
github.com/philhofer/fwd v1.2.0/go.mod h1:RqIHx9QI14HlwKwm98g9Re5prTQ6LdeRQn+gXJFxsJM=
github.com/pion/dtls/v2 v2.2.7/go.mod h1:8WiMkebSHFD0T+dIU+UeBaoV7kDhOW5oDCzZ7WZ/F9s=
github.com/pion/dtls/v2 v2.2.10 h1:u2Axk+FyIR1VFTPurktB+1zoEPGIW3bmyj3LEFrXjAA=
github.com/pion/dtls/v2 v2.2.10/go.mod h1:d9SYc9fch0CqK90mRk1dC7AkzzpwJj6u2GU3u+9pqFE=
@@ -542,6 +582,8 @@ github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 h1:o4JXh1EVt9k/+g42oCprj/FisM4qX9L3sZB3upGN2ZU=
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
github.com/pquerna/otp v1.5.0 h1:NMMR+WrmaqXU4EzdGJEE1aUUI0AMRzsp96fFFWNPwxs=
github.com/pquerna/otp v1.5.0/go.mod h1:dkJfzwRKNiegxyNb54X/3fLwhCynbMspSyWKnvi1AEg=
github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o=
github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg=
github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk=
@@ -565,6 +607,8 @@ github.com/rs/xid v1.3.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
github.com/russellhaering/goxmldsig v1.6.0 h1:8fdWXEPh2k/NZNQBPFNoVfS3JmzS4ZprY/sAOpKQLks=
github.com/russellhaering/goxmldsig v1.6.0/go.mod h1:TrnaquDcYxWXfJrOjeMBTX4mLBeYAqaHEyUeWPxZlBM=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/ryanuber/go-glob v1.0.0 h1:iQh3xXAumdQ+4Ufa5b25cRpC5TYKlno6hsv6Cb3pkBk=
github.com/ryanuber/go-glob v1.0.0/go.mod h1:807d1WSdnB0XRJzKNil9Om6lcp/3a0v4qIHxIXzX/Yc=
github.com/rymdport/portal v0.4.2 h1:7jKRSemwlTyVHHrTGgQg7gmNPJs88xkbKcIL3NlcmSU=
github.com/rymdport/portal v0.4.2/go.mod h1:kFF4jslnJ8pD5uCi17brj/ODlfIidOxlgUDTO5ncnC4=
github.com/shirou/gopsutil/v3 v3.24.4 h1:dEHgzZXt4LMNm+oYELpzl9YCqV65Yr/6SfrvgRBtXeU=
@@ -587,8 +631,8 @@ github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 h1:TG/diQgUe0pntT/2D
github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8/go.mod h1:P5HUIBuIWKbyjl083/loAegFkfbFNx5i2qEP4CNbm7E=
github.com/spf13/cast v1.7.0 h1:ntdiHjuueXFgm5nzDRdOS4yfT43P5Fnud6DH50rz/7w=
github.com/spf13/cast v1.7.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
github.com/spf13/cobra v1.10.1 h1:lJeBwCfmrnXthfAupyUTzJ/J4Nc1RsHC/mSRU2dll/s=
github.com/spf13/cobra v1.10.1/go.mod h1:7SmJGaTHFVBY0jW4NXGluQoLvhqFQM+6XSKD+P4XaB0=
github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU=
github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4=
github.com/spf13/pflag v1.0.9 h1:9exaQaMOCwffKiiiYk6/BndUBv+iRViNW+4lEMi0PvY=
github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spkg/bom v0.0.0-20160624110644-59b7046e48ad/go.mod h1:qLr4V1qq6nMqFKkMo8ZTx3f+BZEkzsRUY10Xsm2mwU0=
@@ -628,6 +672,8 @@ github.com/ti-mo/conntrack v0.5.1 h1:opEwkFICnDbQc0BUXl73PHBK0h23jEIFVjXsqvF4GY0
github.com/ti-mo/conntrack v0.5.1/go.mod h1:T6NCbkMdVU4qEIgwL0njA6lw/iCAbzchlnwm1Sa314o=
github.com/ti-mo/netfilter v0.5.2 h1:CTjOwFuNNeZ9QPdRXt1MZFLFUf84cKtiQutNauHWd40=
github.com/ti-mo/netfilter v0.5.2/go.mod h1:Btx3AtFiOVdHReTDmP9AE+hlkOcvIy403u7BXXbWZKo=
github.com/tinylib/msgp v1.6.3 h1:bCSxiTz386UTgyT1i0MSCvdbWjVW+8sG3PjkGsZQt4s=
github.com/tinylib/msgp v1.6.3/go.mod h1:RSp0LW9oSxFut3KzESt5Voq4GVWyS+PSulT77roAqEA=
github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI=
github.com/tklauser/go-sysconf v0.3.15 h1:VE89k0criAymJ/Os65CSn1IXaol+1wrsFHEB8Ol49K4=
github.com/tklauser/go-sysconf v0.3.15/go.mod h1:Dmjwr6tYFIseJw7a3dRLJfsHAMXZ3nEnL/aZY+0IuI4=
@@ -646,6 +692,8 @@ github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAh
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
github.com/wlynxg/anet v0.0.5 h1:J3VJGi1gvo0JwZ/P1/Yc/8p63SoW98B5dHkYDmpgvvU=
github.com/wlynxg/anet v0.0.5/go.mod h1:eay5PRQr7fIVAMbTbchTnO9gG65Hg/uYGdc7mguHxoA=
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
@@ -690,14 +738,15 @@ go.opentelemetry.io/proto/otlp v1.0.0 h1:T0TX0tmXU8a3CbNXzEKGeU5mIVOdf0oykP+u2lI
go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v80hjKIs5JXpM=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
go.uber.org/mock v0.5.2 h1:LbtPTcP8A5k9WPXj54PPPbjcI4Y6lhyOZXn+VS7wNko=
go.uber.org/mock v0.5.2/go.mod h1:wLlUxC2vVTPTaE3UD51E0BGOAElKrILxhVSDYQLld5o=
go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y=
go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU=
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8=
go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
go.yaml.in/yaml/v2 v2.4.3 h1:6gvOSjQoTB3vt1l+CU+tSyi/HOjfOjRLJ4YwYZGwRO0=
go.yaml.in/yaml/v2 v2.4.3/go.mod h1:zSxWcmIDjOzPXpjlTTbAsKokqkDNAVtZO0WOMiT90s8=
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
goauthentik.io/api/v3 v3.2023051.3 h1:NebAhD/TeTWNo/9X3/Uj+rM5fG1HaiLOlKTNLQv9Qq4=
goauthentik.io/api/v3 v3.2023051.3/go.mod h1:nYECml4jGbp/541hj8GcylKQG1gVBsKppHy4+7G8u4U=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=

View File

@@ -51,6 +51,70 @@ type YAMLConfig struct {
// StaticPasswords cause the server use this list of passwords rather than
// querying the storage.
StaticPasswords []Password `yaml:"staticPasswords" json:"staticPasswords"`
// Sessions holds authentication session configuration.
// Requires DEX_SESSIONS_ENABLED=true feature flag.
Sessions *Sessions `yaml:"sessions" json:"sessions"`
// MFA holds multi-factor authentication configuration.
MFA MFAConfig `yaml:"mfa" json:"mfa"`
}
type Sessions struct {
// CookieName is the name of the session cookie. Defaults to "dex_session".
CookieName string `yaml:"cookieName" json:"cookieName"`
// AbsoluteLifetime is the maximum session lifetime from creation. Defaults to "24h".
AbsoluteLifetime string `yaml:"absoluteLifetime" json:"absoluteLifetime"`
// ValidIfNotUsedFor is the idle timeout. Defaults to "1h".
ValidIfNotUsedFor string `yaml:"validIfNotUsedFor" json:"validIfNotUsedFor"`
// RememberMeCheckedByDefault controls the default state of the "remember me" checkbox.
RememberMeCheckedByDefault *bool `yaml:"rememberMeCheckedByDefault" json:"rememberMeCheckedByDefault"`
// CookieEncryptionKey is the AES key for encrypting session cookies.
// Must be 16, 24, or 32 bytes for AES-128, AES-192, or AES-256.
// If empty, cookies are not encrypted.
CookieEncryptionKey string `yaml:"cookieEncryptionKey" json:"cookieEncryptionKey"`
// SSOSharedWithDefault is the default SSO sharing policy for clients without explicit ssoSharedWith.
// "all" = share with all clients, "none" = share with no one (default: "none").
SSOSharedWithDefault string `yaml:"ssoSharedWithDefault" json:"ssoSharedWithDefault"`
}
type MFAConfig struct {
Authenticators []MFAAuthenticator `yaml:"authenticators" json:"authenticators"`
}
type MFAAuthenticator struct {
ID string `yaml:"id" json:"id"`
Type string `yaml:"type" json:"type"`
Config map[string]interface{} `yaml:"config" json:"config"`
ConnectorTypes []string `yaml:"connectorTypes" json:"connectorTypes"`
}
type TOTPConfig struct {
Issuer string `yaml:"issuer" json:"issuer"`
}
// WebAuthnConfig holds configuration for a WebAuthn authenticator.
type WebAuthnConfig struct {
// RPDisplayName is the human-readable relying party name shown in the browser
// dialog during key registration and authentication (e.g., "My Company SSO").
RPDisplayName string `yaml:"rpDisplayName" json:"rpDisplayName"`
// RPID is the relying party identifier — must match the domain in the browser
// address bar. If empty, derived from the issuer URL hostname.
// Example: "auth.example.com"
RPID string `yaml:"rpID" json:"rpID"`
// RPOrigins is the list of allowed origins for WebAuthn ceremonies.
// If empty, derived from the issuer URL (scheme + host).
// Example: ["https://auth.example.com"]
RPOrigins []string `yaml:"rpOrigins" json:"rpOrigins"`
// AttestationPreference controls what attestation data the authenticator should provide:
// "none" — don't request attestation (simpler, more private)
// "indirect" — authenticator may anonymize attestation (default)
// "direct" — request full attestation (for enterprise key model verification)
AttestationPreference string `yaml:"attestationPreference" json:"attestationPreference"`
// Timeout is the duration allowed for the browser WebAuthn ceremony
// (registration or login). Defaults to "60s".
Timeout string `yaml:"timeout" json:"timeout"`
}
// Web is the config format for the HTTP server.
@@ -116,7 +180,6 @@ type Storage struct {
Config map[string]interface{} `yaml:"config" json:"config"`
}
// Password represents a static user configuration
type Password storage.Password
func (p *Password) UnmarshalYAML(node *yaml.Node) error {
@@ -245,7 +308,7 @@ func (s *Storage) OpenStorage(logger *slog.Logger) (storage.Storage, error) {
if file == "" {
return nil, fmt.Errorf("sqlite3 storage requires 'file' config")
}
return (&sql.SQLite3{File: file}).Open(logger)
return newSQLite3(file).Open(logger)
case "postgres":
dsn, _ := s.Config["dsn"].(string)
if dsn == "" {
@@ -429,9 +492,98 @@ func (c *YAMLConfig) Validate() error {
if !c.EnablePasswordDB && len(c.StaticPasswords) != 0 {
return fmt.Errorf("cannot specify static passwords without enabling password db")
}
return nil
}
func buildTotpConfig(auth MFAAuthenticator) (*server.TOTPProvider, error) {
data, err := json.Marshal(auth.Config)
if err != nil {
return nil, fmt.Errorf("failed to marshal TOTP config id: %s - %w", auth.ID, err)
}
var cfg TOTPConfig
if err := json.Unmarshal(data, &cfg); err != nil {
return nil, fmt.Errorf("failed to parse TOTP config id: %s - %w", auth.ID, err)
}
return server.NewTOTPProvider(cfg.Issuer, auth.ConnectorTypes), nil
}
func buildWebAuthnConfig(auth MFAAuthenticator, issuerURL string) (*server.WebAuthnProvider, error) {
data, err := json.Marshal(auth.Config)
if err != nil {
return nil, fmt.Errorf("failed to marshal WebAuthn config id: %s - %w", auth.ID, err)
}
var cfg WebAuthnConfig
if err := json.Unmarshal(data, &cfg); err != nil {
return nil, fmt.Errorf("failed to parse WebAuthn config id: %s - %w", auth.ID, err)
}
provider, err := server.NewWebAuthnProvider(cfg.RPDisplayName, cfg.RPID, cfg.RPOrigins,
cfg.AttestationPreference, cfg.Timeout, issuerURL, auth.ConnectorTypes)
if err != nil {
return nil, fmt.Errorf("failed to create WebAuthn provider id: %s - err: %w", auth.ID, err)
}
return provider, nil
}
func buildMFAProviders(authenticators []MFAAuthenticator, issuerURL string, logger *slog.Logger) map[string]server.MFAProvider {
if len(authenticators) == 0 {
return nil
}
providers := make(map[string]server.MFAProvider, len(authenticators))
for _, auth := range authenticators {
switch auth.Type {
case "TOTP":
provider, err := buildTotpConfig(auth)
if err != nil {
logger.Error("failed to parse TOTP config", "id", auth.ID, "err", err)
continue
}
providers[auth.ID] = provider
logger.Info("MFA authenticator configured", "id", auth.ID, "type", auth.Type)
case "WebAuthn":
provider, err := buildWebAuthnConfig(auth, issuerURL)
if err != nil {
logger.Error("failed to parse WebAuthn config", "id", auth.ID, "err", err)
continue
}
providers[auth.ID] = provider
logger.Info("MFA authenticator configured", "id", auth.ID, "type", auth.Type)
default:
logger.Error("unknown MFA authenticator type, skipping", "id", auth.ID, "type", auth.Type)
}
}
return providers
}
func buildSessionsConfig(sessions *Sessions) *server.SessionConfig {
if sessions == nil {
return nil
}
if sessions.RememberMeCheckedByDefault == nil {
defaultRememberMeCheckedByDefault := false
sessions.RememberMeCheckedByDefault = &defaultRememberMeCheckedByDefault
}
absoluteLifetime, _ := parseDuration(sessions.AbsoluteLifetime)
validIfNotUsedFor, _ := parseDuration(sessions.ValidIfNotUsedFor)
return &server.SessionConfig{
CookieEncryptionKey: []byte(sessions.CookieEncryptionKey),
CookieName: sessions.CookieName,
AbsoluteLifetime: absoluteLifetime,
ValidIfNotUsedFor: validIfNotUsedFor,
RememberMeCheckedByDefault: *sessions.RememberMeCheckedByDefault,
SSOSharedWithDefault: sessions.SSOSharedWithDefault,
}
}
// ToServerConfig converts YAMLConfig to dex server.Config
func (c *YAMLConfig) ToServerConfig(stor storage.Storage, logger *slog.Logger) server.Config {
cfg := server.Config{
@@ -448,6 +600,8 @@ func (c *YAMLConfig) ToServerConfig(stor storage.Storage, logger *slog.Logger) s
Dir: c.Frontend.Dir,
Extra: c.Frontend.Extra,
},
SessionConfig: buildSessionsConfig(c.Sessions),
MFAProviders: buildMFAProviders(c.MFA.Authenticators, c.Issuer, logger),
}
// Use embedded NetBird-styled templates if no custom dir specified
@@ -460,11 +614,6 @@ func (c *YAMLConfig) ToServerConfig(stor storage.Storage, logger *slog.Logger) s
}
// Apply expiry settings
if c.Expiry.SigningKeys != "" {
if d, err := parseDuration(c.Expiry.SigningKeys); err == nil {
cfg.RotateKeysAfter = d
}
}
if c.Expiry.IDTokens != "" {
if d, err := parseDuration(c.Expiry.IDTokens); err == nil {
cfg.IDTokensValidFor = d

View File

@@ -18,8 +18,8 @@ import (
dexapi "github.com/dexidp/dex/api/v2"
"github.com/dexidp/dex/server"
"github.com/dexidp/dex/server/signer"
"github.com/dexidp/dex/storage"
"github.com/dexidp/dex/storage/sql"
jose "github.com/go-jose/go-jose/v4"
"github.com/google/uuid"
"github.com/prometheus/client_golang/prometheus"
@@ -70,13 +70,13 @@ func NewProvider(ctx context.Context, config *Config) (*Provider, error) {
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
// Ensure data directory exists
if err := os.MkdirAll(config.DataDir, 0700); err != nil {
if err := os.MkdirAll(config.DataDir, 0o700); err != nil {
return nil, fmt.Errorf("failed to create data directory: %w", err)
}
// Initialize SQLite storage
dbPath := filepath.Join(config.DataDir, "oidc.db")
sqliteConfig := &sql.SQLite3{File: dbPath}
sqliteConfig := newSQLite3(dbPath)
stor, err := sqliteConfig.Open(logger)
if err != nil {
return nil, fmt.Errorf("failed to open storage: %w", err)
@@ -101,6 +101,15 @@ func NewProvider(ctx context.Context, config *Config) (*Provider, error) {
return nil, fmt.Errorf("failed to create refresh token policy: %w", err)
}
localSignerConfig := signer.LocalConfig{
KeysRotationPeriod: "6h",
}
localSigner, err := localSignerConfig.Open(ctx, stor, 24*time.Hour, time.Now, logger)
if err != nil {
return nil, fmt.Errorf("failed to create local signer: %w", err)
}
// Build Dex server config - use Dex's types directly
dexConfig := server.Config{
Issuer: issuer,
@@ -110,12 +119,12 @@ func NewProvider(ctx context.Context, config *Config) (*Provider, error) {
ContinueOnConnectorFailure: true,
Logger: logger,
PrometheusRegistry: prometheus.NewRegistry(),
RotateKeysAfter: 6 * time.Hour,
IDTokensValidFor: 24 * time.Hour,
RefreshTokenPolicy: refreshPolicy,
Web: server.WebConfig{
Issuer: "NetBird",
},
Signer: localSigner,
}
dexSrv, err := server.NewServer(ctx, dexConfig)
@@ -167,6 +176,14 @@ func NewProviderFromYAML(ctx context.Context, yamlConfig *YAMLConfig) (*Provider
return nil, fmt.Errorf("failed to create refresh token policy: %w", err)
}
localSigner, err := getSigner(ctx, stor, yamlConfig, logger)
if err != nil {
stor.Close()
return nil, fmt.Errorf("failed to create local signer: %w", err)
}
dexConfig.Signer = localSigner
dexSrv, err := server.NewServer(ctx, dexConfig)
if err != nil {
stor.Close()
@@ -182,6 +199,32 @@ func NewProviderFromYAML(ctx context.Context, yamlConfig *YAMLConfig) (*Provider
}, nil
}
func getSigner(ctx context.Context, stor storage.Storage, yamlConfig *YAMLConfig, logger *slog.Logger) (signer.Signer, error) {
// Parse expiry durations
idTokensValidFor := 24 * time.Hour // default
if yamlConfig.Expiry.IDTokens != "" {
var err error
idTokensValidFor, err = parseDuration(yamlConfig.Expiry.IDTokens)
if err != nil {
return nil, fmt.Errorf("invalid config value %q for id token expiry: %v", yamlConfig.Expiry.IDTokens, err)
}
}
localSignerConfig := &signer.LocalConfig{
KeysRotationPeriod: "720h", // 30 Days
}
if yamlConfig.Expiry.SigningKeys != "" {
if _, err := parseDuration(yamlConfig.Expiry.SigningKeys); err != nil {
return nil, fmt.Errorf("invalid config value %q for signing key expiry: %v", yamlConfig.Expiry.SigningKeys, err)
}
localSignerConfig.KeysRotationPeriod = yamlConfig.Expiry.SigningKeys
}
return localSignerConfig.Open(ctx, stor, idTokensValidFor, time.Now, logger)
}
// initializeStorage sets up connectors, passwords, and clients in storage
func initializeStorage(ctx context.Context, stor storage.Storage, cfg *YAMLConfig) error {
if cfg.EnablePasswordDB {
@@ -241,6 +284,8 @@ func ensureStaticClients(ctx context.Context, stor storage.Storage, clients []st
old.RedirectURIs = client.RedirectURIs
old.Name = client.Name
old.Public = client.Public
old.PostLogoutRedirectURIs = client.PostLogoutRedirectURIs
old.MFAChain = client.MFAChain
return old, nil
}); err != nil {
return fmt.Errorf("failed to update client %s: %w", client.ID, err)
@@ -253,9 +298,6 @@ func ensureStaticClients(ctx context.Context, stor storage.Storage, clients []st
func buildDexConfig(yamlConfig *YAMLConfig, stor storage.Storage, logger *slog.Logger) server.Config {
cfg := yamlConfig.ToServerConfig(stor, logger)
cfg.PrometheusRegistry = prometheus.NewRegistry()
if cfg.RotateKeysAfter == 0 {
cfg.RotateKeysAfter = 24 * 30 * time.Hour
}
if cfg.IDTokensValidFor == 0 {
cfg.IDTokensValidFor = 24 * time.Hour
}
@@ -450,10 +492,34 @@ func (p *Provider) Storage() storage.Storage {
return p.storage
}
// SetClientsMFAChain updates the MFAChain field on the dashboard and CLI OAuth2 clients.
// Pass a non-empty slice (e.g. []string{"default-totp"}) to enable MFA, or nil to disable it.
func (p *Provider) SetClientsMFAChain(ctx context.Context, clientIDs []string, mfaChain []string) error {
for _, clientID := range clientIDs {
if err := p.storage.UpdateClient(ctx, clientID, func(old storage.Client) (storage.Client, error) {
old.MFAChain = mfaChain
return old, nil
}); err != nil {
return fmt.Errorf("failed to update MFA chain on client %s: %w", clientID, err)
}
}
return nil
}
// Handler returns the Dex server as an http.Handler for embedding in another server.
// The handler expects requests with path prefix "/oauth2/".
func (p *Provider) Handler() http.Handler {
return p.dexServer
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Dex's /logout endpoint requires id_token_hint for RP-initiated logout with
// post_logout_redirect_uri. If the dashboard calls logout without one, avoid
// rendering Dex's non-actionable Bad Request page and send the user home.
if strings.HasSuffix(r.URL.Path, "/logout") && r.FormValue("id_token_hint") == "" {
http.Redirect(w, r, "/", http.StatusSeeOther)
return
}
p.dexServer.ServeHTTP(w, r)
})
}
// CreateUser creates a new user with the given email, username, and password.

View File

@@ -4,6 +4,8 @@ import (
"context"
"encoding/json"
"log/slog"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
@@ -144,6 +146,30 @@ func TestEncodeDexUserID_MatchesDexFormat(t *testing.T) {
assert.Equal(t, knownEncodedID, reEncoded)
}
func TestHandlerRedirectsLogoutWithoutIDTokenHint(t *testing.T) {
ctx := context.Background()
tmpDir, err := os.MkdirTemp("", "dex-logout-handler-*")
require.NoError(t, err)
defer os.RemoveAll(tmpDir)
provider, err := NewProvider(ctx, &Config{
Issuer: "http://localhost:5556/oauth2",
Port: 5556,
DataDir: tmpDir,
})
require.NoError(t, err)
defer func() { _ = provider.Stop(ctx) }()
req := httptest.NewRequest(http.MethodGet, "/oauth2/logout?post_logout_redirect_uri=https://example.com", nil)
rec := httptest.NewRecorder()
provider.Handler().ServeHTTP(rec, req)
require.Equal(t, http.StatusSeeOther, rec.Code)
require.Equal(t, "/", rec.Header().Get("Location"))
}
func TestCreateUserInTempDB(t *testing.T) {
ctx := context.Background()

View File

@@ -0,0 +1,12 @@
{{ template "header.html" . }}
<script>globalThis.location.replace("/");</script>
<noscript>
<div class="nb-card">
<h1 class="nb-heading">Redirecting…</h1>
<p class="nb-subheading">You are being redirected to the NetBird dashboard.</p>
<a href="/" class="nb-btn" style="display:block;text-align:center;text-decoration:none">Go to Dashboard</a>
</div>
</noscript>
{{ template "footer.html" . }}

View File

@@ -0,0 +1,14 @@
{{ template "header.html" . }}
<div class="nb-card">
<h1 class="nb-heading">Logged Out</h1>
<p class="nb-subheading">You have been successfully logged out.</p>
{{ if .BackURL }}
<div class="nb-back-link">
<a href="{{ .BackURL }}" class="nb-link">&larr; Back to Application</a>
</div>
{{ end }}
</div>
{{ template "footer.html" . }}

View File

@@ -18,6 +18,7 @@
id="login"
name="login"
class="nb-input"
autocomplete="username"
placeholder="Enter your {{ .UsernamePrompt | lower }}"
{{ if .Username }}value="{{ .Username }}"{{ else }}autofocus{{ end }}
required
@@ -31,6 +32,7 @@
id="password"
name="password"
class="nb-input"
autocomplete="current-password"
placeholder="Enter your password"
{{ if .Invalid }}autofocus{{ end }}
required

View File

@@ -0,0 +1,44 @@
{{ template "header.html" . }}
<div class="nb-card">
<h1 class="nb-heading">Two-factor authentication</h1>
{{ if not (eq .QRCode "") }}
<p class="nb-subheading">Scan the QR code below using your authenticator app, then enter the code.</p>
<div style="text-align: center; margin: 1em 0;">
<img src="data:image/png;base64,{{ .QRCode }}" alt="QR code" width="200" height="200"/>
</div>
{{ else }}
<p class="nb-subheading">Enter the code from your authenticator app.</p>
{{ end }}
<form method="post" action="{{ .PostURL }}">
{{ if .Invalid }}
<div class="nb-error">
Invalid code. Please try again.
</div>
{{ end }}
<div class="nb-form-group">
<label class="nb-label" for="totp">One-time code</label>
<input
type="text"
id="totp"
name="totp"
class="nb-input"
inputmode="numeric"
pattern="[0-9]*"
maxlength="6"
autocomplete="one-time-code"
placeholder="000000"
autofocus
required
>
</div>
<button type="submit" id="submit-login" class="nb-btn">
Verify
</button>
</form>
</div>
{{ template "footer.html" . }}

View File

@@ -0,0 +1,12 @@
{{ template "header.html" . }}
<script>globalThis.location.replace("/");</script>
<noscript>
<div class="nb-card">
<h1 class="nb-heading">Redirecting…</h1>
<p class="nb-subheading">You are being redirected to the NetBird dashboard.</p>
<a href="/" class="nb-btn" style="display:block;text-align:center;text-decoration:none">Go to Dashboard</a>
</div>
</noscript>
{{ template "footer.html" . }}

View File

@@ -55,6 +55,15 @@ type Controller struct {
proxyController port_forwarding.Controller
integratedPeerValidator integrated_validator.IntegratedValidator
// componentsDisabled is the kill switch for the component-based wire
// format. When true the controller emits legacy proto.NetworkMap to every
// peer regardless of capability — used to roll back instantly via a
// management restart from a bad components encoder.
//
// Set once in NewController from NB_NETWORK_MAP_COMPONENTS_DISABLE and
// never written after — readers race-free without a mutex.
componentsDisabled bool
}
type bufferUpdate struct {
@@ -81,12 +90,30 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App
settingsManager: settingsManager,
dnsDomain: dnsDomain,
config: config,
componentsDisabled: parseBoolEnv("NB_NETWORK_MAP_COMPONENTS_DISABLE"),
proxyController: proxyController,
EphemeralPeersManager: ephemeralPeersManager,
}
}
// PeerNeedsComponents reports whether the gRPC layer should emit the
// component-based wire format for this peer. Combines the peer's advertised
// capability with the controller-level kill switch — callers ask exactly
// this question, so encapsulating it removes accidental double-checks.
func (c *Controller) PeerNeedsComponents(p *nbpeer.Peer) bool {
return p != nil && p.SupportsComponentNetworkMap() && !c.componentsDisabled
}
// parseBoolEnv reads an env var via strconv.ParseBool so callers accept the
// usual "1/t/T/TRUE/true/True" set instead of being strict about a single
// literal — matches the convention used elsewhere in the codebase
// (e.g. event.go's NB_TRAFFIC_EVENT_*) and reduces operator surprises.
func parseBoolEnv(key string) bool {
v, _ := strconv.ParseBool(os.Getenv(key))
return v
}
func (c *Controller) OnPeerConnected(ctx context.Context, accountID string, peerID string) (chan *network_map.UpdateMessage, error) {
peer, err := c.repo.GetPeerByID(ctx, accountID, peerID)
if err != nil {
@@ -192,18 +219,26 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
c.metrics.CountCalcPostureChecksDuration(time.Since(start))
start = time.Now()
remotePeerNetworkMap := account.GetPeerNetworkMapFromComponents(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
result := account.GetPeerNetworkMapResult(ctx, p.ID, c.componentsDisabled, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start))
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
if ok {
remotePeerNetworkMap.Merge(proxyNetworkMap)
proxyNetworkMap := proxyNetworkMaps[p.ID]
if result.NetworkMap != nil && proxyNetworkMap != nil {
result.NetworkMap.Merge(proxyNetworkMap)
}
peerGroups := account.GetPeerGroups(p.ID)
start = time.Now()
update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort)
var update *proto.SyncResponse
if result.IsComponents() {
// proxyNetworkMap rides the envelope as a ProxyPatch sidecar;
// the client merges it into Calculate()'s output the same
// way the legacy server did via NetworkMap.Merge.
update = grpc.ToComponentSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, p, nil, nil, result.Components, proxyNetworkMap, dnsDomain, postureChecks, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort)
} else {
update = grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, p, nil, nil, result.NetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort)
}
c.metrics.CountToSyncResponseDuration(time.Since(start))
c.peersUpdateManager.SendUpdate(ctx, p.ID, &network_map.UpdateMessage{
@@ -221,9 +256,13 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
return nil
}
func (c *Controller) bufferSendUpdateAccountPeers(ctx context.Context, accountID string) error {
func (c *Controller) bufferSendUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error {
log.WithContext(ctx).Tracef("buffer sending update peers for account %s from %s", accountID, util.GetCallerName())
if c.accountManagerMetrics != nil {
c.accountManagerMetrics.CountUpdateAccountPeersTriggered(string(reason.Resource), string(reason.Operation))
}
bufUpd, _ := c.sendAccountUpdateLocks.LoadOrStore(accountID, &bufferUpdate{})
b := bufUpd.(*bufferUpdate)
@@ -310,11 +349,11 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
return err
}
remotePeerNetworkMap := account.GetPeerNetworkMapFromComponents(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
result := account.GetPeerNetworkMapResult(ctx, peerId, c.componentsDisabled, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
if ok {
remotePeerNetworkMap.Merge(proxyNetworkMap)
proxyNetworkMap := proxyNetworkMaps[peer.ID]
if result.NetworkMap != nil && proxyNetworkMap != nil {
result.NetworkMap.Merge(proxyNetworkMap)
}
extraSettings, err := c.settingsManager.GetExtraSettings(ctx, peer.AccountID)
@@ -325,7 +364,12 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
peerGroups := account.GetPeerGroups(peerId)
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort)
var update *proto.SyncResponse
if result.IsComponents() {
update = grpc.ToComponentSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, peer, nil, nil, result.Components, proxyNetworkMap, dnsDomain, postureChecks, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort)
} else {
update = grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, peer, nil, nil, result.NetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort)
}
c.peersUpdateManager.SendUpdate(ctx, peer.ID, &network_map.UpdateMessage{
Update: update,
MessageType: network_map.MessageTypeNetworkMap,
@@ -372,6 +416,67 @@ func (c *Controller) BufferUpdateAccountPeers(ctx context.Context, accountID str
return nil
}
// GetValidatedPeerWithComponents is the components-format counterpart of
// GetValidatedPeerWithMap. It returns raw NetworkMapComponents for capable
// peers along with the proxy NetworkMap fragment (BYOP / port-forwarding
// data the legacy server folds in via NetworkMap.Merge). The gRPC layer
// encodes both into the wire envelope. The caller is responsible for
// checking peer capability + componentsDisabled before dispatching here —
// this method does NOT branch on capability itself.
func (c *Controller) GetValidatedPeerWithComponents(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMapComponents, *types.NetworkMap, []*posture.Checks, int64, error) {
if isRequiresApproval {
network, err := c.repo.GetAccountNetwork(ctx, accountID)
if err != nil {
return nil, nil, nil, nil, 0, err
}
return peer, &types.NetworkMapComponents{Network: network.Copy()}, nil, nil, 0, nil
}
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return nil, nil, nil, nil, 0, err
}
account.InjectProxyPolicies(ctx)
approvedPeersMap, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
if err != nil {
return nil, nil, nil, nil, 0, err
}
postureChecks, err := c.getPeerPostureChecks(account, peer.ID)
if err != nil {
return nil, nil, nil, nil, 0, err
}
accountZones, err := c.repo.GetAccountZones(ctx, account.Id)
if err != nil {
return nil, nil, nil, nil, 0, err
}
// Fetch the proxy network map fragment for this peer alongside the
// components — same single-account-load path the streaming controller
// uses, so initial-sync delivers BYOP/forwarding patches synchronously
// instead of waiting for the next streaming push.
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMaps(ctx, account.Id, peer.ID, account.Peers)
if err != nil {
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
return nil, nil, nil, nil, 0, err
}
dnsDomain := c.GetDNSDomain(account.Settings)
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
components := account.GetPeerNetworkMapComponents(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, groupIDToUserIDs)
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
return peer, components, proxyNetworkMaps[peer.ID], postureChecks, dnsFwdPort, nil
}
func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
if isRequiresApproval {
network, err := c.repo.GetAccountNetwork(ctx, accountID)
@@ -570,7 +675,7 @@ func isPeerInPolicySourceGroups(account *types.Account, peerID string, policy *t
}
func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerIDs []string) error {
err := c.bufferSendUpdateAccountPeers(ctx, accountID)
err := c.bufferSendUpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourcePeer, Operation: types.UpdateOperationUpdate})
if err != nil {
log.WithContext(ctx).Errorf("failed to buffer update account peers for peer update in account %s: %v", accountID, err)
}
@@ -580,7 +685,7 @@ func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerI
func (c *Controller) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error {
log.WithContext(ctx).Debugf("OnPeersAdded call to add peers: %v", peerIDs)
return c.bufferSendUpdateAccountPeers(ctx, accountID)
return c.bufferSendUpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourcePeer, Operation: types.UpdateOperationCreate})
}
func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string) error {
@@ -616,7 +721,7 @@ func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerI
c.peersUpdateManager.CloseChannel(ctx, peerID)
}
return c.bufferSendUpdateAccountPeers(ctx, accountID)
return c.bufferSendUpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourcePeer, Operation: types.UpdateOperationDelete})
}
// GetNetworkMap returns Network map for a given peer (omits original peer from the Peers result)

View File

@@ -22,6 +22,10 @@ type Controller interface {
UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error
BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error
GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
GetValidatedPeerWithComponents(ctx context.Context, isRequiresApproval bool, accountID string, p *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMapComponents, *types.NetworkMap, []*posture.Checks, int64, error)
// PeerNeedsComponents combines the peer's advertised capability with the
// kill-switch flag — the only public predicate gRPC layers should ask.
PeerNeedsComponents(p *nbpeer.Peer) bool
GetDNSDomain(settings *types.Settings) string
StartWarmup(context.Context)
GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error)

View File

@@ -130,6 +130,39 @@ func (mr *MockControllerMockRecorder) GetValidatedPeerWithMap(ctx, isRequiresApp
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValidatedPeerWithMap", reflect.TypeOf((*MockController)(nil).GetValidatedPeerWithMap), ctx, isRequiresApproval, accountID, p)
}
// GetValidatedPeerWithComponents mocks base method.
func (m *MockController) GetValidatedPeerWithComponents(ctx context.Context, isRequiresApproval bool, accountID string, p *peer.Peer) (*peer.Peer, *types.NetworkMapComponents, *types.NetworkMap, []*posture.Checks, int64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetValidatedPeerWithComponents", ctx, isRequiresApproval, accountID, p)
ret0, _ := ret[0].(*peer.Peer)
ret1, _ := ret[1].(*types.NetworkMapComponents)
ret2, _ := ret[2].(*types.NetworkMap)
ret3, _ := ret[3].([]*posture.Checks)
ret4, _ := ret[4].(int64)
ret5, _ := ret[5].(error)
return ret0, ret1, ret2, ret3, ret4, ret5
}
// GetValidatedPeerWithComponents indicates an expected call of GetValidatedPeerWithComponents.
func (mr *MockControllerMockRecorder) GetValidatedPeerWithComponents(ctx, isRequiresApproval, accountID, p any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValidatedPeerWithComponents", reflect.TypeOf((*MockController)(nil).GetValidatedPeerWithComponents), ctx, isRequiresApproval, accountID, p)
}
// PeerNeedsComponents mocks base method.
func (m *MockController) PeerNeedsComponents(p *peer.Peer) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "PeerNeedsComponents", p)
ret0, _ := ret[0].(bool)
return ret0
}
// PeerNeedsComponents indicates an expected call of PeerNeedsComponents.
func (mr *MockControllerMockRecorder) PeerNeedsComponents(p any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PeerNeedsComponents", reflect.TypeOf((*MockController)(nil).PeerNeedsComponents), p)
}
// OnPeerConnected mocks base method.
func (m *MockController) OnPeerConnected(ctx context.Context, accountID, peerID string) (chan *UpdateMessage, error) {
m.ctrl.T.Helper()

View File

@@ -11,6 +11,7 @@ import (
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral"
"github.com/netbirdio/netbird/management/server/activity"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/store"
)
@@ -47,6 +48,11 @@ type EphemeralManager struct {
lifeTime time.Duration
cleanupWindow time.Duration
// metrics is nil-safe; methods on telemetry.EphemeralPeersMetrics
// no-op when the receiver is nil so deployments without an app
// metrics provider work unchanged.
metrics *telemetry.EphemeralPeersMetrics
}
// NewEphemeralManager instantiate new EphemeralManager
@@ -60,6 +66,15 @@ func NewEphemeralManager(store store.Store, peersManager peers.Manager) *Ephemer
}
}
// SetMetrics attaches a metrics collector. Safe to call once before
// LoadInitialPeers; later attachment is fine but earlier loads won't be
// reflected in the gauge. Pass nil to detach.
func (e *EphemeralManager) SetMetrics(m *telemetry.EphemeralPeersMetrics) {
e.peersLock.Lock()
e.metrics = m
e.peersLock.Unlock()
}
// LoadInitialPeers load from the database the ephemeral type of peers and schedule a cleanup procedure to the head
// of the linked list (to the most deprecated peer). At the end of cleanup it schedules the next cleanup to the new
// head.
@@ -97,7 +112,9 @@ func (e *EphemeralManager) OnPeerConnected(ctx context.Context, peer *nbpeer.Pee
e.peersLock.Lock()
defer e.peersLock.Unlock()
e.removePeer(peer.ID)
if e.removePeer(peer.ID) {
e.metrics.DecPending(1)
}
// stop the unnecessary timer
if e.headPeer == nil && e.timer != nil {
@@ -123,6 +140,7 @@ func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer.
}
e.addPeer(peer.AccountID, peer.ID, e.newDeadLine())
e.metrics.IncPending()
if e.timer == nil {
delay := e.headPeer.deadline.Sub(timeNow()) + e.cleanupWindow
if delay < 0 {
@@ -145,6 +163,7 @@ func (e *EphemeralManager) loadEphemeralPeers(ctx context.Context) {
for _, p := range peers {
e.addPeer(p.AccountID, p.ID, t)
}
e.metrics.AddPending(int64(len(peers)))
log.WithContext(ctx).Debugf("loaded ephemeral peer(s): %d", len(peers))
}
@@ -181,6 +200,15 @@ func (e *EphemeralManager) cleanup(ctx context.Context) {
e.peersLock.Unlock()
// Drop the gauge by the number of entries we just took off the list,
// regardless of whether the subsequent DeletePeers call succeeds. The
// list invariant is what the gauge tracks; failed delete batches are
// counted separately via CountCleanupError so we can still see them.
if len(deletePeers) > 0 {
e.metrics.CountCleanupRun()
e.metrics.DecPending(int64(len(deletePeers)))
}
peerIDsPerAccount := make(map[string][]string)
for id, p := range deletePeers {
peerIDsPerAccount[p.accountID] = append(peerIDsPerAccount[p.accountID], id)
@@ -191,7 +219,10 @@ func (e *EphemeralManager) cleanup(ctx context.Context) {
err := e.peersManager.DeletePeers(ctx, accountID, peerIDs, activity.SystemInitiator, true)
if err != nil {
log.WithContext(ctx).Errorf("failed to delete ephemeral peers: %s", err)
e.metrics.CountCleanupError()
continue
}
e.metrics.CountPeersCleaned(int64(len(peerIDs)))
}
}
@@ -211,9 +242,12 @@ func (e *EphemeralManager) addPeer(accountID string, peerID string, deadline tim
e.tailPeer = ep
}
func (e *EphemeralManager) removePeer(id string) {
// removePeer drops the entry from the linked list. Returns true if a
// matching entry was found and removed so callers can keep the pending
// metric gauge in sync.
func (e *EphemeralManager) removePeer(id string) bool {
if e.headPeer == nil {
return
return false
}
if e.headPeer.id == id {
@@ -221,7 +255,7 @@ func (e *EphemeralManager) removePeer(id string) {
if e.tailPeer.id == id {
e.tailPeer = nil
}
return
return true
}
for p := e.headPeer; p.next != nil; p = p.next {
@@ -231,9 +265,10 @@ func (e *EphemeralManager) removePeer(id string) {
e.tailPeer = p
}
p.next = p.next.next
return
return true
}
}
return false
}
func (e *EphemeralManager) isPeerOnList(id string) bool {

View File

@@ -31,6 +31,7 @@ type store interface {
type proxyManager interface {
GetActiveClusterAddresses(ctx context.Context) ([]string, error)
GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error)
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
@@ -71,8 +72,8 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d
var ret []*domain.Domain
// Add connected proxy clusters as free domains.
// The cluster address itself is the free domain base (e.g., "eu.proxy.netbird.io").
allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
// For BYOP accounts, only their own cluster is returned; otherwise shared clusters.
allowList, err := m.getClusterAllowList(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses: %v", err)
return nil, err
@@ -126,8 +127,8 @@ func (m Manager) CreateDomain(ctx context.Context, accountID, userID, domainName
return nil, status.NewPermissionDeniedError()
}
// Verify the target cluster is in the available clusters
allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
// Verify the target cluster is in the available clusters for this account
allowList, err := m.getClusterAllowList(ctx, accountID)
if err != nil {
return nil, fmt.Errorf("failed to get active proxy cluster addresses: %w", err)
}
@@ -273,7 +274,7 @@ func (m Manager) GetClusterDomains() []string {
// For free domains (those ending with a known cluster suffix), the cluster is extracted from the domain.
// For custom domains, the cluster is determined by checking the registered custom domain's target cluster.
func (m Manager) DeriveClusterFromDomain(ctx context.Context, accountID, domain string) (string, error) {
allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
allowList, err := m.getClusterAllowList(ctx, accountID)
if err != nil {
return "", fmt.Errorf("failed to get active proxy cluster addresses: %w", err)
}
@@ -298,6 +299,34 @@ func (m Manager) DeriveClusterFromDomain(ctx context.Context, accountID, domain
return "", fmt.Errorf("domain %s does not match any available proxy cluster", domain)
}
func (m Manager) getClusterAllowList(ctx context.Context, accountID string) ([]string, error) {
byopAddresses, err := m.proxyManager.GetActiveClusterAddressesForAccount(ctx, accountID)
if err != nil {
return nil, fmt.Errorf("get BYOP cluster addresses: %w", err)
}
publicAddresses, err := m.proxyManager.GetActiveClusterAddresses(ctx)
if err != nil {
return nil, fmt.Errorf("get public cluster addresses: %w", err)
}
seen := make(map[string]struct{}, len(byopAddresses)+len(publicAddresses))
merged := make([]string, 0, len(byopAddresses)+len(publicAddresses))
for _, addr := range byopAddresses {
if _, ok := seen[addr]; ok {
continue
}
seen[addr] = struct{}{}
merged = append(merged, addr)
}
for _, addr := range publicAddresses {
if _, ok := seen[addr]; ok {
continue
}
seen[addr] = struct{}{}
merged = append(merged, addr)
}
return merged, nil
}
func extractClusterFromCustomDomains(serviceDomain string, customDomains []*domain.Domain) (string, bool) {
bestCluster := ""
bestLen := -1

View File

@@ -0,0 +1,154 @@
package manager
import (
"context"
"errors"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type mockProxyManager struct {
getActiveClusterAddressesFunc func(ctx context.Context) ([]string, error)
getActiveClusterAddressesForAccountFunc func(ctx context.Context, accountID string) ([]string, error)
}
func (m *mockProxyManager) GetActiveClusterAddresses(ctx context.Context) ([]string, error) {
if m.getActiveClusterAddressesFunc != nil {
return m.getActiveClusterAddressesFunc(ctx)
}
return nil, nil
}
func (m *mockProxyManager) GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) {
if m.getActiveClusterAddressesForAccountFunc != nil {
return m.getActiveClusterAddressesForAccountFunc(ctx, accountID)
}
return nil, nil
}
func (m *mockProxyManager) ClusterSupportsCustomPorts(_ context.Context, _ string) *bool {
return nil
}
func (m *mockProxyManager) ClusterRequireSubdomain(_ context.Context, _ string) *bool {
return nil
}
func (m *mockProxyManager) ClusterSupportsCrowdSec(_ context.Context, _ string) *bool {
return nil
}
func TestGetClusterAllowList_BYOPMergedWithPublic(t *testing.T) {
pm := &mockProxyManager{
getActiveClusterAddressesForAccountFunc: func(_ context.Context, accID string) ([]string, error) {
assert.Equal(t, "acc-123", accID)
return []string{"byop.example.com"}, nil
},
getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
return []string{"eu.proxy.netbird.io"}, nil
},
}
mgr := Manager{proxyManager: pm}
result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
require.NoError(t, err)
assert.Equal(t, []string{"byop.example.com", "eu.proxy.netbird.io"}, result)
}
func TestGetClusterAllowList_DeduplicatesBYOPAndPublic(t *testing.T) {
pm := &mockProxyManager{
getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) {
return []string{"shared.example.com", "byop.example.com"}, nil
},
getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
return []string{"shared.example.com", "eu.proxy.netbird.io"}, nil
},
}
mgr := Manager{proxyManager: pm}
result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
require.NoError(t, err)
assert.Equal(t, []string{"shared.example.com", "byop.example.com", "eu.proxy.netbird.io"}, result)
}
func TestGetClusterAllowList_NoBYOP_FallbackToShared(t *testing.T) {
pm := &mockProxyManager{
getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) {
return nil, nil
},
getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
return []string{"eu.proxy.netbird.io", "us.proxy.netbird.io"}, nil
},
}
mgr := Manager{proxyManager: pm}
result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
require.NoError(t, err)
assert.Equal(t, []string{"eu.proxy.netbird.io", "us.proxy.netbird.io"}, result)
}
func TestGetClusterAllowList_BYOPError_ReturnsError(t *testing.T) {
pm := &mockProxyManager{
getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) {
return nil, errors.New("db error")
},
}
mgr := Manager{proxyManager: pm}
result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
require.Error(t, err)
assert.Nil(t, result)
assert.Contains(t, err.Error(), "BYOP cluster addresses")
}
func TestGetClusterAllowList_PublicError_ReturnsError(t *testing.T) {
pm := &mockProxyManager{
getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) {
return []string{"byop.example.com"}, nil
},
getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
return nil, errors.New("db error")
},
}
mgr := Manager{proxyManager: pm}
result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
require.Error(t, err)
assert.Nil(t, result)
assert.Contains(t, err.Error(), "public cluster addresses")
}
func TestGetClusterAllowList_BYOPEmptySlice_FallbackToShared(t *testing.T) {
pm := &mockProxyManager{
getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) {
return []string{}, nil
},
getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
return []string{"eu.proxy.netbird.io"}, nil
},
}
mgr := Manager{proxyManager: pm}
result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
require.NoError(t, err)
assert.Equal(t, []string{"eu.proxy.netbird.io"}, result)
}
func TestGetClusterAllowList_PublicEmpty_BYOPOnly(t *testing.T) {
pm := &mockProxyManager{
getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) {
return []string{"byop.example.com"}, nil
},
getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
return nil, nil
},
}
mgr := Manager{proxyManager: pm}
result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
require.NoError(t, err)
assert.Equal(t, []string{"byop.example.com"}, result)
}

View File

@@ -11,15 +11,19 @@ import (
// Manager defines the interface for proxy operations
type Manager interface {
Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, capabilities *Capabilities) (*Proxy, error)
Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, accountID *string, capabilities *Capabilities) (*Proxy, error)
Disconnect(ctx context.Context, proxyID, sessionID string) error
Heartbeat(ctx context.Context, p *Proxy) error
GetActiveClusterAddresses(ctx context.Context) ([]string, error)
GetActiveClusters(ctx context.Context) ([]Cluster, error)
GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error)
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
CleanupStale(ctx context.Context, inactivityDuration time.Duration) error
GetAccountProxy(ctx context.Context, accountID string) (*Proxy, error)
CountAccountProxies(ctx context.Context, accountID string) (int64, error)
IsClusterAddressAvailable(ctx context.Context, clusterAddress, accountID string) (bool, error)
DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error
}
// OIDCValidationConfig contains the OIDC configuration needed for token validation.

View File

@@ -16,11 +16,16 @@ type store interface {
DisconnectProxy(ctx context.Context, proxyID, sessionID string) error
UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error)
GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error)
GetActiveProxyClusters(ctx context.Context, accountID string) ([]proxy.Cluster, error)
GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error
GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error)
CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error)
IsClusterAddressConflicting(ctx context.Context, clusterAddress, accountID string) (bool, error)
DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error
}
// Manager handles all proxy operations
@@ -44,7 +49,7 @@ func NewManager(store store, meter metric.Meter) (*Manager, error) {
// Connect registers a new proxy connection in the database.
// capabilities may be nil for old proxies that do not report them.
func (m Manager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, capabilities *proxy.Capabilities) (*proxy.Proxy, error) {
func (m *Manager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, accountID *string, capabilities *proxy.Capabilities) (*proxy.Proxy, error) {
now := time.Now()
var caps proxy.Capabilities
if capabilities != nil {
@@ -55,9 +60,10 @@ func (m Manager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress
SessionID: sessionID,
ClusterAddress: clusterAddress,
IPAddress: ipAddress,
AccountID: accountID,
LastSeen: now,
ConnectedAt: &now,
Status: "connected",
Status: proxy.StatusConnected,
Capabilities: caps,
}
@@ -77,7 +83,7 @@ func (m Manager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress
}
// Disconnect marks a proxy as disconnected in the database.
func (m Manager) Disconnect(ctx context.Context, proxyID, sessionID string) error {
func (m *Manager) Disconnect(ctx context.Context, proxyID, sessionID string) error {
if err := m.store.DisconnectProxy(ctx, proxyID, sessionID); err != nil {
log.WithContext(ctx).Errorf("failed to disconnect proxy %s session %s: %v", proxyID, sessionID, err)
return err
@@ -92,7 +98,7 @@ func (m Manager) Disconnect(ctx context.Context, proxyID, sessionID string) erro
}
// Heartbeat updates the proxy's last seen timestamp.
func (m Manager) Heartbeat(ctx context.Context, p *proxy.Proxy) error {
func (m *Manager) Heartbeat(ctx context.Context, p *proxy.Proxy) error {
if err := m.store.UpdateProxyHeartbeat(ctx, p); err != nil {
log.WithContext(ctx).Debugf("failed to update proxy %s heartbeat: %v", p.ID, err)
return err
@@ -104,7 +110,7 @@ func (m Manager) Heartbeat(ctx context.Context, p *proxy.Proxy) error {
}
// GetActiveClusterAddresses returns all unique cluster addresses for active proxies
func (m Manager) GetActiveClusterAddresses(ctx context.Context) ([]string, error) {
func (m *Manager) GetActiveClusterAddresses(ctx context.Context) ([]string, error) {
addresses, err := m.store.GetActiveProxyClusterAddresses(ctx)
if err != nil {
log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses: %v", err)
@@ -113,16 +119,6 @@ func (m Manager) GetActiveClusterAddresses(ctx context.Context) ([]string, error
return addresses, nil
}
// GetActiveClusters returns all active proxy clusters with their connected proxy count.
func (m Manager) GetActiveClusters(ctx context.Context) ([]proxy.Cluster, error) {
clusters, err := m.store.GetActiveProxyClusters(ctx)
if err != nil {
log.WithContext(ctx).Errorf("failed to get active proxy clusters: %v", err)
return nil, err
}
return clusters, nil
}
// ClusterSupportsCustomPorts returns whether any active proxy in the cluster
// supports custom ports. Returns nil when no proxy has reported capabilities.
func (m Manager) ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool {
@@ -142,10 +138,44 @@ func (m Manager) ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string
}
// CleanupStale removes proxies that haven't sent heartbeat in the specified duration
func (m Manager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error {
func (m *Manager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error {
if err := m.store.CleanupStaleProxies(ctx, inactivityDuration); err != nil {
log.WithContext(ctx).Errorf("failed to cleanup stale proxies: %v", err)
return err
}
return nil
}
func (m *Manager) GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) {
addresses, err := m.store.GetActiveProxyClusterAddressesForAccount(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses for account %s: %v", accountID, err)
return nil, err
}
return addresses, nil
}
func (m *Manager) GetAccountProxy(ctx context.Context, accountID string) (*proxy.Proxy, error) {
return m.store.GetProxyByAccountID(ctx, accountID)
}
func (m *Manager) CountAccountProxies(ctx context.Context, accountID string) (int64, error) {
return m.store.CountProxiesByAccountID(ctx, accountID)
}
func (m *Manager) IsClusterAddressAvailable(ctx context.Context, clusterAddress, accountID string) (bool, error) {
conflicting, err := m.store.IsClusterAddressConflicting(ctx, clusterAddress, accountID)
if err != nil {
return false, err
}
return !conflicting, nil
}
func (m *Manager) DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error {
if err := m.store.DeleteAccountCluster(ctx, clusterAddress, accountID); err != nil {
log.WithContext(ctx).Errorf("failed to delete cluster %s for account %s: %v", clusterAddress, accountID, err)
return err
}
return nil
}

View File

@@ -0,0 +1,337 @@
package manager
import (
"context"
"errors"
"fmt"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel/metric/noop"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
)
type mockStore struct {
saveProxyFunc func(ctx context.Context, p *proxy.Proxy) error
disconnectProxyFunc func(ctx context.Context, proxyID, sessionID string) error
updateProxyHeartbeatFunc func(ctx context.Context, p *proxy.Proxy) error
getActiveProxyClusterAddressesFunc func(ctx context.Context) ([]string, error)
getActiveProxyClusterAddressesForAccFunc func(ctx context.Context, accountID string) ([]string, error)
cleanupStaleProxiesFunc func(ctx context.Context, d time.Duration) error
getProxyByAccountIDFunc func(ctx context.Context, accountID string) (*proxy.Proxy, error)
countProxiesByAccountIDFunc func(ctx context.Context, accountID string) (int64, error)
isClusterAddressConflictingFunc func(ctx context.Context, clusterAddress, accountID string) (bool, error)
deleteAccountClusterFunc func(ctx context.Context, clusterAddress, accountID string) error
}
func (m *mockStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error {
if m.saveProxyFunc != nil {
return m.saveProxyFunc(ctx, p)
}
return nil
}
func (m *mockStore) DisconnectProxy(ctx context.Context, proxyID, sessionID string) error {
if m.disconnectProxyFunc != nil {
return m.disconnectProxyFunc(ctx, proxyID, sessionID)
}
return nil
}
func (m *mockStore) UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error {
if m.updateProxyHeartbeatFunc != nil {
return m.updateProxyHeartbeatFunc(ctx, p)
}
return nil
}
func (m *mockStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) {
if m.getActiveProxyClusterAddressesFunc != nil {
return m.getActiveProxyClusterAddressesFunc(ctx)
}
return nil, nil
}
func (m *mockStore) GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) {
if m.getActiveProxyClusterAddressesForAccFunc != nil {
return m.getActiveProxyClusterAddressesForAccFunc(ctx, accountID)
}
return nil, nil
}
func (m *mockStore) GetActiveProxyClusters(_ context.Context, _ string) ([]proxy.Cluster, error) {
return nil, nil
}
func (m *mockStore) CleanupStaleProxies(ctx context.Context, d time.Duration) error {
if m.cleanupStaleProxiesFunc != nil {
return m.cleanupStaleProxiesFunc(ctx, d)
}
return nil
}
func (m *mockStore) GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error) {
if m.getProxyByAccountIDFunc != nil {
return m.getProxyByAccountIDFunc(ctx, accountID)
}
return nil, fmt.Errorf("proxy not found for account %s", accountID)
}
func (m *mockStore) CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error) {
if m.countProxiesByAccountIDFunc != nil {
return m.countProxiesByAccountIDFunc(ctx, accountID)
}
return 0, nil
}
func (m *mockStore) IsClusterAddressConflicting(ctx context.Context, clusterAddress, accountID string) (bool, error) {
if m.isClusterAddressConflictingFunc != nil {
return m.isClusterAddressConflictingFunc(ctx, clusterAddress, accountID)
}
return false, nil
}
func (m *mockStore) DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error {
if m.deleteAccountClusterFunc != nil {
return m.deleteAccountClusterFunc(ctx, clusterAddress, accountID)
}
return nil
}
func (m *mockStore) GetClusterSupportsCustomPorts(_ context.Context, _ string) *bool {
return nil
}
func (m *mockStore) GetClusterRequireSubdomain(_ context.Context, _ string) *bool {
return nil
}
func (m *mockStore) GetClusterSupportsCrowdSec(_ context.Context, _ string) *bool {
return nil
}
func newTestManager(s store) *Manager {
meter := noop.NewMeterProvider().Meter("test")
m, err := NewManager(s, meter)
if err != nil {
panic(err)
}
return m
}
func TestConnect_WithAccountID(t *testing.T) {
accountID := "acc-123"
var savedProxy *proxy.Proxy
s := &mockStore{
saveProxyFunc: func(_ context.Context, p *proxy.Proxy) error {
savedProxy = p
return nil
},
}
mgr := newTestManager(s)
_, err := mgr.Connect(context.Background(), "proxy-1", "session-1", "cluster.example.com", "10.0.0.1", &accountID, nil)
require.NoError(t, err)
require.NotNil(t, savedProxy)
assert.Equal(t, "proxy-1", savedProxy.ID)
assert.Equal(t, "session-1", savedProxy.SessionID)
assert.Equal(t, "cluster.example.com", savedProxy.ClusterAddress)
assert.Equal(t, "10.0.0.1", savedProxy.IPAddress)
assert.Equal(t, &accountID, savedProxy.AccountID)
assert.Equal(t, proxy.StatusConnected, savedProxy.Status)
assert.NotNil(t, savedProxy.ConnectedAt)
}
func TestConnect_WithoutAccountID(t *testing.T) {
var savedProxy *proxy.Proxy
s := &mockStore{
saveProxyFunc: func(_ context.Context, p *proxy.Proxy) error {
savedProxy = p
return nil
},
}
mgr := newTestManager(s)
_, err := mgr.Connect(context.Background(), "proxy-1", "session-1", "eu.proxy.netbird.io", "10.0.0.1", nil, nil)
require.NoError(t, err)
require.NotNil(t, savedProxy)
assert.Nil(t, savedProxy.AccountID)
assert.Equal(t, proxy.StatusConnected, savedProxy.Status)
}
func TestConnect_StoreError(t *testing.T) {
s := &mockStore{
saveProxyFunc: func(_ context.Context, _ *proxy.Proxy) error {
return errors.New("db error")
},
}
mgr := newTestManager(s)
_, err := mgr.Connect(context.Background(), "proxy-1", "session-1", "cluster.example.com", "10.0.0.1", nil, nil)
assert.Error(t, err)
}
func TestIsClusterAddressAvailable(t *testing.T) {
tests := []struct {
name string
conflicting bool
storeErr error
wantResult bool
wantErr bool
}{
{
name: "available - no conflict",
conflicting: false,
wantResult: true,
},
{
name: "not available - conflict exists",
conflicting: true,
wantResult: false,
},
{
name: "store error",
storeErr: errors.New("db error"),
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := &mockStore{
isClusterAddressConflictingFunc: func(_ context.Context, _, _ string) (bool, error) {
return tt.conflicting, tt.storeErr
},
}
mgr := newTestManager(s)
result, err := mgr.IsClusterAddressAvailable(context.Background(), "cluster.example.com", "acc-123")
if tt.wantErr {
assert.Error(t, err)
return
}
require.NoError(t, err)
assert.Equal(t, tt.wantResult, result)
})
}
}
func TestCountAccountProxies(t *testing.T) {
tests := []struct {
name string
count int64
storeErr error
wantCount int64
wantErr bool
}{
{
name: "no proxies",
count: 0,
wantCount: 0,
},
{
name: "one proxy",
count: 1,
wantCount: 1,
},
{
name: "store error",
storeErr: errors.New("db error"),
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := &mockStore{
countProxiesByAccountIDFunc: func(_ context.Context, _ string) (int64, error) {
return tt.count, tt.storeErr
},
}
mgr := newTestManager(s)
count, err := mgr.CountAccountProxies(context.Background(), "acc-123")
if tt.wantErr {
assert.Error(t, err)
return
}
require.NoError(t, err)
assert.Equal(t, tt.wantCount, count)
})
}
}
func TestGetAccountProxy(t *testing.T) {
accountID := "acc-123"
t.Run("found", func(t *testing.T) {
expected := &proxy.Proxy{
ID: "proxy-1",
ClusterAddress: "byop.example.com",
AccountID: &accountID,
Status: proxy.StatusConnected,
}
s := &mockStore{
getProxyByAccountIDFunc: func(_ context.Context, accID string) (*proxy.Proxy, error) {
assert.Equal(t, accountID, accID)
return expected, nil
},
}
mgr := newTestManager(s)
p, err := mgr.GetAccountProxy(context.Background(), accountID)
require.NoError(t, err)
assert.Equal(t, expected, p)
})
t.Run("not found", func(t *testing.T) {
s := &mockStore{
getProxyByAccountIDFunc: func(_ context.Context, _ string) (*proxy.Proxy, error) {
return nil, errors.New("not found")
},
}
mgr := newTestManager(s)
_, err := mgr.GetAccountProxy(context.Background(), accountID)
assert.Error(t, err)
})
}
func TestDeleteAccountCluster(t *testing.T) {
t.Run("success", func(t *testing.T) {
var deletedCluster, deletedAccount string
s := &mockStore{
deleteAccountClusterFunc: func(_ context.Context, clusterAddress, accountID string) error {
deletedCluster = clusterAddress
deletedAccount = accountID
return nil
},
}
mgr := newTestManager(s)
err := mgr.DeleteAccountCluster(context.Background(), "cluster.example.com", "acc-123")
require.NoError(t, err)
assert.Equal(t, "cluster.example.com", deletedCluster)
assert.Equal(t, "acc-123", deletedAccount)
})
t.Run("store error", func(t *testing.T) {
s := &mockStore{
deleteAccountClusterFunc: func(_ context.Context, _, _ string) error {
return errors.New("db error")
},
}
mgr := newTestManager(s)
err := mgr.DeleteAccountCluster(context.Background(), "cluster.example.com", "acc-123")
assert.Error(t, err)
})
}
func TestGetActiveClusterAddressesForAccount(t *testing.T) {
expected := []string{"byop.example.com"}
s := &mockStore{
getActiveProxyClusterAddressesForAccFunc: func(_ context.Context, accID string) ([]string, error) {
assert.Equal(t, "acc-123", accID)
return expected, nil
},
}
mgr := newTestManager(s)
result, err := mgr.GetActiveClusterAddressesForAccount(context.Background(), "acc-123")
require.NoError(t, err)
assert.Equal(t, expected, result)
}

View File

@@ -93,18 +93,18 @@ func (mr *MockManagerMockRecorder) ClusterSupportsCrowdSec(ctx, clusterAddr inte
}
// Connect mocks base method.
func (m *MockManager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, capabilities *Capabilities) (*Proxy, error) {
func (m *MockManager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, accountID *string, capabilities *Capabilities) (*Proxy, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Connect", ctx, proxyID, sessionID, clusterAddress, ipAddress, capabilities)
ret := m.ctrl.Call(m, "Connect", ctx, proxyID, sessionID, clusterAddress, ipAddress, accountID, capabilities)
ret0, _ := ret[0].(*Proxy)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Connect indicates an expected call of Connect.
func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, sessionID, clusterAddress, ipAddress, capabilities interface{}) *gomock.Call {
func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, sessionID, clusterAddress, ipAddress, accountID, capabilities interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, sessionID, clusterAddress, ipAddress, capabilities)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, sessionID, clusterAddress, ipAddress, accountID, capabilities)
}
// Disconnect mocks base method.
@@ -136,19 +136,17 @@ func (mr *MockManagerMockRecorder) GetActiveClusterAddresses(ctx interface{}) *g
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusterAddresses", reflect.TypeOf((*MockManager)(nil).GetActiveClusterAddresses), ctx)
}
// GetActiveClusters mocks base method.
func (m *MockManager) GetActiveClusters(ctx context.Context) ([]Cluster, error) {
func (m *MockManager) GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetActiveClusters", ctx)
ret0, _ := ret[0].([]Cluster)
ret := m.ctrl.Call(m, "GetActiveClusterAddressesForAccount", ctx, accountID)
ret0, _ := ret[0].([]string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetActiveClusters indicates an expected call of GetActiveClusters.
func (mr *MockManagerMockRecorder) GetActiveClusters(ctx interface{}) *gomock.Call {
func (mr *MockManagerMockRecorder) GetActiveClusterAddressesForAccount(ctx, accountID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusters", reflect.TypeOf((*MockManager)(nil).GetActiveClusters), ctx)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusterAddressesForAccount", reflect.TypeOf((*MockManager)(nil).GetActiveClusterAddressesForAccount), ctx, accountID)
}
// Heartbeat mocks base method.
@@ -165,6 +163,65 @@ func (mr *MockManagerMockRecorder) Heartbeat(ctx, p interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Heartbeat", reflect.TypeOf((*MockManager)(nil).Heartbeat), ctx, p)
}
// GetAccountProxy mocks base method.
func (m *MockManager) GetAccountProxy(ctx context.Context, accountID string) (*Proxy, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetAccountProxy", ctx, accountID)
ret0, _ := ret[0].(*Proxy)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetAccountProxy indicates an expected call of GetAccountProxy.
func (mr *MockManagerMockRecorder) GetAccountProxy(ctx, accountID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountProxy", reflect.TypeOf((*MockManager)(nil).GetAccountProxy), ctx, accountID)
}
// CountAccountProxies mocks base method.
func (m *MockManager) CountAccountProxies(ctx context.Context, accountID string) (int64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CountAccountProxies", ctx, accountID)
ret0, _ := ret[0].(int64)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CountAccountProxies indicates an expected call of CountAccountProxies.
func (mr *MockManagerMockRecorder) CountAccountProxies(ctx, accountID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountAccountProxies", reflect.TypeOf((*MockManager)(nil).CountAccountProxies), ctx, accountID)
}
// IsClusterAddressAvailable mocks base method.
func (m *MockManager) IsClusterAddressAvailable(ctx context.Context, clusterAddress, accountID string) (bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "IsClusterAddressAvailable", ctx, clusterAddress, accountID)
ret0, _ := ret[0].(bool)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// IsClusterAddressAvailable indicates an expected call of IsClusterAddressAvailable.
func (mr *MockManagerMockRecorder) IsClusterAddressAvailable(ctx, clusterAddress, accountID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsClusterAddressAvailable", reflect.TypeOf((*MockManager)(nil).IsClusterAddressAvailable), ctx, clusterAddress, accountID)
}
// DeleteAccountCluster mocks base method.
func (m *MockManager) DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteAccountCluster", ctx, clusterAddress, accountID)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteAccountCluster indicates an expected call of DeleteAccountCluster.
func (mr *MockManagerMockRecorder) DeleteAccountCluster(ctx, clusterAddress, accountID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAccountCluster", reflect.TypeOf((*MockManager)(nil).DeleteAccountCluster), ctx, clusterAddress, accountID)
}
// MockController is a mock of Controller interface.
type MockController struct {
ctrl *gomock.Controller

View File

@@ -1,6 +1,13 @@
package proxy
import "time"
import (
"time"
)
const (
StatusConnected = "connected"
StatusDisconnected = "disconnected"
)
// Capabilities describes what a proxy can handle, as reported via gRPC.
// Nil fields mean the proxy never reported this capability.
@@ -21,6 +28,7 @@ type Proxy struct {
SessionID string `gorm:"type:varchar(36)"`
ClusterAddress string `gorm:"type:varchar(255);not null;index:idx_proxy_cluster_status"`
IPAddress string `gorm:"type:varchar(45)"`
AccountID *string `gorm:"type:varchar(255);index:idx_proxy_account_id"`
LastSeen time.Time `gorm:"not null;index:idx_proxy_last_seen"`
ConnectedAt *time.Time
DisconnectedAt *time.Time
@@ -36,6 +44,8 @@ func (Proxy) TableName() string {
// Cluster represents a group of proxy nodes serving the same address.
type Cluster struct {
ID string
Address string
ConnectedProxies int
SelfHosted bool
}

View File

@@ -0,0 +1,195 @@
package proxytoken
import (
"encoding/json"
"net/http"
"time"
"github.com/gorilla/mux"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
)
type handler struct {
store store.Store
permissionsManager permissions.Manager
}
func RegisterEndpoints(s store.Store, permissionsManager permissions.Manager, router *mux.Router) {
h := &handler{store: s, permissionsManager: permissionsManager}
router.HandleFunc("/reverse-proxies/proxy-tokens", h.listTokens).Methods("GET", "OPTIONS")
router.HandleFunc("/reverse-proxies/proxy-tokens", h.createToken).Methods("POST", "OPTIONS")
router.HandleFunc("/reverse-proxies/proxy-tokens/{tokenId}", h.revokeToken).Methods("DELETE", "OPTIONS")
}
func (h *handler) createToken(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
ok, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Services, operations.Create)
if err != nil {
util.WriteErrorResponse("failed to validate permissions", http.StatusInternalServerError, w)
return
}
if !ok {
util.WriteErrorResponse("permission denied", http.StatusForbidden, w)
return
}
var req api.ProxyTokenRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
if req.Name == "" || len(req.Name) > 255 {
util.WriteErrorResponse("name is required and must be at most 255 characters", http.StatusBadRequest, w)
return
}
var expiresIn time.Duration
if req.ExpiresIn != nil {
if *req.ExpiresIn < 0 {
util.WriteErrorResponse("expires_in must be non-negative", http.StatusBadRequest, w)
return
}
if *req.ExpiresIn > 0 {
expiresIn = time.Duration(*req.ExpiresIn) * time.Second
}
}
accountID := userAuth.AccountId
generated, err := types.CreateNewProxyAccessToken(req.Name, expiresIn, &accountID, userAuth.UserId)
if err != nil {
util.WriteErrorResponse("failed to generate token", http.StatusInternalServerError, w)
return
}
if err := h.store.SaveProxyAccessToken(r.Context(), &generated.ProxyAccessToken); err != nil {
util.WriteErrorResponse("failed to save token", http.StatusInternalServerError, w)
return
}
resp := toProxyTokenCreatedResponse(generated)
util.WriteJSONObject(r.Context(), w, resp)
}
func (h *handler) listTokens(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
ok, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Services, operations.Read)
if err != nil {
util.WriteErrorResponse("failed to validate permissions", http.StatusInternalServerError, w)
return
}
if !ok {
util.WriteErrorResponse("permission denied", http.StatusForbidden, w)
return
}
tokens, err := h.store.GetProxyAccessTokensByAccountID(r.Context(), store.LockingStrengthNone, userAuth.AccountId)
if err != nil {
util.WriteErrorResponse("failed to list tokens", http.StatusInternalServerError, w)
return
}
resp := make([]api.ProxyToken, 0, len(tokens))
for _, token := range tokens {
resp = append(resp, toProxyTokenResponse(token))
}
util.WriteJSONObject(r.Context(), w, resp)
}
func (h *handler) revokeToken(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
ok, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Services, operations.Delete)
if err != nil {
util.WriteErrorResponse("failed to validate permissions", http.StatusInternalServerError, w)
return
}
if !ok {
util.WriteErrorResponse("permission denied", http.StatusForbidden, w)
return
}
tokenID := mux.Vars(r)["tokenId"]
if tokenID == "" {
util.WriteErrorResponse("token ID is required", http.StatusBadRequest, w)
return
}
token, err := h.store.GetProxyAccessTokenByID(r.Context(), store.LockingStrengthNone, tokenID)
if err != nil {
if s, ok := status.FromError(err); ok && s.ErrorType == status.NotFound {
util.WriteErrorResponse("token not found", http.StatusNotFound, w)
} else {
util.WriteErrorResponse("failed to retrieve token", http.StatusInternalServerError, w)
}
return
}
if token.AccountID == nil || *token.AccountID != userAuth.AccountId {
util.WriteErrorResponse("token not found", http.StatusNotFound, w)
return
}
if err := h.store.RevokeProxyAccessToken(r.Context(), tokenID); err != nil {
util.WriteErrorResponse("failed to revoke token", http.StatusInternalServerError, w)
return
}
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
}
func toProxyTokenResponse(token *types.ProxyAccessToken) api.ProxyToken {
resp := api.ProxyToken{
Id: token.ID,
Name: token.Name,
Revoked: token.Revoked,
}
if !token.CreatedAt.IsZero() {
resp.CreatedAt = token.CreatedAt
}
if token.ExpiresAt != nil {
resp.ExpiresAt = token.ExpiresAt
}
if token.LastUsed != nil {
resp.LastUsed = token.LastUsed
}
return resp
}
func toProxyTokenCreatedResponse(generated *types.ProxyAccessTokenGenerated) api.ProxyTokenCreated {
base := toProxyTokenResponse(&generated.ProxyAccessToken)
plainToken := string(generated.PlainToken)
return api.ProxyTokenCreated{
Id: base.Id,
Name: base.Name,
CreatedAt: base.CreatedAt,
ExpiresAt: base.ExpiresAt,
LastUsed: base.LastUsed,
Revoked: base.Revoked,
PlainToken: plainToken,
}
}

View File

@@ -0,0 +1,275 @@
package proxytoken
import (
"bytes"
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
)
func authContext(accountID, userID string) context.Context {
return nbcontext.SetUserAuthInContext(context.Background(), auth.UserAuth{
AccountId: accountID,
UserId: userID,
})
}
func TestCreateToken_AccountScoped(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
accountID := "acc-123"
var savedToken *types.ProxyAccessToken
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().SaveProxyAccessToken(gomock.Any(), gomock.Any()).DoAndReturn(
func(_ context.Context, token *types.ProxyAccessToken) error {
savedToken = token
return nil
},
)
permsMgr := permissions.NewMockManager(ctrl)
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), accountID, "user-1", modules.Services, operations.Create).Return(true, nil)
h := &handler{
store: mockStore,
permissionsManager: permsMgr,
}
body := `{"name": "my-token"}`
req := httptest.NewRequest("POST", "/reverse-proxies/proxy-tokens", bytes.NewBufferString(body))
req = req.WithContext(authContext(accountID, "user-1"))
w := httptest.NewRecorder()
h.createToken(w, req)
assert.Equal(t, http.StatusOK, w.Code)
var resp api.ProxyTokenCreated
require.NoError(t, json.NewDecoder(w.Body).Decode(&resp))
assert.NotEmpty(t, resp.PlainToken)
assert.Equal(t, "my-token", resp.Name)
assert.False(t, resp.Revoked)
require.NotNil(t, savedToken)
require.NotNil(t, savedToken.AccountID)
assert.Equal(t, accountID, *savedToken.AccountID)
assert.Equal(t, "user-1", savedToken.CreatedBy)
}
func TestCreateToken_WithExpiration(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
var savedToken *types.ProxyAccessToken
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().SaveProxyAccessToken(gomock.Any(), gomock.Any()).DoAndReturn(
func(_ context.Context, token *types.ProxyAccessToken) error {
savedToken = token
return nil
},
)
permsMgr := permissions.NewMockManager(ctrl)
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Create).Return(true, nil)
h := &handler{
store: mockStore,
permissionsManager: permsMgr,
}
body := `{"name": "expiring-token", "expires_in": 3600}`
req := httptest.NewRequest("POST", "/reverse-proxies/proxy-tokens", bytes.NewBufferString(body))
req = req.WithContext(authContext("acc-123", "user-1"))
w := httptest.NewRecorder()
h.createToken(w, req)
assert.Equal(t, http.StatusOK, w.Code)
require.NotNil(t, savedToken)
require.NotNil(t, savedToken.ExpiresAt)
assert.True(t, savedToken.ExpiresAt.After(time.Now()))
}
func TestCreateToken_EmptyName(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
permsMgr := permissions.NewMockManager(ctrl)
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Create).Return(true, nil)
h := &handler{
permissionsManager: permsMgr,
}
body := `{"name": ""}`
req := httptest.NewRequest("POST", "/reverse-proxies/proxy-tokens", bytes.NewBufferString(body))
req = req.WithContext(authContext("acc-123", "user-1"))
w := httptest.NewRecorder()
h.createToken(w, req)
assert.Equal(t, http.StatusBadRequest, w.Code)
}
func TestCreateToken_PermissionDenied(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
permsMgr := permissions.NewMockManager(ctrl)
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Create).Return(false, nil)
h := &handler{
permissionsManager: permsMgr,
}
body := `{"name": "test"}`
req := httptest.NewRequest("POST", "/reverse-proxies/proxy-tokens", bytes.NewBufferString(body))
req = req.WithContext(authContext("acc-123", "user-1"))
w := httptest.NewRecorder()
h.createToken(w, req)
assert.Equal(t, http.StatusForbidden, w.Code)
}
func TestListTokens(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
accountID := "acc-123"
now := time.Now()
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().GetProxyAccessTokensByAccountID(gomock.Any(), store.LockingStrengthNone, accountID).Return([]*types.ProxyAccessToken{
{ID: "tok-1", Name: "token-1", AccountID: &accountID, CreatedAt: now, Revoked: false},
{ID: "tok-2", Name: "token-2", AccountID: &accountID, CreatedAt: now, Revoked: true},
}, nil)
permsMgr := permissions.NewMockManager(ctrl)
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), accountID, "user-1", modules.Services, operations.Read).Return(true, nil)
h := &handler{
store: mockStore,
permissionsManager: permsMgr,
}
req := httptest.NewRequest("GET", "/reverse-proxies/proxy-tokens", nil)
req = req.WithContext(authContext(accountID, "user-1"))
w := httptest.NewRecorder()
h.listTokens(w, req)
assert.Equal(t, http.StatusOK, w.Code)
var resp []api.ProxyToken
require.NoError(t, json.NewDecoder(w.Body).Decode(&resp))
require.Len(t, resp, 2)
assert.Equal(t, "tok-1", resp[0].Id)
assert.False(t, resp[0].Revoked)
assert.Equal(t, "tok-2", resp[1].Id)
assert.True(t, resp[1].Revoked)
}
func TestRevokeToken_Success(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
accountID := "acc-123"
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().GetProxyAccessTokenByID(gomock.Any(), store.LockingStrengthNone, "tok-1").Return(&types.ProxyAccessToken{
ID: "tok-1",
Name: "test-token",
AccountID: &accountID,
}, nil)
mockStore.EXPECT().RevokeProxyAccessToken(gomock.Any(), "tok-1").Return(nil)
permsMgr := permissions.NewMockManager(ctrl)
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), accountID, "user-1", modules.Services, operations.Delete).Return(true, nil)
h := &handler{
store: mockStore,
permissionsManager: permsMgr,
}
req := httptest.NewRequest("DELETE", "/reverse-proxies/proxy-tokens/tok-1", nil)
req = req.WithContext(authContext(accountID, "user-1"))
req = mux.SetURLVars(req, map[string]string{"tokenId": "tok-1"})
w := httptest.NewRecorder()
h.revokeToken(w, req)
assert.Equal(t, http.StatusOK, w.Code)
}
func TestRevokeToken_WrongAccount(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
otherAccount := "acc-other"
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().GetProxyAccessTokenByID(gomock.Any(), store.LockingStrengthNone, "tok-1").Return(&types.ProxyAccessToken{
ID: "tok-1",
AccountID: &otherAccount,
}, nil)
permsMgr := permissions.NewMockManager(ctrl)
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Delete).Return(true, nil)
h := &handler{
store: mockStore,
permissionsManager: permsMgr,
}
req := httptest.NewRequest("DELETE", "/reverse-proxies/proxy-tokens/tok-1", nil)
req = req.WithContext(authContext("acc-123", "user-1"))
req = mux.SetURLVars(req, map[string]string{"tokenId": "tok-1"})
w := httptest.NewRecorder()
h.revokeToken(w, req)
assert.Equal(t, http.StatusNotFound, w.Code)
}
func TestRevokeToken_ManagementWideToken(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().GetProxyAccessTokenByID(gomock.Any(), store.LockingStrengthNone, "tok-1").Return(&types.ProxyAccessToken{
ID: "tok-1",
AccountID: nil,
}, nil)
permsMgr := permissions.NewMockManager(ctrl)
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Delete).Return(true, nil)
h := &handler{
store: mockStore,
permissionsManager: permsMgr,
}
req := httptest.NewRequest("DELETE", "/reverse-proxies/proxy-tokens/tok-1", nil)
req = req.WithContext(authContext("acc-123", "user-1"))
req = mux.SetURLVars(req, map[string]string{"tokenId": "tok-1"})
w := httptest.NewRecorder()
h.revokeToken(w, req)
assert.Equal(t, http.StatusNotFound, w.Code)
}

View File

@@ -10,6 +10,7 @@ import (
type Manager interface {
GetActiveClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error)
DeleteAccountCluster(ctx context.Context, accountID, userID, clusterAddress string) error
GetAllServices(ctx context.Context, accountID, userID string) ([]*Service, error)
GetService(ctx context.Context, accountID, userID, serviceID string) (*Service, error)
CreateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error)
@@ -28,4 +29,5 @@ type Manager interface {
RenewServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error
StopServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error
StartExposeReaper(ctx context.Context)
GetServiceByDomain(ctx context.Context, domain string) (*Service, error)
}

View File

@@ -79,6 +79,20 @@ func (mr *MockManagerMockRecorder) DeleteAllServices(ctx, accountID, userID inte
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAllServices", reflect.TypeOf((*MockManager)(nil).DeleteAllServices), ctx, accountID, userID)
}
// DeleteAccountCluster mocks base method.
func (m *MockManager) DeleteAccountCluster(ctx context.Context, accountID, userID, clusterAddress string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteAccountCluster", ctx, accountID, userID, clusterAddress)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteAccountCluster indicates an expected call of DeleteAccountCluster.
func (mr *MockManagerMockRecorder) DeleteAccountCluster(ctx, accountID, userID, clusterAddress interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAccountCluster", reflect.TypeOf((*MockManager)(nil).DeleteAccountCluster), ctx, accountID, userID, clusterAddress)
}
// DeleteService mocks base method.
func (m *MockManager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
m.ctrl.T.Helper()
@@ -138,6 +152,21 @@ func (mr *MockManagerMockRecorder) GetAllServices(ctx, accountID, userID interfa
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllServices", reflect.TypeOf((*MockManager)(nil).GetAllServices), ctx, accountID, userID)
}
// GetServiceByDomain mocks base method.
func (m *MockManager) GetServiceByDomain(ctx context.Context, domain string) (*Service, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetServiceByDomain", ctx, domain)
ret0, _ := ret[0].(*Service)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetServiceByDomain indicates an expected call of GetServiceByDomain.
func (mr *MockManagerMockRecorder) GetServiceByDomain(ctx, domain interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceByDomain", reflect.TypeOf((*MockManager)(nil).GetServiceByDomain), ctx, domain)
}
// GetGlobalServices mocks base method.
func (m *MockManager) GetGlobalServices(ctx context.Context) ([]*Service, error) {
m.ctrl.T.Helper()

View File

@@ -35,6 +35,7 @@ func RegisterEndpoints(manager rpservice.Manager, domainManager domainmanager.Ma
accesslogsmanager.RegisterEndpoints(router, accessLogsManager)
router.HandleFunc("/reverse-proxies/clusters", h.getClusters).Methods("GET", "OPTIONS")
router.HandleFunc("/reverse-proxies/clusters/{clusterAddress}", h.deleteCluster).Methods("DELETE", "OPTIONS")
router.HandleFunc("/reverse-proxies/services", h.getAllServices).Methods("GET", "OPTIONS")
router.HandleFunc("/reverse-proxies/services", h.createService).Methods("POST", "OPTIONS")
router.HandleFunc("/reverse-proxies/services/{serviceId}", h.getService).Methods("GET", "OPTIONS")
@@ -195,10 +196,33 @@ func (h *handler) getClusters(w http.ResponseWriter, r *http.Request) {
apiClusters := make([]api.ProxyCluster, 0, len(clusters))
for _, c := range clusters {
apiClusters = append(apiClusters, api.ProxyCluster{
Id: c.ID,
Address: c.Address,
ConnectedProxies: c.ConnectedProxies,
SelfHosted: c.SelfHosted,
})
}
util.WriteJSONObject(r.Context(), w, apiClusters)
}
func (h *handler) deleteCluster(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
clusterAddress := mux.Vars(r)["clusterAddress"]
if clusterAddress == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "cluster address is required"), w)
return
}
if err := h.manager.DeleteAccountCluster(r.Context(), userAuth.AccountId, userAuth.UserId, clusterAddress); err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
}

View File

@@ -122,7 +122,21 @@ func (m *Manager) GetActiveClusters(ctx context.Context, accountID, userID strin
return nil, status.NewPermissionDeniedError()
}
return m.store.GetActiveProxyClusters(ctx)
return m.store.GetActiveProxyClusters(ctx, accountID)
}
// DeleteAccountCluster removes all proxy registrations for the given cluster address
// owned by the account.
func (m *Manager) DeleteAccountCluster(ctx context.Context, accountID, userID, clusterAddress string) error {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !ok {
return status.NewPermissionDeniedError()
}
return m.store.DeleteAccountCluster(ctx, clusterAddress, accountID)
}
func (m *Manager) GetAllServices(ctx context.Context, accountID, userID string) ([]*service.Service, error) {
@@ -292,6 +306,10 @@ func (m *Manager) validateSubdomainRequirement(ctx context.Context, domain, clus
func (m *Manager) persistNewService(ctx context.Context, accountID string, svc *service.Service) error {
customPorts := m.clusterCustomPorts(ctx, svc)
if err := validateTargetReferences(ctx, m.store, accountID, svc.Targets); err != nil {
return err
}
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if svc.Domain != "" {
if err := m.checkDomainAvailable(ctx, transaction, svc.Domain, ""); err != nil {
@@ -307,10 +325,6 @@ func (m *Manager) persistNewService(ctx context.Context, accountID string, svc *
return err
}
if err := validateTargetReferences(ctx, transaction, accountID, svc.Targets); err != nil {
return err
}
if err := transaction.CreateService(ctx, svc); err != nil {
return fmt.Errorf("create service: %w", err)
}
@@ -421,6 +435,10 @@ func (m *Manager) assignPort(ctx context.Context, tx store.Store, cluster string
func (m *Manager) persistNewEphemeralService(ctx context.Context, accountID, peerID string, svc *service.Service) error {
customPorts := m.clusterCustomPorts(ctx, svc)
if err := validateTargetReferences(ctx, m.store, accountID, svc.Targets); err != nil {
return err
}
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err := m.validateEphemeralPreconditions(ctx, transaction, accountID, peerID, svc); err != nil {
return err
@@ -434,10 +452,6 @@ func (m *Manager) persistNewEphemeralService(ctx context.Context, accountID, pee
return err
}
if err := validateTargetReferences(ctx, transaction, accountID, svc.Targets); err != nil {
return err
}
if err := transaction.CreateService(ctx, svc); err != nil {
return fmt.Errorf("create service: %w", err)
}
@@ -538,10 +552,22 @@ func (m *Manager) persistServiceUpdate(ctx context.Context, accountID string, se
svcForCaps.ProxyCluster = effectiveCluster
customPorts := m.clusterCustomPorts(ctx, &svcForCaps)
if err := validateTargetReferences(ctx, m.store, accountID, service.Targets); err != nil {
return nil, err
}
// Validate subdomain requirement *before* the transaction: the underlying
// capability lookup talks to the main DB pool, and SQLite's single-connection
// pool would self-deadlock if this ran while the tx already held the only
// connection.
if err := m.validateSubdomainRequirement(ctx, service.Domain, effectiveCluster); err != nil {
return nil, err
}
var updateInfo serviceUpdateInfo
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
return m.executeServiceUpdate(ctx, transaction, accountID, service, &updateInfo, customPorts)
return m.executeServiceUpdate(ctx, transaction, accountID, service, &updateInfo, customPorts, effectiveCluster)
})
return &updateInfo, err
@@ -571,7 +597,7 @@ func (m *Manager) resolveEffectiveCluster(ctx context.Context, accountID string,
return existing.ProxyCluster, nil
}
func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.Store, accountID string, service *service.Service, updateInfo *serviceUpdateInfo, customPorts *bool) error {
func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.Store, accountID string, service *service.Service, updateInfo *serviceUpdateInfo, customPorts *bool, effectiveCluster string) error {
existingService, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, service.ID)
if err != nil {
return err
@@ -589,17 +615,13 @@ func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.St
updateInfo.domainChanged = existingService.Domain != service.Domain
if updateInfo.domainChanged {
if err := m.handleDomainChange(ctx, transaction, accountID, service); err != nil {
if err := m.handleDomainChange(ctx, transaction, service, effectiveCluster); err != nil {
return err
}
} else {
service.ProxyCluster = existingService.ProxyCluster
}
if err := m.validateSubdomainRequirement(ctx, service.Domain, service.ProxyCluster); err != nil {
return err
}
m.preserveExistingAuthSecrets(service, existingService)
if err := validateHeaderAuthValues(service.Auth.HeaderAuths); err != nil {
return err
@@ -614,9 +636,6 @@ func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.St
if err := m.checkPortConflict(ctx, transaction, service); err != nil {
return err
}
if err := validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil {
return err
}
if err := transaction.UpdateService(ctx, service); err != nil {
return fmt.Errorf("update service: %w", err)
}
@@ -624,20 +643,18 @@ func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.St
return nil
}
func (m *Manager) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, svc *service.Service) error {
// handleDomainChange validates the new domain is free inside the transaction
// and applies the pre-resolved cluster (computed outside the tx by
// resolveEffectiveCluster). It must NOT call clusterDeriver here: that talks
// to the main DB pool and would self-deadlock under SQLite (max_open_conns=1)
// because the transaction already holds the only connection.
func (m *Manager) handleDomainChange(ctx context.Context, transaction store.Store, svc *service.Service, effectiveCluster string) error {
if err := m.checkDomainAvailable(ctx, transaction, svc.Domain, svc.ID); err != nil {
return err
}
if m.clusterDeriver != nil {
newCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, svc.Domain)
if err != nil {
log.WithError(err).Warnf("could not derive cluster from domain %s", svc.Domain)
} else {
svc.ProxyCluster = newCluster
}
if effectiveCluster != "" {
svc.ProxyCluster = effectiveCluster
}
return nil
}
@@ -986,6 +1003,10 @@ func (m *Manager) GetAccountServices(ctx context.Context, accountID string) ([]*
return services, nil
}
func (m *Manager) GetServiceByDomain(ctx context.Context, domain string) (*service.Service, error) {
return m.store.GetServiceByDomain(ctx, domain)
}
func (m *Manager) GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error) {
target, err := m.store.GetServiceTargetByTargetID(ctx, store.LockingStrengthNone, accountID, resourceID)
if err != nil {

View File

@@ -434,7 +434,7 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
t.Helper()
tokenStore := nbgrpc.NewOneTimeTokenStore(context.Background(), testCacheStore(t))
pkceStore := nbgrpc.NewPKCEVerifierStore(context.Background(), testCacheStore(t))
srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil)
return srv
}
@@ -714,7 +714,7 @@ func setupIntegrationTest(t *testing.T) (*Manager, store.Store) {
tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, testCacheStore(t))
pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, testCacheStore(t))
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil)
proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
require.NoError(t, err)
@@ -1138,7 +1138,7 @@ func TestDeleteService_DeletesTargets(t *testing.T) {
tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, testCacheStore(t))
pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, testCacheStore(t))
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil)
proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
require.NoError(t, err)

View File

@@ -381,13 +381,14 @@ func (s *Service) buildPathMappings() []*proto.PathMapping {
}
// HTTP/HTTPS: build full URL
hostNoBrackets := strings.TrimSuffix(strings.TrimPrefix(target.Host, "["), "]")
targetURL := url.URL{
Scheme: target.Protocol,
Host: target.Host,
Host: bracketIPv6Host(hostNoBrackets),
Path: "/",
}
if target.Port > 0 && !isDefaultPort(target.Protocol, target.Port) {
targetURL.Host = net.JoinHostPort(targetURL.Host, strconv.FormatUint(uint64(target.Port), 10))
targetURL.Host = net.JoinHostPort(hostNoBrackets, strconv.FormatUint(uint64(target.Port), 10))
}
path := "/"
@@ -405,6 +406,19 @@ func (s *Service) buildPathMappings() []*proto.PathMapping {
return pathMappings
}
// bracketIPv6Host wraps host in square brackets when it is an IPv6 literal, as
// required for the Host field of net/url.URL (RFC 3986 §3.2.2). v4-mapped IPv6
// addresses are bracketed too since their textual form contains colons.
func bracketIPv6Host(host string) string {
if strings.HasPrefix(host, "[") {
return host
}
if addr, err := netip.ParseAddr(host); err == nil && addr.Is6() {
return "[" + host + "]"
}
return host
}
func operationToProtoType(op Operation) proto.ProxyMappingUpdateType {
switch op {
case Create:

View File

@@ -351,6 +351,83 @@ func TestToProtoMapping_PortInTargetURL(t *testing.T) {
port: 80,
wantTarget: "https://10.0.0.1:80/",
},
{
name: "domain host without port is unchanged",
protocol: "http",
host: "example.com",
port: 0,
wantTarget: "http://example.com/",
},
{
name: "domain host with non-default port is unchanged",
protocol: "http",
host: "example.com",
port: 8080,
wantTarget: "http://example.com:8080/",
},
{
name: "ipv6 host without port is bracketed",
protocol: "http",
host: "fb00:cafe:1::3",
port: 0,
wantTarget: "http://[fb00:cafe:1::3]/",
},
{
name: "ipv6 host with default port omits port and brackets host",
protocol: "http",
host: "fb00:cafe:1::3",
port: 80,
wantTarget: "http://[fb00:cafe:1::3]/",
},
{
name: "ipv6 host with non-default port is bracketed",
protocol: "http",
host: "fb00:cafe:1::3",
port: 8080,
wantTarget: "http://[fb00:cafe:1::3]:8080/",
},
{
name: "ipv6 loopback without port is bracketed",
protocol: "http",
host: "::1",
port: 0,
wantTarget: "http://[::1]/",
},
{
name: "ipv6 host with 5-digit port is bracketed",
protocol: "http",
host: "fb00:cafe::1",
port: 18080,
wantTarget: "http://[fb00:cafe::1]:18080/",
},
{
name: "pre-bracketed ipv6 without port stays single-bracketed",
protocol: "http",
host: "[fb00:cafe::1]",
port: 0,
wantTarget: "http://[fb00:cafe::1]/",
},
{
name: "pre-bracketed ipv6 with port is not double-bracketed",
protocol: "http",
host: "[fb00:cafe::1]",
port: 8080,
wantTarget: "http://[fb00:cafe::1]:8080/",
},
{
name: "v4-mapped ipv6 host without port is bracketed",
protocol: "http",
host: "::ffff:10.0.0.1",
port: 0,
wantTarget: "http://[::ffff:10.0.0.1]/",
},
{
name: "full-form 8-group ipv6 without port is bracketed",
protocol: "http",
host: "fb00:cafe:1:0:0:0:0:3",
port: 0,
wantTarget: "http://[fb00:cafe:1:0:0:0:0:3]/",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {

View File

@@ -193,7 +193,7 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
func (s *BaseServer) ReverseProxyGRPCServer() *nbgrpc.ProxyServiceServer {
return Create(s, func() *nbgrpc.ProxyServiceServer {
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.PKCEVerifierStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.ProxyManager())
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.PKCEVerifierStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.ProxyManager(), s.Store())
s.AfterInit(func(s *BaseServer) {
proxyService.SetServiceManager(s.ServiceManager())
proxyService.SetProxyController(s.ServiceProxyController())

View File

@@ -112,7 +112,11 @@ func (s *BaseServer) AuthManager() auth.Manager {
func (s *BaseServer) EphemeralManager() ephemeral.Manager {
return Create(s, func() ephemeral.Manager {
return manager.NewEphemeralManager(s.Store(), s.PeersManager())
em := manager.NewEphemeralManager(s.Store(), s.PeersManager())
if metrics := s.Metrics(); metrics != nil {
em.SetMetrics(metrics.EphemeralPeersMetrics())
}
return em
})
}

View File

@@ -26,6 +26,7 @@ import (
"github.com/netbirdio/netbird/management/server/networks"
"github.com/netbirdio/netbird/management/server/networks/resources"
"github.com/netbirdio/netbird/management/server/networks/routers"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
@@ -113,30 +114,47 @@ func (s *BaseServer) AccountManager() account.Manager {
})
}
func isMFAEnabledForAccount(accounts []*types.Account) bool {
if len(accounts) != 1 {
return false
}
settings := accounts[0].Settings
return settings != nil && settings.LocalMfaEnabled
}
func (s *BaseServer) IdpManager() idp.Manager {
return Create(s, func() idp.Manager {
var idpManager idp.Manager
var err error
// Use embedded IdP service if embedded Dex is configured and enabled.
// Legacy IdpManager won't be used anymore even if configured.
embeddedEnabled := s.Config.EmbeddedIdP != nil && s.Config.EmbeddedIdP.Enabled
if embeddedEnabled {
idpManager, err = idp.NewEmbeddedIdPManager(context.Background(), s.Config.EmbeddedIdP, s.Metrics())
embeddedMgr, err := idp.NewEmbeddedIdPManager(context.Background(), s.Config.EmbeddedIdP, s.Metrics())
if err != nil {
log.Fatalf("failed to create embedded IDP service: %v", err)
}
return idpManager
if val := isMFAEnabledForAccount(s.Store().GetAllAccounts(context.Background())); val {
if err := embeddedMgr.SetMFAEnabled(context.Background(), val); err != nil {
log.Errorf("failed to set MFA enabled on embedded IDP: %v", err)
}
}
return embeddedMgr
}
// Fall back to external IdP service
if s.Config.IdpManagerConfig != nil {
idpManager, err = idp.NewManager(context.Background(), *s.Config.IdpManagerConfig, s.Metrics())
idpManager, err := idp.NewManager(context.Background(), *s.Config.IdpManagerConfig, s.Metrics())
if err != nil {
log.Fatalf("failed to create IDP service: %v", err)
}
return idpManager
}
return idpManager
return nil
})
}

View File

@@ -0,0 +1,815 @@
package grpc
import (
"encoding/base64"
"strconv"
nbdns "github.com/netbirdio/netbird/dns"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/types"
nbroute "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/networkmap"
"github.com/netbirdio/netbird/shared/management/proto"
)
// wgKeyRawLen is the raw byte length of a WireGuard public key.
const wgKeyRawLen = 32
// ComponentsEnvelopeInput bundles the data the component-format encoder needs.
// In Step 2 the envelope is fully self-contained — every field needed by the
// client's local Calculate() comes from the components struct itself. The
// only externally-supplied data is the receiving peer's PeerConfig (which is
// computed alongside the components in the network_map controller and reused
// from the legacy proto path) and the dns_domain string.
type ComponentsEnvelopeInput struct {
Components *types.NetworkMapComponents
PeerConfig *proto.PeerConfig
DNSDomain string
DNSForwarderPort int64
// UserIDClaim is the OIDC claim name the client should embed in
// SshAuth.UserIDClaim when reconstructing the NetworkMap. Empty value
// is OK — client treats empty as "no SshAuth to build".
UserIDClaim string
// ProxyPatch carries pre-expanded NetworkMap fragments injected by
// external controllers (BYOP/port-forwarding). Nil when no proxy data
// is present; encoder skips the field in that case.
ProxyPatch *proto.ProxyPatch
}
// EncodeNetworkMapEnvelope converts NetworkMapComponents into the component
// wire envelope. The encoder is intentionally non-deterministic: it iterates
// Go maps in their native (random) order. Indexes inside the envelope
// (peer_indexes, source_group_ids, agent_version_idx, router_peer_indexes)
// are self-consistent within a single encode, so the decoder reconstructs
// the same typed objects regardless of emit order. Tests that need to
// compare envelopes do so semantically via proto round-trip + canonicalize,
// not byte-equal.
//
// Callers must NOT concatenate or merge envelopes from different encodes —
// index spaces are local to a single envelope. Delta sync (Step 3+) will
// use a different shape for the same reason.
func EncodeNetworkMapEnvelope(in ComponentsEnvelopeInput) *proto.NetworkMapEnvelope {
c := in.Components
// Graceful degrade when components is nil — matches the legacy path's
// account_components.go:43 behaviour for missing/unvalidated peers
// (return a NetworkMap with only Network populated). The receiver gets
// an envelope it can decode without crashing; AccountSettings stays
// non-nil so client-side dereferences are safe.
if c == nil {
// Match legacy missing-peer minimum: a NetworkMap with only Network
// populated (account_components.go:43). The receiver gets enough to
// bootstrap (Network identifier, dns_domain, account_settings) and
// nothing else.
return &proto.NetworkMapEnvelope{
Payload: &proto.NetworkMapEnvelope_Full{
Full: &proto.NetworkMapComponentsFull{
PeerConfig: in.PeerConfig,
DnsDomain: in.DNSDomain,
DnsForwarderPort: in.DNSForwarderPort,
UserIdClaim: in.UserIDClaim,
AccountSettings: &proto.AccountSettingsCompact{},
ProxyPatch: in.ProxyPatch,
},
},
}
}
// Phase 1: build dedup tables. Every routing peer (in c.RouterPeers) and
// every regular peer (in c.Peers) must be indexed before any encoder
// looks up indexes via e.peerOrder — otherwise routes / routers_map for
// peers that exist only in c.RouterPeers would silently lose their
// peer_index reference.
enc := newComponentEncoder(c)
enc.indexAllPeers()
routerIdxs := enc.indexRouterPeers(c.RouterPeers)
// Phase 2: gather every policy that any consumer references (peer-pair
// policies + resource-only policies) so encodeResourcePoliciesMap can
// translate every *Policy pointer to a wire index.
allPolicies := unionPolicies(c.Policies, c.ResourcePoliciesMap)
policies, policyToIdxs := enc.encodePolicies(allPolicies)
// Phase 3: emit. Order of struct field expressions no longer matters:
// every encoder either reads from the dedup tables or works on
// independent input.
full := &proto.NetworkMapComponentsFull{
Serial: networkSerial(c.Network),
PeerConfig: in.PeerConfig,
Network: toAccountNetwork(c.Network),
AccountSettings: toAccountSettingsCompact(c.AccountSettings),
DnsForwarderPort: in.DNSForwarderPort,
UserIdClaim: in.UserIDClaim,
ProxyPatch: in.ProxyPatch,
DnsSettings: enc.encodeDNSSettings(c.DNSSettings),
DnsDomain: in.DNSDomain,
CustomZoneDomain: c.CustomZoneDomain,
AgentVersions: enc.agentVersions,
Peers: enc.peers,
RouterPeerIndexes: routerIdxs,
Policies: policies,
Groups: enc.encodeGroups(),
Routes: enc.encodeRoutes(c.Routes),
NameserverGroups: enc.encodeNameServerGroups(c.NameServerGroups),
AllDnsRecords: encodeSimpleRecords(c.AllDNSRecords),
AccountZones: encodeCustomZones(c.AccountZones),
NetworkResources: enc.encodeNetworkResources(c.NetworkResources),
RoutersMap: enc.encodeRoutersMap(c.RoutersMap),
ResourcePoliciesMap: enc.encodeResourcePoliciesMap(c.ResourcePoliciesMap, policyToIdxs),
GroupIdToUserIds: enc.encodeGroupIDToUserIDs(c.GroupIDToUserIDs),
AllowedUserIds: stringSetToSlice(c.AllowedUserIDs),
PostureFailedPeers: enc.encodePostureFailedPeers(c.PostureFailedPeers),
}
return &proto.NetworkMapEnvelope{
Payload: &proto.NetworkMapEnvelope_Full{Full: full},
}
}
// networkSerial returns c.Network.CurrentSerial() with a nil guard. The
// production path always populates c.Network (account_components.go:86), but
// the encoder is exported and a hand-built components struct may omit it.
func networkSerial(n *types.Network) uint64 {
if n == nil {
return 0
}
return n.CurrentSerial()
}
type componentEncoder struct {
components *types.NetworkMapComponents
peerOrder map[string]uint32
peers []*proto.PeerCompact
agentVersionOrder map[string]uint32
agentVersions []string
}
func newComponentEncoder(c *types.NetworkMapComponents) *componentEncoder {
return &componentEncoder{
components: c,
peerOrder: make(map[string]uint32, len(c.Peers)),
peers: make([]*proto.PeerCompact, 0, len(c.Peers)),
agentVersionOrder: make(map[string]uint32),
}
}
func (e *componentEncoder) indexAllPeers() {
for _, p := range e.components.Peers {
if p == nil {
continue
}
e.appendPeer(p)
}
}
func (e *componentEncoder) appendPeer(p *nbpeer.Peer) uint32 {
if idx, ok := e.peerOrder[p.ID]; ok {
return idx
}
idx := uint32(len(e.peers))
e.peerOrder[p.ID] = idx
e.peers = append(e.peers, toPeerCompact(p, e.agentVersionIndex(p.Meta.WtVersion)))
return idx
}
func (e *componentEncoder) agentVersionIndex(v string) uint32 {
if idx, ok := e.agentVersionOrder[v]; ok {
return idx
}
// Lazy-initialise the table with "" at index 0 so the empty string
// stays interchangeable with proto3's default uint32=0 — peers without
// a WtVersion don't force the table to materialise.
if v == "" {
idx := uint32(len(e.agentVersions))
if idx == 0 {
e.agentVersions = append(e.agentVersions, "")
}
e.agentVersionOrder[""] = idx
return idx
}
if len(e.agentVersions) == 0 {
e.agentVersions = append(e.agentVersions, "")
e.agentVersionOrder[""] = 0
}
idx := uint32(len(e.agentVersions))
e.agentVersionOrder[v] = idx
e.agentVersions = append(e.agentVersions, v)
return idx
}
// indexRouterPeers ensures every router peer is in the peer dedup table
// (c.RouterPeers may contain peers not in c.Peers when validation rules drop
// them) and returns their wire indexes for the RouterPeerIndexes field. Must
// run before any encoder that resolves peer ids via e.peerOrder.
func (e *componentEncoder) indexRouterPeers(routers map[string]*nbpeer.Peer) []uint32 {
if len(routers) == 0 {
return nil
}
out := make([]uint32, 0, len(routers))
for _, p := range routers {
if p == nil {
continue
}
out = append(out, e.appendPeer(p))
}
return out
}
func (e *componentEncoder) encodeGroups() []*proto.GroupCompact {
if len(e.components.Groups) == 0 {
return nil
}
out := make([]*proto.GroupCompact, 0, len(e.components.Groups))
for _, g := range e.components.Groups {
if !g.HasSeqID() {
continue
}
peerIdxs := make([]uint32, 0, len(g.Peers))
for _, peerID := range g.Peers {
if idx, ok := e.peerOrder[peerID]; ok {
peerIdxs = append(peerIdxs, idx)
}
}
out = append(out, &proto.GroupCompact{
Id: g.AccountSeqID,
Name: g.Name,
PeerIndexes: peerIdxs,
})
}
return out
}
// encodePolicies flattens Policy{Rules} → []PolicyCompact. Returns the wire
// list and a map from policy pointer to the indexes of its emitted rules in
// that list — used by encodeResourcePoliciesMap to translate
// ResourcePoliciesMap[resourceID][]*Policy into wire-side indexes.
func (e *componentEncoder) encodePolicies(policies []*types.Policy) ([]*proto.PolicyCompact, map[*types.Policy][]uint32) {
if len(policies) == 0 {
return nil, nil
}
out := make([]*proto.PolicyCompact, 0, len(policies))
idxByPolicy := make(map[*types.Policy][]uint32, len(policies))
for _, pol := range policies {
if !pol.HasSeqID() || !pol.Enabled {
continue
}
for _, r := range pol.Rules {
if r == nil || !r.Enabled {
continue
}
idxByPolicy[pol] = append(idxByPolicy[pol], uint32(len(out)))
out = append(out, e.encodePolicyRule(pol, r))
}
}
return out, idxByPolicy
}
// encodePolicyRule maps a single PolicyRule under pol to a PolicyCompact entry.
func (e *componentEncoder) encodePolicyRule(pol *types.Policy, r *types.PolicyRule) *proto.PolicyCompact {
return &proto.PolicyCompact{
Id: pol.AccountSeqID,
Action: networkmap.GetProtoAction(string(r.Action)),
Protocol: networkmap.GetProtoProtocol(string(r.Protocol)),
Bidirectional: r.Bidirectional,
Ports: portsToUint32(r.Ports),
PortRanges: portRangesToProto(r.PortRanges),
SourceGroupIds: e.groupSeqIDs(r.Sources),
DestinationGroupIds: e.groupSeqIDs(r.Destinations),
AuthorizedUser: r.AuthorizedUser,
AuthorizedGroups: e.encodeAuthorizedGroups(r.AuthorizedGroups),
SourceResource: e.resourceToProto(r.SourceResource),
DestinationResource: e.resourceToProto(r.DestinationResource),
SourcePostureCheckSeqIds: e.postureCheckSeqs(pol.SourcePostureChecks),
}
}
// groupSeqIDs maps the xid group IDs in src to their per-account seq ids,
// dropping any group that has no seq id assigned.
func (e *componentEncoder) groupSeqIDs(src []string) []uint32 {
if len(src) == 0 {
return nil
}
out := make([]uint32, 0, len(src))
for _, gid := range src {
if seq, ok := e.groupSeq(gid); ok {
out = append(out, seq)
}
}
return out
}
// unionPolicies merges c.Policies with every policy referenced by
// c.ResourcePoliciesMap, deduplicating by pointer identity. Resource-only
// policies (relevant to a NetworkResource but not to peer-pair traffic)
// only live in ResourcePoliciesMap; without this union step they'd be lost
// from the wire and the client's resource-policy lookup would come back
// empty.
func unionPolicies(policies []*types.Policy, resourcePolicies map[string][]*types.Policy) []*types.Policy {
// Fast path: non-router peers have no resource-only policies, so the
// "union" is identical to `policies`. Skip the dedup map allocation.
if len(resourcePolicies) == 0 {
return policies
}
seen := make(map[*types.Policy]struct{}, len(policies))
out := make([]*types.Policy, 0, len(policies))
for _, p := range policies {
if p == nil {
continue
}
if _, ok := seen[p]; ok {
continue
}
seen[p] = struct{}{}
out = append(out, p)
}
for _, list := range resourcePolicies {
for _, p := range list {
if p == nil {
continue
}
if _, ok := seen[p]; ok {
continue
}
seen[p] = struct{}{}
out = append(out, p)
}
}
return out
}
// encodeAuthorizedGroups translates rule.AuthorizedGroups (map keyed by
// group xid → local-user names) to the wire form (map keyed by group
// account_seq_id → UserNameList). Groups without a seq id are dropped —
// matches how source/destination group references handle the same case.
func (e *componentEncoder) encodeAuthorizedGroups(m map[string][]string) map[uint32]*proto.UserNameList {
if len(m) == 0 {
return nil
}
out := make(map[uint32]*proto.UserNameList, len(m))
for groupID, names := range m {
seq, ok := e.groupSeq(groupID)
if !ok {
continue
}
out[seq] = &proto.UserNameList{Names: append([]string(nil), names...)}
}
return out
}
func (e *componentEncoder) groupSeq(groupID string) (uint32, bool) {
g, ok := e.components.Groups[groupID]
if !ok || !g.HasSeqID() {
return 0, false
}
return g.AccountSeqID, true
}
// resourceToProto translates types.Resource for the wire. For peer-typed
// resources the peer id is converted to a peer index into the envelope's
// peers array. For other resource types only the type string is shipped
// today (Calculate's resource-typed rule path consults SourceResource only
// for "peer" — other types fall through to group-based lookup).
func (e *componentEncoder) resourceToProto(r types.Resource) *proto.ResourceCompact {
if r.ID == "" && r.Type == "" {
return nil
}
out := &proto.ResourceCompact{Type: string(r.Type)}
if r.Type == types.ResourceTypePeer && r.ID != "" {
if idx, ok := e.peerOrder[r.ID]; ok {
out.PeerIndexSet = true
out.PeerIndex = idx
}
}
return out
}
// postureCheckSeqs translates a slice of posture-check xids to their
// per-account integer ids using the NetworkMapComponents.PostureCheckXIDToSeq
// lookup. Unresolvable xids are silently dropped — matches how group/peer
// references handle the same case.
func (e *componentEncoder) postureCheckSeqs(xids []string) []uint32 {
if len(xids) == 0 || len(e.components.PostureCheckXIDToSeq) == 0 {
return nil
}
out := make([]uint32, 0, len(xids))
for _, xid := range xids {
if seq, ok := e.components.PostureCheckXIDToSeq[xid]; ok {
out = append(out, seq)
}
}
return out
}
// networkSeq translates a Network xid to its per-account integer id using
// the NetworkMapComponents.NetworkXIDToSeq lookup. Returns (0,false) when
// the xid isn't known — callers decide whether to skip the parent record.
func (e *componentEncoder) networkSeq(xid string) (uint32, bool) {
if xid == "" {
return 0, false
}
seq, ok := e.components.NetworkXIDToSeq[xid]
if !ok || seq == 0 {
return 0, false
}
return seq, true
}
func (e *componentEncoder) encodeDNSSettings(s *types.DNSSettings) *proto.DNSSettingsCompact {
if s == nil || len(s.DisabledManagementGroups) == 0 {
return nil
}
out := &proto.DNSSettingsCompact{
DisabledManagementGroupIds: make([]uint32, 0, len(s.DisabledManagementGroups)),
}
for _, gid := range s.DisabledManagementGroups {
if seq, ok := e.groupSeq(gid); ok {
out.DisabledManagementGroupIds = append(out.DisabledManagementGroupIds, seq)
}
}
return out
}
func (e *componentEncoder) encodeRoutes(routes []*nbroute.Route) []*proto.RouteRaw {
if len(routes) == 0 {
return nil
}
out := make([]*proto.RouteRaw, 0, len(routes))
for _, r := range routes {
if r == nil {
continue
}
rr := &proto.RouteRaw{
Id: r.AccountSeqID,
NetId: string(r.NetID),
Description: r.Description,
KeepRoute: r.KeepRoute,
NetworkType: int32(r.NetworkType),
Masquerade: r.Masquerade,
Metric: int32(r.Metric),
Enabled: r.Enabled,
SkipAutoApply: r.SkipAutoApply,
Domains: r.Domains.ToPunycodeList(),
GroupIds: e.groupIDsToSeq(r.Groups),
AccessControlGroupIds: e.groupIDsToSeq(r.AccessControlGroups),
PeerGroupIds: e.groupIDsToSeq(r.PeerGroups),
}
if r.Network.IsValid() {
rr.NetworkCidr = r.Network.String()
}
if r.Peer != "" {
if idx, ok := e.peerOrder[r.Peer]; ok {
rr.PeerIndexSet = true
rr.PeerIndex = idx
}
}
out = append(out, rr)
}
return out
}
func (e *componentEncoder) groupIDsToSeq(groupIDs []string) []uint32 {
if len(groupIDs) == 0 {
return nil
}
out := make([]uint32, 0, len(groupIDs))
for _, gid := range groupIDs {
if seq, ok := e.groupSeq(gid); ok {
out = append(out, seq)
}
}
return out
}
func (e *componentEncoder) encodeNameServerGroups(nsgs []*nbdns.NameServerGroup) []*proto.NameServerGroupRaw {
if len(nsgs) == 0 {
return nil
}
out := make([]*proto.NameServerGroupRaw, 0, len(nsgs))
for _, nsg := range nsgs {
if nsg == nil {
continue
}
entry := &proto.NameServerGroupRaw{
Id: nsg.AccountSeqID,
Name: nsg.Name,
Description: nsg.Description,
Nameservers: encodeNameServers(nsg.NameServers),
GroupIds: e.groupIDsToSeq(nsg.Groups),
Primary: nsg.Primary,
Domains: nsg.Domains,
Enabled: nsg.Enabled,
SearchDomainsEnabled: nsg.SearchDomainsEnabled,
}
out = append(out, entry)
}
return out
}
func encodeNameServers(servers []nbdns.NameServer) []*proto.NameServer {
if len(servers) == 0 {
return nil
}
out := make([]*proto.NameServer, 0, len(servers))
for _, s := range servers {
out = append(out, &proto.NameServer{
IP: s.IP.String(),
NSType: int64(s.NSType),
Port: int64(s.Port),
})
}
return out
}
func encodeSimpleRecords(records []nbdns.SimpleRecord) []*proto.SimpleRecord {
if len(records) == 0 {
return nil
}
out := make([]*proto.SimpleRecord, 0, len(records))
for _, r := range records {
out = append(out, &proto.SimpleRecord{
Name: r.Name,
Type: int64(r.Type),
Class: r.Class,
TTL: int64(r.TTL),
RData: r.RData,
})
}
return out
}
func encodeCustomZones(zones []nbdns.CustomZone) []*proto.CustomZone {
if len(zones) == 0 {
return nil
}
out := make([]*proto.CustomZone, 0, len(zones))
for _, z := range zones {
out = append(out, &proto.CustomZone{
Domain: z.Domain,
Records: encodeSimpleRecords(z.Records),
SearchDomainDisabled: z.SearchDomainDisabled,
NonAuthoritative: z.NonAuthoritative,
})
}
return out
}
func (e *componentEncoder) encodeNetworkResources(resources []*resourceTypes.NetworkResource) []*proto.NetworkResourceRaw {
if len(resources) == 0 {
return nil
}
out := make([]*proto.NetworkResourceRaw, 0, len(resources))
for _, r := range resources {
if r == nil {
continue
}
entry := &proto.NetworkResourceRaw{
Id: r.AccountSeqID,
Name: r.Name,
Description: r.Description,
Type: string(r.Type),
Address: r.Address,
DomainValue: r.Domain,
Enabled: r.Enabled,
}
if seq, ok := e.networkSeq(r.NetworkID); ok {
entry.NetworkSeq = seq
}
if r.Prefix.IsValid() {
entry.PrefixCidr = r.Prefix.String()
}
out = append(out, entry)
}
return out
}
func (e *componentEncoder) encodeRoutersMap(routersMap map[string]map[string]*routerTypes.NetworkRouter) map[uint32]*proto.NetworkRouterList {
if len(routersMap) == 0 {
return nil
}
out := make(map[uint32]*proto.NetworkRouterList, len(routersMap))
for networkXID, routers := range routersMap {
if len(routers) == 0 {
continue
}
netSeq, ok := e.networkSeq(networkXID)
if !ok {
continue
}
entries := make([]*proto.NetworkRouterEntry, 0, len(routers))
for peerID, r := range routers {
if r == nil {
continue
}
entry := &proto.NetworkRouterEntry{
Id: r.AccountSeqID,
PeerGroupIds: e.groupIDsToSeq(r.PeerGroups),
Masquerade: r.Masquerade,
Metric: int32(r.Metric),
Enabled: r.Enabled,
}
if idx, ok := e.peerOrder[peerID]; ok {
entry.PeerIndexSet = true
entry.PeerIndex = idx
}
entries = append(entries, entry)
}
out[netSeq] = &proto.NetworkRouterList{Entries: entries}
}
return out
}
func (e *componentEncoder) encodeResourcePoliciesMap(rpm map[string][]*types.Policy, policyToIdxs map[*types.Policy][]uint32) map[uint32]*proto.PolicyIndexes {
if len(rpm) == 0 {
return nil
}
// resourceXIDToSeq is local to one encode — built from components.NetworkResources
// (small slice). Network resources without seq id are dropped, matching how
// other components-without-seq are silently filtered.
resourceXIDToSeq := make(map[string]uint32, len(e.components.NetworkResources))
for _, r := range e.components.NetworkResources {
if r != nil && r.AccountSeqID != 0 {
resourceXIDToSeq[r.ID] = r.AccountSeqID
}
}
out := make(map[uint32]*proto.PolicyIndexes, len(rpm))
for resourceXID, policies := range rpm {
seq, ok := resourceXIDToSeq[resourceXID]
if !ok {
continue
}
idxs := make([]uint32, 0, len(policies)*2)
for _, pol := range policies {
idxs = append(idxs, policyToIdxs[pol]...)
}
if len(idxs) == 0 {
continue
}
out[seq] = &proto.PolicyIndexes{Indexes: idxs}
}
return out
}
func (e *componentEncoder) encodeGroupIDToUserIDs(m map[string][]string) map[uint32]*proto.UserIDList {
if len(m) == 0 {
return nil
}
out := make(map[uint32]*proto.UserIDList, len(m))
for groupID, userIDs := range m {
seq, ok := e.groupSeq(groupID)
if !ok || len(userIDs) == 0 {
continue
}
out[seq] = &proto.UserIDList{UserIds: userIDs}
}
return out
}
func stringSetToSlice(s map[string]struct{}) []string {
if len(s) == 0 {
return nil
}
out := make([]string, 0, len(s))
for k := range s {
out = append(out, k)
}
return out
}
func (e *componentEncoder) encodePostureFailedPeers(m map[string]map[string]struct{}) map[uint32]*proto.PeerIndexSet {
if len(m) == 0 {
return nil
}
out := make(map[uint32]*proto.PeerIndexSet, len(m))
for checkXID, failedPeerIDs := range m {
seq, ok := e.components.PostureCheckXIDToSeq[checkXID]
if !ok || seq == 0 {
continue
}
idxs := make([]uint32, 0, len(failedPeerIDs))
for peerID := range failedPeerIDs {
if idx, ok := e.peerOrder[peerID]; ok {
idxs = append(idxs, idx)
}
}
if len(idxs) == 0 {
continue
}
out[seq] = &proto.PeerIndexSet{PeerIndexes: idxs}
}
return out
}
// toAccountSettingsCompact always returns a non-nil message — the client
// dereferences it unconditionally during Calculate(), so a nil here would
// crash the receiver. A missing types.AccountSettingsInfo on the server
// (which shouldn't happen in production but the encoder is exported)
// degrades to login_expiration_enabled = false, which makes
// LoginExpired() return false for every peer.
func toAccountSettingsCompact(s *types.AccountSettingsInfo) *proto.AccountSettingsCompact {
if s == nil {
return &proto.AccountSettingsCompact{}
}
return &proto.AccountSettingsCompact{
PeerLoginExpirationEnabled: s.PeerLoginExpirationEnabled,
PeerLoginExpirationNs: int64(s.PeerLoginExpiration),
}
}
func toAccountNetwork(n *types.Network) *proto.AccountNetwork {
if n == nil {
return nil
}
out := &proto.AccountNetwork{
Identifier: n.Identifier,
NetCidr: n.Net.String(),
Dns: n.Dns,
Serial: n.CurrentSerial(),
}
if len(n.NetV6.IP) > 0 {
out.NetV6Cidr = n.NetV6.String()
}
return out
}
func toPeerCompact(p *nbpeer.Peer, agentVersionIdx uint32) *proto.PeerCompact {
pc := &proto.PeerCompact{
WgPubKey: decodeWgKey(p.Key),
SshPubKey: []byte(p.SSHKey),
DnsLabel: p.DNSLabel,
AgentVersionIdx: agentVersionIdx,
AddedWithSsoLogin: p.UserID != "",
LoginExpirationEnabled: p.LoginExpirationEnabled,
SshEnabled: p.SSHEnabled,
SupportsIpv6: p.SupportsIPv6(),
SupportsSourcePrefixes: p.SupportsSourcePrefixes(),
ServerSshAllowed: p.Meta.Flags.ServerSSHAllowed,
}
if p.LastLogin != nil {
pc.LastLoginUnixNano = p.LastLogin.UnixNano()
}
switch {
case !p.IP.IsValid():
// leave Ip nil
case p.IP.Is4() || p.IP.Is4In6():
ip := p.IP.Unmap().As4()
pc.Ip = ip[:]
default:
ip := p.IP.As16()
pc.Ip = ip[:]
}
if p.IPv6.IsValid() {
ip := p.IPv6.As16()
pc.Ipv6 = ip[:]
}
return pc
}
// decodeWgKey returns the raw 32 bytes of a base64-encoded WireGuard public
// key, or nil for an empty / malformed key.
func decodeWgKey(s string) []byte {
if s == "" {
return nil
}
out := make([]byte, wgKeyRawLen)
n, err := base64.StdEncoding.Decode(out, []byte(s))
if err != nil || n != wgKeyRawLen {
return nil
}
return out
}
func portsToUint32(ports []string) []uint32 {
if len(ports) == 0 {
return nil
}
out := make([]uint32, 0, len(ports))
for _, p := range ports {
v, err := strconv.ParseUint(p, 10, 16)
if err != nil {
continue
}
out = append(out, uint32(v))
}
return out
}
func portRangesToProto(ranges []types.RulePortRange) []*proto.PortInfo_Range {
if len(ranges) == 0 {
return nil
}
out := make([]*proto.PortInfo_Range, 0, len(ranges))
for _, r := range ranges {
out = append(out, &proto.PortInfo_Range{
Start: uint32(r.Start),
End: uint32(r.End),
})
}
return out
}

View File

@@ -0,0 +1,879 @@
package grpc
import (
"bytes"
"cmp"
"net"
"net/netip"
"slices"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
goproto "google.golang.org/protobuf/proto"
nbdns "github.com/netbirdio/netbird/dns"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/types"
nbroute "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/proto"
)
const testWgKeyA = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopq="
const testWgKeyB = "BBCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopq="
const testWgKeyC = "CBCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopq="
// canonicalize rewrites a NetworkMapComponentsFull in place into a canonical
// form: peers reordered by wg_pub_key, with the rest of the message rewritten
// to reference the new peer indexes. Groups, policies, and router indexes are
// also sorted. After canonicalize, two envelopes built from the same logical
// input compare byte-equal via proto.Equal.
//
// This lives on the test side — the encoder itself emits in map-iteration
// order. Test-side normalization is the contract for "two encodes are
// equivalent".
func canonicalize(full *proto.NetworkMapComponentsFull) {
if full == nil {
return
}
// Canonicalize agent_versions first: sort the slice and rewrite each
// peer's AgentVersionIdx accordingly. The empty placeholder stays at
// index 0 by convention.
avRemap := make(map[uint32]uint32, len(full.AgentVersions))
if len(full.AgentVersions) > 0 {
// Pair version → original index, sort, rebuild.
type avEntry struct {
version string
oldIdx uint32
}
entries := make([]avEntry, len(full.AgentVersions))
for i, v := range full.AgentVersions {
entries[i] = avEntry{version: v, oldIdx: uint32(i)}
}
// Empty stays at 0; sort the rest by string. Tiebreaker on oldIdx
// keeps the canonicalize output stable when two entries compare
// equal (the encoder dedups, but defending against future inputs).
slices.SortFunc(entries, func(a, b avEntry) int {
if a.version == "" && b.version != "" {
return -1
}
if b.version == "" && a.version != "" {
return 1
}
if c := cmp.Compare(a.version, b.version); c != 0 {
return c
}
return cmp.Compare(a.oldIdx, b.oldIdx)
})
newVersions := make([]string, len(entries))
for newIdx, e := range entries {
avRemap[e.oldIdx] = uint32(newIdx)
newVersions[newIdx] = e.version
}
full.AgentVersions = newVersions
}
for _, p := range full.Peers {
if newIdx, ok := avRemap[p.AgentVersionIdx]; ok {
p.AgentVersionIdx = newIdx
}
}
type peerEntry struct {
peer *proto.PeerCompact
oldIdx uint32
}
entries := make([]peerEntry, len(full.Peers))
for i, p := range full.Peers {
entries[i] = peerEntry{peer: p, oldIdx: uint32(i)}
}
// DnsLabel is unique per peer; it tiebreaks on equal WgPubKey (e.g. both
// nil from malformed keys, or both empty for placeholders).
slices.SortFunc(entries, func(a, b peerEntry) int {
if c := bytes.Compare(a.peer.WgPubKey, b.peer.WgPubKey); c != 0 {
return c
}
return cmp.Compare(a.peer.DnsLabel, b.peer.DnsLabel)
})
remap := make(map[uint32]uint32, len(entries))
newPeers := make([]*proto.PeerCompact, len(entries))
for newIdx, e := range entries {
remap[e.oldIdx] = uint32(newIdx)
newPeers[newIdx] = e.peer
}
full.Peers = newPeers
full.RouterPeerIndexes = remapAndSort(full.RouterPeerIndexes, remap)
for _, g := range full.Groups {
g.PeerIndexes = remapAndSort(g.PeerIndexes, remap)
}
slices.SortFunc(full.Groups, func(a, b *proto.GroupCompact) int { return cmp.Compare(a.Id, b.Id) })
for _, r := range full.Routes {
if r.PeerIndexSet {
if newIdx, ok := remap[r.PeerIndex]; ok {
r.PeerIndex = newIdx
}
}
slices.Sort(r.GroupIds)
slices.Sort(r.AccessControlGroupIds)
slices.Sort(r.PeerGroupIds)
}
slices.SortFunc(full.Routes, func(a, b *proto.RouteRaw) int { return cmp.Compare(a.Id, b.Id) })
for _, list := range full.RoutersMap {
for _, entry := range list.Entries {
if entry.PeerIndexSet {
if newIdx, ok := remap[entry.PeerIndex]; ok {
entry.PeerIndex = newIdx
}
}
slices.Sort(entry.PeerGroupIds)
}
slices.SortFunc(list.Entries, func(a, b *proto.NetworkRouterEntry) int { return cmp.Compare(a.Id, b.Id) })
}
for _, set := range full.PostureFailedPeers {
set.PeerIndexes = remapAndSort(set.PeerIndexes, remap)
}
for _, p := range full.Policies {
slices.Sort(p.SourceGroupIds)
slices.Sort(p.DestinationGroupIds)
}
// Sort policies by (Id, source_group_ids, destination_group_ids) so that
// multiple PolicyCompact entries sharing the same Id (one per rule, when
// a Policy has multiple rules) still get a deterministic order. After
// sorting we remap indexes in ResourcePoliciesMap.
policyOldOrder := make(map[*proto.PolicyCompact]uint32, len(full.Policies))
for i, p := range full.Policies {
policyOldOrder[p] = uint32(i)
}
slices.SortFunc(full.Policies, func(a, b *proto.PolicyCompact) int {
if c := cmp.Compare(a.Id, b.Id); c != 0 {
return c
}
if c := slices.Compare(a.SourceGroupIds, b.SourceGroupIds); c != 0 {
return c
}
return slices.Compare(a.DestinationGroupIds, b.DestinationGroupIds)
})
policyRemap := make(map[uint32]uint32, len(full.Policies))
for newIdx, p := range full.Policies {
policyRemap[policyOldOrder[p]] = uint32(newIdx)
}
for _, idxs := range full.ResourcePoliciesMap {
idxs.Indexes = remapAndSort(idxs.Indexes, policyRemap)
}
for _, list := range full.GroupIdToUserIds {
slices.Sort(list.UserIds)
}
slices.Sort(full.AllowedUserIds)
}
func remapAndSort(idxs []uint32, remap map[uint32]uint32) []uint32 {
out := make([]uint32, 0, len(idxs))
for _, i := range idxs {
if newIdx, ok := remap[i]; ok {
out = append(out, newIdx)
}
}
slices.Sort(out)
return out
}
// envelopesEquivalent decodes both envelopes, canonicalizes them, and reports
// whether they're proto.Equal. Use instead of byte-comparing marshaled output:
// the encoder is intentionally non-deterministic.
func envelopesEquivalent(a, b *proto.NetworkMapEnvelope) bool {
canonicalize(a.GetFull())
canonicalize(b.GetFull())
return goproto.Equal(a, b)
}
func newTestComponents() *types.NetworkMapComponents {
peerA := &nbpeer.Peer{
ID: "peer-a",
Key: testWgKeyA,
IP: netip.AddrFrom4([4]byte{100, 64, 0, 1}),
DNSLabel: "peera",
SSHKey: "ssh-a",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()},
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.40.0"},
}
peerB := &nbpeer.Peer{
ID: "peer-b",
Key: testWgKeyB,
IP: netip.AddrFrom4([4]byte{100, 64, 0, 2}),
IPv6: netip.AddrFrom16([16]byte{0xfd, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2}),
DNSLabel: "peerb",
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.25.0"},
}
peerC := &nbpeer.Peer{
ID: "peer-c",
Key: testWgKeyC,
IP: netip.AddrFrom4([4]byte{100, 64, 0, 3}),
DNSLabel: "peerc",
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.40.0"},
}
return &types.NetworkMapComponents{
PeerID: "peer-a",
Network: &types.Network{
Identifier: "net-test",
Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(10, 32)},
Serial: 7,
},
AccountSettings: &types.AccountSettingsInfo{
PeerLoginExpirationEnabled: true,
PeerLoginExpiration: 2 * time.Hour,
},
Peers: map[string]*nbpeer.Peer{
"peer-a": peerA,
"peer-b": peerB,
"peer-c": peerC,
},
Groups: map[string]*types.Group{
"group-src": {ID: "group-src", AccountSeqID: 1, Name: "Src", Peers: []string{"peer-a"}},
"group-dst": {ID: "group-dst", AccountSeqID: 2, Name: "Dst", Peers: []string{"peer-b", "peer-c"}},
},
Policies: []*types.Policy{
{
ID: "pol-1",
AccountSeqID: 10,
Enabled: true,
Rules: []*types.PolicyRule{{
ID: "rule-1", Enabled: true, Action: types.PolicyTrafficActionAccept,
Protocol: types.PolicyRuleProtocolTCP, Bidirectional: true,
Ports: []string{"22", "80"},
PortRanges: []types.RulePortRange{{Start: 8000, End: 8100}},
Sources: []string{"group-src"},
Destinations: []string{"group-dst"},
}},
},
},
RouterPeers: map[string]*nbpeer.Peer{"peer-c": peerC},
}
}
func TestEncodeNetworkMapEnvelope_Basic(t *testing.T) {
c := newTestComponents()
env := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{
Components: c,
DNSDomain: "netbird.cloud",
})
require.NotNil(t, env)
full := env.GetFull()
require.NotNil(t, full, "envelope must contain Full payload")
assert.EqualValues(t, 7, full.Serial)
assert.Equal(t, "netbird.cloud", full.DnsDomain)
require.NotNil(t, full.Network)
assert.Equal(t, "net-test", full.Network.Identifier)
assert.Equal(t, "100.64.0.0/10", full.Network.NetCidr)
require.NotNil(t, full.AccountSettings)
assert.True(t, full.AccountSettings.PeerLoginExpirationEnabled)
assert.EqualValues(t, (2 * time.Hour).Nanoseconds(), full.AccountSettings.PeerLoginExpirationNs)
require.Len(t, full.Peers, 3)
byLabel := map[string]*proto.PeerCompact{}
for _, p := range full.Peers {
assert.Len(t, p.WgPubKey, 32, "wg key must be raw 32 bytes")
assert.Len(t, p.Ip, 4, "ipv4 must be raw 4 bytes")
byLabel[p.DnsLabel] = p
}
assert.Len(t, byLabel["peerb"].Ipv6, 16, "peer-b has ipv6 → 16 bytes")
}
func TestEncodeNetworkMapEnvelope_RepeatEncodesEquivalent(t *testing.T) {
c := newTestComponents()
expected := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c})
// Hammer it 100 times — Go map iteration is randomized per call, so each
// run produces different wire bytes, but the canonicalized form must
// match.
for i := 0; i < 100; i++ {
got := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c})
require.True(t, envelopesEquivalent(expected, got),
"encode #%d must be semantically equivalent to first encode", i)
}
}
func TestEncodeNetworkMapEnvelope_ConcurrentEncodesEquivalent(t *testing.T) {
c := newTestComponents()
expected := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c})
const goroutines = 50
var wg sync.WaitGroup
wg.Add(goroutines)
results := make([]*proto.NetworkMapEnvelope, goroutines)
for i := 0; i < goroutines; i++ {
i := i
go func() {
defer wg.Done()
results[i] = EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c})
}()
}
wg.Wait()
for i, got := range results {
require.NotNil(t, got, "goroutine %d returned nil", i)
require.True(t, envelopesEquivalent(expected, got),
"goroutine %d produced inequivalent envelope", i)
}
}
func TestEncodeNetworkMapEnvelope_GroupsByAccountSeqID(t *testing.T) {
c := newTestComponents()
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
require.Len(t, full.Groups, 2)
groupByID := map[uint32]*proto.GroupCompact{}
for _, g := range full.Groups {
groupByID[g.Id] = g
}
require.Contains(t, groupByID, uint32(1))
require.Contains(t, groupByID, uint32(2))
assert.Equal(t, "Src", groupByID[1].Name)
assert.Equal(t, "Dst", groupByID[2].Name)
assert.Len(t, groupByID[1].PeerIndexes, 1)
assert.Len(t, groupByID[2].PeerIndexes, 2)
}
func TestEncodeNetworkMapEnvelope_PolicyExpansion(t *testing.T) {
c := newTestComponents()
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
require.Len(t, full.Policies, 1)
pc := full.Policies[0]
assert.EqualValues(t, 10, pc.Id)
assert.Equal(t, proto.RuleAction_ACCEPT, pc.Action)
assert.Equal(t, proto.RuleProtocol_TCP, pc.Protocol)
assert.True(t, pc.Bidirectional)
assert.Equal(t, []uint32{22, 80}, pc.Ports)
require.Len(t, pc.PortRanges, 1)
assert.EqualValues(t, 8000, pc.PortRanges[0].Start)
assert.EqualValues(t, 8100, pc.PortRanges[0].End)
assert.Equal(t, []uint32{1}, pc.SourceGroupIds)
assert.Equal(t, []uint32{2}, pc.DestinationGroupIds)
}
func TestEncodeNetworkMapEnvelope_RouterIndexes(t *testing.T) {
c := newTestComponents()
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
require.Len(t, full.RouterPeerIndexes, 1)
idx := full.RouterPeerIndexes[0]
require.Less(t, int(idx), len(full.Peers))
assert.Equal(t, "peerc", full.Peers[idx].DnsLabel)
}
func TestEncodeNetworkMapEnvelope_AgentVersionDedup(t *testing.T) {
c := newTestComponents()
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
require.Len(t, full.AgentVersions, 3, "empty placeholder + 2 distinct versions")
assert.Equal(t, "", full.AgentVersions[0], "index 0 reserved for empty version")
assert.ElementsMatch(t, []string{"0.40.0", "0.25.0"}, full.AgentVersions[1:],
"two distinct versions, order depends on map iteration")
idxByLabel := map[string]uint32{}
for _, p := range full.Peers {
idxByLabel[p.DnsLabel] = p.AgentVersionIdx
}
assert.Equal(t, idxByLabel["peera"], idxByLabel["peerc"], "peers with the same agent version share an index")
assert.NotEqual(t, idxByLabel["peera"], idxByLabel["peerb"])
}
func TestEncodeNetworkMapEnvelope_DisabledPolicySkipped(t *testing.T) {
c := newTestComponents()
c.Policies[0].Enabled = false
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
assert.Empty(t, full.Policies)
}
func TestEncodeNetworkMapEnvelope_GroupZeroSeqIDSkipped(t *testing.T) {
c := newTestComponents()
c.Groups["group-src"].AccountSeqID = 0
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
require.Len(t, full.Groups, 1, "groups with AccountSeqID=0 are not yet persisted and must be skipped")
assert.EqualValues(t, 2, full.Groups[0].Id)
require.Len(t, full.Policies, 1)
pc := full.Policies[0]
assert.Empty(t, pc.SourceGroupIds, "rule references a group that was filtered out → no group id on wire")
assert.Equal(t, []uint32{2}, pc.DestinationGroupIds)
}
func TestEncodeNetworkMapEnvelope_TwoPeersSameMalformedKey(t *testing.T) {
// Both peers have nil WgPubKey after decode; canonicalize must still
// produce a stable order using DnsLabel as a tiebreaker, so 100 encodes
// canonicalize identically.
c := newTestComponents()
c.Peers["peer-a"].Key = "garbage-a-!!!"
c.Peers["peer-b"].Key = "garbage-b-!!!"
expected := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c})
for i := 0; i < 100; i++ {
got := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c})
require.True(t, envelopesEquivalent(expected, got),
"encode #%d with two same-key peers must canonicalize equivalently", i)
}
}
func TestEncodeNetworkMapEnvelope_MalformedWgKey(t *testing.T) {
c := newTestComponents()
c.Peers["peer-a"].Key = "not-base64-!!!"
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
require.Len(t, full.Peers, 3)
var byLabel = map[string]*proto.PeerCompact{}
for _, p := range full.Peers {
byLabel[p.DnsLabel] = p
}
assert.Nil(t, byLabel["peera"].WgPubKey, "peer with malformed key encodes nil WgPubKey")
assert.Len(t, byLabel["peerb"].WgPubKey, 32, "other peers retain their key")
}
func TestEncodeNetworkMapEnvelope_IPv6OnlyPeer(t *testing.T) {
c := newTestComponents()
v6Only := &nbpeer.Peer{
ID: "peer-v6",
Key: testWgKeyA,
IPv6: netip.AddrFrom16([16]byte{0xfd, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 9}),
DNSLabel: "peerv6",
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.40.0"},
}
c.Peers["peer-v6"] = v6Only
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
var found *proto.PeerCompact
for _, p := range full.Peers {
if p.DnsLabel == "peerv6" {
found = p
}
}
require.NotNil(t, found, "ipv6-only peer must be present")
assert.Empty(t, found.Ip, "no IPv4 address → empty Ip")
assert.Len(t, found.Ipv6, 16)
}
func TestEncodeNetworkMapEnvelope_PeerWithoutIP(t *testing.T) {
c := newTestComponents()
c.Peers["peer-noip"] = &nbpeer.Peer{
ID: "peer-noip",
Key: testWgKeyA,
DNSLabel: "peernoip",
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.40.0"},
}
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
var found *proto.PeerCompact
for _, p := range full.Peers {
if p.DnsLabel == "peernoip" {
found = p
}
}
require.NotNil(t, found)
assert.Empty(t, found.Ip)
assert.Empty(t, found.Ipv6)
}
func TestEncodeNetworkMapEnvelope_EmptyInput(t *testing.T) {
c := &types.NetworkMapComponents{
Network: &types.Network{Identifier: "x", Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(10, 32)}},
}
env := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c})
full := env.GetFull()
require.NotNil(t, full)
assert.Empty(t, full.Peers)
assert.Empty(t, full.Groups)
assert.Empty(t, full.Policies)
assert.Empty(t, full.RouterPeerIndexes)
require.NotNil(t, full.AccountSettings, "AccountSettingsCompact must always be emitted (client dereferences it unconditionally)")
}
func TestEncodeNetworkMapEnvelope_PeerLoginExpirationFields(t *testing.T) {
c := newTestComponents()
now := time.Date(2024, 1, 2, 3, 4, 5, 0, time.UTC)
c.Peers["peer-a"].UserID = "user-1"
c.Peers["peer-a"].LoginExpirationEnabled = true
c.Peers["peer-a"].LastLogin = &now
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
var pa *proto.PeerCompact
for _, p := range full.Peers {
if p.DnsLabel == "peera" {
pa = p
}
}
require.NotNil(t, pa)
assert.True(t, pa.AddedWithSsoLogin)
assert.True(t, pa.LoginExpirationEnabled)
assert.Equal(t, now.UnixNano(), pa.LastLoginUnixNano)
// peer-b has no UserID and no LastLogin → all fields zero-value.
var pb *proto.PeerCompact
for _, p := range full.Peers {
if p.DnsLabel == "peerb" {
pb = p
}
}
require.NotNil(t, pb)
assert.False(t, pb.AddedWithSsoLogin)
assert.False(t, pb.LoginExpirationEnabled)
assert.Zero(t, pb.LastLoginUnixNano)
}
func TestEncodeNetworkMapEnvelope_RoutesRoundTrip(t *testing.T) {
c := newTestComponents()
c.Routes = []*nbroute.Route{
{
ID: "route-peer",
AccountSeqID: 100,
NetID: "net-A",
Description: "via peer-c",
Network: netip.MustParsePrefix("10.0.0.0/16"),
Peer: "peer-c", // peer ID, not WG key
Groups: []string{"group-src"},
AccessControlGroups: []string{"group-dst"},
Enabled: true,
},
{
ID: "route-peergroup",
AccountSeqID: 101,
NetID: "net-B",
Network: netip.MustParsePrefix("10.1.0.0/16"),
PeerGroups: []string{"group-src", "group-dst"},
Enabled: true,
},
{
ID: "route-no-seq",
AccountSeqID: 0, // unset — should still ship (no group seq filter on routes)
Network: netip.MustParsePrefix("10.2.0.0/16"),
Enabled: true,
},
}
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
require.Len(t, full.Routes, 3)
byNetID := map[string]*proto.RouteRaw{}
for _, r := range full.Routes {
byNetID[r.NetId] = r
}
r1 := byNetID["net-A"]
require.NotNil(t, r1)
assert.True(t, r1.PeerIndexSet, "route with peer must set peer_index_set")
require.Less(t, int(r1.PeerIndex), len(full.Peers))
assert.Equal(t, "peerc", full.Peers[r1.PeerIndex].DnsLabel)
assert.Equal(t, []uint32{1}, r1.GroupIds, "group-src has AccountSeqID 1")
assert.Equal(t, []uint32{2}, r1.AccessControlGroupIds, "group-dst has AccountSeqID 2")
assert.Empty(t, r1.PeerGroupIds)
r2 := byNetID["net-B"]
require.NotNil(t, r2)
assert.False(t, r2.PeerIndexSet, "route with peer_groups must NOT set peer_index_set")
assert.ElementsMatch(t, []uint32{1, 2}, r2.PeerGroupIds)
}
func TestEncodeNetworkMapEnvelope_RouteWithMissingPeerLeavesIndexUnset(t *testing.T) {
c := newTestComponents()
c.Routes = []*nbroute.Route{{
ID: "route-x",
AccountSeqID: 100,
Peer: "peer-not-in-components",
Network: netip.MustParsePrefix("10.0.0.0/16"),
Enabled: true,
}}
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
require.Len(t, full.Routes, 1)
assert.False(t, full.Routes[0].PeerIndexSet,
"missing peer reference must not pretend to point at peer index 0")
}
func TestEncodeNetworkMapEnvelope_ResourceOnlyPolicyShippedAndIndexed(t *testing.T) {
c := newTestComponents()
// Policy that exists ONLY in ResourcePoliciesMap, not in c.Policies. This
// is the I1 case — without unionPolicies the encoder would silently
// drop it from the wire.
resourceOnlyPolicy := &types.Policy{
ID: "pol-resource", AccountSeqID: 99, Enabled: true,
Rules: []*types.PolicyRule{{
ID: "rule-r", Enabled: true, Action: types.PolicyTrafficActionAccept,
Protocol: types.PolicyRuleProtocolTCP,
Sources: []string{"group-src"},
Destinations: []string{"group-dst"},
}},
}
c.ResourcePoliciesMap = map[string][]*types.Policy{
"resource-x": {c.Policies[0], resourceOnlyPolicy}, // shared + resource-only
}
// Resource must appear in components.NetworkResources with a seq id —
// encoder uses that to translate the xid map key to uint32.
c.NetworkResources = []*resourceTypes.NetworkResource{
{ID: "resource-x", AccountSeqID: 77, Name: "res-x", Enabled: true},
}
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
require.Len(t, full.Policies, 2, "encoded policies must include both peer-traffic and resource-only")
policyByID := map[uint32]*proto.PolicyCompact{}
policyIdxByID := map[uint32]uint32{}
for i, p := range full.Policies {
policyByID[p.Id] = p
policyIdxByID[p.Id] = uint32(i)
}
require.Contains(t, policyByID, uint32(10), "original peer-traffic policy id 10")
require.Contains(t, policyByID, uint32(99), "resource-only policy id 99")
require.Contains(t, full.ResourcePoliciesMap, uint32(77))
idxs := full.ResourcePoliciesMap[77].Indexes
require.Len(t, idxs, 2)
assert.ElementsMatch(t, []uint32{policyIdxByID[10], policyIdxByID[99]}, idxs,
"resource policies map must reference both wire policy indexes")
}
func TestEncodeNetworkMapEnvelope_NameServerGroups(t *testing.T) {
c := newTestComponents()
c.NameServerGroups = []*nbdns.NameServerGroup{{
ID: "nsg-1", AccountSeqID: 50, Name: "Main", Description: "primary",
NameServers: []nbdns.NameServer{{
IP: netip.MustParseAddr("8.8.8.8"), NSType: nbdns.UDPNameServerType, Port: 53,
}},
Groups: []string{"group-src", "group-not-persisted"},
Primary: true, Enabled: true,
Domains: []string{"corp.example"},
}}
c.Groups["group-not-persisted"] = &types.Group{ID: "group-not-persisted", AccountSeqID: 0, Peers: []string{}}
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
require.Len(t, full.NameserverGroups, 1)
nsg := full.NameserverGroups[0]
assert.EqualValues(t, 50, nsg.Id)
assert.Equal(t, "Main", nsg.Name)
assert.True(t, nsg.Primary)
require.Len(t, nsg.Nameservers, 1)
assert.Equal(t, "8.8.8.8", nsg.Nameservers[0].IP)
assert.Equal(t, []uint32{1}, nsg.GroupIds, "group-not-persisted is filtered out (AccountSeqID=0)")
}
func TestEncodeNetworkMapEnvelope_PostureFailedPeers(t *testing.T) {
c := newTestComponents()
c.PostureCheckXIDToSeq = map[string]uint32{"check-1": 33}
c.PostureFailedPeers = map[string]map[string]struct{}{
"check-1": {
"peer-a": {},
"peer-b": {},
"peer-not-in-account": {},
},
}
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
require.Contains(t, full.PostureFailedPeers, uint32(33))
idxs := full.PostureFailedPeers[33].PeerIndexes
assert.Len(t, idxs, 2, "missing peer is silently dropped (filterPostureFailedPeers guarantees presence in real data)")
}
func TestEncodeNetworkMapEnvelope_RoutersMap(t *testing.T) {
c := newTestComponents()
c.NetworkXIDToSeq = map[string]uint32{"net-1": 5}
c.RoutersMap = map[string]map[string]*routerTypes.NetworkRouter{
"net-1": {
"peer-c": {
ID: "router-1", AccountSeqID: 200,
Peer: "peer-c", Masquerade: true, Metric: 10, Enabled: true,
},
},
}
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
require.Contains(t, full.RoutersMap, uint32(5))
entries := full.RoutersMap[5].Entries
require.Len(t, entries, 1)
e := entries[0]
assert.EqualValues(t, 200, e.Id)
assert.True(t, e.PeerIndexSet)
require.Less(t, int(e.PeerIndex), len(full.Peers))
assert.Equal(t, "peerc", full.Peers[e.PeerIndex].DnsLabel)
assert.True(t, e.Masquerade)
assert.EqualValues(t, 10, e.Metric)
assert.True(t, e.Enabled)
}
func TestEncodeNetworkMapEnvelope_RouterPeerNotInComponentsPeers(t *testing.T) {
// Router peer in c.RouterPeers but NOT in c.Peers (validation may have
// filtered it). indexRouterPeers runs before encodeRoutersMap, so the
// peer_index reference must still resolve.
c := newTestComponents()
delete(c.Peers, "peer-c")
routerPeer := &nbpeer.Peer{
ID: "peer-c", Key: testWgKeyC, IP: netip.AddrFrom4([4]byte{100, 64, 0, 3}),
DNSLabel: "peerc", Meta: nbpeer.PeerSystemMeta{WtVersion: "0.40.0"},
}
c.RouterPeers = map[string]*nbpeer.Peer{"peer-c": routerPeer}
c.NetworkXIDToSeq = map[string]uint32{"net-1": 5}
c.RoutersMap = map[string]map[string]*routerTypes.NetworkRouter{
"net-1": {"peer-c": {ID: "r-1", AccountSeqID: 1, Peer: "peer-c", Enabled: true}},
}
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
require.Contains(t, full.RoutersMap, uint32(5))
require.Len(t, full.RoutersMap[5].Entries, 1)
e := full.RoutersMap[5].Entries[0]
assert.True(t, e.PeerIndexSet, "router peer must be indexed even when not in c.Peers")
}
func TestEncodeNetworkMapEnvelope_DNSSettingsFiltersUnpersistedGroups(t *testing.T) {
c := newTestComponents()
c.DNSSettings = &types.DNSSettings{
DisabledManagementGroups: []string{"group-src", "group-missing", "group-no-seq"},
}
c.Groups["group-no-seq"] = &types.Group{ID: "group-no-seq", AccountSeqID: 0}
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
require.NotNil(t, full.DnsSettings)
assert.Equal(t, []uint32{1}, full.DnsSettings.DisabledManagementGroupIds,
"only group-src (AccountSeqID=1) survives — missing and unpersisted are dropped")
}
func TestEncodeNetworkMapEnvelope_GroupIDToUserIDs(t *testing.T) {
c := newTestComponents()
c.GroupIDToUserIDs = map[string][]string{
"group-src": {"user-1", "user-2"},
"group-no-seq": {"user-3"}, // group not persisted → drop
"group-missing": {"user-4"}, // group not in components → drop
}
c.Groups["group-no-seq"] = &types.Group{ID: "group-no-seq", AccountSeqID: 0}
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
require.Len(t, full.GroupIdToUserIds, 1, "only persisted+present groups survive")
require.Contains(t, full.GroupIdToUserIds, uint32(1))
assert.ElementsMatch(t, []string{"user-1", "user-2"}, full.GroupIdToUserIds[1].UserIds)
}
func TestToProxyPatch_EmptyInputReturnsNil(t *testing.T) {
assert.Nil(t, toProxyPatch(nil, "netbird.cloud", false, false))
assert.Nil(t, toProxyPatch(&types.NetworkMap{}, "netbird.cloud", false, false),
"empty NetworkMap (no peers, rules, routes etc) → nil patch so proto3 omits the field")
}
func TestToProxyPatch_PopulatesAllFields(t *testing.T) {
nm := &types.NetworkMap{
Peers: []*nbpeer.Peer{{
ID: "ext-peer", Key: testWgKeyA, IP: netip.AddrFrom4([4]byte{100, 64, 0, 9}),
DNSLabel: "extpeer", Meta: nbpeer.PeerSystemMeta{WtVersion: "0.40.0"},
}},
FirewallRules: []*types.FirewallRule{{
PeerIP: "100.64.0.9", Action: "accept", Direction: 0, Protocol: "tcp",
}},
}
patch := toProxyPatch(nm, "netbird.cloud", false, false)
require.NotNil(t, patch)
assert.Len(t, patch.Peers, 1)
assert.Len(t, patch.FirewallRules, 1)
}
// TestEncodeNetworkMapEnvelope_ProxyPatchPropagated covers the ProxyPatch
// pass-through in both encoder branches (normal path + nil-Components
// graceful-degrade). Without this test a regression that drops `ProxyPatch:`
// from one of the struct literals in components_encoder.go would slip past CI.
func TestEncodeNetworkMapEnvelope_ProxyPatchPropagated(t *testing.T) {
patch := &proto.ProxyPatch{
ForwardingRules: []*proto.ForwardingRule{{
Protocol: proto.RuleProtocol_TCP,
DestinationPort: &proto.PortInfo{PortSelection: &proto.PortInfo_Port{Port: 80}},
TranslatedAddress: net.IPv4(10, 0, 0, 1).To4(),
TranslatedPort: &proto.PortInfo{PortSelection: &proto.PortInfo_Port{Port: 8080}},
}},
}
t.Run("normal_path", func(t *testing.T) {
c := newTestComponents()
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{
Components: c,
ProxyPatch: patch,
}).GetFull()
require.NotNil(t, full.ProxyPatch, "ProxyPatch must propagate through the normal encode path")
assert.Len(t, full.ProxyPatch.ForwardingRules, 1)
})
t.Run("nil_components_graceful_degrade", func(t *testing.T) {
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{
Components: nil,
ProxyPatch: patch,
}).GetFull()
require.NotNil(t, full.ProxyPatch, "ProxyPatch must propagate through the nil-Components branch too")
assert.Len(t, full.ProxyPatch.ForwardingRules, 1)
})
}
func TestEncodeNetworkMapEnvelope_NilComponentsGracefulDegrade(t *testing.T) {
// nil Components → minimal envelope, no crash. Matches the legacy
// account_components.go:43 behaviour for missing/unvalidated peers.
env := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{
Components: nil,
DNSDomain: "netbird.cloud",
})
require.NotNil(t, env)
full := env.GetFull()
require.NotNil(t, full)
require.NotNil(t, full.AccountSettings, "AccountSettings must always be non-nil")
assert.Equal(t, "netbird.cloud", full.DnsDomain)
assert.Empty(t, full.Peers)
assert.Empty(t, full.Policies)
}
func TestEncodeNetworkMapEnvelope_AccountSettingsAlwaysEmitted(t *testing.T) {
c := &types.NetworkMapComponents{
Network: &types.Network{Identifier: "x", Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(10, 32)}},
// AccountSettings deliberately nil
}
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
require.NotNil(t, full.AccountSettings, "client dereferences AccountSettings unconditionally during Calculate(); a nil here would crash the receiver")
assert.False(t, full.AccountSettings.PeerLoginExpirationEnabled)
assert.Zero(t, full.AccountSettings.PeerLoginExpirationNs)
}

View File

@@ -0,0 +1,193 @@
package grpc
import (
"context"
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
"github.com/netbirdio/netbird/client/ssh/auth"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/networkmap"
"github.com/netbirdio/netbird/shared/management/proto"
)
// ToComponentSyncResponse builds a SyncResponse carrying the compact
// NetworkMapEnvelope for capability-aware peers. The legacy proto.NetworkMap
// field is intentionally left empty — capable peers ignore it and the
// envelope alone is the authoritative wire shape.
//
// PeerConfig is computed once server-side using the receiving peer's own
// account-level network metadata. EnableSSH inside PeerConfig is left at
// peer.SSHEnabled (the peer's local setting); account-policy-driven SSH is
// computed by the client from the envelope's GroupIDToUserIDs / AllowedUserIDs
// inside Calculate(), so the SshConfig.SshEnabled bit may flip true on the
// client even though the server-side PeerConfig reports false.
func ToComponentSyncResponse(
ctx context.Context,
config *nbconfig.Config,
httpConfig *nbconfig.HttpServerConfig,
deviceFlowConfig *nbconfig.DeviceAuthorizationFlow,
peer *nbpeer.Peer,
turnCredentials *Token,
relayCredentials *Token,
components *types.NetworkMapComponents,
proxyPatch *types.NetworkMap,
dnsName string,
checks []*posture.Checks,
settings *types.Settings,
extraSettings *types.ExtraSettings,
peerGroups []string,
dnsFwdPort int64,
) *proto.SyncResponse {
network := networkOrZero(components)
enableSSH := computeSSHEnabledForPeer(components, peer)
peerConfig := toPeerConfig(peer, network, dnsName, settings, httpConfig, deviceFlowConfig, enableSSH)
includeIPv6 := peer.SupportsIPv6() && peer.IPv6.IsValid()
useSourcePrefixes := peer.SupportsSourcePrefixes()
userIDClaim := auth.DefaultUserIDClaim
if httpConfig != nil && httpConfig.AuthUserIDClaim != "" {
userIDClaim = httpConfig.AuthUserIDClaim
}
envelope := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{
Components: components,
PeerConfig: peerConfig,
DNSDomain: dnsName,
DNSForwarderPort: dnsFwdPort,
UserIDClaim: userIDClaim,
ProxyPatch: toProxyPatch(proxyPatch, dnsName, includeIPv6, useSourcePrefixes),
})
resp := &proto.SyncResponse{
PeerConfig: peerConfig,
NetworkMapEnvelope: envelope,
Checks: toProtocolChecks(ctx, checks),
}
nbConfig := toNetbirdConfig(config, turnCredentials, relayCredentials, extraSettings)
resp.NetbirdConfig = integrationsConfig.ExtendNetBirdConfig(peer.ID, peerGroups, nbConfig, extraSettings)
return resp
}
// networkOrZero returns components.Network or a zero Network — toPeerConfig
// dereferences network.Net which would panic on nil.
func networkOrZero(c *types.NetworkMapComponents) *types.Network {
if c == nil || c.Network == nil {
return &types.Network{}
}
return c.Network
}
// toProxyPatch converts a proxy-injected *types.NetworkMap into the wire
// patch the components envelope ships alongside. Returns nil when there are
// no fragments to merge — proto3 omits a nil message field, so the receiver
// sees no patch and skips the merge step entirely.
//
// We reuse the legacy proto-conversion helpers (toProtocolRoutes,
// toProtocolFirewallRules, toProtocolRoutesFirewallRules,
// appendRemotePeerConfig, ForwardingRule.ToProto) because the proxy
// delivers fragments pre-expanded — there's no raw component shape to
// derive them from. Components purity isn't violated: proxy data isn't
// policy-graph-derived, it's externally injected post-Calculate, so the
// client merges it on top of its locally-computed NetworkMap.
func toProxyPatch(nm *types.NetworkMap, dnsName string, includeIPv6, useSourcePrefixes bool) *proto.ProxyPatch {
if nm == nil {
return nil
}
if len(nm.Peers) == 0 && len(nm.OfflinePeers) == 0 && len(nm.FirewallRules) == 0 &&
len(nm.Routes) == 0 && len(nm.RoutesFirewallRules) == 0 && len(nm.ForwardingRules) == 0 {
return nil
}
patch := &proto.ProxyPatch{
Peers: networkmap.AppendRemotePeerConfig(nil, nm.Peers, dnsName, includeIPv6),
OfflinePeers: networkmap.AppendRemotePeerConfig(nil, nm.OfflinePeers, dnsName, includeIPv6),
FirewallRules: networkmap.ToProtocolFirewallRules(nm.FirewallRules, includeIPv6, useSourcePrefixes),
Routes: networkmap.ToProtocolRoutes(nm.Routes),
RouteFirewallRules: networkmap.ToProtocolRoutesFirewallRules(nm.RoutesFirewallRules),
}
if len(nm.ForwardingRules) > 0 {
patch.ForwardingRules = make([]*proto.ForwardingRule, 0, len(nm.ForwardingRules))
for _, r := range nm.ForwardingRules {
patch.ForwardingRules = append(patch.ForwardingRules, r.ToProto())
}
}
return patch
}
// computeSSHEnabledForPeer mirrors the SSH-server-activation bit that
// Calculate() folds into NetworkMap.EnableSSH. Components-format peers
// receive a freshly-computed PeerConfig.SshConfig.SshEnabled at sync time;
// without this helper the field would be incorrectly false for any peer
// that's the destination of an SSH-enabling policy without having
// peer.SSHEnabled set locally.
//
// Mirrors the two activation paths in Calculate() (`networkmap_components.go`
// `getPeerConnectionResources`):
// 1. Explicit: rule.Protocol == NetbirdSSH and peer is in the rule's
// destinations.
// 2. Legacy implicit: rule covers TCP/22 or TCP/22022 (or ALL), peer is in
// destinations, AND the peer has SSHEnabled set locally — this is the
// "allow-all/TCP-22 implies SSH activation for SSH-capable peers" path.
//
// The full SSH AuthorizedUsers map is still produced by the client when it
// runs Calculate() over the envelope.
func computeSSHEnabledForPeer(c *types.NetworkMapComponents, peer *nbpeer.Peer) bool {
if c == nil || peer == nil {
return false
}
// Mirror Calculate's `getAllPeersFromGroups` invariant: target peer must
// exist in c.Peers, otherwise no rule applies to it.
if _, ok := c.Peers[peer.ID]; !ok {
return false
}
for _, policy := range c.Policies {
if policy == nil || !policy.Enabled {
continue
}
for _, rule := range policy.Rules {
if ruleEnablesSSHForPeer(c, rule, peer) {
return true
}
}
}
return false
}
// ruleEnablesSSHForPeer returns true when rule is active, targets peer, and
// either explicitly authorises SSH or covers the legacy TCP/22 path while the
// peer itself has SSH enabled locally.
func ruleEnablesSSHForPeer(c *types.NetworkMapComponents, rule *types.PolicyRule, peer *nbpeer.Peer) bool {
if rule == nil || !rule.Enabled {
return false
}
if !peerInDestinations(c, rule, peer.ID) {
return false
}
if rule.Protocol == types.PolicyRuleProtocolNetbirdSSH {
return true
}
return peer.SSHEnabled && types.PolicyRuleImpliesLegacySSH(rule)
}
// peerInDestinations reports whether peerID is in any of rule.Destinations'
// groups (or matches DestinationResource if it's a peer-typed resource —
// for non-peer types Calculate falls through to group lookup, so we mirror
// that exactly to avoid silent divergence).
func peerInDestinations(c *types.NetworkMapComponents, rule *types.PolicyRule, peerID string) bool {
if rule.DestinationResource.Type == types.ResourceTypePeer && rule.DestinationResource.ID != "" {
return rule.DestinationResource.ID == peerID
}
for _, groupID := range rule.Destinations {
if c.IsPeerInGroup(peerID, groupID) {
return true
}
}
return false
}

View File

@@ -0,0 +1,186 @@
package grpc
import (
"testing"
"github.com/stretchr/testify/assert"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/types"
)
// TestComputeSSHEnabledForPeer covers both Calculate-mirroring branches:
// explicit NetbirdSSH protocol, and the legacy implicit case where a
// TCP/22 (or 22022 / ALL / port-range-covering-22) rule activates SSH when
// the destination peer has SSHEnabled=true locally. Belt-and-suspenders for
// the B1 fix that the prod-DB equivalence test alone wouldn't have caught
// if no account had this combination.
func TestComputeSSHEnabledForPeer(t *testing.T) {
const targetPeerID = "target"
const targetGroupID = "g_dst"
mkComponents := func(rule *types.PolicyRule, sshEnabled bool) (*types.NetworkMapComponents, *nbpeer.Peer) {
peer := &nbpeer.Peer{ID: targetPeerID, SSHEnabled: sshEnabled}
group := &types.Group{ID: targetGroupID, Name: "dst", Peers: []string{targetPeerID}}
return &types.NetworkMapComponents{
Peers: map[string]*nbpeer.Peer{targetPeerID: peer},
Groups: map[string]*types.Group{targetGroupID: group},
Policies: []*types.Policy{{
ID: "p",
Enabled: true,
Rules: []*types.PolicyRule{rule},
}},
}, peer
}
cases := []struct {
name string
peerSSH bool
rule types.PolicyRule
wantEnabled bool
}{
{
name: "explicit-netbird-ssh-activates-regardless-of-peer-ssh",
peerSSH: false,
rule: types.PolicyRule{
Enabled: true, Protocol: types.PolicyRuleProtocolNetbirdSSH,
Destinations: []string{targetGroupID},
},
wantEnabled: true,
},
{
name: "implicit-tcp-22-with-peer-ssh",
peerSSH: true,
rule: types.PolicyRule{
Enabled: true, Protocol: types.PolicyRuleProtocolTCP, Ports: []string{"22"},
Destinations: []string{targetGroupID},
},
wantEnabled: true,
},
{
name: "implicit-tcp-22-without-peer-ssh-disabled",
peerSSH: false,
rule: types.PolicyRule{
Enabled: true, Protocol: types.PolicyRuleProtocolTCP, Ports: []string{"22"},
Destinations: []string{targetGroupID},
},
wantEnabled: false,
},
{
name: "implicit-tcp-22022-with-peer-ssh",
peerSSH: true,
rule: types.PolicyRule{
Enabled: true, Protocol: types.PolicyRuleProtocolTCP, Ports: []string{"22022"},
Destinations: []string{targetGroupID},
},
wantEnabled: true,
},
{
name: "implicit-all-protocol-with-peer-ssh",
peerSSH: true,
rule: types.PolicyRule{
Enabled: true, Protocol: types.PolicyRuleProtocolALL,
Destinations: []string{targetGroupID},
},
wantEnabled: true,
},
{
name: "implicit-port-range-covers-22",
peerSSH: true,
rule: types.PolicyRule{
Enabled: true,
Protocol: types.PolicyRuleProtocolTCP,
PortRanges: []types.RulePortRange{{Start: 20, End: 30}},
Destinations: []string{targetGroupID},
},
wantEnabled: true,
},
{
name: "tcp-80-no-ssh",
peerSSH: true,
rule: types.PolicyRule{
Enabled: true, Protocol: types.PolicyRuleProtocolTCP, Ports: []string{"80"},
Destinations: []string{targetGroupID},
},
wantEnabled: false,
},
{
name: "disabled-rule-skipped",
peerSSH: true,
rule: types.PolicyRule{
Enabled: false, Protocol: types.PolicyRuleProtocolNetbirdSSH,
Destinations: []string{targetGroupID},
},
wantEnabled: false,
},
{
name: "peer-not-in-destinations",
peerSSH: true,
rule: types.PolicyRule{
Enabled: true, Protocol: types.PolicyRuleProtocolNetbirdSSH,
Destinations: []string{"g_other"}, // target not in this group
},
wantEnabled: false,
},
{
name: "peer-typed-destination-resource-matches",
peerSSH: false,
rule: types.PolicyRule{
Enabled: true,
Protocol: types.PolicyRuleProtocolNetbirdSSH,
DestinationResource: types.Resource{ID: targetPeerID, Type: types.ResourceTypePeer},
},
wantEnabled: true,
},
{
name: "non-peer-destination-resource-falls-through-to-groups",
peerSSH: false,
rule: types.PolicyRule{
Enabled: true,
Protocol: types.PolicyRuleProtocolNetbirdSSH,
DestinationResource: types.Resource{ID: targetPeerID, Type: "host"}, // wrong type
Destinations: []string{targetGroupID}, // saved by group fallback
},
wantEnabled: true,
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
c, peer := mkComponents(&tc.rule, tc.peerSSH)
got := computeSSHEnabledForPeer(c, peer)
assert.Equal(t, tc.wantEnabled, got)
})
}
}
// TestComputeSSHEnabledForPeer_TargetMissingFromComponents covers the
// belt-and-suspenders presence guard mirroring Calculate's
// getAllPeersFromGroups invariant.
func TestComputeSSHEnabledForPeer_TargetMissingFromComponents(t *testing.T) {
peer := &nbpeer.Peer{ID: "missing", SSHEnabled: true}
c := &types.NetworkMapComponents{
Peers: map[string]*nbpeer.Peer{}, // target peer NOT present
Groups: map[string]*types.Group{
"g": {ID: "g", Peers: []string{"missing"}},
},
Policies: []*types.Policy{{
ID: "p", Enabled: true,
Rules: []*types.PolicyRule{{
Enabled: true, Protocol: types.PolicyRuleProtocolNetbirdSSH,
Destinations: []string{"g"},
}},
}},
}
assert.False(t, computeSSHEnabledForPeer(c, peer),
"missing target peer must short-circuit to false, not consult policies")
}
// TestComputeSSHEnabledForPeer_NilInputs guards the cheap nil-checks at
// function entry — Calculate doesn't accept nil either, but the helper is
// exported indirectly via ToComponentSyncResponse and may receive nil
// components on graceful-degrade paths.
func TestComputeSSHEnabledForPeer_NilInputs(t *testing.T) {
assert.False(t, computeSSHEnabledForPeer(nil, &nbpeer.Peer{ID: "x"}))
assert.False(t, computeSSHEnabledForPeer(&types.NetworkMapComponents{}, nil))
}

View File

@@ -7,23 +7,18 @@ import (
"net/url"
"strings"
log "github.com/sirupsen/logrus"
goproto "google.golang.org/protobuf/proto"
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
"github.com/netbirdio/netbird/client/ssh/auth"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/types"
nbroute "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/networkmap"
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/netiputil"
"github.com/netbirdio/netbird/shared/sshauth"
)
func toNetbirdConfig(config *nbconfig.Config, turnCredentials *Token, relayToken *Token, extraSettings *types.ExtraSettings) *proto.NetbirdConfig {
@@ -138,8 +133,8 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig, networkMap.EnableSSH),
NetworkMap: &proto.NetworkMap{
Serial: networkMap.Network.CurrentSerial(),
Routes: toProtocolRoutes(networkMap.Routes),
DNSConfig: toProtocolDNSConfig(networkMap.DNSConfig, dnsCache, dnsFwdPort),
Routes: networkmap.ToProtocolRoutes(networkMap.Routes),
DNSConfig: networkmap.ToProtocolDNSConfig(networkMap.DNSConfig, dnsCache, dnsFwdPort),
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig, networkMap.EnableSSH),
},
Checks: toProtocolChecks(ctx, checks),
@@ -152,19 +147,19 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
response.NetworkMap.PeerConfig = response.PeerConfig
remotePeers := make([]*proto.RemotePeerConfig, 0, len(networkMap.Peers)+len(networkMap.OfflinePeers))
remotePeers = appendRemotePeerConfig(remotePeers, networkMap.Peers, dnsName, includeIPv6)
remotePeers = networkmap.AppendRemotePeerConfig(remotePeers, networkMap.Peers, dnsName, includeIPv6)
response.RemotePeers = remotePeers
response.NetworkMap.RemotePeers = remotePeers
response.RemotePeersIsEmpty = len(remotePeers) == 0
response.NetworkMap.RemotePeersIsEmpty = response.RemotePeersIsEmpty
response.NetworkMap.OfflinePeers = appendRemotePeerConfig(nil, networkMap.OfflinePeers, dnsName, includeIPv6)
response.NetworkMap.OfflinePeers = networkmap.AppendRemotePeerConfig(nil, networkMap.OfflinePeers, dnsName, includeIPv6)
firewallRules := toProtocolFirewallRules(networkMap.FirewallRules, includeIPv6, useSourcePrefixes)
firewallRules := networkmap.ToProtocolFirewallRules(networkMap.FirewallRules, includeIPv6, useSourcePrefixes)
response.NetworkMap.FirewallRules = firewallRules
response.NetworkMap.FirewallRulesIsEmpty = len(firewallRules) == 0
routesFirewallRules := toProtocolRoutesFirewallRules(networkMap.RoutesFirewallRules)
routesFirewallRules := networkmap.ToProtocolRoutesFirewallRules(networkMap.RoutesFirewallRules)
response.NetworkMap.RoutesFirewallRules = routesFirewallRules
response.NetworkMap.RoutesFirewallRulesIsEmpty = len(routesFirewallRules) == 0
@@ -177,7 +172,7 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
}
if networkMap.AuthorizedUsers != nil {
hashedUsers, machineUsers := buildAuthorizedUsersProto(ctx, networkMap.AuthorizedUsers)
hashedUsers, machineUsers := networkmap.BuildAuthorizedUsersProto(ctx, networkMap.AuthorizedUsers)
userIDClaim := auth.DefaultUserIDClaim
if httpConfig != nil && httpConfig.AuthUserIDClaim != "" {
userIDClaim = httpConfig.AuthUserIDClaim
@@ -188,78 +183,6 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
return response
}
func buildAuthorizedUsersProto(ctx context.Context, authorizedUsers map[string]map[string]struct{}) ([][]byte, map[string]*proto.MachineUserIndexes) {
userIDToIndex := make(map[string]uint32)
var hashedUsers [][]byte
machineUsers := make(map[string]*proto.MachineUserIndexes, len(authorizedUsers))
for machineUser, users := range authorizedUsers {
indexes := make([]uint32, 0, len(users))
for userID := range users {
idx, exists := userIDToIndex[userID]
if !exists {
hash, err := sshauth.HashUserID(userID)
if err != nil {
log.WithContext(ctx).Errorf("failed to hash user id %s: %v", userID, err)
continue
}
idx = uint32(len(hashedUsers))
userIDToIndex[userID] = idx
hashedUsers = append(hashedUsers, hash[:])
}
indexes = append(indexes, idx)
}
machineUsers[machineUser] = &proto.MachineUserIndexes{Indexes: indexes}
}
return hashedUsers, machineUsers
}
func appendRemotePeerConfig(dst []*proto.RemotePeerConfig, peers []*nbpeer.Peer, dnsName string, includeIPv6 bool) []*proto.RemotePeerConfig {
for _, rPeer := range peers {
allowedIPs := []string{rPeer.IP.String() + "/32"}
if includeIPv6 && rPeer.IPv6.IsValid() {
allowedIPs = append(allowedIPs, rPeer.IPv6.String()+"/128")
}
dst = append(dst, &proto.RemotePeerConfig{
WgPubKey: rPeer.Key,
AllowedIps: allowedIPs,
SshConfig: &proto.SSHConfig{SshPubKey: []byte(rPeer.SSHKey)},
Fqdn: rPeer.FQDN(dnsName),
AgentVersion: rPeer.Meta.WtVersion,
})
}
return dst
}
// toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache
func toProtocolDNSConfig(update nbdns.Config, cache *cache.DNSConfigCache, forwardPort int64) *proto.DNSConfig {
protoUpdate := &proto.DNSConfig{
ServiceEnable: update.ServiceEnable,
CustomZones: make([]*proto.CustomZone, 0, len(update.CustomZones)),
NameServerGroups: make([]*proto.NameServerGroup, 0, len(update.NameServerGroups)),
ForwarderPort: forwardPort,
}
for _, zone := range update.CustomZones {
protoZone := convertToProtoCustomZone(zone)
protoUpdate.CustomZones = append(protoUpdate.CustomZones, protoZone)
}
for _, nsGroup := range update.NameServerGroups {
cacheKey := nsGroup.ID
if cachedGroup, exists := cache.GetNameServerGroup(cacheKey); exists {
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, cachedGroup)
} else {
protoGroup := convertToProtoNameServerGroup(nsGroup)
cache.SetNameServerGroup(cacheKey, protoGroup)
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, protoGroup)
}
}
return protoUpdate
}
func ToResponseProto(configProto nbconfig.Protocol) proto.HostConfig_Protocol {
switch configProto {
case nbconfig.UDP:
@@ -277,204 +200,6 @@ func ToResponseProto(configProto nbconfig.Protocol) proto.HostConfig_Protocol {
}
}
func toProtocolRoutes(routes []*nbroute.Route) []*proto.Route {
protoRoutes := make([]*proto.Route, 0, len(routes))
for _, r := range routes {
protoRoutes = append(protoRoutes, toProtocolRoute(r))
}
return protoRoutes
}
func toProtocolRoute(route *nbroute.Route) *proto.Route {
return &proto.Route{
ID: string(route.ID),
NetID: string(route.NetID),
Network: route.Network.String(),
Domains: route.Domains.ToPunycodeList(),
NetworkType: int64(route.NetworkType),
Peer: route.Peer,
Metric: int64(route.Metric),
Masquerade: route.Masquerade,
KeepRoute: route.KeepRoute,
SkipAutoApply: route.SkipAutoApply,
}
}
// toProtocolFirewallRules converts the firewall rules to the protocol firewall rules.
// When useSourcePrefixes is true, the compact SourcePrefixes field is populated
// alongside the deprecated PeerIP for forward compatibility.
// Wildcard rules ("0.0.0.0") are expanded into separate v4 and v6 SourcePrefixes
// when includeIPv6 is true.
func toProtocolFirewallRules(rules []*types.FirewallRule, includeIPv6, useSourcePrefixes bool) []*proto.FirewallRule {
result := make([]*proto.FirewallRule, 0, len(rules))
for i := range rules {
rule := rules[i]
fwRule := &proto.FirewallRule{
PolicyID: []byte(rule.PolicyID),
PeerIP: rule.PeerIP, //nolint:staticcheck // populated for backward compatibility
Direction: getProtoDirection(rule.Direction),
Action: getProtoAction(rule.Action),
Protocol: getProtoProtocol(rule.Protocol),
Port: rule.Port,
}
if useSourcePrefixes && rule.PeerIP != "" {
result = append(result, populateSourcePrefixes(fwRule, rule, includeIPv6)...)
}
if shouldUsePortRange(fwRule) {
fwRule.PortInfo = rule.PortRange.ToProto()
}
result = append(result, fwRule)
}
return result
}
// populateSourcePrefixes sets SourcePrefixes on fwRule and returns any
// additional rules needed (e.g. a v6 wildcard clone when the peer IP is unspecified).
func populateSourcePrefixes(fwRule *proto.FirewallRule, rule *types.FirewallRule, includeIPv6 bool) []*proto.FirewallRule {
addr, err := netip.ParseAddr(rule.PeerIP)
if err != nil {
return nil
}
if !addr.IsUnspecified() {
fwRule.SourcePrefixes = [][]byte{netiputil.EncodeAddr(addr.Unmap())}
return nil
}
// IPv4Unspecified/0 is always valid, error is impossible.
v4Wildcard, _ := netiputil.EncodePrefix(netip.PrefixFrom(netip.IPv4Unspecified(), 0))
fwRule.SourcePrefixes = [][]byte{v4Wildcard}
if !includeIPv6 {
return nil
}
v6Rule := goproto.Clone(fwRule).(*proto.FirewallRule)
v6Rule.PeerIP = "::" //nolint:staticcheck // populated for backward compatibility
// IPv6Unspecified/0 is always valid, error is impossible.
v6Wildcard, _ := netiputil.EncodePrefix(netip.PrefixFrom(netip.IPv6Unspecified(), 0))
v6Rule.SourcePrefixes = [][]byte{v6Wildcard}
if shouldUsePortRange(v6Rule) {
v6Rule.PortInfo = rule.PortRange.ToProto()
}
return []*proto.FirewallRule{v6Rule}
}
// getProtoDirection converts the direction to proto.RuleDirection.
func getProtoDirection(direction int) proto.RuleDirection {
if direction == types.FirewallRuleDirectionOUT {
return proto.RuleDirection_OUT
}
return proto.RuleDirection_IN
}
func toProtocolRoutesFirewallRules(rules []*types.RouteFirewallRule) []*proto.RouteFirewallRule {
result := make([]*proto.RouteFirewallRule, len(rules))
for i := range rules {
rule := rules[i]
result[i] = &proto.RouteFirewallRule{
SourceRanges: rule.SourceRanges,
Action: getProtoAction(rule.Action),
Destination: rule.Destination,
Protocol: getProtoProtocol(rule.Protocol),
PortInfo: getProtoPortInfo(rule),
IsDynamic: rule.IsDynamic,
Domains: rule.Domains.ToPunycodeList(),
PolicyID: []byte(rule.PolicyID),
RouteID: string(rule.RouteID),
}
}
return result
}
// getProtoAction converts the action to proto.RuleAction.
func getProtoAction(action string) proto.RuleAction {
if action == string(types.PolicyTrafficActionDrop) {
return proto.RuleAction_DROP
}
return proto.RuleAction_ACCEPT
}
// getProtoProtocol converts the protocol to proto.RuleProtocol.
func getProtoProtocol(protocol string) proto.RuleProtocol {
switch types.PolicyRuleProtocolType(protocol) {
case types.PolicyRuleProtocolALL:
return proto.RuleProtocol_ALL
case types.PolicyRuleProtocolTCP:
return proto.RuleProtocol_TCP
case types.PolicyRuleProtocolUDP:
return proto.RuleProtocol_UDP
case types.PolicyRuleProtocolICMP:
return proto.RuleProtocol_ICMP
default:
return proto.RuleProtocol_UNKNOWN
}
}
// getProtoPortInfo converts the port info to proto.PortInfo.
func getProtoPortInfo(rule *types.RouteFirewallRule) *proto.PortInfo {
var portInfo proto.PortInfo
if rule.Port != 0 {
portInfo.PortSelection = &proto.PortInfo_Port{Port: uint32(rule.Port)}
} else if portRange := rule.PortRange; portRange.Start != 0 && portRange.End != 0 {
portInfo.PortSelection = &proto.PortInfo_Range_{
Range: &proto.PortInfo_Range{
Start: uint32(portRange.Start),
End: uint32(portRange.End),
},
}
}
return &portInfo
}
func shouldUsePortRange(rule *proto.FirewallRule) bool {
return rule.Port == "" && (rule.Protocol == proto.RuleProtocol_UDP || rule.Protocol == proto.RuleProtocol_TCP)
}
// Helper function to convert nbdns.CustomZone to proto.CustomZone
func convertToProtoCustomZone(zone nbdns.CustomZone) *proto.CustomZone {
protoZone := &proto.CustomZone{
Domain: zone.Domain,
Records: make([]*proto.SimpleRecord, 0, len(zone.Records)),
SearchDomainDisabled: zone.SearchDomainDisabled,
NonAuthoritative: zone.NonAuthoritative,
}
for _, record := range zone.Records {
protoZone.Records = append(protoZone.Records, &proto.SimpleRecord{
Name: record.Name,
Type: int64(record.Type),
Class: record.Class,
TTL: int64(record.TTL),
RData: record.RData,
})
}
return protoZone
}
// Helper function to convert nbdns.NameServerGroup to proto.NameServerGroup
func convertToProtoNameServerGroup(nsGroup *nbdns.NameServerGroup) *proto.NameServerGroup {
protoGroup := &proto.NameServerGroup{
Primary: nsGroup.Primary,
Domains: nsGroup.Domains,
SearchDomainsEnabled: nsGroup.SearchDomainsEnabled,
NameServers: make([]*proto.NameServer, 0, len(nsGroup.NameServers)),
}
for _, ns := range nsGroup.NameServers {
protoGroup.NameServers = append(protoGroup.NameServers, &proto.NameServer{
IP: ns.IP.String(),
Port: int64(ns.Port),
NSType: int64(ns.NSType),
})
}
return protoGroup
}
// buildJWTConfig constructs JWT configuration for SSH servers from management server config
func buildJWTConfig(config *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfig.DeviceAuthorizationFlow) *proto.JWTConfig {
if config == nil || config.AuthAudience == "" {

View File

@@ -12,6 +12,7 @@ import (
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/shared/management/networkmap"
)
func TestToProtocolDNSConfigWithCache(t *testing.T) {
@@ -61,13 +62,13 @@ func TestToProtocolDNSConfigWithCache(t *testing.T) {
}
// First run with config1
result1 := toProtocolDNSConfig(config1, &cache, int64(network_map.DnsForwarderPort))
result1 := networkmap.ToProtocolDNSConfig(config1, &cache, int64(network_map.DnsForwarderPort))
// Second run with config2
result2 := toProtocolDNSConfig(config2, &cache, int64(network_map.DnsForwarderPort))
result2 := networkmap.ToProtocolDNSConfig(config2, &cache, int64(network_map.DnsForwarderPort))
// Third run with config1 again
result3 := toProtocolDNSConfig(config1, &cache, int64(network_map.DnsForwarderPort))
result3 := networkmap.ToProtocolDNSConfig(config1, &cache, int64(network_map.DnsForwarderPort))
// Verify that result1 and result3 are identical
if !reflect.DeepEqual(result1, result3) {
@@ -99,7 +100,7 @@ func BenchmarkToProtocolDNSConfig(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
toProtocolDNSConfig(testData, cache, int64(network_map.DnsForwarderPort))
networkmap.ToProtocolDNSConfig(testData, cache, int64(network_map.DnsForwarderPort))
}
})
@@ -107,7 +108,7 @@ func BenchmarkToProtocolDNSConfig(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
cache := &cache.DNSConfigCache{}
toProtocolDNSConfig(testData, cache, int64(network_map.DnsForwarderPort))
networkmap.ToProtocolDNSConfig(testData, cache, int64(network_map.DnsForwarderPort))
}
})
}

Some files were not shown because too many files have changed in this diff Show More