diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml index 62dfe9bce..7b7b32ec0 100644 --- a/.github/workflows/golangci-lint.yml +++ b/.github/workflows/golangci-lint.yml @@ -19,7 +19,7 @@ jobs: - name: codespell uses: codespell-project/actions-codespell@v2 with: - ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te,userA + ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te,userA,ede,additionals skip: go.mod,go.sum,**/proxy/web/** golangci: strategy: diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index efc7d9460..960cd30e9 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -8,7 +8,7 @@ There are many ways that you can contribute: - Sharing use cases in slack or Reddit - Bug fix or feature enhancement -If you haven't already, join our slack workspace [here](https://join.slack.com/t/netbirdio/shared_invite/zt-vrahf41g-ik1v7fV8du6t0RwxSrJ96A), we would love to discuss topics that need community contribution and enhancements to existing features. +If you haven't already, join our slack workspace [here](https://docs.netbird.io/slack-url), we would love to discuss topics that need community contribution and enhancements to existing features. ## Contents diff --git a/client/firewall/uspfilter/conntrack/cap_test.go b/client/firewall/uspfilter/conntrack/cap_test.go new file mode 100644 index 000000000..ee6b72e7f --- /dev/null +++ b/client/firewall/uspfilter/conntrack/cap_test.go @@ -0,0 +1,125 @@ +package conntrack + +import ( + "net/netip" + "testing" + + "github.com/google/gopacket/layers" + "github.com/stretchr/testify/require" +) + +func TestTCPCapEvicts(t *testing.T) { + t.Setenv(EnvTCPMaxEntries, "4") + + tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger) + defer tracker.Close() + require.Equal(t, 4, tracker.maxEntries) + + src := netip.MustParseAddr("100.64.0.1") + dst := netip.MustParseAddr("100.64.0.2") + + for i := 0; i < 10; i++ { + tracker.TrackOutbound(src, dst, uint16(10000+i), 80, TCPSyn, 0) + } + require.LessOrEqual(t, len(tracker.connections), 4, + "TCP table must not exceed the configured cap") + require.Greater(t, len(tracker.connections), 0, + "some entries must remain after eviction") + + // The most recently admitted flow must be present: eviction must make + // room for new entries, not silently drop them. + require.Contains(t, tracker.connections, + ConnKey{SrcIP: src, DstIP: dst, SrcPort: uint16(10009), DstPort: 80}, + "newest TCP flow must be admitted after eviction") + // A pre-cap flow must have been evicted to fit the last one. + require.NotContains(t, tracker.connections, + ConnKey{SrcIP: src, DstIP: dst, SrcPort: uint16(10000), DstPort: 80}, + "oldest TCP flow should have been evicted") +} + +func TestTCPCapPrefersTombstonedForEviction(t *testing.T) { + t.Setenv(EnvTCPMaxEntries, "3") + + tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger) + defer tracker.Close() + + src := netip.MustParseAddr("100.64.0.1") + dst := netip.MustParseAddr("100.64.0.2") + + // Fill to cap with 3 live connections. + for i := 0; i < 3; i++ { + tracker.TrackOutbound(src, dst, uint16(20000+i), 80, TCPSyn, 0) + } + require.Len(t, tracker.connections, 3) + + // Tombstone one by sending RST through IsValidInbound. + tombstonedKey := ConnKey{SrcIP: src, DstIP: dst, SrcPort: 20001, DstPort: 80} + require.True(t, tracker.IsValidInbound(dst, src, 80, 20001, TCPRst|TCPAck, 0)) + require.True(t, tracker.connections[tombstonedKey].IsTombstone()) + + // Another live connection forces eviction. The tombstone must go first. + tracker.TrackOutbound(src, dst, uint16(29999), 80, TCPSyn, 0) + + _, tombstonedStillPresent := tracker.connections[tombstonedKey] + require.False(t, tombstonedStillPresent, + "tombstoned entry should be evicted before live entries") + require.LessOrEqual(t, len(tracker.connections), 3) + + // Both live pre-cap entries must survive: eviction must prefer the + // tombstone, not just satisfy the size bound by dropping any entry. + require.Contains(t, tracker.connections, + ConnKey{SrcIP: src, DstIP: dst, SrcPort: uint16(20000), DstPort: 80}, + "live entries must not be evicted while a tombstone exists") + require.Contains(t, tracker.connections, + ConnKey{SrcIP: src, DstIP: dst, SrcPort: uint16(20002), DstPort: 80}, + "live entries must not be evicted while a tombstone exists") +} + +func TestUDPCapEvicts(t *testing.T) { + t.Setenv(EnvUDPMaxEntries, "5") + + tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger) + defer tracker.Close() + require.Equal(t, 5, tracker.maxEntries) + + src := netip.MustParseAddr("100.64.0.1") + dst := netip.MustParseAddr("100.64.0.2") + + for i := 0; i < 12; i++ { + tracker.TrackOutbound(src, dst, uint16(30000+i), 53, 0) + } + require.LessOrEqual(t, len(tracker.connections), 5) + require.Greater(t, len(tracker.connections), 0) + + require.Contains(t, tracker.connections, + ConnKey{SrcIP: src, DstIP: dst, SrcPort: uint16(30011), DstPort: 53}, + "newest UDP flow must be admitted after eviction") + require.NotContains(t, tracker.connections, + ConnKey{SrcIP: src, DstIP: dst, SrcPort: uint16(30000), DstPort: 53}, + "oldest UDP flow should have been evicted") +} + +func TestICMPCapEvicts(t *testing.T) { + t.Setenv(EnvICMPMaxEntries, "3") + + tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger) + defer tracker.Close() + require.Equal(t, 3, tracker.maxEntries) + + src := netip.MustParseAddr("100.64.0.1") + dst := netip.MustParseAddr("100.64.0.2") + + echoReq := layers.CreateICMPv4TypeCode(uint8(layers.ICMPv4TypeEchoRequest), 0) + for i := 0; i < 8; i++ { + tracker.TrackOutbound(src, dst, uint16(i), echoReq, nil, 64) + } + require.LessOrEqual(t, len(tracker.connections), 3) + require.Greater(t, len(tracker.connections), 0) + + require.Contains(t, tracker.connections, + ICMPConnKey{SrcIP: src, DstIP: dst, ID: uint16(7)}, + "newest ICMP flow must be admitted after eviction") + require.NotContains(t, tracker.connections, + ICMPConnKey{SrcIP: src, DstIP: dst, ID: uint16(0)}, + "oldest ICMP flow should have been evicted") +} diff --git a/client/firewall/uspfilter/conntrack/common.go b/client/firewall/uspfilter/conntrack/common.go index 88e90317c..e497a0bff 100644 --- a/client/firewall/uspfilter/conntrack/common.go +++ b/client/firewall/uspfilter/conntrack/common.go @@ -3,15 +3,61 @@ package conntrack import ( "net" "net/netip" + "os" "strconv" "sync/atomic" "time" "github.com/google/uuid" + nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log" nftypes "github.com/netbirdio/netbird/client/internal/netflow/types" ) +// evictSampleSize bounds how many map entries we scan per eviction call. +// Keeps eviction O(1) even at cap under sustained load; the sampled-LRU +// heuristic is good enough for a conntrack table that only overflows under +// abuse. +const evictSampleSize = 8 + +// envDuration parses an os.Getenv(name) as a time.Duration. Falls back to +// def on empty or invalid; logs a warning on invalid. +func envDuration(logger *nblog.Logger, name string, def time.Duration) time.Duration { + v := os.Getenv(name) + if v == "" { + return def + } + d, err := time.ParseDuration(v) + if err != nil { + logger.Warn3("invalid %s=%q: %v, using default", name, v, err) + return def + } + if d <= 0 { + logger.Warn2("invalid %s=%q: must be positive, using default", name, v) + return def + } + return d +} + +// envInt parses an os.Getenv(name) as an int. Falls back to def on empty, +// invalid, or non-positive. Logs a warning on invalid input. +func envInt(logger *nblog.Logger, name string, def int) int { + v := os.Getenv(name) + if v == "" { + return def + } + n, err := strconv.Atoi(v) + switch { + case err != nil: + logger.Warn3("invalid %s=%q: %v, using default", name, v, err) + return def + case n <= 0: + logger.Warn2("invalid %s=%q: must be positive, using default", name, v) + return def + } + return n +} + // BaseConnTrack provides common fields and locking for all connection types type BaseConnTrack struct { FlowId uuid.UUID diff --git a/client/firewall/uspfilter/conntrack/defaults_desktop.go b/client/firewall/uspfilter/conntrack/defaults_desktop.go new file mode 100644 index 000000000..2f07f5984 --- /dev/null +++ b/client/firewall/uspfilter/conntrack/defaults_desktop.go @@ -0,0 +1,11 @@ +//go:build !ios && !android + +package conntrack + +// Default per-tracker entry caps on desktop/server platforms. These mirror +// typical Linux netfilter nf_conntrack_max territory with ample headroom. +const ( + DefaultMaxTCPEntries = 65536 + DefaultMaxUDPEntries = 16384 + DefaultMaxICMPEntries = 2048 +) diff --git a/client/firewall/uspfilter/conntrack/defaults_mobile.go b/client/firewall/uspfilter/conntrack/defaults_mobile.go new file mode 100644 index 000000000..c9e05d229 --- /dev/null +++ b/client/firewall/uspfilter/conntrack/defaults_mobile.go @@ -0,0 +1,13 @@ +//go:build ios || android + +package conntrack + +// Default per-tracker entry caps on mobile platforms. iOS network extensions +// are capped at ~50 MB; Android runs under aggressive memory pressure. These +// values keep conntrack footprint well under 5 MB worst case (TCPConnTrack +// is ~200 B plus map overhead). +const ( + DefaultMaxTCPEntries = 4096 + DefaultMaxUDPEntries = 2048 + DefaultMaxICMPEntries = 512 +) diff --git a/client/firewall/uspfilter/conntrack/icmp.go b/client/firewall/uspfilter/conntrack/icmp.go index a48215ca9..3c96548b5 100644 --- a/client/firewall/uspfilter/conntrack/icmp.go +++ b/client/firewall/uspfilter/conntrack/icmp.go @@ -50,6 +50,9 @@ type ICMPConnTrack struct { ICMPCode uint8 } +// EnvICMPMaxEntries caps the ICMP conntrack table size. +const EnvICMPMaxEntries = "NB_CONNTRACK_ICMP_MAX" + // ICMPTracker manages ICMP connection states type ICMPTracker struct { logger *nblog.Logger @@ -58,6 +61,7 @@ type ICMPTracker struct { cleanupTicker *time.Ticker tickerCancel context.CancelFunc mutex sync.RWMutex + maxEntries int flowLogger nftypes.FlowLogger } @@ -171,6 +175,7 @@ func NewICMPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nfty timeout: timeout, cleanupTicker: time.NewTicker(ICMPCleanupInterval), tickerCancel: cancel, + maxEntries: envInt(logger, EnvICMPMaxEntries, DefaultMaxICMPEntries), flowLogger: flowLogger, } @@ -257,7 +262,9 @@ func (t *ICMPTracker) track( // non echo requests don't need tracking if typ != uint8(layers.ICMPv4TypeEchoRequest) { - t.logger.Trace3("New %s ICMP connection %s - %s", direction, key, icmpInfo) + if t.logger.Enabled(nblog.LevelTrace) { + t.logger.Trace3("New %s ICMP connection %s - %s", direction, key, icmpInfo) + } t.sendStartEvent(direction, srcIP, dstIP, typ, code, ruleId, size) return } @@ -276,10 +283,15 @@ func (t *ICMPTracker) track( conn.UpdateCounters(direction, size) t.mutex.Lock() + if t.maxEntries > 0 && len(t.connections) >= t.maxEntries { + t.evictOneLocked() + } t.connections[key] = conn t.mutex.Unlock() - t.logger.Trace3("New %s ICMP connection %s - %s", direction, key, icmpInfo) + if t.logger.Enabled(nblog.LevelTrace) { + t.logger.Trace3("New %s ICMP connection %s - %s", direction, key, icmpInfo) + } t.sendEvent(nftypes.TypeStart, conn, ruleId) } @@ -323,6 +335,34 @@ func (t *ICMPTracker) cleanupRoutine(ctx context.Context) { } } +// evictOneLocked removes one entry to make room. Caller must hold t.mutex. +// Bounded sample scan: picks the oldest among up to evictSampleSize entries. +func (t *ICMPTracker) evictOneLocked() { + var candKey ICMPConnKey + var candSeen int64 + haveCand := false + sampled := 0 + + for k, c := range t.connections { + seen := c.lastSeen.Load() + if !haveCand || seen < candSeen { + candKey = k + candSeen = seen + haveCand = true + } + sampled++ + if sampled >= evictSampleSize { + break + } + } + if haveCand { + if evicted := t.connections[candKey]; evicted != nil { + t.sendEvent(nftypes.TypeEnd, evicted, nil) + } + delete(t.connections, candKey) + } +} + func (t *ICMPTracker) cleanup() { t.mutex.Lock() defer t.mutex.Unlock() @@ -331,8 +371,10 @@ func (t *ICMPTracker) cleanup() { if conn.timeoutExceeded(t.timeout) { delete(t.connections, key) - t.logger.Trace5("Removed ICMP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]", - key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load()) + if t.logger.Enabled(nblog.LevelTrace) { + t.logger.Trace5("Removed ICMP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]", + key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load()) + } t.sendEvent(nftypes.TypeEnd, conn, nil) } } diff --git a/client/firewall/uspfilter/conntrack/tcp.go b/client/firewall/uspfilter/conntrack/tcp.go index 335a3abab..9edc9af22 100644 --- a/client/firewall/uspfilter/conntrack/tcp.go +++ b/client/firewall/uspfilter/conntrack/tcp.go @@ -38,6 +38,27 @@ const ( TCPHandshakeTimeout = 60 * time.Second // TCPCleanupInterval is how often we check for stale connections TCPCleanupInterval = 5 * time.Minute + // FinWaitTimeout bounds FIN_WAIT_1 / FIN_WAIT_2 / CLOSING states. + // Matches Linux netfilter nf_conntrack_tcp_timeout_fin_wait. + FinWaitTimeout = 60 * time.Second + // CloseWaitTimeout bounds CLOSE_WAIT. Matches Linux default; apps + // holding CloseWait longer than this should bump the env var. + CloseWaitTimeout = 60 * time.Second + // LastAckTimeout bounds LAST_ACK. Matches Linux default. + LastAckTimeout = 30 * time.Second +) + +// Env vars to override per-state teardown timeouts. Values parsed by +// time.ParseDuration (e.g. "120s", "2m"). Invalid values fall back to the +// defaults above with a warning. +const ( + EnvTCPFinWaitTimeout = "NB_CONNTRACK_TCP_FIN_WAIT_TIMEOUT" + EnvTCPCloseWaitTimeout = "NB_CONNTRACK_TCP_CLOSE_WAIT_TIMEOUT" + EnvTCPLastAckTimeout = "NB_CONNTRACK_TCP_LAST_ACK_TIMEOUT" + + // EnvTCPMaxEntries caps the TCP conntrack table size. Oldest entries + // (tombstones first) are evicted when the cap is reached. + EnvTCPMaxEntries = "NB_CONNTRACK_TCP_MAX" ) // TCPState represents the state of a TCP connection @@ -133,14 +154,18 @@ func (t *TCPConnTrack) SetTombstone() { // TCPTracker manages TCP connection states type TCPTracker struct { - logger *nblog.Logger - connections map[ConnKey]*TCPConnTrack - mutex sync.RWMutex - cleanupTicker *time.Ticker - tickerCancel context.CancelFunc - timeout time.Duration - waitTimeout time.Duration - flowLogger nftypes.FlowLogger + logger *nblog.Logger + connections map[ConnKey]*TCPConnTrack + mutex sync.RWMutex + cleanupTicker *time.Ticker + tickerCancel context.CancelFunc + timeout time.Duration + waitTimeout time.Duration + finWaitTimeout time.Duration + closeWaitTimeout time.Duration + lastAckTimeout time.Duration + maxEntries int + flowLogger nftypes.FlowLogger } // NewTCPTracker creates a new TCP connection tracker @@ -155,13 +180,17 @@ func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp ctx, cancel := context.WithCancel(context.Background()) tracker := &TCPTracker{ - logger: logger, - connections: make(map[ConnKey]*TCPConnTrack), - cleanupTicker: time.NewTicker(TCPCleanupInterval), - tickerCancel: cancel, - timeout: timeout, - waitTimeout: waitTimeout, - flowLogger: flowLogger, + logger: logger, + connections: make(map[ConnKey]*TCPConnTrack), + cleanupTicker: time.NewTicker(TCPCleanupInterval), + tickerCancel: cancel, + timeout: timeout, + waitTimeout: waitTimeout, + finWaitTimeout: envDuration(logger, EnvTCPFinWaitTimeout, FinWaitTimeout), + closeWaitTimeout: envDuration(logger, EnvTCPCloseWaitTimeout, CloseWaitTimeout), + lastAckTimeout: envDuration(logger, EnvTCPLastAckTimeout, LastAckTimeout), + maxEntries: envInt(logger, EnvTCPMaxEntries, DefaultMaxTCPEntries), + flowLogger: flowLogger, } go tracker.cleanupRoutine(ctx) @@ -209,6 +238,12 @@ func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, fla if exists || flags&TCPSyn == 0 { return } + // Reject illegal SYN combinations (SYN+FIN, SYN+RST, …) so they don't + // create spurious conntrack entries. Not mandated by RFC 9293 but a + // common hardening (Linux netfilter/nftables rejects these too). + if !isValidFlagCombination(flags) { + return + } conn := &TCPConnTrack{ BaseConnTrack: BaseConnTrack{ @@ -225,20 +260,65 @@ func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, fla conn.state.Store(int32(TCPStateNew)) conn.DNATOrigPort.Store(uint32(origPort)) - if origPort != 0 { - t.logger.Trace4("New %s TCP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort) - } else { - t.logger.Trace2("New %s TCP connection: %s", direction, key) + if t.logger.Enabled(nblog.LevelTrace) { + if origPort != 0 { + t.logger.Trace4("New %s TCP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort) + } else { + t.logger.Trace2("New %s TCP connection: %s", direction, key) + } } t.updateState(key, conn, flags, direction, size) t.mutex.Lock() + if t.maxEntries > 0 && len(t.connections) >= t.maxEntries { + t.evictOneLocked() + } t.connections[key] = conn t.mutex.Unlock() t.sendEvent(nftypes.TypeStart, conn, ruleID) } +// evictOneLocked removes one entry to make room. Caller must hold t.mutex. +// Bounded scan: samples up to evictSampleSize pseudo-random entries (Go map +// iteration order is randomized), preferring a tombstone. If no tombstone +// found in the sample, evicts the oldest among the sampled entries. O(1) +// worst case — cheap enough to run on every insert at cap during abuse. +func (t *TCPTracker) evictOneLocked() { + var candKey ConnKey + var candSeen int64 + haveCand := false + sampled := 0 + + for k, c := range t.connections { + if c.IsTombstone() { + delete(t.connections, k) + return + } + seen := c.lastSeen.Load() + if !haveCand || seen < candSeen { + candKey = k + candSeen = seen + haveCand = true + } + sampled++ + if sampled >= evictSampleSize { + break + } + } + if haveCand { + if evicted := t.connections[candKey]; evicted != nil { + // TypeEnd is already emitted at the state transition to + // TimeWait and when a connection is tombstoned. Only emit + // here when we're reaping a still-active flow. + if evicted.GetState() != TCPStateTimeWait && !evicted.IsTombstone() { + t.sendEvent(nftypes.TypeEnd, evicted, nil) + } + } + delete(t.connections, candKey) + } +} + // IsValidInbound checks if an inbound TCP packet matches a tracked connection func (t *TCPTracker) IsValidInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) bool { key := ConnKey{ @@ -256,12 +336,19 @@ func (t *TCPTracker) IsValidInbound(srcIP, dstIP netip.Addr, srcPort, dstPort ui return false } + // Reject illegal flag combinations regardless of state. These never belong + // to a legitimate flow and must not advance or tear down state. + if !isValidFlagCombination(flags) { + if t.logger.Enabled(nblog.LevelWarn) { + t.logger.Warn3("TCP illegal flag combination %x for connection %s (state %s)", flags, key, conn.GetState()) + } + return false + } + currentState := conn.GetState() if !t.isValidStateForFlags(currentState, flags) { - t.logger.Warn3("TCP state %s is not valid with flags %x for connection %s", currentState, flags, key) - // allow all flags for established for now - if currentState == TCPStateEstablished { - return true + if t.logger.Enabled(nblog.LevelWarn) { + t.logger.Warn3("TCP state %s is not valid with flags %x for connection %s", currentState, flags, key) } return false } @@ -270,116 +357,208 @@ func (t *TCPTracker) IsValidInbound(srcIP, dstIP netip.Addr, srcPort, dstPort ui return true } -// updateState updates the TCP connection state based on flags +// updateState updates the TCP connection state based on flags. func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, packetDir nftypes.Direction, size int) { - conn.UpdateLastSeen() conn.UpdateCounters(packetDir, size) + // Malformed flag combinations must not refresh lastSeen or drive state, + // otherwise spoofed packets keep a dead flow alive past its timeout. + if !isValidFlagCombination(flags) { + return + } + + conn.UpdateLastSeen() + currentState := conn.GetState() if flags&TCPRst != 0 { - if conn.CompareAndSwapState(currentState, TCPStateClosed) { - conn.SetTombstone() - t.logger.Trace6("TCP connection reset: %s (dir: %s) [in: %d Pkts/%d B, out: %d Pkts/%d B]", - key, packetDir, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load()) - t.sendEvent(nftypes.TypeEnd, conn, nil) - } + // Hardening beyond RFC 9293 §3.10.7.4: without sequence tracking we + // cannot apply the RFC 5961 in-window RST check, so we conservatively + // reject RSTs that the spec would accept (TIME-WAIT with in-window + // SEQ, SynSent from same direction as own SYN, etc.). + t.handleRst(key, conn, currentState, packetDir) return } - var newState TCPState - switch currentState { - case TCPStateNew: - if flags&TCPSyn != 0 && flags&TCPAck == 0 { - if conn.Direction == nftypes.Egress { - newState = TCPStateSynSent - } else { - newState = TCPStateSynReceived - } - } + newState := nextState(currentState, conn.Direction, packetDir, flags) + if newState == 0 || !conn.CompareAndSwapState(currentState, newState) { + return + } + t.onTransition(key, conn, currentState, newState, packetDir) +} - case TCPStateSynSent: - if flags&TCPSyn != 0 && flags&TCPAck != 0 { - if packetDir != conn.Direction { - newState = TCPStateEstablished - } else { - // Simultaneous open - newState = TCPStateSynReceived - } - } +// handleRst processes a RST segment. Late RSTs in TimeWait and spoofed RSTs +// from the SYN direction are ignored; otherwise the flow is tombstoned. +func (t *TCPTracker) handleRst(key ConnKey, conn *TCPConnTrack, currentState TCPState, packetDir nftypes.Direction) { + // TimeWait exists to absorb late segments; don't let a late RST + // tombstone the entry and break same-4-tuple reuse. + if currentState == TCPStateTimeWait { + return + } + // A RST from the same direction as the SYN cannot be a legitimate + // response and must not tear down a half-open connection. + if currentState == TCPStateSynSent && packetDir == conn.Direction { + return + } + if !conn.CompareAndSwapState(currentState, TCPStateClosed) { + return + } + conn.SetTombstone() + if t.logger.Enabled(nblog.LevelTrace) { + t.logger.Trace6("TCP connection reset: %s (dir: %s) [in: %d Pkts/%d B, out: %d Pkts/%d B]", + key, packetDir, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load()) + } + t.sendEvent(nftypes.TypeEnd, conn, nil) +} - case TCPStateSynReceived: - if flags&TCPAck != 0 && flags&TCPSyn == 0 { - if packetDir == conn.Direction { - newState = TCPStateEstablished - } - } +// stateTransition describes one state's transition logic. It receives the +// packet's flags plus whether the packet direction matches the connection's +// origin direction (same=true means same side as the SYN initiator). Return 0 +// for no transition. +type stateTransition func(flags uint8, connDir nftypes.Direction, same bool) TCPState - case TCPStateEstablished: - if flags&TCPFin != 0 { - if packetDir == conn.Direction { - newState = TCPStateFinWait1 - } else { - newState = TCPStateCloseWait - } - } +// stateTable maps each state to its transition function. Centralized here so +// nextState stays trivial and each rule is easy to read in isolation. +var stateTable = map[TCPState]stateTransition{ + TCPStateNew: transNew, + TCPStateSynSent: transSynSent, + TCPStateSynReceived: transSynReceived, + TCPStateEstablished: transEstablished, + TCPStateFinWait1: transFinWait1, + TCPStateFinWait2: transFinWait2, + TCPStateClosing: transClosing, + TCPStateCloseWait: transCloseWait, + TCPStateLastAck: transLastAck, +} - case TCPStateFinWait1: - if packetDir != conn.Direction { - switch { - case flags&TCPFin != 0 && flags&TCPAck != 0: - newState = TCPStateClosing - case flags&TCPFin != 0: - newState = TCPStateClosing - case flags&TCPAck != 0: - newState = TCPStateFinWait2 - } - } +// nextState returns the target TCP state for the given current state and +// packet, or 0 if the packet does not trigger a transition. +func nextState(currentState TCPState, connDir, packetDir nftypes.Direction, flags uint8) TCPState { + fn, ok := stateTable[currentState] + if !ok { + return 0 + } + return fn(flags, connDir, packetDir == connDir) +} - case TCPStateFinWait2: - if flags&TCPFin != 0 { - newState = TCPStateTimeWait +func transNew(flags uint8, connDir nftypes.Direction, _ bool) TCPState { + if flags&TCPSyn != 0 && flags&TCPAck == 0 { + if connDir == nftypes.Egress { + return TCPStateSynSent } + return TCPStateSynReceived + } + return 0 +} - case TCPStateClosing: - if flags&TCPAck != 0 { - newState = TCPStateTimeWait +func transSynSent(flags uint8, _ nftypes.Direction, same bool) TCPState { + if flags&TCPSyn != 0 && flags&TCPAck != 0 { + if same { + return TCPStateSynReceived // simultaneous open } + return TCPStateEstablished + } + return 0 +} - case TCPStateCloseWait: - if flags&TCPFin != 0 { - newState = TCPStateLastAck - } +func transSynReceived(flags uint8, _ nftypes.Direction, same bool) TCPState { + if flags&TCPAck != 0 && flags&TCPSyn == 0 && same { + return TCPStateEstablished + } + return 0 +} - case TCPStateLastAck: - if flags&TCPAck != 0 { - newState = TCPStateClosed - } +func transEstablished(flags uint8, _ nftypes.Direction, same bool) TCPState { + if flags&TCPFin == 0 { + return 0 + } + if same { + return TCPStateFinWait1 + } + return TCPStateCloseWait +} + +// transFinWait1 handles the active-close peer response. A FIN carrying our +// ACK piggybacked goes straight to TIME-WAIT (RFC 9293 §3.10.7.4, FIN-WAIT-1: +// "if our FIN has been ACKed... enter the TIME-WAIT state"); a lone FIN moves +// to CLOSING; a pure ACK of our FIN moves to FIN-WAIT-2. +func transFinWait1(flags uint8, _ nftypes.Direction, same bool) TCPState { + if same { + return 0 + } + if flags&TCPFin != 0 && flags&TCPAck != 0 { + return TCPStateTimeWait + } + switch { + case flags&TCPFin != 0: + return TCPStateClosing + case flags&TCPAck != 0: + return TCPStateFinWait2 + } + return 0 +} + +// transFinWait2 ignores own-side FIN retransmits; only the peer's FIN advances. +func transFinWait2(flags uint8, _ nftypes.Direction, same bool) TCPState { + if flags&TCPFin != 0 && !same { + return TCPStateTimeWait + } + return 0 +} + +// transClosing completes a simultaneous close on the peer's ACK. +func transClosing(flags uint8, _ nftypes.Direction, same bool) TCPState { + if flags&TCPAck != 0 && !same { + return TCPStateTimeWait + } + return 0 +} + +// transCloseWait only advances to LastAck when WE send FIN, ignoring peer retransmits. +func transCloseWait(flags uint8, _ nftypes.Direction, same bool) TCPState { + if flags&TCPFin != 0 && same { + return TCPStateLastAck + } + return 0 +} + +// transLastAck closes the flow only on the peer's ACK (not our own ACK retransmits). +func transLastAck(flags uint8, _ nftypes.Direction, same bool) TCPState { + if flags&TCPAck != 0 && !same { + return TCPStateClosed + } + return 0 +} + +// onTransition handles logging and flow-event emission after a successful +// state transition. TimeWait and Closed are terminal for flow accounting. +func (t *TCPTracker) onTransition(key ConnKey, conn *TCPConnTrack, from, to TCPState, packetDir nftypes.Direction) { + traceOn := t.logger.Enabled(nblog.LevelTrace) + if traceOn { + t.logger.Trace4("TCP connection %s transitioned from %s to %s (dir: %s)", key, from, to, packetDir) } - if newState != 0 && conn.CompareAndSwapState(currentState, newState) { - t.logger.Trace4("TCP connection %s transitioned from %s to %s (dir: %s)", key, currentState, newState, packetDir) - - switch newState { - case TCPStateTimeWait: + switch to { + case TCPStateTimeWait: + if traceOn { t.logger.Trace5("TCP connection %s completed [in: %d Pkts/%d B, out: %d Pkts/%d B]", key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load()) - t.sendEvent(nftypes.TypeEnd, conn, nil) - - case TCPStateClosed: - conn.SetTombstone() + } + t.sendEvent(nftypes.TypeEnd, conn, nil) + case TCPStateClosed: + conn.SetTombstone() + if traceOn { t.logger.Trace5("TCP connection %s closed gracefully [in: %d Pkts/%d, B out: %d Pkts/%d B]", key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load()) - t.sendEvent(nftypes.TypeEnd, conn, nil) } + t.sendEvent(nftypes.TypeEnd, conn, nil) } } -// isValidStateForFlags checks if the TCP flags are valid for the current connection state +// isValidStateForFlags checks if the TCP flags are valid for the current +// connection state. Caller must have already verified the flag combination is +// legal via isValidFlagCombination. func (t *TCPTracker) isValidStateForFlags(state TCPState, flags uint8) bool { - if !isValidFlagCombination(flags) { - return false - } if flags&TCPRst != 0 { if state == TCPStateSynSent { return flags&TCPAck != 0 @@ -449,15 +628,24 @@ func (t *TCPTracker) cleanup() { timeout = t.waitTimeout case TCPStateEstablished: timeout = t.timeout + case TCPStateFinWait1, TCPStateFinWait2, TCPStateClosing: + timeout = t.finWaitTimeout + case TCPStateCloseWait: + timeout = t.closeWaitTimeout + case TCPStateLastAck: + timeout = t.lastAckTimeout default: + // SynSent / SynReceived / New timeout = TCPHandshakeTimeout } if conn.timeoutExceeded(timeout) { delete(t.connections, key) - t.logger.Trace6("Cleaned up timed-out TCP connection %s (%s) [in: %d Pkts/%d, B out: %d Pkts/%d B]", - key, conn.GetState(), conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load()) + if t.logger.Enabled(nblog.LevelTrace) { + t.logger.Trace6("Cleaned up timed-out TCP connection %s (%s) [in: %d Pkts/%d, B out: %d Pkts/%d B]", + key, conn.GetState(), conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load()) + } // event already handled by state change if currentState != TCPStateTimeWait { diff --git a/client/firewall/uspfilter/conntrack/tcp_rst_bugs_test.go b/client/firewall/uspfilter/conntrack/tcp_rst_bugs_test.go new file mode 100644 index 000000000..81d4f5710 --- /dev/null +++ b/client/firewall/uspfilter/conntrack/tcp_rst_bugs_test.go @@ -0,0 +1,100 @@ +package conntrack + +import ( + "net/netip" + "testing" + + "github.com/stretchr/testify/require" +) + +// RST hygiene tests: the tracker currently closes the flow on any RST that +// matches the 4-tuple, regardless of direction or state. These tests cover +// the minimum checks we want (no SEQ tracking). + +func TestTCPRstInSynSentWrongDirection(t *testing.T) { + tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger) + defer tracker.Close() + + srcIP := netip.MustParseAddr("100.64.0.1") + dstIP := netip.MustParseAddr("100.64.0.2") + srcPort := uint16(12345) + dstPort := uint16(80) + key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort} + + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0) + conn := tracker.connections[key] + require.Equal(t, TCPStateSynSent, conn.GetState()) + + // A RST arriving in the same direction as the SYN (i.e. TrackOutbound) + // cannot be a legitimate response. It must not close the connection. + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPRst|TCPAck, 0) + require.Equal(t, TCPStateSynSent, conn.GetState(), + "RST in same direction as SYN must not close connection") + require.False(t, conn.IsTombstone()) +} + +func TestTCPRstInTimeWaitIgnored(t *testing.T) { + tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger) + defer tracker.Close() + + srcIP := netip.MustParseAddr("100.64.0.1") + dstIP := netip.MustParseAddr("100.64.0.2") + srcPort := uint16(12345) + dstPort := uint16(80) + key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort} + + // Drive to TIME-WAIT via active close. + establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort) + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) + require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)) + require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)) + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0) + + conn := tracker.connections[key] + require.Equal(t, TCPStateTimeWait, conn.GetState()) + require.False(t, conn.IsTombstone(), "TIME-WAIT must not be tombstoned") + + // Late RST during TIME-WAIT must not tombstone the entry (TIME-WAIT + // exists to absorb late segments). + tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0) + require.Equal(t, TCPStateTimeWait, conn.GetState(), + "RST in TIME-WAIT must not transition state") + require.False(t, conn.IsTombstone(), + "RST in TIME-WAIT must not tombstone the entry") +} + +func TestTCPIllegalFlagCombos(t *testing.T) { + tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger) + defer tracker.Close() + + srcIP := netip.MustParseAddr("100.64.0.1") + dstIP := netip.MustParseAddr("100.64.0.2") + srcPort := uint16(12345) + dstPort := uint16(80) + key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort} + + establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort) + conn := tracker.connections[key] + + // Illegal combos must be rejected and must not change state. + combos := []struct { + name string + flags uint8 + }{ + {"SYN+RST", TCPSyn | TCPRst}, + {"FIN+RST", TCPFin | TCPRst}, + {"SYN+FIN", TCPSyn | TCPFin}, + {"SYN+FIN+RST", TCPSyn | TCPFin | TCPRst}, + } + + for _, c := range combos { + t.Run(c.name, func(t *testing.T) { + before := conn.GetState() + valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, c.flags, 0) + require.False(t, valid, "illegal flag combo must be rejected: %s", c.name) + require.Equal(t, before, conn.GetState(), + "illegal flag combo must not change state") + require.False(t, conn.IsTombstone()) + }) + } +} diff --git a/client/firewall/uspfilter/conntrack/tcp_state_bugs_test.go b/client/firewall/uspfilter/conntrack/tcp_state_bugs_test.go new file mode 100644 index 000000000..32112cd58 --- /dev/null +++ b/client/firewall/uspfilter/conntrack/tcp_state_bugs_test.go @@ -0,0 +1,235 @@ +package conntrack + +import ( + "net/netip" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +// These tests exercise cases where the TCP state machine currently advances +// on retransmitted or wrong-direction segments and tears the flow down +// prematurely. They are expected to fail until the direction checks are added. + +func TestTCPCloseWaitRetransmittedPeerFIN(t *testing.T) { + tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger) + defer tracker.Close() + + srcIP := netip.MustParseAddr("100.64.0.1") + dstIP := netip.MustParseAddr("100.64.0.2") + srcPort := uint16(12345) + dstPort := uint16(80) + key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort} + + establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort) + + // Peer sends FIN -> CloseWait (our app has not yet closed). + valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0) + require.True(t, valid) + conn := tracker.connections[key] + require.Equal(t, TCPStateCloseWait, conn.GetState()) + + // Peer retransmits their FIN (ACK may have been delayed). We have NOT + // sent our FIN yet, so state must remain CloseWait. + valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0) + require.True(t, valid, "retransmitted peer FIN must still be accepted") + require.Equal(t, TCPStateCloseWait, conn.GetState(), + "retransmitted peer FIN must not advance CloseWait to LastAck") + + // Our app finally closes -> LastAck. + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) + require.Equal(t, TCPStateLastAck, conn.GetState()) + + // Peer ACK closes. + valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0) + require.True(t, valid) + require.Equal(t, TCPStateClosed, conn.GetState()) +} + +func TestTCPFinWait2RetransmittedOwnFIN(t *testing.T) { + tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger) + defer tracker.Close() + + srcIP := netip.MustParseAddr("100.64.0.1") + dstIP := netip.MustParseAddr("100.64.0.2") + srcPort := uint16(12345) + dstPort := uint16(80) + key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort} + + establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort) + + // We initiate close. + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) + valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0) + require.True(t, valid) + conn := tracker.connections[key] + require.Equal(t, TCPStateFinWait2, conn.GetState()) + + // Stray retransmit of our own FIN (same direction as originator) must + // NOT advance FinWait2 to TimeWait; only the peer's FIN should. + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) + require.Equal(t, TCPStateFinWait2, conn.GetState(), + "own FIN retransmit must not advance FinWait2 to TimeWait") + + // Peer FIN -> TimeWait. + valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0) + require.True(t, valid) + require.Equal(t, TCPStateTimeWait, conn.GetState()) +} + +func TestTCPLastAckDirectionCheck(t *testing.T) { + tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger) + defer tracker.Close() + + srcIP := netip.MustParseAddr("100.64.0.1") + dstIP := netip.MustParseAddr("100.64.0.2") + srcPort := uint16(12345) + dstPort := uint16(80) + key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort} + + establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort) + + // Drive to LastAck: peer FIN -> CloseWait, our FIN -> LastAck. + require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)) + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) + conn := tracker.connections[key] + require.Equal(t, TCPStateLastAck, conn.GetState()) + + // Our own ACK retransmit (same direction as originator) must NOT close. + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0) + require.Equal(t, TCPStateLastAck, conn.GetState(), + "own ACK retransmit in LastAck must not transition to Closed") + + // Peer's ACK -> Closed. + require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)) + require.Equal(t, TCPStateClosed, conn.GetState()) +} + +func TestTCPFinWait1OwnAckDoesNotAdvance(t *testing.T) { + tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger) + defer tracker.Close() + + srcIP := netip.MustParseAddr("100.64.0.1") + dstIP := netip.MustParseAddr("100.64.0.2") + srcPort := uint16(12345) + dstPort := uint16(80) + key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort} + + establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort) + + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) + conn := tracker.connections[key] + require.Equal(t, TCPStateFinWait1, conn.GetState()) + + // Our own ACK retransmit (same direction as originator) must not advance. + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0) + require.Equal(t, TCPStateFinWait1, conn.GetState(), + "own ACK in FinWait1 must not advance to FinWait2") +} + +func TestTCPPerStateTeardownTimeouts(t *testing.T) { + // Verify cleanup reaps entries in each teardown state at the configured + // per-state timeout, not at the single handshake timeout. + t.Setenv(EnvTCPFinWaitTimeout, "50ms") + t.Setenv(EnvTCPCloseWaitTimeout, "80ms") + t.Setenv(EnvTCPLastAckTimeout, "30ms") + + dstIP := netip.MustParseAddr("100.64.0.2") + dstPort := uint16(80) + + // Drives a connection to the target state, forces its lastSeen well + // beyond the configured timeout, runs cleanup, and asserts reaping. + cases := []struct { + name string + // drive takes a fresh tracker and returns the conn key after + // transitioning the flow into the intended teardown state. + drive func(t *testing.T, tr *TCPTracker, srcIP netip.Addr, srcPort uint16) (ConnKey, TCPState) + }{ + { + name: "FinWait1", + drive: func(t *testing.T, tr *TCPTracker, srcIP netip.Addr, srcPort uint16) (ConnKey, TCPState) { + establishConnection(t, tr, srcIP, dstIP, srcPort, dstPort) + tr.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) // → FinWait1 + return ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}, TCPStateFinWait1 + }, + }, + { + name: "FinWait2", + drive: func(t *testing.T, tr *TCPTracker, srcIP netip.Addr, srcPort uint16) (ConnKey, TCPState) { + establishConnection(t, tr, srcIP, dstIP, srcPort, dstPort) + tr.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) // FinWait1 + require.True(t, tr.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)) // → FinWait2 + return ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}, TCPStateFinWait2 + }, + }, + { + name: "CloseWait", + drive: func(t *testing.T, tr *TCPTracker, srcIP netip.Addr, srcPort uint16) (ConnKey, TCPState) { + establishConnection(t, tr, srcIP, dstIP, srcPort, dstPort) + require.True(t, tr.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)) // → CloseWait + return ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}, TCPStateCloseWait + }, + }, + { + name: "LastAck", + drive: func(t *testing.T, tr *TCPTracker, srcIP netip.Addr, srcPort uint16) (ConnKey, TCPState) { + establishConnection(t, tr, srcIP, dstIP, srcPort, dstPort) + require.True(t, tr.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)) // CloseWait + tr.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) // → LastAck + return ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}, TCPStateLastAck + }, + }, + } + + // Use a unique source port per subtest so nothing aliases. + port := uint16(12345) + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger) + defer tracker.Close() + + require.Equal(t, 50*time.Millisecond, tracker.finWaitTimeout) + require.Equal(t, 80*time.Millisecond, tracker.closeWaitTimeout) + require.Equal(t, 30*time.Millisecond, tracker.lastAckTimeout) + + srcIP := netip.MustParseAddr("100.64.0.1") + port++ + key, wantState := c.drive(t, tracker, srcIP, port) + conn := tracker.connections[key] + require.NotNil(t, conn) + require.Equal(t, wantState, conn.GetState()) + + // Age the entry past the largest per-state timeout. + conn.lastSeen.Store(time.Now().Add(-500 * time.Millisecond).UnixNano()) + tracker.cleanup() + _, exists := tracker.connections[key] + require.False(t, exists, "%s entry should be reaped", c.name) + }) + } +} + +func TestTCPEstablishedPSHACKInFinStates(t *testing.T) { + // Verifies FIN|PSH|ACK and bare ACK keepalives are not dropped in FIN + // teardown states, which some stacks emit during close. + tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger) + defer tracker.Close() + + srcIP := netip.MustParseAddr("100.64.0.1") + dstIP := netip.MustParseAddr("100.64.0.2") + srcPort := uint16(12345) + dstPort := uint16(80) + + establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort) + + // Peer FIN -> CloseWait. + require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)) + + // Peer pushes trailing data + FIN|PSH|ACK (legal). + require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPPush|TCPAck, 100), + "FIN|PSH|ACK in CloseWait must be accepted") + + // Bare ACK keepalive from peer in CloseWait must be accepted. + require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0), + "bare ACK in CloseWait must be accepted") +} diff --git a/client/firewall/uspfilter/conntrack/udp.go b/client/firewall/uspfilter/conntrack/udp.go index a3b6a418b..335c5832a 100644 --- a/client/firewall/uspfilter/conntrack/udp.go +++ b/client/firewall/uspfilter/conntrack/udp.go @@ -17,6 +17,9 @@ const ( DefaultUDPTimeout = 30 * time.Second // UDPCleanupInterval is how often we check for stale connections UDPCleanupInterval = 15 * time.Second + + // EnvUDPMaxEntries caps the UDP conntrack table size. + EnvUDPMaxEntries = "NB_CONNTRACK_UDP_MAX" ) // UDPConnTrack represents a UDP connection state @@ -34,6 +37,7 @@ type UDPTracker struct { cleanupTicker *time.Ticker tickerCancel context.CancelFunc mutex sync.RWMutex + maxEntries int flowLogger nftypes.FlowLogger } @@ -51,6 +55,7 @@ func NewUDPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp timeout: timeout, cleanupTicker: time.NewTicker(UDPCleanupInterval), tickerCancel: cancel, + maxEntries: envInt(logger, EnvUDPMaxEntries, DefaultMaxUDPEntries), flowLogger: flowLogger, } @@ -117,13 +122,18 @@ func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d conn.UpdateCounters(direction, size) t.mutex.Lock() + if t.maxEntries > 0 && len(t.connections) >= t.maxEntries { + t.evictOneLocked() + } t.connections[key] = conn t.mutex.Unlock() - if origPort != 0 { - t.logger.Trace4("New %s UDP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort) - } else { - t.logger.Trace2("New %s UDP connection: %s", direction, key) + if t.logger.Enabled(nblog.LevelTrace) { + if origPort != 0 { + t.logger.Trace4("New %s UDP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort) + } else { + t.logger.Trace2("New %s UDP connection: %s", direction, key) + } } t.sendEvent(nftypes.TypeStart, conn, ruleID) } @@ -151,6 +161,34 @@ func (t *UDPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort return true } +// evictOneLocked removes one entry to make room. Caller must hold t.mutex. +// Bounded sample: picks the oldest among up to evictSampleSize entries. +func (t *UDPTracker) evictOneLocked() { + var candKey ConnKey + var candSeen int64 + haveCand := false + sampled := 0 + + for k, c := range t.connections { + seen := c.lastSeen.Load() + if !haveCand || seen < candSeen { + candKey = k + candSeen = seen + haveCand = true + } + sampled++ + if sampled >= evictSampleSize { + break + } + } + if haveCand { + if evicted := t.connections[candKey]; evicted != nil { + t.sendEvent(nftypes.TypeEnd, evicted, nil) + } + delete(t.connections, candKey) + } +} + // cleanupRoutine periodically removes stale connections func (t *UDPTracker) cleanupRoutine(ctx context.Context) { defer t.cleanupTicker.Stop() @@ -173,8 +211,10 @@ func (t *UDPTracker) cleanup() { if conn.timeoutExceeded(t.timeout) { delete(t.connections, key) - t.logger.Trace5("Removed UDP connection %s (timeout) [in: %d Pkts/%d B, out: %d Pkts/%d B]", - key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load()) + if t.logger.Enabled(nblog.LevelTrace) { + t.logger.Trace5("Removed UDP connection %s (timeout) [in: %d Pkts/%d B, out: %d Pkts/%d B]", + key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load()) + } t.sendEvent(nftypes.TypeEnd, conn, nil) } } diff --git a/client/firewall/uspfilter/filter.go b/client/firewall/uspfilter/filter.go index 5ecd08abf..91866dcab 100644 --- a/client/firewall/uspfilter/filter.go +++ b/client/firewall/uspfilter/filter.go @@ -787,7 +787,9 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool { srcIP, dstIP := m.extractIPs(d) if !srcIP.IsValid() { - m.logger.Error1("Unknown network layer: %v", d.decoded[0]) + if m.logger.Enabled(nblog.LevelError) { + m.logger.Error1("Unknown network layer: %v", d.decoded[0]) + } return false } @@ -901,7 +903,9 @@ func (m *Manager) clampTCPMSS(packetData []byte, d *decoder) bool { return false } - m.logger.Trace2("Clamped TCP MSS from %d to %d", currentMSS, mssClampValue) + if m.logger.Enabled(nblog.LevelTrace) { + m.logger.Trace2("Clamped TCP MSS from %d to %d", currentMSS, mssClampValue) + } return true } @@ -1044,11 +1048,13 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool { // TODO: pass fragments of routed packets to forwarder if fragment { - if d.decoded[0] == layers.LayerTypeIPv4 { - m.logger.Trace4("packet is a fragment: src=%v dst=%v id=%v flags=%v", - srcIP, dstIP, d.ip4.Id, d.ip4.Flags) - } else { - m.logger.Trace2("packet is an IPv6 fragment: src=%v dst=%v", srcIP, dstIP) + if m.logger.Enabled(nblog.LevelTrace) { + if d.decoded[0] == layers.LayerTypeIPv4 { + m.logger.Trace4("packet is a fragment: src=%v dst=%v id=%v flags=%v", + srcIP, dstIP, d.ip4.Id, d.ip4.Flags) + } else { + m.logger.Trace2("packet is an IPv6 fragment: src=%v dst=%v", srcIP, dstIP) + } } return false } @@ -1091,8 +1097,10 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packet pnum := getProtocolFromPacket(d) srcPort, dstPort := getPortsFromPacket(d) - m.logger.Trace6("Dropping local packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d", - ruleID, pnum, srcIP, srcPort, dstIP, dstPort) + if m.logger.Enabled(nblog.LevelTrace) { + m.logger.Trace6("Dropping local packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d", + ruleID, pnum, srcIP, srcPort, dstIP, dstPort) + } m.flowLogger.StoreEvent(nftypes.EventFields{ FlowID: uuid.New(), @@ -1142,8 +1150,10 @@ func (m *Manager) handleForwardedLocalTraffic(packetData []byte) bool { func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) bool { // Drop if routing is disabled if !m.routingEnabled.Load() { - m.logger.Trace2("Dropping routed packet (routing disabled): src=%s dst=%s", - srcIP, dstIP) + if m.logger.Enabled(nblog.LevelTrace) { + m.logger.Trace2("Dropping routed packet (routing disabled): src=%s dst=%s", + srcIP, dstIP) + } return true } @@ -1160,8 +1170,10 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe if !pass { proto := getProtocolFromPacket(d) - m.logger.Trace6("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d", - ruleID, proto, srcIP, srcPort, dstIP, dstPort) + if m.logger.Enabled(nblog.LevelTrace) { + m.logger.Trace6("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d", + ruleID, proto, srcIP, srcPort, dstIP, dstPort) + } m.flowLogger.StoreEvent(nftypes.EventFields{ FlowID: uuid.New(), @@ -1287,7 +1299,9 @@ func getPortsFromPacket(d *decoder) (srcPort, dstPort uint16) { // It returns true, true if the packet is a fragment and valid. func (m *Manager) isValidPacket(d *decoder, packetData []byte) (bool, bool) { if err := d.decodePacket(packetData); err != nil { - m.logger.Trace1("couldn't decode packet, err: %s", err) + if m.logger.Enabled(nblog.LevelTrace) { + m.logger.Trace1("couldn't decode packet, err: %s", err) + } return false, false } diff --git a/client/firewall/uspfilter/forwarder/icmp.go b/client/firewall/uspfilter/forwarder/icmp.go index 3922c2052..d6d4e705e 100644 --- a/client/firewall/uspfilter/forwarder/icmp.go +++ b/client/firewall/uspfilter/forwarder/icmp.go @@ -13,6 +13,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" + nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log" nftypes "github.com/netbirdio/netbird/client/internal/netflow/types" ) @@ -97,8 +98,10 @@ func (f *Forwarder) forwardICMPPacket(id stack.TransportEndpointID, payload []by return nil, fmt.Errorf("write ICMP packet: %w", err) } - f.logger.Trace3("forwarder: Forwarded ICMP packet %v type %v code %v", - epID(id), icmpType, icmpCode) + if f.logger.Enabled(nblog.LevelTrace) { + f.logger.Trace3("forwarder: Forwarded ICMP packet %v type %v code %v", + epID(id), icmpType, icmpCode) + } return conn, nil } @@ -121,12 +124,14 @@ func (f *Forwarder) handleICMPViaSocket(flowID uuid.UUID, id stack.TransportEndp txBytes := f.handleEchoResponse(conn, id, v6) rtt := time.Since(sendTime).Round(10 * time.Microsecond) - proto := "ICMP" - if v6 { - proto = "ICMPv6" + if f.logger.Enabled(nblog.LevelTrace) { + proto := "ICMP" + if v6 { + proto = "ICMPv6" + } + f.logger.Trace5("forwarder: Forwarded %s echo reply %v type %v code %v (rtt=%v, raw socket)", + proto, epID(id), icmpType, icmpCode, rtt) } - f.logger.Trace5("forwarder: Forwarded %s echo reply %v type %v code %v (rtt=%v, raw socket)", - proto, epID(id), icmpType, icmpCode, rtt) f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes)) } @@ -224,13 +229,17 @@ func (f *Forwarder) handleICMPViaPing(flowID uuid.UUID, id stack.TransportEndpoi } rtt := time.Since(pingStart).Round(10 * time.Microsecond) - f.logger.Trace3("forwarder: Forwarded ICMP echo request %v type %v code %v", - epID(id), icmpType, icmpCode) + if f.logger.Enabled(nblog.LevelTrace) { + f.logger.Trace3("forwarder: Forwarded ICMP echo request %v type %v code %v", + epID(id), icmpType, icmpCode) + } txBytes := f.synthesizeEchoReply(id, icmpData) - f.logger.Trace4("forwarder: Forwarded ICMP echo reply %v type %v code %v (rtt=%v, ping binary)", - epID(id), icmpType, icmpCode, rtt) + if f.logger.Enabled(nblog.LevelTrace) { + f.logger.Trace4("forwarder: Forwarded ICMP echo reply %v type %v code %v (rtt=%v, ping binary)", + epID(id), icmpType, icmpCode, rtt) + } f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes)) } diff --git a/client/firewall/uspfilter/forwarder/tcp.go b/client/firewall/uspfilter/forwarder/tcp.go index 8844463f5..c65ebcde0 100644 --- a/client/firewall/uspfilter/forwarder/tcp.go +++ b/client/firewall/uspfilter/forwarder/tcp.go @@ -1,11 +1,8 @@ package forwarder import ( - "context" - "io" "net" "strconv" - "sync" "github.com/google/uuid" @@ -15,7 +12,9 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/waiter" + nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log" nftypes "github.com/netbirdio/netbird/client/internal/netflow/types" + "github.com/netbirdio/netbird/util/netrelay" ) // handleTCP is called by the TCP forwarder for new connections. @@ -37,7 +36,9 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) { outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr) if err != nil { r.Complete(true) - f.logger.Trace2("forwarder: dial error for %v: %v", epID(id), err) + if f.logger.Enabled(nblog.LevelTrace) { + f.logger.Trace2("forwarder: dial error for %v: %v", epID(id), err) + } return } @@ -60,64 +61,22 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) { inConn := gonet.NewTCPConn(&wq, ep) success = true - f.logger.Trace1("forwarder: established TCP connection %v", epID(id)) + if f.logger.Enabled(nblog.LevelTrace) { + f.logger.Trace1("forwarder: established TCP connection %v", epID(id)) + } go f.proxyTCP(id, inConn, outConn, ep, flowID) } func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn, outConn net.Conn, ep tcpip.Endpoint, flowID uuid.UUID) { + // netrelay.Relay copies bidirectionally with proper half-close propagation + // and fully closes both conns before returning. + bytesFromInToOut, bytesFromOutToIn := netrelay.Relay(f.ctx, inConn, outConn, netrelay.Options{ + Logger: f.logger, + }) - ctx, cancel := context.WithCancel(f.ctx) - defer cancel() - - go func() { - <-ctx.Done() - // Close connections and endpoint. - if err := inConn.Close(); err != nil && !isClosedError(err) { - f.logger.Debug1("forwarder: inConn close error: %v", err) - } - if err := outConn.Close(); err != nil && !isClosedError(err) { - f.logger.Debug1("forwarder: outConn close error: %v", err) - } - - ep.Close() - }() - - var wg sync.WaitGroup - wg.Add(2) - - var ( - bytesFromInToOut int64 // bytes from client to server (tx for client) - bytesFromOutToIn int64 // bytes from server to client (rx for client) - errInToOut error - errOutToIn error - ) - - go func() { - bytesFromInToOut, errInToOut = io.Copy(outConn, inConn) - cancel() - wg.Done() - }() - - go func() { - - bytesFromOutToIn, errOutToIn = io.Copy(inConn, outConn) - cancel() - wg.Done() - }() - - wg.Wait() - - if errInToOut != nil { - if !isClosedError(errInToOut) { - f.logger.Error2("proxyTCP: copy error (in → out) for %s: %v", epID(id), errInToOut) - } - } - if errOutToIn != nil { - if !isClosedError(errOutToIn) { - f.logger.Error2("proxyTCP: copy error (out → in) for %s: %v", epID(id), errOutToIn) - } - } + // Close the netstack endpoint after both conns are drained. + ep.Close() var rxPackets, txPackets uint64 if tcpStats, ok := ep.Stats().(*tcp.Stats); ok { @@ -126,7 +85,9 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn txPackets = tcpStats.SegmentsReceived.Value() } - f.logger.Trace5("forwarder: Removed TCP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, bytesFromOutToIn, txPackets, bytesFromInToOut) + if f.logger.Enabled(nblog.LevelTrace) { + f.logger.Trace5("forwarder: Removed TCP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, bytesFromOutToIn, txPackets, bytesFromInToOut) + } f.sendTCPEvent(nftypes.TypeEnd, flowID, id, uint64(bytesFromOutToIn), uint64(bytesFromInToOut), rxPackets, txPackets) } diff --git a/client/firewall/uspfilter/forwarder/udp.go b/client/firewall/uspfilter/forwarder/udp.go index c92fa1f32..d840ef06b 100644 --- a/client/firewall/uspfilter/forwarder/udp.go +++ b/client/firewall/uspfilter/forwarder/udp.go @@ -125,7 +125,9 @@ func (f *udpForwarder) cleanup() { delete(f.conns, idle.id) f.Unlock() - f.logger.Trace1("forwarder: cleaned up idle UDP connection %v", epID(idle.id)) + if f.logger.Enabled(nblog.LevelTrace) { + f.logger.Trace1("forwarder: cleaned up idle UDP connection %v", epID(idle.id)) + } } } } @@ -144,7 +146,9 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool { _, exists := f.udpForwarder.conns[id] f.udpForwarder.RUnlock() if exists { - f.logger.Trace1("forwarder: existing UDP connection for %v", epID(id)) + if f.logger.Enabled(nblog.LevelTrace) { + f.logger.Trace1("forwarder: existing UDP connection for %v", epID(id)) + } return true } @@ -206,7 +210,9 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool { f.udpForwarder.Unlock() success = true - f.logger.Trace1("forwarder: established UDP connection %v", epID(id)) + if f.logger.Enabled(nblog.LevelTrace) { + f.logger.Trace1("forwarder: established UDP connection %v", epID(id)) + } go f.proxyUDP(connCtx, pConn, id, ep) return true @@ -265,7 +271,9 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack txPackets = udpStats.PacketsReceived.Value() } - f.logger.Trace5("forwarder: Removed UDP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, rxBytes, txPackets, txBytes) + if f.logger.Enabled(nblog.LevelTrace) { + f.logger.Trace5("forwarder: Removed UDP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, rxBytes, txPackets, txBytes) + } f.udpForwarder.Lock() delete(f.udpForwarder.conns, id) diff --git a/client/firewall/uspfilter/log/log.go b/client/firewall/uspfilter/log/log.go index c6ca55e70..03e7d4809 100644 --- a/client/firewall/uspfilter/log/log.go +++ b/client/firewall/uspfilter/log/log.go @@ -53,16 +53,17 @@ var levelStrings = map[Level]string{ } type logMessage struct { - level Level - format string - arg1 any - arg2 any - arg3 any - arg4 any - arg5 any - arg6 any - arg7 any - arg8 any + level Level + argCount uint8 + format string + arg1 any + arg2 any + arg3 any + arg4 any + arg5 any + arg6 any + arg7 any + arg8 any } // Logger is a high-performance, non-blocking logger @@ -107,6 +108,13 @@ func (l *Logger) SetLevel(level Level) { log.Debugf("Set uspfilter logger loglevel to %v", levelStrings[level]) } +// Enabled reports whether the given level is currently logged. Callers on the +// hot path should guard log sites with this to avoid boxing arguments into +// any when the level is off. +func (l *Logger) Enabled(level Level) bool { + return l.level.Load() >= uint32(level) +} + func (l *Logger) Error(format string) { if l.level.Load() >= uint32(LevelError) { select { @@ -155,7 +163,7 @@ func (l *Logger) Trace(format string) { func (l *Logger) Error1(format string, arg1 any) { if l.level.Load() >= uint32(LevelError) { select { - case l.msgChannel <- logMessage{level: LevelError, format: format, arg1: arg1}: + case l.msgChannel <- logMessage{level: LevelError, argCount: 1, format: format, arg1: arg1}: default: } } @@ -164,7 +172,16 @@ func (l *Logger) Error1(format string, arg1 any) { func (l *Logger) Error2(format string, arg1, arg2 any) { if l.level.Load() >= uint32(LevelError) { select { - case l.msgChannel <- logMessage{level: LevelError, format: format, arg1: arg1, arg2: arg2}: + case l.msgChannel <- logMessage{level: LevelError, argCount: 2, format: format, arg1: arg1, arg2: arg2}: + default: + } + } +} + +func (l *Logger) Warn2(format string, arg1, arg2 any) { + if l.level.Load() >= uint32(LevelWarn) { + select { + case l.msgChannel <- logMessage{level: LevelWarn, argCount: 2, format: format, arg1: arg1, arg2: arg2}: default: } } @@ -173,7 +190,7 @@ func (l *Logger) Error2(format string, arg1, arg2 any) { func (l *Logger) Warn3(format string, arg1, arg2, arg3 any) { if l.level.Load() >= uint32(LevelWarn) { select { - case l.msgChannel <- logMessage{level: LevelWarn, format: format, arg1: arg1, arg2: arg2, arg3: arg3}: + case l.msgChannel <- logMessage{level: LevelWarn, argCount: 3, format: format, arg1: arg1, arg2: arg2, arg3: arg3}: default: } } @@ -182,7 +199,7 @@ func (l *Logger) Warn3(format string, arg1, arg2, arg3 any) { func (l *Logger) Warn4(format string, arg1, arg2, arg3, arg4 any) { if l.level.Load() >= uint32(LevelWarn) { select { - case l.msgChannel <- logMessage{level: LevelWarn, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4}: + case l.msgChannel <- logMessage{level: LevelWarn, argCount: 4, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4}: default: } } @@ -191,7 +208,7 @@ func (l *Logger) Warn4(format string, arg1, arg2, arg3, arg4 any) { func (l *Logger) Debug1(format string, arg1 any) { if l.level.Load() >= uint32(LevelDebug) { select { - case l.msgChannel <- logMessage{level: LevelDebug, format: format, arg1: arg1}: + case l.msgChannel <- logMessage{level: LevelDebug, argCount: 1, format: format, arg1: arg1}: default: } } @@ -200,7 +217,7 @@ func (l *Logger) Debug1(format string, arg1 any) { func (l *Logger) Debug2(format string, arg1, arg2 any) { if l.level.Load() >= uint32(LevelDebug) { select { - case l.msgChannel <- logMessage{level: LevelDebug, format: format, arg1: arg1, arg2: arg2}: + case l.msgChannel <- logMessage{level: LevelDebug, argCount: 2, format: format, arg1: arg1, arg2: arg2}: default: } } @@ -209,16 +226,59 @@ func (l *Logger) Debug2(format string, arg1, arg2 any) { func (l *Logger) Debug3(format string, arg1, arg2, arg3 any) { if l.level.Load() >= uint32(LevelDebug) { select { - case l.msgChannel <- logMessage{level: LevelDebug, format: format, arg1: arg1, arg2: arg2, arg3: arg3}: + case l.msgChannel <- logMessage{level: LevelDebug, argCount: 3, format: format, arg1: arg1, arg2: arg2, arg3: arg3}: default: } } } +// Debugf is the variadic shape. Dispatches to Debug/Debug1/Debug2/Debug3 +// to avoid allocating an args slice on the fast path when the arg count is +// known (0-3). Args beyond 3 land on the general variadic path; callers on +// the hot path should prefer DebugN for known counts. +func (l *Logger) Debugf(format string, args ...any) { + if l.level.Load() < uint32(LevelDebug) { + return + } + switch len(args) { + case 0: + l.Debug(format) + case 1: + l.Debug1(format, args[0]) + case 2: + l.Debug2(format, args[0], args[1]) + case 3: + l.Debug3(format, args[0], args[1], args[2]) + default: + l.sendVariadic(LevelDebug, format, args) + } +} + +// sendVariadic packs a slice of arguments into a logMessage and non-blocking +// enqueues it. Used for arg counts beyond the fixed-arity fast paths. Args +// beyond the 8-arg slot limit are dropped so callers don't produce silently +// empty log lines via uint8 wraparound in argCount. +func (l *Logger) sendVariadic(level Level, format string, args []any) { + const maxArgs = 8 + n := len(args) + if n > maxArgs { + n = maxArgs + } + msg := logMessage{level: level, argCount: uint8(n), format: format} + slots := [maxArgs]*any{&msg.arg1, &msg.arg2, &msg.arg3, &msg.arg4, &msg.arg5, &msg.arg6, &msg.arg7, &msg.arg8} + for i := 0; i < n; i++ { + *slots[i] = args[i] + } + select { + case l.msgChannel <- msg: + default: + } +} + func (l *Logger) Trace1(format string, arg1 any) { if l.level.Load() >= uint32(LevelTrace) { select { - case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1}: + case l.msgChannel <- logMessage{level: LevelTrace, argCount: 1, format: format, arg1: arg1}: default: } } @@ -227,7 +287,7 @@ func (l *Logger) Trace1(format string, arg1 any) { func (l *Logger) Trace2(format string, arg1, arg2 any) { if l.level.Load() >= uint32(LevelTrace) { select { - case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2}: + case l.msgChannel <- logMessage{level: LevelTrace, argCount: 2, format: format, arg1: arg1, arg2: arg2}: default: } } @@ -236,7 +296,7 @@ func (l *Logger) Trace2(format string, arg1, arg2 any) { func (l *Logger) Trace3(format string, arg1, arg2, arg3 any) { if l.level.Load() >= uint32(LevelTrace) { select { - case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3}: + case l.msgChannel <- logMessage{level: LevelTrace, argCount: 3, format: format, arg1: arg1, arg2: arg2, arg3: arg3}: default: } } @@ -245,7 +305,7 @@ func (l *Logger) Trace3(format string, arg1, arg2, arg3 any) { func (l *Logger) Trace4(format string, arg1, arg2, arg3, arg4 any) { if l.level.Load() >= uint32(LevelTrace) { select { - case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4}: + case l.msgChannel <- logMessage{level: LevelTrace, argCount: 4, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4}: default: } } @@ -254,7 +314,7 @@ func (l *Logger) Trace4(format string, arg1, arg2, arg3, arg4 any) { func (l *Logger) Trace5(format string, arg1, arg2, arg3, arg4, arg5 any) { if l.level.Load() >= uint32(LevelTrace) { select { - case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5}: + case l.msgChannel <- logMessage{level: LevelTrace, argCount: 5, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5}: default: } } @@ -263,7 +323,7 @@ func (l *Logger) Trace5(format string, arg1, arg2, arg3, arg4, arg5 any) { func (l *Logger) Trace6(format string, arg1, arg2, arg3, arg4, arg5, arg6 any) { if l.level.Load() >= uint32(LevelTrace) { select { - case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5, arg6: arg6}: + case l.msgChannel <- logMessage{level: LevelTrace, argCount: 6, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5, arg6: arg6}: default: } } @@ -273,7 +333,7 @@ func (l *Logger) Trace6(format string, arg1, arg2, arg3, arg4, arg5, arg6 any) { func (l *Logger) Trace8(format string, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8 any) { if l.level.Load() >= uint32(LevelTrace) { select { - case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5, arg6: arg6, arg7: arg7, arg8: arg8}: + case l.msgChannel <- logMessage{level: LevelTrace, argCount: 8, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5, arg6: arg6, arg7: arg7, arg8: arg8}: default: } } @@ -286,35 +346,8 @@ func (l *Logger) formatMessage(buf *[]byte, msg logMessage) { *buf = append(*buf, levelStrings[msg.level]...) *buf = append(*buf, ' ') - // Count non-nil arguments for switch - argCount := 0 - if msg.arg1 != nil { - argCount++ - if msg.arg2 != nil { - argCount++ - if msg.arg3 != nil { - argCount++ - if msg.arg4 != nil { - argCount++ - if msg.arg5 != nil { - argCount++ - if msg.arg6 != nil { - argCount++ - if msg.arg7 != nil { - argCount++ - if msg.arg8 != nil { - argCount++ - } - } - } - } - } - } - } - } - var formatted string - switch argCount { + switch msg.argCount { case 0: formatted = msg.format case 1: diff --git a/client/firewall/uspfilter/nat.go b/client/firewall/uspfilter/nat.go index 0d411c21e..5d51c1538 100644 --- a/client/firewall/uspfilter/nat.go +++ b/client/firewall/uspfilter/nat.go @@ -11,6 +11,7 @@ import ( "github.com/google/gopacket/layers" firewall "github.com/netbirdio/netbird/client/firewall/manager" + nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log" ) var ( @@ -262,11 +263,15 @@ func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool { } if err := m.rewritePacketIP(packetData, d, translatedIP, false); err != nil { - m.logger.Error1("failed to rewrite packet destination: %v", err) + if m.logger.Enabled(nblog.LevelError) { + m.logger.Error1("failed to rewrite packet destination: %v", err) + } return false } - m.logger.Trace2("DNAT: %s -> %s", dstIP, translatedIP) + if m.logger.Enabled(nblog.LevelTrace) { + m.logger.Trace2("DNAT: %s -> %s", dstIP, translatedIP) + } return true } @@ -283,11 +288,15 @@ func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool { } if err := m.rewritePacketIP(packetData, d, originalIP, true); err != nil { - m.logger.Error1("failed to rewrite packet source: %v", err) + if m.logger.Enabled(nblog.LevelError) { + m.logger.Error1("failed to rewrite packet source: %v", err) + } return false } - m.logger.Trace2("Reverse DNAT: %s -> %s", srcIP, originalIP) + if m.logger.Enabled(nblog.LevelTrace) { + m.logger.Trace2("Reverse DNAT: %s -> %s", srcIP, originalIP) + } return true } @@ -612,7 +621,9 @@ func (m *Manager) applyPortRule(packetData []byte, d *decoder, srcIP, dstIP neti } if err := rewriteFn(packetData, d, rule.targetPort, destinationPortOffset); err != nil { - m.logger.Error1("failed to rewrite port: %v", err) + if m.logger.Enabled(nblog.LevelError) { + m.logger.Error1("failed to rewrite port: %v", err) + } return false } d.dnatOrigPort = rule.origPort diff --git a/client/internal/dns/upstream.go b/client/internal/dns/upstream.go index 62a7747f1..915296069 100644 --- a/client/internal/dns/upstream.go +++ b/client/internal/dns/upstream.go @@ -57,6 +57,27 @@ import ( var currentMTU uint16 = iface.DefaultMTU +// nonRetryableEDECodes lists EDE info codes (RFC 8914) for which a SERVFAIL +// from one upstream means another upstream would return the same answer: +// DNSSEC validation outcomes and policy-based blocks. Transient errors +// (network, cached, not ready) are not included. +var nonRetryableEDECodes = map[uint16]struct{}{ + dns.ExtendedErrorCodeUnsupportedDNSKEYAlgorithm: {}, + dns.ExtendedErrorCodeUnsupportedDSDigestType: {}, + dns.ExtendedErrorCodeDNSSECIndeterminate: {}, + dns.ExtendedErrorCodeDNSBogus: {}, + dns.ExtendedErrorCodeSignatureExpired: {}, + dns.ExtendedErrorCodeSignatureNotYetValid: {}, + dns.ExtendedErrorCodeDNSKEYMissing: {}, + dns.ExtendedErrorCodeRRSIGsMissing: {}, + dns.ExtendedErrorCodeNoZoneKeyBitSet: {}, + dns.ExtendedErrorCodeNSECMissing: {}, + dns.ExtendedErrorCodeBlocked: {}, + dns.ExtendedErrorCodeCensored: {}, + dns.ExtendedErrorCodeFiltered: {}, + dns.ExtendedErrorCodeProhibited: {}, +} + // privateClientIface is the subset of the WireGuard interface needed by GetClientPrivate. type privateClientIface interface { Name() string @@ -151,6 +172,7 @@ type raceResult struct { msg *dns.Msg upstream netip.AddrPort protocol string + ede string failures []upstreamFailure } @@ -308,6 +330,9 @@ func (u *upstreamResolverBase) tryOnlyRace(ctx context.Context, w dns.ResponseWr if res.msg == nil { return false, res.failures } + if res.ede != "" { + resutil.SetMeta(w, "ede", res.ede) + } u.writeSuccessResponse(w, res.msg, res.upstream, r.Question[0].Name, res.protocol, logger) return true, res.failures } @@ -367,21 +392,33 @@ func (u *upstreamResolverBase) tryRace(ctx context.Context, r *dns.Msg, group up // options in-place, so reusing the same *dns.Msg across sequential // upstreams would carry those mutations (e.g. a reduced UDP size) // into the next attempt. - msg, proto, failure := u.queryUpstream(ctx, r.Copy(), upstream, timeout) + res, failure := u.queryUpstream(ctx, r.Copy(), upstream, timeout) if failure != nil { failures = append(failures, *failure) continue } - return raceResult{msg: msg, upstream: upstream, protocol: proto, failures: failures} + res.failures = failures + return res } return raceResult{failures: failures} } -func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration) (*dns.Msg, string, *upstreamFailure) { +func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration) (raceResult, *upstreamFailure) { ctx, cancel := context.WithTimeout(parentCtx, timeout) defer cancel() ctx, upstreamProto := contextWithUpstreamProtocolResult(ctx) + // Advertise EDNS0 so the upstream may include Extended DNS Errors + // (RFC 8914) in failure responses; we use those to short-circuit + // failover for definitive answers like DNSSEC validation failures. + // The caller already passed a per-attempt copy, so we can mutate r + // directly; hadEdns reflects the original client request's state and + // controls whether we strip the OPT from the response. + hadEdns := r.IsEdns0() != nil + if !hadEdns { + r.SetEdns0(upstreamUDPSize(), false) + } + startTime := time.Now() rm, _, err := u.upstreamClient.exchange(ctx, upstream.String(), r) @@ -391,31 +428,42 @@ func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, r *dns.M // error chain and the parent context: a transport may surface the // cancellation as a read/deadline error rather than context.Canceled. if errors.Is(err, context.Canceled) || errors.Is(parentCtx.Err(), context.Canceled) { - return nil, "", &upstreamFailure{upstream: upstream, reason: "canceled"} + return raceResult{}, &upstreamFailure{upstream: upstream, reason: "canceled"} } failure := u.handleUpstreamError(err, upstream, startTime) u.markUpstreamFail(upstream, failure.reason) - return nil, "", failure + return raceResult{}, failure } if rm == nil || !rm.Response { u.markUpstreamFail(upstream, "no response") - return nil, "", &upstreamFailure{upstream: upstream, reason: "no response"} + return raceResult{}, &upstreamFailure{upstream: upstream, reason: "no response"} } - if rm.Rcode == dns.RcodeServerFailure || rm.Rcode == dns.RcodeRefused { - reason := dns.RcodeToString[rm.Rcode] - u.markUpstreamFail(upstream, reason) - return nil, "", &upstreamFailure{upstream: upstream, reason: reason} - } - - u.markUpstreamOk(upstream) - proto := "" if upstreamProto != nil { proto = upstreamProto.protocol } - return rm, proto, nil + + if rm.Rcode == dns.RcodeServerFailure || rm.Rcode == dns.RcodeRefused { + if _, ok := nonRetryableEDE(rm); ok { + if !hadEdns { + stripOPT(rm) + } + u.markUpstreamOk(upstream) + return raceResult{msg: rm, upstream: upstream, protocol: proto}, nil + } + reason := dns.RcodeToString[rm.Rcode] + u.markUpstreamFail(upstream, reason) + return raceResult{}, &upstreamFailure{upstream: upstream, reason: reason} + } + + if !hadEdns { + stripOPT(rm) + } + + u.markUpstreamOk(upstream) + return raceResult{msg: rm, upstream: upstream, protocol: proto}, nil } // healthEntry returns the mutable health record for addr, lazily creating @@ -460,6 +508,31 @@ func (u *upstreamResolverBase) UpstreamHealth() map[netip.AddrPort]UpstreamHealt return out } +// upstreamUDPSize returns the EDNS0 UDP buffer size we advertise to upstreams, +// derived from the tunnel MTU and bounded against underflow. +func upstreamUDPSize() uint16 { + if currentMTU > ipUDPHeaderSize { + return currentMTU - ipUDPHeaderSize + } + return dns.MinMsgSize +} + +// stripOPT removes any OPT pseudo-RRs from the response's Extra section so +// the response complies with RFC 6891 when the client did not advertise EDNS0. +func stripOPT(rm *dns.Msg) { + if len(rm.Extra) == 0 { + return + } + out := rm.Extra[:0] + for _, rr := range rm.Extra { + if _, ok := rr.(*dns.OPT); ok { + continue + } + out = append(out, rr) + } + rm.Extra = out +} + func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.AddrPort, startTime time.Time) *upstreamFailure { if !errors.Is(err, context.DeadlineExceeded) && !isTimeout(err) { return &upstreamFailure{upstream: upstream, reason: err.Error()} @@ -529,6 +602,34 @@ func formatFailures(failures []upstreamFailure) string { return strings.Join(parts, ", ") } +// nonRetryableEDE returns the first non-retryable EDE code carried in the +// response, if any. +func nonRetryableEDE(rm *dns.Msg) (uint16, bool) { + opt := rm.IsEdns0() + if opt == nil { + return 0, false + } + for _, o := range opt.Option { + ede, ok := o.(*dns.EDNS0_EDE) + if !ok { + continue + } + if _, ok := nonRetryableEDECodes[ede.InfoCode]; ok { + return ede.InfoCode, true + } + } + return 0, false +} + +// edeName returns a human-readable name for an EDE code, falling back to +// the numeric code when unknown. +func edeName(code uint16) string { + if name, ok := dns.ExtendedErrorCodeToString[code]; ok { + return name + } + return fmt.Sprintf("EDE %d", code) +} + // isTimeout returns true if the given error is a network timeout error. // // Copied from k8s.io/apimachinery/pkg/util/net.IsTimeout diff --git a/client/internal/dns/upstream_test.go b/client/internal/dns/upstream_test.go index 60195b5ac..8b3c589f1 100644 --- a/client/internal/dns/upstream_test.go +++ b/client/internal/dns/upstream_test.go @@ -847,3 +847,132 @@ func TestExchangeWithFallback_TCPTruncatesToClientSize(t *testing.T) { assert.Less(t, len(rm2.Answer), 20, "small EDNS0 client should get fewer records") assert.True(t, rm2.Truncated, "response should be truncated for small buffer client") } + +func msgWithEDE(rcode int, codes ...uint16) *dns.Msg { + m := new(dns.Msg) + m.Response = true + m.Rcode = rcode + if len(codes) == 0 { + return m + } + opt := &dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}} + opt.SetUDPSize(dns.MinMsgSize) + for _, c := range codes { + opt.Option = append(opt.Option, &dns.EDNS0_EDE{InfoCode: c}) + } + m.Extra = append(m.Extra, opt) + return m +} + +func TestNonRetryableEDE(t *testing.T) { + tests := []struct { + name string + msg *dns.Msg + wantOK bool + wantCode uint16 + }{ + {name: "no edns0", msg: msgWithEDE(dns.RcodeServerFailure)}, + { + name: "opt without ede", + msg: func() *dns.Msg { + m := msgWithEDE(dns.RcodeServerFailure) + opt := &dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}} + opt.Option = append(opt.Option, &dns.EDNS0_NSID{Code: dns.EDNS0NSID}) + m.Extra = []dns.RR{opt} + return m + }(), + }, + {name: "ede dnsbogus", msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeDNSBogus), wantOK: true, wantCode: dns.ExtendedErrorCodeDNSBogus}, + {name: "ede signature expired", msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeSignatureExpired), wantOK: true, wantCode: dns.ExtendedErrorCodeSignatureExpired}, + {name: "ede blocked", msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeBlocked), wantOK: true, wantCode: dns.ExtendedErrorCodeBlocked}, + {name: "ede prohibited", msg: msgWithEDE(dns.RcodeRefused, dns.ExtendedErrorCodeProhibited), wantOK: true, wantCode: dns.ExtendedErrorCodeProhibited}, + {name: "ede cached error retryable", msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeCachedError)}, + {name: "ede network error retryable", msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeNetworkError)}, + {name: "ede not ready retryable", msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeNotReady)}, + { + name: "first non-retryable wins", + msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeNetworkError, dns.ExtendedErrorCodeDNSBogus), + wantOK: true, + wantCode: dns.ExtendedErrorCodeDNSBogus, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + code, ok := nonRetryableEDE(tc.msg) + assert.Equal(t, tc.wantOK, ok, "ok should match") + if tc.wantOK { + assert.Equal(t, tc.wantCode, code, "code should match") + } + }) + } +} + +func TestEDEName(t *testing.T) { + assert.Equal(t, "DNSSEC Bogus", edeName(dns.ExtendedErrorCodeDNSBogus)) + assert.Equal(t, "Signature Expired", edeName(dns.ExtendedErrorCodeSignatureExpired)) + assert.Equal(t, "EDE 9999", edeName(9999), "unknown code falls back to numeric") +} + +func TestStripOPT(t *testing.T) { + rm := &dns.Msg{ + Extra: []dns.RR{ + &dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}}, + &dns.A{Hdr: dns.RR_Header{Name: "x.", Rrtype: dns.TypeA}, A: net.IPv4(1, 2, 3, 4)}, + }, + } + stripOPT(rm) + assert.Len(t, rm.Extra, 1, "OPT should be removed, A kept") + _, isOPT := rm.Extra[0].(*dns.OPT) + assert.False(t, isOPT, "remaining record must not be OPT") +} + +func TestUpstreamResolver_NonRetryableEDEShortCircuits(t *testing.T) { + upstream1 := netip.MustParseAddrPort("192.0.2.1:53") + upstream2 := netip.MustParseAddrPort("192.0.2.2:53") + + servfailWithEDE := msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeDNSBogus) + successResp := buildMockResponse(dns.RcodeSuccess, "192.0.2.100") + + var queried []string + tracking := &trackingMockClient{ + inner: &mockUpstreamResolverPerServer{ + responses: map[string]mockUpstreamResponse{ + upstream1.String(): {msg: servfailWithEDE}, + upstream2.String(): {msg: successResp}, + }, + rtt: time.Millisecond, + }, + queriedUpstreams: &queried, + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + resolver := &upstreamResolverBase{ + ctx: ctx, + upstreamClient: tracking, + upstreamServers: []upstreamRace{{upstream1, upstream2}}, + upstreamTimeout: UpstreamTimeout, + } + + var written *dns.Msg + w := &test.MockResponseWriter{ + WriteMsgFunc: func(m *dns.Msg) error { + written = m + return nil + }, + } + + // Client query without EDNS0 must not see an OPT in the response. + q := new(dns.Msg).SetQuestion("example.com.", dns.TypeA) + resolver.ServeDNS(w, q) + + require.NotNil(t, written, "response must be written") + assert.Equal(t, dns.RcodeServerFailure, written.Rcode, "SERVFAIL must propagate") + assert.Len(t, queried, 1, "only first upstream should be queried") + assert.Equal(t, upstream1.String(), queried[0]) + for _, rr := range written.Extra { + _, isOPT := rr.(*dns.OPT) + assert.False(t, isOPT, "synthetic OPT must not leak to a non-EDNS0 client") + } +} diff --git a/client/ssh/client/client.go b/client/ssh/client/client.go index 7f72a72cf..ebf8eb794 100644 --- a/client/ssh/client/client.go +++ b/client/ssh/client/client.go @@ -25,6 +25,7 @@ import ( nbssh "github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/ssh/detection" "github.com/netbirdio/netbird/util" + "github.com/netbirdio/netbird/util/netrelay" ) const ( @@ -536,7 +537,7 @@ func (c *Client) LocalPortForward(ctx context.Context, localAddr, remoteAddr str continue } - go c.handleLocalForward(localConn, remoteAddr) + go c.handleLocalForward(ctx, localConn, remoteAddr) } }() @@ -548,7 +549,7 @@ func (c *Client) LocalPortForward(ctx context.Context, localAddr, remoteAddr str } // handleLocalForward handles a single local port forwarding connection -func (c *Client) handleLocalForward(localConn net.Conn, remoteAddr string) { +func (c *Client) handleLocalForward(ctx context.Context, localConn net.Conn, remoteAddr string) { defer func() { if err := localConn.Close(); err != nil { log.Debugf("local port forwarding: close local connection: %v", err) @@ -571,7 +572,7 @@ func (c *Client) handleLocalForward(localConn net.Conn, remoteAddr string) { } }() - nbssh.BidirectionalCopy(log.NewEntry(log.StandardLogger()), localConn, channel) + netrelay.Relay(ctx, localConn, channel, netrelay.Options{Logger: log.NewEntry(log.StandardLogger())}) } // RemotePortForward sets up remote port forwarding, binding on remote and forwarding to localAddr @@ -653,16 +654,19 @@ func (c *Client) handleRemoteForwardChannels(ctx context.Context, localAddr stri select { case <-ctx.Done(): return - case newChan := <-channelRequests: + case newChan, ok := <-channelRequests: + if !ok { + return + } if newChan != nil { - go c.handleRemoteForwardChannel(newChan, localAddr) + go c.handleRemoteForwardChannel(ctx, newChan, localAddr) } } } } // handleRemoteForwardChannel handles a single forwarded-tcpip channel -func (c *Client) handleRemoteForwardChannel(newChan ssh.NewChannel, localAddr string) { +func (c *Client) handleRemoteForwardChannel(ctx context.Context, newChan ssh.NewChannel, localAddr string) { channel, reqs, err := newChan.Accept() if err != nil { return @@ -675,8 +679,14 @@ func (c *Client) handleRemoteForwardChannel(newChan ssh.NewChannel, localAddr st go ssh.DiscardRequests(reqs) - localConn, err := net.Dial("tcp", localAddr) + // Bound the dial so a black-holed localAddr can't pin the accepted SSH + // channel open indefinitely; the relay itself runs under the outer ctx. + dialCtx, cancelDial := context.WithTimeout(ctx, 10*time.Second) + var dialer net.Dialer + localConn, err := dialer.DialContext(dialCtx, "tcp", localAddr) + cancelDial() if err != nil { + log.Debugf("remote port forwarding: dial %s: %v", localAddr, err) return } defer func() { @@ -685,7 +695,7 @@ func (c *Client) handleRemoteForwardChannel(newChan ssh.NewChannel, localAddr st } }() - nbssh.BidirectionalCopy(log.NewEntry(log.StandardLogger()), localConn, channel) + netrelay.Relay(ctx, localConn, channel, netrelay.Options{Logger: log.NewEntry(log.StandardLogger())}) } // tcpipForwardMsg represents the structure for tcpip-forward requests diff --git a/client/ssh/common.go b/client/ssh/common.go index f6aec5f9c..92e647b7d 100644 --- a/client/ssh/common.go +++ b/client/ssh/common.go @@ -194,63 +194,3 @@ func buildAddressList(hostname string, remote net.Addr) []string { return addresses } -// BidirectionalCopy copies data bidirectionally between two io.ReadWriter connections. -// It waits for both directions to complete before returning. -// The caller is responsible for closing the connections. -func BidirectionalCopy(logger *log.Entry, rw1, rw2 io.ReadWriter) { - done := make(chan struct{}, 2) - - go func() { - if _, err := io.Copy(rw2, rw1); err != nil && !isExpectedCopyError(err) { - logger.Debugf("copy error (1->2): %v", err) - } - done <- struct{}{} - }() - - go func() { - if _, err := io.Copy(rw1, rw2); err != nil && !isExpectedCopyError(err) { - logger.Debugf("copy error (2->1): %v", err) - } - done <- struct{}{} - }() - - <-done - <-done -} - -func isExpectedCopyError(err error) bool { - return errors.Is(err, io.EOF) || errors.Is(err, context.Canceled) -} - -// BidirectionalCopyWithContext copies data bidirectionally between two io.ReadWriteCloser connections. -// It waits for both directions to complete or for context cancellation before returning. -// Both connections are closed when the function returns. -func BidirectionalCopyWithContext(logger *log.Entry, ctx context.Context, conn1, conn2 io.ReadWriteCloser) { - done := make(chan struct{}, 2) - - go func() { - if _, err := io.Copy(conn2, conn1); err != nil && !isExpectedCopyError(err) { - logger.Debugf("copy error (1->2): %v", err) - } - done <- struct{}{} - }() - - go func() { - if _, err := io.Copy(conn1, conn2); err != nil && !isExpectedCopyError(err) { - logger.Debugf("copy error (2->1): %v", err) - } - done <- struct{}{} - }() - - select { - case <-ctx.Done(): - case <-done: - select { - case <-ctx.Done(): - case <-done: - } - } - - _ = conn1.Close() - _ = conn2.Close() -} diff --git a/client/ssh/config/manager.go b/client/ssh/config/manager.go index 01822ead6..b58bf2233 100644 --- a/client/ssh/config/manager.go +++ b/client/ssh/config/manager.go @@ -229,18 +229,31 @@ func (m *Manager) buildHostPatterns(peer PeerSSHInfo) []string { func (m *Manager) writeSSHConfig(sshConfig string) error { sshConfigPath := filepath.Join(m.sshConfigDir, m.sshConfigFile) - sshConfigPathTmp := sshConfigPath + ".tmp" if err := os.MkdirAll(m.sshConfigDir, 0755); err != nil { return fmt.Errorf("create SSH config directory %s: %w", m.sshConfigDir, err) } - if err := writeFileWithTimeout(sshConfigPathTmp, []byte(sshConfig), 0644); err != nil { - return fmt.Errorf("write SSH config file %s: %w", sshConfigPath, err) + tmp, err := os.CreateTemp(m.sshConfigDir, m.sshConfigFile+".*.tmp") + if err != nil { + return fmt.Errorf("create temp SSH config: %w", err) + } + tmpPath := tmp.Name() + defer func() { + if err := os.Remove(tmpPath); err != nil && !os.IsNotExist(err) { + log.Debugf("remove temp SSH config %s: %v", tmpPath, err) + } + }() + if err := tmp.Close(); err != nil { + return fmt.Errorf("close temp SSH config %s: %w", tmpPath, err) } - if err := os.Rename(sshConfigPathTmp, sshConfigPath); err != nil { - return fmt.Errorf("rename ssh config %s -> %s: %w", sshConfigPathTmp, sshConfigPath, err) + if err := writeFileWithTimeout(tmpPath, []byte(sshConfig), 0644); err != nil { + return fmt.Errorf("write SSH config file %s: %w", tmpPath, err) + } + + if err := os.Rename(tmpPath, sshConfigPath); err != nil { + return fmt.Errorf("rename SSH config %s -> %s: %w", tmpPath, sshConfigPath, err) } log.Infof("Created NetBird SSH client config: %s", sshConfigPath) diff --git a/client/ssh/proxy/proxy.go b/client/ssh/proxy/proxy.go index eb659fe21..73b50122c 100644 --- a/client/ssh/proxy/proxy.go +++ b/client/ssh/proxy/proxy.go @@ -23,6 +23,7 @@ import ( "github.com/netbirdio/netbird/client/proto" nbssh "github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/ssh/detection" + "github.com/netbirdio/netbird/util/netrelay" "github.com/netbirdio/netbird/version" ) @@ -352,7 +353,7 @@ func (p *SSHProxy) directTCPIPHandler(_ *ssh.Server, _ *cryptossh.ServerConn, ne } go cryptossh.DiscardRequests(clientReqs) - nbssh.BidirectionalCopyWithContext(log.NewEntry(log.StandardLogger()), sshCtx, clientChan, backendChan) + netrelay.Relay(sshCtx, clientChan, backendChan, netrelay.Options{Logger: log.NewEntry(log.StandardLogger())}) } func (p *SSHProxy) sftpSubsystemHandler(s ssh.Session, jwtToken string) { @@ -591,7 +592,7 @@ func (p *SSHProxy) handleForwardedChannel(sshCtx ssh.Context, sshConn *cryptossh } go cryptossh.DiscardRequests(clientReqs) - nbssh.BidirectionalCopyWithContext(log.NewEntry(log.StandardLogger()), sshCtx, clientChan, backendChan) + netrelay.Relay(sshCtx, clientChan, backendChan, netrelay.Options{Logger: log.NewEntry(log.StandardLogger())}) } func (p *SSHProxy) dialBackend(ctx context.Context, addr, user, jwtToken string) (*cryptossh.Client, error) { diff --git a/client/ssh/server/port_forwarding.go b/client/ssh/server/port_forwarding.go index f5ac66fca..a47fdb48a 100644 --- a/client/ssh/server/port_forwarding.go +++ b/client/ssh/server/port_forwarding.go @@ -17,7 +17,7 @@ import ( log "github.com/sirupsen/logrus" cryptossh "golang.org/x/crypto/ssh" - nbssh "github.com/netbirdio/netbird/client/ssh" + "github.com/netbirdio/netbird/util/netrelay" ) const privilegedPortThreshold = 1024 @@ -357,7 +357,7 @@ func (s *Server) handleRemoteForwardConnection(ctx ssh.Context, conn net.Conn, h return } - nbssh.BidirectionalCopyWithContext(logger, ctx, conn, channel) + netrelay.Relay(ctx, conn, channel, netrelay.Options{Logger: logger}) } // openForwardChannel creates an SSH forwarded-tcpip channel diff --git a/client/ssh/server/server.go b/client/ssh/server/server.go index de40d3091..6735e0f3b 100644 --- a/client/ssh/server/server.go +++ b/client/ssh/server/server.go @@ -8,9 +8,9 @@ import ( "fmt" "io" "net" - "strconv" "net/netip" "slices" + "strconv" "strings" "sync" "time" @@ -27,6 +27,7 @@ import ( "github.com/netbirdio/netbird/client/ssh/detection" "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/auth/jwt" + "github.com/netbirdio/netbird/util/netrelay" "github.com/netbirdio/netbird/version" ) @@ -53,6 +54,10 @@ const ( DefaultJWTMaxTokenAge = 10 * 60 ) +// directTCPIPDialTimeout bounds how long relayDirectTCPIP waits on a dial to +// the forwarded destination before rejecting the SSH channel. +const directTCPIPDialTimeout = 30 * time.Second + var ( ErrPrivilegedUserDisabled = errors.New(msgPrivilegedUserDisabled) ErrUserNotFound = errors.New("user not found") @@ -933,5 +938,29 @@ func (s *Server) directTCPIPHandler(srv *ssh.Server, conn *cryptossh.ServerConn, s.addConnectionPortForward(ctx.User(), ctx.RemoteAddr(), forwardAddr) logger.Infof("local port forwarding: %s", hostPort) - ssh.DirectTCPIPHandler(srv, conn, newChan, ctx) + s.relayDirectTCPIP(ctx, newChan, payload.Host, int(payload.Port), logger) +} + +// relayDirectTCPIP is a netrelay-based replacement for gliderlabs' +// DirectTCPIPHandler. The upstream handler closes both sides on the first +// EOF; netrelay.Relay propagates CloseWrite so each direction drains on its +// own terms. +func (s *Server) relayDirectTCPIP(ctx ssh.Context, newChan cryptossh.NewChannel, host string, port int, logger *log.Entry) { + dest := net.JoinHostPort(host, strconv.Itoa(port)) + + dialer := net.Dialer{Timeout: directTCPIPDialTimeout} + dconn, err := dialer.DialContext(ctx, "tcp", dest) + if err != nil { + _ = newChan.Reject(cryptossh.ConnectionFailed, err.Error()) + return + } + + ch, reqs, err := newChan.Accept() + if err != nil { + _ = dconn.Close() + return + } + go cryptossh.DiscardRequests(reqs) + + netrelay.Relay(ctx, dconn, ch, netrelay.Options{Logger: logger}) } diff --git a/combined/cmd/config.go b/combined/cmd/config.go index 9959f7a56..fe350e52a 100644 --- a/combined/cmd/config.go +++ b/combined/cmd/config.go @@ -133,13 +133,18 @@ type ManagementConfig struct { // AuthConfig contains authentication/identity provider settings type AuthConfig struct { - Issuer string `yaml:"issuer"` - LocalAuthDisabled bool `yaml:"localAuthDisabled"` - SignKeyRefreshEnabled bool `yaml:"signKeyRefreshEnabled"` - Storage AuthStorageConfig `yaml:"storage"` - DashboardRedirectURIs []string `yaml:"dashboardRedirectURIs"` - CLIRedirectURIs []string `yaml:"cliRedirectURIs"` - Owner *AuthOwnerConfig `yaml:"owner,omitempty"` + Issuer string `yaml:"issuer"` + LocalAuthDisabled bool `yaml:"localAuthDisabled"` + SignKeyRefreshEnabled bool `yaml:"signKeyRefreshEnabled"` + MfaSessionMaxLifetime string `yaml:"mfaSessionMaxLifetime"` + MfaSessionIdleTimeout string `yaml:"mfaSessionIdleTimeout"` + MfaSessionRememberMe bool `yaml:"mfaSessionRememberMe"` + SessionCookieEncryptionKey string `yaml:"sessionCookieEncryptionKey"` + Storage AuthStorageConfig `yaml:"storage"` + DashboardRedirectURIs []string `yaml:"dashboardRedirectURIs"` + CLIRedirectURIs []string `yaml:"cliRedirectURIs"` + Owner *AuthOwnerConfig `yaml:"owner,omitempty"` + DashboardPostLogoutRedirectURIs []string `yaml:"dashboardPostLogoutRedirectURIs"` } // AuthStorageConfig contains auth storage settings @@ -581,10 +586,14 @@ func (c *CombinedConfig) buildEmbeddedIdPConfig(mgmt ManagementConfig) (*idp.Emb } cfg := &idp.EmbeddedIdPConfig{ - Enabled: true, - Issuer: mgmt.Auth.Issuer, - LocalAuthDisabled: mgmt.Auth.LocalAuthDisabled, - SignKeyRefreshEnabled: mgmt.Auth.SignKeyRefreshEnabled, + Enabled: true, + Issuer: mgmt.Auth.Issuer, + LocalAuthDisabled: mgmt.Auth.LocalAuthDisabled, + SignKeyRefreshEnabled: mgmt.Auth.SignKeyRefreshEnabled, + MfaSessionMaxLifetime: mgmt.Auth.MfaSessionMaxLifetime, + MfaSessionIdleTimeout: mgmt.Auth.MfaSessionIdleTimeout, + MfaSessionRememberMe: mgmt.Auth.MfaSessionRememberMe, + SessionCookieEncryptionKey: mgmt.Auth.SessionCookieEncryptionKey, Storage: idp.EmbeddedStorageConfig{ Type: authStorageType, Config: idp.EmbeddedStorageTypeConfig{ @@ -592,8 +601,9 @@ func (c *CombinedConfig) buildEmbeddedIdPConfig(mgmt ManagementConfig) (*idp.Emb DSN: authStorageDSN, }, }, - DashboardRedirectURIs: mgmt.Auth.DashboardRedirectURIs, - CLIRedirectURIs: mgmt.Auth.CLIRedirectURIs, + DashboardRedirectURIs: mgmt.Auth.DashboardRedirectURIs, + CLIRedirectURIs: mgmt.Auth.CLIRedirectURIs, + DashboardPostLogoutRedirectURIs: mgmt.Auth.DashboardPostLogoutRedirectURIs, } if mgmt.Auth.Owner != nil && mgmt.Auth.Owner.Email != "" { diff --git a/combined/config.yaml.example b/combined/config.yaml.example index af85b0477..66bc71703 100644 --- a/combined/config.yaml.example +++ b/combined/config.yaml.example @@ -86,6 +86,13 @@ server: issuer: "https://example.com/oauth2" localAuthDisabled: false signKeyRefreshEnabled: false + # MFA session settings (applies when TOTP is enabled for an account) + # mfaSessionMaxLifetime: "24h" # Max duration for an MFA session from creation + # mfaSessionIdleTimeout: "1h" # MFA session expires after this idle period + # mfaSessionRememberMe: false # Pre-check "remember me" on login so the MFA session persists across tabs/restarts + # Optional AES key for encrypting embedded IdP session cookies. Can also be set via NB_IDP_SESSION_COOKIE_ENCRYPTION_KEY. + # Must be 16/24/32 raw bytes or base64-encoded to one of those lengths (for example: openssl rand -hex 16). + # sessionCookieEncryptionKey: "" # OAuth2 redirect URIs for dashboard dashboardRedirectURIs: - "https://app.example.com/nb-auth" @@ -93,6 +100,9 @@ server: # OAuth2 redirect URIs for CLI cliRedirectURIs: - "http://localhost:53000/" + # OAuth2 post-logout redirect URIs for dashboard (RP-initiated logout) + # dashboardPostLogoutRedirectURIs: + # - "https://app.example.com/" # Optional initial admin user # owner: # email: "admin@example.com" diff --git a/go.mod b/go.mod index bc4e8af15..5704887ce 100644 --- a/go.mod +++ b/go.mod @@ -14,7 +14,7 @@ require ( github.com/onsi/gomega v1.27.6 github.com/rs/cors v1.8.0 github.com/sirupsen/logrus v1.9.4 - github.com/spf13/cobra v1.10.1 + github.com/spf13/cobra v1.10.2 github.com/spf13/pflag v1.0.9 github.com/vishvananda/netlink v1.3.1 golang.org/x/crypto v0.50.0 @@ -41,11 +41,11 @@ require ( github.com/cilium/ebpf v0.15.0 github.com/coder/websocket v1.8.14 github.com/coreos/go-iptables v0.7.0 - github.com/coreos/go-oidc/v3 v3.14.1 + github.com/coreos/go-oidc/v3 v3.18.0 github.com/creack/pty v1.1.24 github.com/crowdsecurity/crowdsec v1.7.7 github.com/crowdsecurity/go-cs-bouncer v0.0.21 - github.com/dexidp/dex v0.0.0-00010101000000-000000000000 + github.com/dexidp/dex v2.13.0+incompatible github.com/dexidp/dex/api/v2 v2.4.0 github.com/ebitengine/purego v0.8.4 github.com/eko/gocache/lib/v4 v4.2.0 @@ -53,9 +53,9 @@ require ( github.com/eko/gocache/store/redis/v4 v4.2.2 github.com/fsnotify/fsnotify v1.9.0 github.com/gliderlabs/ssh v0.3.8 - github.com/go-jose/go-jose/v4 v4.1.3 + github.com/go-jose/go-jose/v4 v4.1.4 github.com/godbus/dbus/v5 v5.1.0 - github.com/golang-jwt/jwt/v5 v5.3.0 + github.com/golang-jwt/jwt/v5 v5.3.1 github.com/golang/mock v1.6.0 github.com/google/go-cmp v0.7.0 github.com/google/gopacket v1.1.19 @@ -72,7 +72,7 @@ require ( github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81 github.com/mdlayher/socket v0.5.1 github.com/mdp/qrterminal/v3 v3.2.1 - github.com/miekg/dns v1.1.59 + github.com/miekg/dns v1.1.72 github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/netbirdio/management-integrations/integrations v0.0.0-20260416123949-2355d972be42 github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 @@ -113,7 +113,7 @@ require ( go.opentelemetry.io/otel/exporters/prometheus v0.64.0 go.opentelemetry.io/otel/metric v1.43.0 go.opentelemetry.io/otel/sdk/metric v1.43.0 - go.uber.org/mock v0.5.2 + go.uber.org/mock v0.6.0 go.uber.org/zap v1.27.0 goauthentik.io/api/v3 v3.2023051.3 golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b @@ -141,7 +141,7 @@ require ( filippo.io/edwards25519 v1.1.1 // indirect github.com/AppsFlyer/go-sundheit v0.6.0 // indirect github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect - github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358 // indirect + github.com/Azure/go-ntlmssp v0.1.0 // indirect github.com/BurntSushi/toml v1.5.0 // indirect github.com/Masterminds/goutils v1.1.1 // indirect github.com/Masterminds/semver/v3 v3.3.0 // indirect @@ -168,6 +168,7 @@ require ( github.com/aws/smithy-go v1.23.0 // indirect github.com/beevik/etree v1.6.0 // indirect github.com/beorn7/perks v1.0.1 // indirect + github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc // indirect github.com/caddyserver/zerossl v0.1.3 // indirect github.com/cenkalti/backoff/v5 v5.0.3 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect @@ -183,6 +184,7 @@ require ( github.com/docker/go-units v0.5.0 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/fredbi/uri v1.1.1 // indirect + github.com/fxamacker/cbor/v2 v2.9.1 // indirect github.com/fyne-io/gl-js v0.2.0 // indirect github.com/fyne-io/glfw-js v0.3.0 // indirect github.com/fyne-io/image v0.1.1 // indirect @@ -190,7 +192,7 @@ require ( github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 // indirect github.com/go-gl/gl v0.0.0-20231021071112-07e5d0ea2e71 // indirect github.com/go-gl/glfw/v3.3/glfw v0.0.0-20240506104042-037f3cc74f2a // indirect - github.com/go-ldap/ldap/v3 v3.4.12 // indirect + github.com/go-ldap/ldap/v3 v3.4.13 // indirect github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-ole/go-ole v1.3.0 // indirect @@ -206,11 +208,15 @@ require ( github.com/go-sql-driver/mysql v1.9.3 // indirect github.com/go-text/render v0.2.0 // indirect github.com/go-text/typesetting v0.2.1 // indirect + github.com/go-viper/mapstructure/v2 v2.5.0 // indirect + github.com/go-webauthn/webauthn v0.16.4 // indirect + github.com/go-webauthn/x v0.2.3 // indirect github.com/goccy/go-yaml v1.18.0 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/golang-jwt/jwt/v4 v4.5.2 // indirect github.com/google/btree v1.1.2 // indirect github.com/google/go-querystring v1.1.0 // indirect + github.com/google/go-tpm v0.9.8 // indirect github.com/google/s2a-go v0.1.9 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.14 // indirect github.com/googleapis/gax-go/v2 v2.21.0 // indirect @@ -218,7 +224,13 @@ require ( github.com/hack-pad/go-indexeddb v0.3.2 // indirect github.com/hack-pad/safejs v0.1.0 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect + github.com/hashicorp/go-cleanhttp v0.5.2 // indirect + github.com/hashicorp/go-retryablehttp v0.7.8 // indirect + github.com/hashicorp/go-secure-stdlib/parseutil v0.2.0 // indirect + github.com/hashicorp/go-secure-stdlib/strutil v0.1.2 // indirect + github.com/hashicorp/go-sockaddr v1.0.7 // indirect github.com/hashicorp/go-uuid v1.0.3 // indirect + github.com/hashicorp/hcl v1.0.1-vault-7 // indirect github.com/huandu/xstrings v1.5.0 // indirect github.com/huin/goupnp v1.2.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect @@ -238,13 +250,13 @@ require ( github.com/klauspost/cpuid/v2 v2.2.10 // indirect github.com/koron/go-ssdp v0.0.4 // indirect github.com/kr/fs v0.1.0 // indirect - github.com/lib/pq v1.10.9 // indirect + github.com/lib/pq v1.12.3 // indirect github.com/libdns/libdns v0.2.2 // indirect github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae // indirect github.com/magiconair/properties v1.8.10 // indirect github.com/mailru/easyjson v0.9.0 // indirect github.com/mattermost/xml-roundtrip-validator v0.1.0 // indirect - github.com/mattn/go-sqlite3 v1.14.32 // indirect + github.com/mattn/go-sqlite3 v1.14.42 // indirect github.com/mdelapenya/tlscert v0.2.0 // indirect github.com/mdlayher/genetlink v1.3.2 // indirect github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42 // indirect @@ -265,8 +277,10 @@ require ( github.com/nxadm/tail v1.4.11 // indirect github.com/oklog/ulid v1.3.1 // indirect github.com/onsi/ginkgo/v2 v2.9.5 // indirect + github.com/openbao/openbao/api/v2 v2.5.1 // indirect github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/image-spec v1.1.1 // indirect + github.com/philhofer/fwd v1.2.0 // indirect github.com/pion/dtls/v2 v2.2.10 // indirect github.com/pion/dtls/v3 v3.0.9 // indirect github.com/pion/mdns/v2 v2.0.7 // indirect @@ -275,11 +289,13 @@ require ( github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect + github.com/pquerna/otp v1.5.0 // indirect github.com/prometheus/client_model v0.6.2 // indirect github.com/prometheus/common v0.67.5 // indirect github.com/prometheus/otlptranslator v1.0.0 // indirect github.com/prometheus/procfs v0.19.2 // indirect github.com/russellhaering/goxmldsig v1.6.0 // indirect + github.com/ryanuber/go-glob v1.0.0 // indirect github.com/rymdport/portal v0.4.2 // indirect github.com/shirou/gopsutil/v4 v4.25.8 // indirect github.com/shoenig/go-m1cpu v0.2.1 // indirect @@ -288,11 +304,13 @@ require ( github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c // indirect github.com/srwiley/rasterx v0.0.0-20220730225603-2ab79fcdd4ef // indirect github.com/stretchr/objx v0.5.2 // indirect + github.com/tinylib/msgp v1.6.3 // indirect github.com/tklauser/go-sysconf v0.3.15 // indirect github.com/tklauser/numcpus v0.10.0 // indirect github.com/vishvananda/netns v0.0.5 // indirect github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect github.com/wlynxg/anet v0.0.5 // indirect + github.com/x448/float16 v0.8.4 // indirect github.com/yuin/goldmark v1.7.8 // indirect github.com/zeebo/blake3 v0.2.3 // indirect go.mongodb.org/mongo-driver v1.17.9 // indirect @@ -319,10 +337,12 @@ replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-202 replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20260107100953-33b7c9d03db0 -replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6 +replace github.com/cloudflare/circl => codeberg.org/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6 replace github.com/pion/ice/v4 => github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 -replace github.com/dexidp/dex => github.com/netbirdio/dex v0.244.0 +replace github.com/dexidp/dex => github.com/netbirdio/dex v0.244.1-0.20260415145816-a0c6b40ff9f2 + +replace github.com/dexidp/dex/api/v2 => github.com/netbirdio/dex/api/v2 v2.0.0-20260415145816-a0c6b40ff9f2 replace github.com/mailru/easyjson => github.com/netbirdio/easyjson v0.9.0 diff --git a/go.sum b/go.sum index d54dc01e6..42652169c 100644 --- a/go.sum +++ b/go.sum @@ -5,6 +5,8 @@ cloud.google.com/go/auth/oauth2adapt v0.2.8/go.mod h1:XQ9y31RkqZCcwJWNSx2Xvric3R cloud.google.com/go/compute/metadata v0.2.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdBtwLoEkH9Zs= cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10= +codeberg.org/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6 h1:b8xUw3004wk+3ipBhu0VU4RtUJsegMIiqjxSK4++lzA= +codeberg.org/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6/go.mod h1:+CauBF6R70Jqcyl8N2hC8pAXYbWkGIezuSbuGLtRhnw= cunicu.li/go-rosenpass v0.4.0 h1:LtPtBgFWY/9emfgC4glKLEqS0MJTylzV6+ChRhiZERw= cunicu.li/go-rosenpass v0.4.0/go.mod h1:MPbjH9nxV4l3vEagKVdFNwHOketqgS5/To1VYJplf/M= dario.cat/mergo v1.0.1 h1:Ra4+bf83h2ztPIQYNP99R6m+Y7KfnARDfID+a+vLl4s= @@ -23,8 +25,8 @@ github.com/AppsFlyer/go-sundheit v0.6.0 h1:d2hBvCjBSb2lUsEWGfPigr4MCOt04sxB+Rppl github.com/AppsFlyer/go-sundheit v0.6.0/go.mod h1:LDdBHD6tQBtmHsdW+i1GwdTt6Wqc0qazf5ZEJVTbTME= github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 h1:L/gRVlceqvL25UVaW/CKtUDjefjrs0SPonmDGUVOYP0= github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E= -github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358 h1:mFRzDkZVAjdal+s7s0MwaRv9igoPqLRdzOLzw/8Xvq8= -github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358/go.mod h1:chxPXzSsl7ZWRAuOIE23GDNzjWuZquvFlgA8xmpunjU= +github.com/Azure/go-ntlmssp v0.1.0 h1:DjFo6YtWzNqNvQdrwEyr/e4nhU3vRiwenz5QX7sFz+A= +github.com/Azure/go-ntlmssp v0.1.0/go.mod h1:NYqdhxd/8aAct/s4qSYZEerdPuH1liG2/X9DiVTbhpk= github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg= github.com/BurntSushi/toml v1.5.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho= github.com/Masterminds/goutils v1.1.1 h1:5nUrii3FMTL5diU80unEVvNevw1nH4+ZV4DSLVJLSYI= @@ -91,6 +93,8 @@ github.com/beevik/etree v1.6.0/go.mod h1:bh4zJxiIr62SOf9pRzN7UUYaEDa9HEKafK25+sL github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/bmatcuk/doublestar v1.1.1/go.mod h1:UD6OnuiIn0yFxxA2le/rnRU1G4RaI4UvFv1sNto9p6w= +github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc h1:biVzkmvwrH8WK8raXaxBx6fRVTlJILwEwQGL1I/ByEI= +github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= @@ -117,8 +121,8 @@ github.com/containerd/platforms v0.2.1 h1:zvwtM3rz2YHPQsF2CHYM8+KtB5dvhISiXh5ZpS github.com/containerd/platforms v0.2.1/go.mod h1:XHCb+2/hzowdiut9rkudds9bE5yJ7npe7dG/wG+uFPw= github.com/coreos/go-iptables v0.7.0 h1:XWM3V+MPRr5/q51NuWSgU0fqMad64Zyxs8ZUoMsamr8= github.com/coreos/go-iptables v0.7.0/go.mod h1:Qe8Bv2Xik5FyTXwgIbLAnv2sWSBmvWdFETJConOQ//Q= -github.com/coreos/go-oidc/v3 v3.14.1 h1:9ePWwfdwC4QKRlCXsJGou56adA/owXczOzwKdOumLqk= -github.com/coreos/go-oidc/v3 v3.14.1/go.mod h1:HaZ3szPaZ0e4r6ebqvsLWlk2Tn+aejfmrfah6hnSYEU= +github.com/coreos/go-oidc/v3 v3.18.0 h1:V9orjXynvu5wiC9SemFTWnG4F45v403aIcjWo0d41+A= +github.com/coreos/go-oidc/v3 v3.18.0/go.mod h1:DYCf24+ncYi+XkIH97GY1+dqoRlbaSI26KVTCI9SrY4= github.com/cpuguy83/dockercfg v0.3.2 h1:DlJTyZGBDlXqUZ2Dk2Q3xHs/FtnooJJVaad2S9GKorA= github.com/cpuguy83/dockercfg v0.3.2/go.mod h1:sugsbF4//dDlL/i+S+rtpIWp+5h0BHJHfjj5/jFyUJc= github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= @@ -130,14 +134,10 @@ github.com/crowdsecurity/go-cs-bouncer v0.0.21 h1:arPz0VtdVSaz+auOSfHythzkZVLyy1 github.com/crowdsecurity/go-cs-bouncer v0.0.21/go.mod h1:4JiH0XXA4KKnnWThItUpe5+heJHWzsLOSA2IWJqUDBA= github.com/crowdsecurity/go-cs-lib v0.0.25 h1:Ov6VPW9yV+OPsbAIQk1iTkEWhwkpaG0v3lrBzeqjzj4= github.com/crowdsecurity/go-cs-lib v0.0.25/go.mod h1:X0GMJY2CxdA1S09SpuqIKaWQsvRGxXmecUp9cP599dE= -github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6 h1:/DS5cDX3FJdl+XaN2D7XAwFpuanTxnp52DBLZAaJKx0= -github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6/go.mod h1:+CauBF6R70Jqcyl8N2hC8pAXYbWkGIezuSbuGLtRhnw= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/dexidp/dex/api/v2 v2.4.0 h1:gNba7n6BKVp8X4Jp24cxYn5rIIGhM6kDOXcZoL6tr9A= -github.com/dexidp/dex/api/v2 v2.4.0/go.mod h1:/p550ADvFFh7K95VmhUD+jgm15VdaNnab9td8DHOpyI= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk= @@ -156,6 +156,8 @@ github.com/eko/gocache/store/go_cache/v4 v4.2.2 h1:tAI9nl6TLoJyKG1ujF0CS0n/IgTEM github.com/eko/gocache/store/go_cache/v4 v4.2.2/go.mod h1:T9zkHokzr8K9EiC7RfMbDg6HSwaV6rv3UdcNu13SGcA= github.com/eko/gocache/store/redis/v4 v4.2.2 h1:Thw31fzGuH3WzJywsdbMivOmP550D6JS7GDHhvCJPA0= github.com/eko/gocache/store/redis/v4 v4.2.2/go.mod h1:LaTxLKx9TG/YUEybQvPMij++D7PBTIJ4+pzvk0ykz0w= +github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= +github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU= github.com/felixge/fgprof v0.9.3 h1:VvyZxILNuCiUCSXtPtYmmtGvb65nqXh2QFWc0Wpf2/g= github.com/felixge/fgprof v0.9.3/go.mod h1:RdbpDgzqYVh/T9fPELJyV7EYJuHB55UTEULNun8eiPw= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= @@ -171,6 +173,8 @@ github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4 github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw= github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= +github.com/fxamacker/cbor/v2 v2.9.1 h1:2rWm8B193Ll4VdjsJY28jxs70IdDsHRWgQYAI80+rMQ= +github.com/fxamacker/cbor/v2 v2.9.1/go.mod h1:vM4b+DJCtHn+zz7h3FFp/hDAI9WNWCsZj23V5ytsSxQ= github.com/fyne-io/gl-js v0.2.0 h1:+EXMLVEa18EfkXBVKhifYB6OGs3HwKO3lUElA0LlAjs= github.com/fyne-io/gl-js v0.2.0/go.mod h1:ZcepK8vmOYLu96JoxbCKJy2ybr+g1pTnaBDdl7c3ajI= github.com/fyne-io/glfw-js v0.3.0 h1:d8k2+Y7l+zy2pc7wlGRyPfTgZoqDf3AI4G+2zOWhWUk= @@ -189,10 +193,10 @@ github.com/go-gl/gl v0.0.0-20231021071112-07e5d0ea2e71 h1:5BVwOaUSBTlVZowGO6VZGw github.com/go-gl/gl v0.0.0-20231021071112-07e5d0ea2e71/go.mod h1:9YTyiznxEY1fVinfM7RvRcjRHbw2xLBJ3AAGIT0I4Nw= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20240506104042-037f3cc74f2a h1:vxnBhFDDT+xzxf1jTJKMKZw3H0swfWk9RpWbBbDK5+0= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20240506104042-037f3cc74f2a/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= -github.com/go-jose/go-jose/v4 v4.1.3 h1:CVLmWDhDVRa6Mi/IgCgaopNosCaHz7zrMeF9MlZRkrs= -github.com/go-jose/go-jose/v4 v4.1.3/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08= -github.com/go-ldap/ldap/v3 v3.4.12 h1:1b81mv7MagXZ7+1r7cLTWmyuTqVqdwbtJSjC0DAp9s4= -github.com/go-ldap/ldap/v3 v3.4.12/go.mod h1:+SPAGcTtOfmGsCb3h1RFiq4xpp4N636G75OEace8lNo= +github.com/go-jose/go-jose/v4 v4.1.4 h1:moDMcTHmvE6Groj34emNPLs/qtYXRVcd6S7NHbHz3kA= +github.com/go-jose/go-jose/v4 v4.1.4/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08= +github.com/go-ldap/ldap/v3 v3.4.13 h1:+x1nG9h+MZN7h/lUi5Q3UZ0fJ1GyDQYbPvbuH38baDQ= +github.com/go-ldap/ldap/v3 v3.4.13/go.mod h1:LxsGZV6vbaK0sIvYfsv47rfh4ca0JXokCoKjZxsszv0= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= @@ -229,12 +233,20 @@ github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI6 github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls= +github.com/go-test/deep v1.1.1 h1:0r/53hagsehfO4bzD2Pgr/+RgHqhmf+k1Bpse2cTu1U= +github.com/go-test/deep v1.1.1/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE= github.com/go-text/render v0.2.0 h1:LBYoTmp5jYiJ4NPqDc2pz17MLmA3wHw1dZSVGcOdeAc= github.com/go-text/render v0.2.0/go.mod h1:CkiqfukRGKJA5vZZISkjSYrcdtgKQWRa2HIzvwNN5SU= github.com/go-text/typesetting v0.2.1 h1:x0jMOGyO3d1qFAPI0j4GSsh7M0Q3Ypjzr4+CEVg82V8= github.com/go-text/typesetting v0.2.1/go.mod h1:mTOxEwasOFpAMBjEQDhdWRckoLLeI/+qrQeBCTGEt6M= github.com/go-text/typesetting-utils v0.0.0-20241103174707-87a29e9e6066 h1:qCuYC+94v2xrb1PoS4NIDe7DGYtLnU2wWiQe9a1B1c0= github.com/go-text/typesetting-utils v0.0.0-20241103174707-87a29e9e6066/go.mod h1:DDxDdQEnB70R8owOx3LVpEFvpMK9eeH1o2r0yZhFI9o= +github.com/go-viper/mapstructure/v2 v2.5.0 h1:vM5IJoUAy3d7zRSVtIwQgBj7BiWtMPfmPEgAXnvj1Ro= +github.com/go-viper/mapstructure/v2 v2.5.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= +github.com/go-webauthn/webauthn v0.16.4 h1:R9jqR/cYZa7hRquFF7Za/8qoH/K/TIs1/Q/4CyGN+1Q= +github.com/go-webauthn/webauthn v0.16.4/go.mod h1:SU2ljAgToTV/YLPI0C05QS4qn+e04WpB5g1RMfcZfS4= +github.com/go-webauthn/x v0.2.3 h1:8oArS+Rc1SWFLXhE17KZNx258Z4kUSyaDgsSncCO5RA= +github.com/go-webauthn/x v0.2.3/go.mod h1:tM04GF3V6VYq79AZMl7vbj4q6pz9r7L2criWRzbWhPk= github.com/goccy/go-yaml v1.18.0 h1:8W7wMFS12Pcas7KU+VVkaiCng+kG8QiFeFwzFb+rwuw= github.com/goccy/go-yaml v1.18.0/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA= github.com/godbus/dbus/v5 v5.1.0 h1:4KLkAxT3aOY8Li4FRJe/KvhoNFFxo0m6fNuFUO8QJUk= @@ -243,8 +255,8 @@ github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/golang-jwt/jwt/v4 v4.5.2 h1:YtQM7lnr8iZ+j5q71MGKkNw9Mn7AjHM68uc9g5fXeUI= github.com/golang-jwt/jwt/v4 v4.5.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= -github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= -github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= +github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY= +github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= @@ -276,6 +288,10 @@ github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8= github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU= +github.com/google/go-tpm v0.9.8 h1:slArAR9Ft+1ybZu0lBwpSmpwhRXaa85hWtMinMyRAWo= +github.com/google/go-tpm v0.9.8/go.mod h1:h9jEsEECg7gtLis0upRBQU+GhYVH6jMjrFxI8u6bVUY= +github.com/google/go-tpm-tools v0.3.13-0.20230620182252-4639ecce2aba h1:qJEJcuLzH5KDR0gKc0zcktin6KSAwL7+jWKBYceddTc= +github.com/google/go-tpm-tools v0.3.13-0.20230620182252-4639ecce2aba/go.mod h1:EFYHy8/1y2KfgTAsx7Luu7NGhoxtuVHnNo8jE7FikKc= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8= github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo= @@ -308,15 +324,29 @@ github.com/hack-pad/safejs v0.1.0/go.mod h1:HdS+bKF1NrE72VoXZeWzxFOVQVUSqZJAG0xN github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= +github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9neXJWAZQ= +github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/SoxCXGY6BqNFT48= +github.com/hashicorp/go-hclog v1.6.3 h1:Qr2kF+eVWjTiYmU7Y31tYlP1h0q/X3Nl3tPGdaB11/k= +github.com/hashicorp/go-hclog v1.6.3/go.mod h1:W4Qnvbt70Wk/zYJryRzDRU/4r0kIg0PVHBcfoyhpF5M= github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= +github.com/hashicorp/go-retryablehttp v0.7.8 h1:ylXZWnqa7Lhqpk0L1P1LzDtGcCR0rPVUrx/c8Unxc48= +github.com/hashicorp/go-retryablehttp v0.7.8/go.mod h1:rjiScheydd+CxvumBsIrFKlx3iS0jrZ7LvzFGFmuKbw= github.com/hashicorp/go-secure-stdlib/base62 v0.1.2 h1:ET4pqyjiGmY09R5y+rSd70J2w45CtbWDNvGqWp/R3Ng= github.com/hashicorp/go-secure-stdlib/base62 v0.1.2/go.mod h1:EdWO6czbmthiwZ3/PUsDV+UD1D5IRU4ActiaWGwt0Yw= +github.com/hashicorp/go-secure-stdlib/parseutil v0.2.0 h1:U+kC2dOhMFQctRfhK0gRctKAPTloZdMU5ZJxaesJ/VM= +github.com/hashicorp/go-secure-stdlib/parseutil v0.2.0/go.mod h1:Ll013mhdmsVDuoIXVfBtvgGJsXDYkTw1kooNcoCXuE0= +github.com/hashicorp/go-secure-stdlib/strutil v0.1.2 h1:kes8mmyCpxJsI7FTwtzRqEy9CdjCtrXrXGuOpxEA7Ts= +github.com/hashicorp/go-secure-stdlib/strutil v0.1.2/go.mod h1:Gou2R9+il93BqX25LAKCLuM+y9U2T4hlwvT1yprcna4= +github.com/hashicorp/go-sockaddr v1.0.7 h1:G+pTkSO01HpR5qCxg7lxfsFEZaG+C0VssTy/9dbT+Fw= +github.com/hashicorp/go-sockaddr v1.0.7/go.mod h1:FZQbEYa1pxkQ7WLpyXJ6cbjpT8q0YgQaK/JakXqGyWw= github.com/hashicorp/go-uuid v1.0.2/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= github.com/hashicorp/go-uuid v1.0.3 h1:2gKiV6YVmrJ1i2CKKa9obLvRieoRGviZFL26PcT/Co8= github.com/hashicorp/go-uuid v1.0.3/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= github.com/hashicorp/go-version v1.7.0 h1:5tqGy27NaOTB8yJKUZELlFAS/LTKJkrmONwQKeRZfjY= github.com/hashicorp/go-version v1.7.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA= +github.com/hashicorp/hcl v1.0.1-vault-7 h1:ag5OxFVy3QYTFTJODRzTKVZ6xvdfLLCA1cy/Y6xGI0I= +github.com/hashicorp/hcl v1.0.1-vault-7/go.mod h1:XYhtn6ijBSAj6n4YqAaf7RBPS4I06AItNorpy+MoQNM= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/huandu/xstrings v1.5.0 h1:2ag3IFq9ZDANvthTwTiqSSZLjDc+BedvHPAp5tJy2TI= github.com/huandu/xstrings v1.5.0/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE= @@ -387,8 +417,8 @@ github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/leodido/go-urn v1.1.0/go.mod h1:+cyI34gQWZcE1eQU7NVgKkkzdXDQHr1dBMtdAPozLkw= -github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= -github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/lib/pq v1.12.3 h1:tTWxr2YLKwIvK90ZXEw8GP7UFHtcbTtty8zsI+YjrfQ= +github.com/lib/pq v1.12.3/go.mod h1:/p+8NSbOcwzAEI7wiMXFlgydTwcgTr3OSKMsD2BitpA= github.com/libdns/libdns v0.2.2 h1:O6ws7bAfRPaBsgAYt8MDe2HcNBGC29hkZ9MX2eUSX3s= github.com/libdns/libdns v0.2.2/go.mod h1:4Bj9+5CQiNMVGf87wjX4CY3HQJypUHRuLvlsfsZqLWQ= github.com/libdns/route53 v1.5.0 h1:2SKdpPFl/qgWsXQvsLNJJAoX7rSxlk7zgoL4jnWdXVA= @@ -406,9 +436,13 @@ github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8S github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= github.com/mattermost/xml-roundtrip-validator v0.1.0 h1:RXbVD2UAl7A7nOTR4u7E3ILa4IbtvKBHw64LDsmu9hU= github.com/mattermost/xml-roundtrip-validator v0.1.0/go.mod h1:qccnGMcpgwcNaBnxqpJpWWUiPNr5H3O8eDgGV9gT5To= +github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= +github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ= -github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs= -github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-sqlite3 v1.14.42 h1:MigqEP4ZmHw3aIdIT7T+9TLa90Z6smwcthx+Azv4Cgo= +github.com/mattn/go-sqlite3 v1.14.42/go.mod h1:pjEuOr8IwzLJP2MfGeTb0A35jauH+C2kbHKBr7yXKVQ= github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI= github.com/mdelapenya/tlscert v0.2.0/go.mod h1:O4njj3ELLnJjGdkN7M/vIVCpZ+Cf0L6muqOG4tLSl8o= github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy56gw= @@ -421,8 +455,8 @@ github.com/mdp/qrterminal/v3 v3.2.1 h1:6+yQjiiOsSuXT5n9/m60E54vdgFsw0zhADHhHLrFe github.com/mdp/qrterminal/v3 v3.2.1/go.mod h1:jOTmXvnBsMy5xqLniO0R++Jmjs2sTm9dFSuQ5kpz/SU= github.com/mholt/acmez/v2 v2.0.1 h1:3/3N0u1pLjMK4sNEAFSI+bcvzbPhRpY383sy1kLHJ6k= github.com/mholt/acmez/v2 v2.0.1/go.mod h1:fX4c9r5jYwMyMsC+7tkYRxHibkOTgta5DIFGoe67e1U= -github.com/miekg/dns v1.1.59 h1:C9EXc/UToRwKLhK5wKU/I4QVsBUc8kE6MkHBkeypWZs= -github.com/miekg/dns v1.1.59/go.mod h1:nZpewl5p6IvctfgrckopVx2OlSEHPRO/U4SYkRklrEk= +github.com/miekg/dns v1.1.72 h1:vhmr+TF2A3tuoGNkLDFK9zi36F2LS+hKTRW0Uf8kbzI= +github.com/miekg/dns v1.1.72/go.mod h1:+EuEPhdHOsfk6Wk5TT2CzssZdqkmFhf8r+aVyDEToIs= github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721 h1:RlZweED6sbSArvlE924+mUcZuXKLBHA35U7LN621Bws= github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721/go.mod h1:Ickgr2WtCLZ2MDGd4Gr0geeCH5HybhRJbonOgQpvSxc= github.com/mitchellh/copystructure v1.2.0 h1:vpKXTN4ewci03Vljg/q9QvCGUDttBOGBIa15WveJJGw= @@ -451,8 +485,10 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A= github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= -github.com/netbirdio/dex v0.244.0 h1:1GOvi8wnXYassnKGildzNqRHq0RbcfEUw7LKYpKIN7U= -github.com/netbirdio/dex v0.244.0/go.mod h1:STGInJhPcAflrHmDO7vyit2kSq03PdL+8zQPoGALtcU= +github.com/netbirdio/dex v0.244.1-0.20260415145816-a0c6b40ff9f2 h1:AP7OM/JnTogod3rVcLsMuilSG94kWQCr3z6R4rfVXnc= +github.com/netbirdio/dex v0.244.1-0.20260415145816-a0c6b40ff9f2/go.mod h1:+trSlzHNmdJGvz0oLEyyiuaPstUeD7YO6B3Fx9nyziY= +github.com/netbirdio/dex/api/v2 v2.0.0-20260415145816-a0c6b40ff9f2 h1:HEEGJPsVw7/p7SEL3HWP4vaInxHo8OJSEaOkHpUAk+M= +github.com/netbirdio/dex/api/v2 v2.0.0-20260415145816-a0c6b40ff9f2/go.mod h1:awuTyT29CYALpEyET0S307EgNlPWrc7fFKRAyhsO45M= github.com/netbirdio/easyjson v0.9.0 h1:6Nw2lghSVuy8RSkAYDhDv1thBVEmfVbKZnV7T7Z6Aus= github.com/netbirdio/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI= @@ -489,6 +525,8 @@ github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7J github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo= github.com/onsi/gomega v1.27.6 h1:ENqfyGeS5AX/rlXDd/ETokDz93u0YufY1Pgxuy/PvWE= github.com/onsi/gomega v1.27.6/go.mod h1:PIQNjfQwkP3aQAH7lf7j87O/5FiNr+ZR8+ipb+qQlhg= +github.com/openbao/openbao/api/v2 v2.5.1 h1:Br79D6L20SbAa5P7xqENxmvv8LyI4HoKosPy7klhn4o= +github.com/openbao/openbao/api/v2 v2.5.1/go.mod h1:Dh5un77tqGgMbmlVEqjqN+8/dMyUohnkaQVg/wXW0Ig= github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040= @@ -501,6 +539,8 @@ github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0 github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= github.com/petermattis/goid v0.0.0-20250303134427-723919f7f203 h1:E7Kmf11E4K7B5hDti2K2NqPb1nlYlGYsu02S1JNd/Bs= github.com/petermattis/goid v0.0.0-20250303134427-723919f7f203/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4= +github.com/philhofer/fwd v1.2.0 h1:e6DnBTl7vGY+Gz322/ASL4Gyp1FspeMvx1RNDoToZuM= +github.com/philhofer/fwd v1.2.0/go.mod h1:RqIHx9QI14HlwKwm98g9Re5prTQ6LdeRQn+gXJFxsJM= github.com/pion/dtls/v2 v2.2.7/go.mod h1:8WiMkebSHFD0T+dIU+UeBaoV7kDhOW5oDCzZ7WZ/F9s= github.com/pion/dtls/v2 v2.2.10 h1:u2Axk+FyIR1VFTPurktB+1zoEPGIW3bmyj3LEFrXjAA= github.com/pion/dtls/v2 v2.2.10/go.mod h1:d9SYc9fch0CqK90mRk1dC7AkzzpwJj6u2GU3u+9pqFE= @@ -542,6 +582,8 @@ github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE= github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 h1:o4JXh1EVt9k/+g42oCprj/FisM4qX9L3sZB3upGN2ZU= github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE= +github.com/pquerna/otp v1.5.0 h1:NMMR+WrmaqXU4EzdGJEE1aUUI0AMRzsp96fFFWNPwxs= +github.com/pquerna/otp v1.5.0/go.mod h1:dkJfzwRKNiegxyNb54X/3fLwhCynbMspSyWKnvi1AEg= github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o= github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg= github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk= @@ -565,6 +607,8 @@ github.com/rs/xid v1.3.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= github.com/russellhaering/goxmldsig v1.6.0 h1:8fdWXEPh2k/NZNQBPFNoVfS3JmzS4ZprY/sAOpKQLks= github.com/russellhaering/goxmldsig v1.6.0/go.mod h1:TrnaquDcYxWXfJrOjeMBTX4mLBeYAqaHEyUeWPxZlBM= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/ryanuber/go-glob v1.0.0 h1:iQh3xXAumdQ+4Ufa5b25cRpC5TYKlno6hsv6Cb3pkBk= +github.com/ryanuber/go-glob v1.0.0/go.mod h1:807d1WSdnB0XRJzKNil9Om6lcp/3a0v4qIHxIXzX/Yc= github.com/rymdport/portal v0.4.2 h1:7jKRSemwlTyVHHrTGgQg7gmNPJs88xkbKcIL3NlcmSU= github.com/rymdport/portal v0.4.2/go.mod h1:kFF4jslnJ8pD5uCi17brj/ODlfIidOxlgUDTO5ncnC4= github.com/shirou/gopsutil/v3 v3.24.4 h1:dEHgzZXt4LMNm+oYELpzl9YCqV65Yr/6SfrvgRBtXeU= @@ -587,8 +631,8 @@ github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 h1:TG/diQgUe0pntT/2D github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8/go.mod h1:P5HUIBuIWKbyjl083/loAegFkfbFNx5i2qEP4CNbm7E= github.com/spf13/cast v1.7.0 h1:ntdiHjuueXFgm5nzDRdOS4yfT43P5Fnud6DH50rz/7w= github.com/spf13/cast v1.7.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= -github.com/spf13/cobra v1.10.1 h1:lJeBwCfmrnXthfAupyUTzJ/J4Nc1RsHC/mSRU2dll/s= -github.com/spf13/cobra v1.10.1/go.mod h1:7SmJGaTHFVBY0jW4NXGluQoLvhqFQM+6XSKD+P4XaB0= +github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU= +github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4= github.com/spf13/pflag v1.0.9 h1:9exaQaMOCwffKiiiYk6/BndUBv+iRViNW+4lEMi0PvY= github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spkg/bom v0.0.0-20160624110644-59b7046e48ad/go.mod h1:qLr4V1qq6nMqFKkMo8ZTx3f+BZEkzsRUY10Xsm2mwU0= @@ -628,6 +672,8 @@ github.com/ti-mo/conntrack v0.5.1 h1:opEwkFICnDbQc0BUXl73PHBK0h23jEIFVjXsqvF4GY0 github.com/ti-mo/conntrack v0.5.1/go.mod h1:T6NCbkMdVU4qEIgwL0njA6lw/iCAbzchlnwm1Sa314o= github.com/ti-mo/netfilter v0.5.2 h1:CTjOwFuNNeZ9QPdRXt1MZFLFUf84cKtiQutNauHWd40= github.com/ti-mo/netfilter v0.5.2/go.mod h1:Btx3AtFiOVdHReTDmP9AE+hlkOcvIy403u7BXXbWZKo= +github.com/tinylib/msgp v1.6.3 h1:bCSxiTz386UTgyT1i0MSCvdbWjVW+8sG3PjkGsZQt4s= +github.com/tinylib/msgp v1.6.3/go.mod h1:RSp0LW9oSxFut3KzESt5Voq4GVWyS+PSulT77roAqEA= github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI= github.com/tklauser/go-sysconf v0.3.15 h1:VE89k0criAymJ/Os65CSn1IXaol+1wrsFHEB8Ol49K4= github.com/tklauser/go-sysconf v0.3.15/go.mod h1:Dmjwr6tYFIseJw7a3dRLJfsHAMXZ3nEnL/aZY+0IuI4= @@ -646,6 +692,8 @@ github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAh github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds= github.com/wlynxg/anet v0.0.5 h1:J3VJGi1gvo0JwZ/P1/Yc/8p63SoW98B5dHkYDmpgvvU= github.com/wlynxg/anet v0.0.5/go.mod h1:eay5PRQr7fIVAMbTbchTnO9gG65Hg/uYGdc7mguHxoA= +github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= +github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= @@ -690,14 +738,15 @@ go.opentelemetry.io/proto/otlp v1.0.0 h1:T0TX0tmXU8a3CbNXzEKGeU5mIVOdf0oykP+u2lI go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v80hjKIs5JXpM= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= -go.uber.org/mock v0.5.2 h1:LbtPTcP8A5k9WPXj54PPPbjcI4Y6lhyOZXn+VS7wNko= -go.uber.org/mock v0.5.2/go.mod h1:wLlUxC2vVTPTaE3UD51E0BGOAElKrILxhVSDYQLld5o= +go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y= +go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU= go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8= go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= go.yaml.in/yaml/v2 v2.4.3 h1:6gvOSjQoTB3vt1l+CU+tSyi/HOjfOjRLJ4YwYZGwRO0= go.yaml.in/yaml/v2 v2.4.3/go.mod h1:zSxWcmIDjOzPXpjlTTbAsKokqkDNAVtZO0WOMiT90s8= +go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= goauthentik.io/api/v3 v3.2023051.3 h1:NebAhD/TeTWNo/9X3/Uj+rM5fG1HaiLOlKTNLQv9Qq4= goauthentik.io/api/v3 v3.2023051.3/go.mod h1:nYECml4jGbp/541hj8GcylKQG1gVBsKppHy4+7G8u4U= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= diff --git a/idp/dex/config.go b/idp/dex/config.go index e686233ad..56ed998c2 100644 --- a/idp/dex/config.go +++ b/idp/dex/config.go @@ -51,6 +51,70 @@ type YAMLConfig struct { // StaticPasswords cause the server use this list of passwords rather than // querying the storage. StaticPasswords []Password `yaml:"staticPasswords" json:"staticPasswords"` + + // Sessions holds authentication session configuration. + // Requires DEX_SESSIONS_ENABLED=true feature flag. + Sessions *Sessions `yaml:"sessions" json:"sessions"` + + // MFA holds multi-factor authentication configuration. + MFA MFAConfig `yaml:"mfa" json:"mfa"` +} + +type Sessions struct { + // CookieName is the name of the session cookie. Defaults to "dex_session". + CookieName string `yaml:"cookieName" json:"cookieName"` + // AbsoluteLifetime is the maximum session lifetime from creation. Defaults to "24h". + AbsoluteLifetime string `yaml:"absoluteLifetime" json:"absoluteLifetime"` + // ValidIfNotUsedFor is the idle timeout. Defaults to "1h". + ValidIfNotUsedFor string `yaml:"validIfNotUsedFor" json:"validIfNotUsedFor"` + // RememberMeCheckedByDefault controls the default state of the "remember me" checkbox. + RememberMeCheckedByDefault *bool `yaml:"rememberMeCheckedByDefault" json:"rememberMeCheckedByDefault"` + // CookieEncryptionKey is the AES key for encrypting session cookies. + // Must be 16, 24, or 32 bytes for AES-128, AES-192, or AES-256. + // If empty, cookies are not encrypted. + CookieEncryptionKey string `yaml:"cookieEncryptionKey" json:"cookieEncryptionKey"` + // SSOSharedWithDefault is the default SSO sharing policy for clients without explicit ssoSharedWith. + // "all" = share with all clients, "none" = share with no one (default: "none"). + SSOSharedWithDefault string `yaml:"ssoSharedWithDefault" json:"ssoSharedWithDefault"` +} + +type MFAConfig struct { + Authenticators []MFAAuthenticator `yaml:"authenticators" json:"authenticators"` +} + +type MFAAuthenticator struct { + ID string `yaml:"id" json:"id"` + Type string `yaml:"type" json:"type"` + Config map[string]interface{} `yaml:"config" json:"config"` + + ConnectorTypes []string `yaml:"connectorTypes" json:"connectorTypes"` +} + +type TOTPConfig struct { + Issuer string `yaml:"issuer" json:"issuer"` +} + +// WebAuthnConfig holds configuration for a WebAuthn authenticator. +type WebAuthnConfig struct { + // RPDisplayName is the human-readable relying party name shown in the browser + // dialog during key registration and authentication (e.g., "My Company SSO"). + RPDisplayName string `yaml:"rpDisplayName" json:"rpDisplayName"` + // RPID is the relying party identifier — must match the domain in the browser + // address bar. If empty, derived from the issuer URL hostname. + // Example: "auth.example.com" + RPID string `yaml:"rpID" json:"rpID"` + // RPOrigins is the list of allowed origins for WebAuthn ceremonies. + // If empty, derived from the issuer URL (scheme + host). + // Example: ["https://auth.example.com"] + RPOrigins []string `yaml:"rpOrigins" json:"rpOrigins"` + // AttestationPreference controls what attestation data the authenticator should provide: + // "none" — don't request attestation (simpler, more private) + // "indirect" — authenticator may anonymize attestation (default) + // "direct" — request full attestation (for enterprise key model verification) + AttestationPreference string `yaml:"attestationPreference" json:"attestationPreference"` + // Timeout is the duration allowed for the browser WebAuthn ceremony + // (registration or login). Defaults to "60s". + Timeout string `yaml:"timeout" json:"timeout"` } // Web is the config format for the HTTP server. @@ -116,7 +180,6 @@ type Storage struct { Config map[string]interface{} `yaml:"config" json:"config"` } -// Password represents a static user configuration type Password storage.Password func (p *Password) UnmarshalYAML(node *yaml.Node) error { @@ -429,9 +492,98 @@ func (c *YAMLConfig) Validate() error { if !c.EnablePasswordDB && len(c.StaticPasswords) != 0 { return fmt.Errorf("cannot specify static passwords without enabling password db") } + return nil } +func buildTotpConfig(auth MFAAuthenticator) (*server.TOTPProvider, error) { + data, err := json.Marshal(auth.Config) + if err != nil { + return nil, fmt.Errorf("failed to marshal TOTP config id: %s - %w", auth.ID, err) + } + + var cfg TOTPConfig + if err := json.Unmarshal(data, &cfg); err != nil { + return nil, fmt.Errorf("failed to parse TOTP config id: %s - %w", auth.ID, err) + } + + return server.NewTOTPProvider(cfg.Issuer, auth.ConnectorTypes), nil +} + +func buildWebAuthnConfig(auth MFAAuthenticator, issuerURL string) (*server.WebAuthnProvider, error) { + data, err := json.Marshal(auth.Config) + if err != nil { + return nil, fmt.Errorf("failed to marshal WebAuthn config id: %s - %w", auth.ID, err) + } + + var cfg WebAuthnConfig + if err := json.Unmarshal(data, &cfg); err != nil { + return nil, fmt.Errorf("failed to parse WebAuthn config id: %s - %w", auth.ID, err) + } + + provider, err := server.NewWebAuthnProvider(cfg.RPDisplayName, cfg.RPID, cfg.RPOrigins, + cfg.AttestationPreference, cfg.Timeout, issuerURL, auth.ConnectorTypes) + if err != nil { + return nil, fmt.Errorf("failed to create WebAuthn provider id: %s - err: %w", auth.ID, err) + } + + return provider, nil +} + +func buildMFAProviders(authenticators []MFAAuthenticator, issuerURL string, logger *slog.Logger) map[string]server.MFAProvider { + if len(authenticators) == 0 { + return nil + } + + providers := make(map[string]server.MFAProvider, len(authenticators)) + for _, auth := range authenticators { + switch auth.Type { + case "TOTP": + provider, err := buildTotpConfig(auth) + if err != nil { + logger.Error("failed to parse TOTP config", "id", auth.ID, "err", err) + continue + } + providers[auth.ID] = provider + logger.Info("MFA authenticator configured", "id", auth.ID, "type", auth.Type) + case "WebAuthn": + provider, err := buildWebAuthnConfig(auth, issuerURL) + if err != nil { + logger.Error("failed to parse WebAuthn config", "id", auth.ID, "err", err) + continue + } + providers[auth.ID] = provider + logger.Info("MFA authenticator configured", "id", auth.ID, "type", auth.Type) + default: + logger.Error("unknown MFA authenticator type, skipping", "id", auth.ID, "type", auth.Type) + } + } + return providers +} + +func buildSessionsConfig(sessions *Sessions) *server.SessionConfig { + if sessions == nil { + return nil + } + + if sessions.RememberMeCheckedByDefault == nil { + defaultRememberMeCheckedByDefault := false + sessions.RememberMeCheckedByDefault = &defaultRememberMeCheckedByDefault + } + + absoluteLifetime, _ := parseDuration(sessions.AbsoluteLifetime) + validIfNotUsedFor, _ := parseDuration(sessions.ValidIfNotUsedFor) + + return &server.SessionConfig{ + CookieEncryptionKey: []byte(sessions.CookieEncryptionKey), + CookieName: sessions.CookieName, + AbsoluteLifetime: absoluteLifetime, + ValidIfNotUsedFor: validIfNotUsedFor, + RememberMeCheckedByDefault: *sessions.RememberMeCheckedByDefault, + SSOSharedWithDefault: sessions.SSOSharedWithDefault, + } +} + // ToServerConfig converts YAMLConfig to dex server.Config func (c *YAMLConfig) ToServerConfig(stor storage.Storage, logger *slog.Logger) server.Config { cfg := server.Config{ @@ -448,6 +600,8 @@ func (c *YAMLConfig) ToServerConfig(stor storage.Storage, logger *slog.Logger) s Dir: c.Frontend.Dir, Extra: c.Frontend.Extra, }, + SessionConfig: buildSessionsConfig(c.Sessions), + MFAProviders: buildMFAProviders(c.MFA.Authenticators, c.Issuer, logger), } // Use embedded NetBird-styled templates if no custom dir specified @@ -460,11 +614,6 @@ func (c *YAMLConfig) ToServerConfig(stor storage.Storage, logger *slog.Logger) s } // Apply expiry settings - if c.Expiry.SigningKeys != "" { - if d, err := parseDuration(c.Expiry.SigningKeys); err == nil { - cfg.RotateKeysAfter = d - } - } if c.Expiry.IDTokens != "" { if d, err := parseDuration(c.Expiry.IDTokens); err == nil { cfg.IDTokensValidFor = d diff --git a/idp/dex/provider.go b/idp/dex/provider.go index 24aed1b99..526d6a17a 100644 --- a/idp/dex/provider.go +++ b/idp/dex/provider.go @@ -18,6 +18,7 @@ import ( dexapi "github.com/dexidp/dex/api/v2" "github.com/dexidp/dex/server" + "github.com/dexidp/dex/server/signer" "github.com/dexidp/dex/storage" "github.com/dexidp/dex/storage/sql" jose "github.com/go-jose/go-jose/v4" @@ -70,7 +71,7 @@ func NewProvider(ctx context.Context, config *Config) (*Provider, error) { logger := slog.New(slog.NewTextHandler(os.Stderr, nil)) // Ensure data directory exists - if err := os.MkdirAll(config.DataDir, 0700); err != nil { + if err := os.MkdirAll(config.DataDir, 0o700); err != nil { return nil, fmt.Errorf("failed to create data directory: %w", err) } @@ -101,6 +102,15 @@ func NewProvider(ctx context.Context, config *Config) (*Provider, error) { return nil, fmt.Errorf("failed to create refresh token policy: %w", err) } + localSignerConfig := signer.LocalConfig{ + KeysRotationPeriod: "6h", + } + + localSigner, err := localSignerConfig.Open(ctx, stor, 24*time.Hour, time.Now, logger) + if err != nil { + return nil, fmt.Errorf("failed to create local signer: %w", err) + } + // Build Dex server config - use Dex's types directly dexConfig := server.Config{ Issuer: issuer, @@ -110,12 +120,12 @@ func NewProvider(ctx context.Context, config *Config) (*Provider, error) { ContinueOnConnectorFailure: true, Logger: logger, PrometheusRegistry: prometheus.NewRegistry(), - RotateKeysAfter: 6 * time.Hour, IDTokensValidFor: 24 * time.Hour, RefreshTokenPolicy: refreshPolicy, Web: server.WebConfig{ Issuer: "NetBird", }, + Signer: localSigner, } dexSrv, err := server.NewServer(ctx, dexConfig) @@ -167,6 +177,14 @@ func NewProviderFromYAML(ctx context.Context, yamlConfig *YAMLConfig) (*Provider return nil, fmt.Errorf("failed to create refresh token policy: %w", err) } + localSigner, err := getSigner(ctx, stor, yamlConfig, logger) + if err != nil { + stor.Close() + return nil, fmt.Errorf("failed to create local signer: %w", err) + } + + dexConfig.Signer = localSigner + dexSrv, err := server.NewServer(ctx, dexConfig) if err != nil { stor.Close() @@ -182,6 +200,32 @@ func NewProviderFromYAML(ctx context.Context, yamlConfig *YAMLConfig) (*Provider }, nil } +func getSigner(ctx context.Context, stor storage.Storage, yamlConfig *YAMLConfig, logger *slog.Logger) (signer.Signer, error) { + // Parse expiry durations + idTokensValidFor := 24 * time.Hour // default + if yamlConfig.Expiry.IDTokens != "" { + var err error + idTokensValidFor, err = parseDuration(yamlConfig.Expiry.IDTokens) + if err != nil { + return nil, fmt.Errorf("invalid config value %q for id token expiry: %v", yamlConfig.Expiry.IDTokens, err) + } + } + + localSignerConfig := &signer.LocalConfig{ + KeysRotationPeriod: "720h", // 30 Days + } + + if yamlConfig.Expiry.SigningKeys != "" { + if _, err := parseDuration(yamlConfig.Expiry.SigningKeys); err != nil { + return nil, fmt.Errorf("invalid config value %q for signing key expiry: %v", yamlConfig.Expiry.SigningKeys, err) + } + + localSignerConfig.KeysRotationPeriod = yamlConfig.Expiry.SigningKeys + } + + return localSignerConfig.Open(ctx, stor, idTokensValidFor, time.Now, logger) +} + // initializeStorage sets up connectors, passwords, and clients in storage func initializeStorage(ctx context.Context, stor storage.Storage, cfg *YAMLConfig) error { if cfg.EnablePasswordDB { @@ -241,6 +285,8 @@ func ensureStaticClients(ctx context.Context, stor storage.Storage, clients []st old.RedirectURIs = client.RedirectURIs old.Name = client.Name old.Public = client.Public + old.PostLogoutRedirectURIs = client.PostLogoutRedirectURIs + old.MFAChain = client.MFAChain return old, nil }); err != nil { return fmt.Errorf("failed to update client %s: %w", client.ID, err) @@ -253,9 +299,6 @@ func ensureStaticClients(ctx context.Context, stor storage.Storage, clients []st func buildDexConfig(yamlConfig *YAMLConfig, stor storage.Storage, logger *slog.Logger) server.Config { cfg := yamlConfig.ToServerConfig(stor, logger) cfg.PrometheusRegistry = prometheus.NewRegistry() - if cfg.RotateKeysAfter == 0 { - cfg.RotateKeysAfter = 24 * 30 * time.Hour - } if cfg.IDTokensValidFor == 0 { cfg.IDTokensValidFor = 24 * time.Hour } @@ -450,10 +493,34 @@ func (p *Provider) Storage() storage.Storage { return p.storage } +// SetClientsMFAChain updates the MFAChain field on the dashboard and CLI OAuth2 clients. +// Pass a non-empty slice (e.g. []string{"default-totp"}) to enable MFA, or nil to disable it. +func (p *Provider) SetClientsMFAChain(ctx context.Context, clientIDs []string, mfaChain []string) error { + for _, clientID := range clientIDs { + if err := p.storage.UpdateClient(ctx, clientID, func(old storage.Client) (storage.Client, error) { + old.MFAChain = mfaChain + return old, nil + }); err != nil { + return fmt.Errorf("failed to update MFA chain on client %s: %w", clientID, err) + } + } + return nil +} + // Handler returns the Dex server as an http.Handler for embedding in another server. // The handler expects requests with path prefix "/oauth2/". func (p *Provider) Handler() http.Handler { - return p.dexServer + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Dex's /logout endpoint requires id_token_hint for RP-initiated logout with + // post_logout_redirect_uri. If the dashboard calls logout without one, avoid + // rendering Dex's non-actionable Bad Request page and send the user home. + if strings.HasSuffix(r.URL.Path, "/logout") && r.FormValue("id_token_hint") == "" { + http.Redirect(w, r, "/", http.StatusSeeOther) + return + } + + p.dexServer.ServeHTTP(w, r) + }) } // CreateUser creates a new user with the given email, username, and password. diff --git a/idp/dex/provider_test.go b/idp/dex/provider_test.go index 4ed89fd2e..88828fbbb 100644 --- a/idp/dex/provider_test.go +++ b/idp/dex/provider_test.go @@ -4,6 +4,8 @@ import ( "context" "encoding/json" "log/slog" + "net/http" + "net/http/httptest" "os" "path/filepath" "testing" @@ -144,6 +146,30 @@ func TestEncodeDexUserID_MatchesDexFormat(t *testing.T) { assert.Equal(t, knownEncodedID, reEncoded) } +func TestHandlerRedirectsLogoutWithoutIDTokenHint(t *testing.T) { + ctx := context.Background() + + tmpDir, err := os.MkdirTemp("", "dex-logout-handler-*") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + provider, err := NewProvider(ctx, &Config{ + Issuer: "http://localhost:5556/oauth2", + Port: 5556, + DataDir: tmpDir, + }) + require.NoError(t, err) + defer func() { _ = provider.Stop(ctx) }() + + req := httptest.NewRequest(http.MethodGet, "/oauth2/logout?post_logout_redirect_uri=https://example.com", nil) + rec := httptest.NewRecorder() + + provider.Handler().ServeHTTP(rec, req) + + require.Equal(t, http.StatusSeeOther, rec.Code) + require.Equal(t, "/", rec.Header().Get("Location")) +} + func TestCreateUserInTempDB(t *testing.T) { ctx := context.Background() diff --git a/idp/dex/web/templates/home.html b/idp/dex/web/templates/home.html new file mode 100644 index 000000000..be7c938ae --- /dev/null +++ b/idp/dex/web/templates/home.html @@ -0,0 +1,12 @@ +{{ template "header.html" . }} + + + + +{{ template "footer.html" . }} diff --git a/idp/dex/web/templates/logout.html b/idp/dex/web/templates/logout.html new file mode 100644 index 000000000..b623d35af --- /dev/null +++ b/idp/dex/web/templates/logout.html @@ -0,0 +1,14 @@ +{{ template "header.html" . }} + +
You have been successfully logged out.
+ + {{ if .BackURL }} + + {{ end }} +Scan the QR code below using your authenticator app, then enter the code.
+Enter the code from your authenticator app.
+ {{ end }} + + +