mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-25 10:09:56 +00:00
Compare commits
10 Commits
refactor-c
...
reduce-emb
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2f028779a7 | ||
|
|
e31be52ff3 | ||
|
|
39ee9c5986 | ||
|
|
cfc100a9d7 | ||
|
|
f4e42a8929 | ||
|
|
8f7fd69813 | ||
|
|
7720484c85 | ||
|
|
8520bbbe57 | ||
|
|
533dd9577b | ||
|
|
e81ce81494 |
2
.github/workflows/golangci-lint.yml
vendored
2
.github/workflows/golangci-lint.yml
vendored
@@ -19,7 +19,7 @@ jobs:
|
||||
- name: codespell
|
||||
uses: codespell-project/actions-codespell@v2
|
||||
with:
|
||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te,userA,ede,additionals
|
||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te,userA
|
||||
skip: go.mod,go.sum,**/proxy/web/**
|
||||
golangci:
|
||||
strategy:
|
||||
|
||||
@@ -8,7 +8,7 @@ There are many ways that you can contribute:
|
||||
- Sharing use cases in slack or Reddit
|
||||
- Bug fix or feature enhancement
|
||||
|
||||
If you haven't already, join our slack workspace [here](https://docs.netbird.io/slack-url), we would love to discuss topics that need community contribution and enhancements to existing features.
|
||||
If you haven't already, join our slack workspace [here](https://join.slack.com/t/netbirdio/shared_invite/zt-vrahf41g-ik1v7fV8du6t0RwxSrJ96A), we would love to discuss topics that need community contribution and enhancements to existing features.
|
||||
|
||||
## Contents
|
||||
|
||||
|
||||
@@ -143,7 +143,7 @@ func init() {
|
||||
rootCmd.PersistentFlags().StringVar(&preSharedKey, preSharedKeyFlag, "", "Sets WireGuard PreSharedKey property. If set, then only peers that have the same key can communicate.")
|
||||
rootCmd.PersistentFlags().StringVarP(&hostName, "hostname", "n", "", "Sets a custom hostname for the device")
|
||||
rootCmd.PersistentFlags().BoolVarP(&anonymizeFlag, "anonymize", "A", false, "anonymize IP addresses and non-netbird.io domains in logs and status output")
|
||||
rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", profilemanager.DefaultConfigPath, "Overrides the default profile file location")
|
||||
rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", defaultConfigPath, "Overrides the default profile file location")
|
||||
|
||||
rootCmd.AddCommand(upCmd)
|
||||
rootCmd.AddCommand(downCmd)
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"sync"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
wgnetstack "golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
@@ -94,6 +95,26 @@ type Options struct {
|
||||
MTU *uint16
|
||||
// DNSLabels defines additional DNS labels configured in the peer.
|
||||
DNSLabels []string
|
||||
// Performance configures the tunnel's buffer pool cap and batch size.
|
||||
Performance Performance
|
||||
}
|
||||
|
||||
// Performance configures the embedded client's tunnel memory/throughput knobs.
|
||||
//
|
||||
// These settings are process-global: any non-nil field also becomes the
|
||||
// default for Clients constructed by later embed.New calls in the same
|
||||
// process. Nil fields are ignored.
|
||||
type Performance struct {
|
||||
// PreallocatedBuffersPerPool caps the per-tunnel buffer pool. Zero
|
||||
// leaves the pool unbounded. Lower values trade throughput for a
|
||||
// tighter memory ceiling. May also be changed on a running Client via
|
||||
// Client.SetPerformance, provided this field was nonzero at construction.
|
||||
PreallocatedBuffersPerPool *uint32
|
||||
// MaxBatchSize overrides the number of packets the tunnel reads or
|
||||
// writes per syscall, which also bounds eager buffer allocation per
|
||||
// worker. Zero uses the platform default. Applied at construction
|
||||
// only; ignored by Client.SetPerformance.
|
||||
MaxBatchSize *uint32
|
||||
}
|
||||
|
||||
// validateCredentials checks that exactly one credential type is provided
|
||||
@@ -192,6 +213,13 @@ func New(opts Options) (*Client, error) {
|
||||
config.PrivateKey = opts.PrivateKey
|
||||
}
|
||||
|
||||
if opts.Performance.PreallocatedBuffersPerPool != nil {
|
||||
wgdevice.SetPreallocatedBuffersPerPool(*opts.Performance.PreallocatedBuffersPerPool)
|
||||
}
|
||||
if opts.Performance.MaxBatchSize != nil {
|
||||
wgdevice.SetMaxBatchSizeOverride(*opts.Performance.MaxBatchSize)
|
||||
}
|
||||
|
||||
return &Client{
|
||||
deviceName: opts.DeviceName,
|
||||
setupKey: opts.SetupKey,
|
||||
@@ -473,6 +501,25 @@ func (c *Client) VerifySSHHostKey(peerAddress string, key []byte) error {
|
||||
return sshcommon.VerifyHostKey(storedKey, key, peerAddress)
|
||||
}
|
||||
|
||||
// SetPerformance retunes a running Client. Only PreallocatedBuffersPerPool
|
||||
// takes effect, and only when it was nonzero at construction;
|
||||
// MaxBatchSize is construction-only and returns an error if set here.
|
||||
//
|
||||
// Returns ErrClientNotStarted / ErrEngineNotStarted if the Client is not
|
||||
// running yet.
|
||||
func (c *Client) SetPerformance(t Performance) error {
|
||||
if t.MaxBatchSize != nil {
|
||||
return errors.New("MaxBatchSize is construction-only and cannot be changed at runtime")
|
||||
}
|
||||
engine, err := c.getEngine()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return engine.SetPerformance(internal.Performance{
|
||||
PreallocatedBuffersPerPool: t.PreallocatedBuffersPerPool,
|
||||
})
|
||||
}
|
||||
|
||||
// StartCapture begins capturing packets on this client's tunnel device.
|
||||
// Only one capture can be active at a time; starting a new one stops the previous.
|
||||
// Call StopCapture (or CaptureSession.Stop) to end it.
|
||||
|
||||
@@ -1,125 +0,0 @@
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestTCPCapEvicts(t *testing.T) {
|
||||
t.Setenv(EnvTCPMaxEntries, "4")
|
||||
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
require.Equal(t, 4, tracker.maxEntries)
|
||||
|
||||
src := netip.MustParseAddr("100.64.0.1")
|
||||
dst := netip.MustParseAddr("100.64.0.2")
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
tracker.TrackOutbound(src, dst, uint16(10000+i), 80, TCPSyn, 0)
|
||||
}
|
||||
require.LessOrEqual(t, len(tracker.connections), 4,
|
||||
"TCP table must not exceed the configured cap")
|
||||
require.Greater(t, len(tracker.connections), 0,
|
||||
"some entries must remain after eviction")
|
||||
|
||||
// The most recently admitted flow must be present: eviction must make
|
||||
// room for new entries, not silently drop them.
|
||||
require.Contains(t, tracker.connections,
|
||||
ConnKey{SrcIP: src, DstIP: dst, SrcPort: uint16(10009), DstPort: 80},
|
||||
"newest TCP flow must be admitted after eviction")
|
||||
// A pre-cap flow must have been evicted to fit the last one.
|
||||
require.NotContains(t, tracker.connections,
|
||||
ConnKey{SrcIP: src, DstIP: dst, SrcPort: uint16(10000), DstPort: 80},
|
||||
"oldest TCP flow should have been evicted")
|
||||
}
|
||||
|
||||
func TestTCPCapPrefersTombstonedForEviction(t *testing.T) {
|
||||
t.Setenv(EnvTCPMaxEntries, "3")
|
||||
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
src := netip.MustParseAddr("100.64.0.1")
|
||||
dst := netip.MustParseAddr("100.64.0.2")
|
||||
|
||||
// Fill to cap with 3 live connections.
|
||||
for i := 0; i < 3; i++ {
|
||||
tracker.TrackOutbound(src, dst, uint16(20000+i), 80, TCPSyn, 0)
|
||||
}
|
||||
require.Len(t, tracker.connections, 3)
|
||||
|
||||
// Tombstone one by sending RST through IsValidInbound.
|
||||
tombstonedKey := ConnKey{SrcIP: src, DstIP: dst, SrcPort: 20001, DstPort: 80}
|
||||
require.True(t, tracker.IsValidInbound(dst, src, 80, 20001, TCPRst|TCPAck, 0))
|
||||
require.True(t, tracker.connections[tombstonedKey].IsTombstone())
|
||||
|
||||
// Another live connection forces eviction. The tombstone must go first.
|
||||
tracker.TrackOutbound(src, dst, uint16(29999), 80, TCPSyn, 0)
|
||||
|
||||
_, tombstonedStillPresent := tracker.connections[tombstonedKey]
|
||||
require.False(t, tombstonedStillPresent,
|
||||
"tombstoned entry should be evicted before live entries")
|
||||
require.LessOrEqual(t, len(tracker.connections), 3)
|
||||
|
||||
// Both live pre-cap entries must survive: eviction must prefer the
|
||||
// tombstone, not just satisfy the size bound by dropping any entry.
|
||||
require.Contains(t, tracker.connections,
|
||||
ConnKey{SrcIP: src, DstIP: dst, SrcPort: uint16(20000), DstPort: 80},
|
||||
"live entries must not be evicted while a tombstone exists")
|
||||
require.Contains(t, tracker.connections,
|
||||
ConnKey{SrcIP: src, DstIP: dst, SrcPort: uint16(20002), DstPort: 80},
|
||||
"live entries must not be evicted while a tombstone exists")
|
||||
}
|
||||
|
||||
func TestUDPCapEvicts(t *testing.T) {
|
||||
t.Setenv(EnvUDPMaxEntries, "5")
|
||||
|
||||
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
require.Equal(t, 5, tracker.maxEntries)
|
||||
|
||||
src := netip.MustParseAddr("100.64.0.1")
|
||||
dst := netip.MustParseAddr("100.64.0.2")
|
||||
|
||||
for i := 0; i < 12; i++ {
|
||||
tracker.TrackOutbound(src, dst, uint16(30000+i), 53, 0)
|
||||
}
|
||||
require.LessOrEqual(t, len(tracker.connections), 5)
|
||||
require.Greater(t, len(tracker.connections), 0)
|
||||
|
||||
require.Contains(t, tracker.connections,
|
||||
ConnKey{SrcIP: src, DstIP: dst, SrcPort: uint16(30011), DstPort: 53},
|
||||
"newest UDP flow must be admitted after eviction")
|
||||
require.NotContains(t, tracker.connections,
|
||||
ConnKey{SrcIP: src, DstIP: dst, SrcPort: uint16(30000), DstPort: 53},
|
||||
"oldest UDP flow should have been evicted")
|
||||
}
|
||||
|
||||
func TestICMPCapEvicts(t *testing.T) {
|
||||
t.Setenv(EnvICMPMaxEntries, "3")
|
||||
|
||||
tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
require.Equal(t, 3, tracker.maxEntries)
|
||||
|
||||
src := netip.MustParseAddr("100.64.0.1")
|
||||
dst := netip.MustParseAddr("100.64.0.2")
|
||||
|
||||
echoReq := layers.CreateICMPv4TypeCode(uint8(layers.ICMPv4TypeEchoRequest), 0)
|
||||
for i := 0; i < 8; i++ {
|
||||
tracker.TrackOutbound(src, dst, uint16(i), echoReq, nil, 64)
|
||||
}
|
||||
require.LessOrEqual(t, len(tracker.connections), 3)
|
||||
require.Greater(t, len(tracker.connections), 0)
|
||||
|
||||
require.Contains(t, tracker.connections,
|
||||
ICMPConnKey{SrcIP: src, DstIP: dst, ID: uint16(7)},
|
||||
"newest ICMP flow must be admitted after eviction")
|
||||
require.NotContains(t, tracker.connections,
|
||||
ICMPConnKey{SrcIP: src, DstIP: dst, ID: uint16(0)},
|
||||
"oldest ICMP flow should have been evicted")
|
||||
}
|
||||
@@ -3,61 +3,15 @@ package conntrack
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
)
|
||||
|
||||
// evictSampleSize bounds how many map entries we scan per eviction call.
|
||||
// Keeps eviction O(1) even at cap under sustained load; the sampled-LRU
|
||||
// heuristic is good enough for a conntrack table that only overflows under
|
||||
// abuse.
|
||||
const evictSampleSize = 8
|
||||
|
||||
// envDuration parses an os.Getenv(name) as a time.Duration. Falls back to
|
||||
// def on empty or invalid; logs a warning on invalid.
|
||||
func envDuration(logger *nblog.Logger, name string, def time.Duration) time.Duration {
|
||||
v := os.Getenv(name)
|
||||
if v == "" {
|
||||
return def
|
||||
}
|
||||
d, err := time.ParseDuration(v)
|
||||
if err != nil {
|
||||
logger.Warn3("invalid %s=%q: %v, using default", name, v, err)
|
||||
return def
|
||||
}
|
||||
if d <= 0 {
|
||||
logger.Warn2("invalid %s=%q: must be positive, using default", name, v)
|
||||
return def
|
||||
}
|
||||
return d
|
||||
}
|
||||
|
||||
// envInt parses an os.Getenv(name) as an int. Falls back to def on empty,
|
||||
// invalid, or non-positive. Logs a warning on invalid input.
|
||||
func envInt(logger *nblog.Logger, name string, def int) int {
|
||||
v := os.Getenv(name)
|
||||
if v == "" {
|
||||
return def
|
||||
}
|
||||
n, err := strconv.Atoi(v)
|
||||
switch {
|
||||
case err != nil:
|
||||
logger.Warn3("invalid %s=%q: %v, using default", name, v, err)
|
||||
return def
|
||||
case n <= 0:
|
||||
logger.Warn2("invalid %s=%q: must be positive, using default", name, v)
|
||||
return def
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
// BaseConnTrack provides common fields and locking for all connection types
|
||||
type BaseConnTrack struct {
|
||||
FlowId uuid.UUID
|
||||
|
||||
@@ -1,11 +0,0 @@
|
||||
//go:build !ios && !android
|
||||
|
||||
package conntrack
|
||||
|
||||
// Default per-tracker entry caps on desktop/server platforms. These mirror
|
||||
// typical Linux netfilter nf_conntrack_max territory with ample headroom.
|
||||
const (
|
||||
DefaultMaxTCPEntries = 65536
|
||||
DefaultMaxUDPEntries = 16384
|
||||
DefaultMaxICMPEntries = 2048
|
||||
)
|
||||
@@ -1,13 +0,0 @@
|
||||
//go:build ios || android
|
||||
|
||||
package conntrack
|
||||
|
||||
// Default per-tracker entry caps on mobile platforms. iOS network extensions
|
||||
// are capped at ~50 MB; Android runs under aggressive memory pressure. These
|
||||
// values keep conntrack footprint well under 5 MB worst case (TCPConnTrack
|
||||
// is ~200 B plus map overhead).
|
||||
const (
|
||||
DefaultMaxTCPEntries = 4096
|
||||
DefaultMaxUDPEntries = 2048
|
||||
DefaultMaxICMPEntries = 512
|
||||
)
|
||||
@@ -50,9 +50,6 @@ type ICMPConnTrack struct {
|
||||
ICMPCode uint8
|
||||
}
|
||||
|
||||
// EnvICMPMaxEntries caps the ICMP conntrack table size.
|
||||
const EnvICMPMaxEntries = "NB_CONNTRACK_ICMP_MAX"
|
||||
|
||||
// ICMPTracker manages ICMP connection states
|
||||
type ICMPTracker struct {
|
||||
logger *nblog.Logger
|
||||
@@ -61,7 +58,6 @@ type ICMPTracker struct {
|
||||
cleanupTicker *time.Ticker
|
||||
tickerCancel context.CancelFunc
|
||||
mutex sync.RWMutex
|
||||
maxEntries int
|
||||
flowLogger nftypes.FlowLogger
|
||||
}
|
||||
|
||||
@@ -175,7 +171,6 @@ func NewICMPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nfty
|
||||
timeout: timeout,
|
||||
cleanupTicker: time.NewTicker(ICMPCleanupInterval),
|
||||
tickerCancel: cancel,
|
||||
maxEntries: envInt(logger, EnvICMPMaxEntries, DefaultMaxICMPEntries),
|
||||
flowLogger: flowLogger,
|
||||
}
|
||||
|
||||
@@ -262,9 +257,7 @@ func (t *ICMPTracker) track(
|
||||
|
||||
// non echo requests don't need tracking
|
||||
if typ != uint8(layers.ICMPv4TypeEchoRequest) {
|
||||
if t.logger.Enabled(nblog.LevelTrace) {
|
||||
t.logger.Trace3("New %s ICMP connection %s - %s", direction, key, icmpInfo)
|
||||
}
|
||||
t.logger.Trace3("New %s ICMP connection %s - %s", direction, key, icmpInfo)
|
||||
t.sendStartEvent(direction, srcIP, dstIP, typ, code, ruleId, size)
|
||||
return
|
||||
}
|
||||
@@ -283,15 +276,10 @@ func (t *ICMPTracker) track(
|
||||
conn.UpdateCounters(direction, size)
|
||||
|
||||
t.mutex.Lock()
|
||||
if t.maxEntries > 0 && len(t.connections) >= t.maxEntries {
|
||||
t.evictOneLocked()
|
||||
}
|
||||
t.connections[key] = conn
|
||||
t.mutex.Unlock()
|
||||
|
||||
if t.logger.Enabled(nblog.LevelTrace) {
|
||||
t.logger.Trace3("New %s ICMP connection %s - %s", direction, key, icmpInfo)
|
||||
}
|
||||
t.logger.Trace3("New %s ICMP connection %s - %s", direction, key, icmpInfo)
|
||||
t.sendEvent(nftypes.TypeStart, conn, ruleId)
|
||||
}
|
||||
|
||||
@@ -335,34 +323,6 @@ func (t *ICMPTracker) cleanupRoutine(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// evictOneLocked removes one entry to make room. Caller must hold t.mutex.
|
||||
// Bounded sample scan: picks the oldest among up to evictSampleSize entries.
|
||||
func (t *ICMPTracker) evictOneLocked() {
|
||||
var candKey ICMPConnKey
|
||||
var candSeen int64
|
||||
haveCand := false
|
||||
sampled := 0
|
||||
|
||||
for k, c := range t.connections {
|
||||
seen := c.lastSeen.Load()
|
||||
if !haveCand || seen < candSeen {
|
||||
candKey = k
|
||||
candSeen = seen
|
||||
haveCand = true
|
||||
}
|
||||
sampled++
|
||||
if sampled >= evictSampleSize {
|
||||
break
|
||||
}
|
||||
}
|
||||
if haveCand {
|
||||
if evicted := t.connections[candKey]; evicted != nil {
|
||||
t.sendEvent(nftypes.TypeEnd, evicted, nil)
|
||||
}
|
||||
delete(t.connections, candKey)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *ICMPTracker) cleanup() {
|
||||
t.mutex.Lock()
|
||||
defer t.mutex.Unlock()
|
||||
@@ -371,10 +331,8 @@ func (t *ICMPTracker) cleanup() {
|
||||
if conn.timeoutExceeded(t.timeout) {
|
||||
delete(t.connections, key)
|
||||
|
||||
if t.logger.Enabled(nblog.LevelTrace) {
|
||||
t.logger.Trace5("Removed ICMP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]",
|
||||
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||
}
|
||||
t.logger.Trace5("Removed ICMP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]",
|
||||
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -38,27 +38,6 @@ const (
|
||||
TCPHandshakeTimeout = 60 * time.Second
|
||||
// TCPCleanupInterval is how often we check for stale connections
|
||||
TCPCleanupInterval = 5 * time.Minute
|
||||
// FinWaitTimeout bounds FIN_WAIT_1 / FIN_WAIT_2 / CLOSING states.
|
||||
// Matches Linux netfilter nf_conntrack_tcp_timeout_fin_wait.
|
||||
FinWaitTimeout = 60 * time.Second
|
||||
// CloseWaitTimeout bounds CLOSE_WAIT. Matches Linux default; apps
|
||||
// holding CloseWait longer than this should bump the env var.
|
||||
CloseWaitTimeout = 60 * time.Second
|
||||
// LastAckTimeout bounds LAST_ACK. Matches Linux default.
|
||||
LastAckTimeout = 30 * time.Second
|
||||
)
|
||||
|
||||
// Env vars to override per-state teardown timeouts. Values parsed by
|
||||
// time.ParseDuration (e.g. "120s", "2m"). Invalid values fall back to the
|
||||
// defaults above with a warning.
|
||||
const (
|
||||
EnvTCPFinWaitTimeout = "NB_CONNTRACK_TCP_FIN_WAIT_TIMEOUT"
|
||||
EnvTCPCloseWaitTimeout = "NB_CONNTRACK_TCP_CLOSE_WAIT_TIMEOUT"
|
||||
EnvTCPLastAckTimeout = "NB_CONNTRACK_TCP_LAST_ACK_TIMEOUT"
|
||||
|
||||
// EnvTCPMaxEntries caps the TCP conntrack table size. Oldest entries
|
||||
// (tombstones first) are evicted when the cap is reached.
|
||||
EnvTCPMaxEntries = "NB_CONNTRACK_TCP_MAX"
|
||||
)
|
||||
|
||||
// TCPState represents the state of a TCP connection
|
||||
@@ -154,18 +133,14 @@ func (t *TCPConnTrack) SetTombstone() {
|
||||
|
||||
// TCPTracker manages TCP connection states
|
||||
type TCPTracker struct {
|
||||
logger *nblog.Logger
|
||||
connections map[ConnKey]*TCPConnTrack
|
||||
mutex sync.RWMutex
|
||||
cleanupTicker *time.Ticker
|
||||
tickerCancel context.CancelFunc
|
||||
timeout time.Duration
|
||||
waitTimeout time.Duration
|
||||
finWaitTimeout time.Duration
|
||||
closeWaitTimeout time.Duration
|
||||
lastAckTimeout time.Duration
|
||||
maxEntries int
|
||||
flowLogger nftypes.FlowLogger
|
||||
logger *nblog.Logger
|
||||
connections map[ConnKey]*TCPConnTrack
|
||||
mutex sync.RWMutex
|
||||
cleanupTicker *time.Ticker
|
||||
tickerCancel context.CancelFunc
|
||||
timeout time.Duration
|
||||
waitTimeout time.Duration
|
||||
flowLogger nftypes.FlowLogger
|
||||
}
|
||||
|
||||
// NewTCPTracker creates a new TCP connection tracker
|
||||
@@ -180,17 +155,13 @@ func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
tracker := &TCPTracker{
|
||||
logger: logger,
|
||||
connections: make(map[ConnKey]*TCPConnTrack),
|
||||
cleanupTicker: time.NewTicker(TCPCleanupInterval),
|
||||
tickerCancel: cancel,
|
||||
timeout: timeout,
|
||||
waitTimeout: waitTimeout,
|
||||
finWaitTimeout: envDuration(logger, EnvTCPFinWaitTimeout, FinWaitTimeout),
|
||||
closeWaitTimeout: envDuration(logger, EnvTCPCloseWaitTimeout, CloseWaitTimeout),
|
||||
lastAckTimeout: envDuration(logger, EnvTCPLastAckTimeout, LastAckTimeout),
|
||||
maxEntries: envInt(logger, EnvTCPMaxEntries, DefaultMaxTCPEntries),
|
||||
flowLogger: flowLogger,
|
||||
logger: logger,
|
||||
connections: make(map[ConnKey]*TCPConnTrack),
|
||||
cleanupTicker: time.NewTicker(TCPCleanupInterval),
|
||||
tickerCancel: cancel,
|
||||
timeout: timeout,
|
||||
waitTimeout: waitTimeout,
|
||||
flowLogger: flowLogger,
|
||||
}
|
||||
|
||||
go tracker.cleanupRoutine(ctx)
|
||||
@@ -238,12 +209,6 @@ func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, fla
|
||||
if exists || flags&TCPSyn == 0 {
|
||||
return
|
||||
}
|
||||
// Reject illegal SYN combinations (SYN+FIN, SYN+RST, …) so they don't
|
||||
// create spurious conntrack entries. Not mandated by RFC 9293 but a
|
||||
// common hardening (Linux netfilter/nftables rejects these too).
|
||||
if !isValidFlagCombination(flags) {
|
||||
return
|
||||
}
|
||||
|
||||
conn := &TCPConnTrack{
|
||||
BaseConnTrack: BaseConnTrack{
|
||||
@@ -260,65 +225,20 @@ func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, fla
|
||||
conn.state.Store(int32(TCPStateNew))
|
||||
conn.DNATOrigPort.Store(uint32(origPort))
|
||||
|
||||
if t.logger.Enabled(nblog.LevelTrace) {
|
||||
if origPort != 0 {
|
||||
t.logger.Trace4("New %s TCP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort)
|
||||
} else {
|
||||
t.logger.Trace2("New %s TCP connection: %s", direction, key)
|
||||
}
|
||||
if origPort != 0 {
|
||||
t.logger.Trace4("New %s TCP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort)
|
||||
} else {
|
||||
t.logger.Trace2("New %s TCP connection: %s", direction, key)
|
||||
}
|
||||
t.updateState(key, conn, flags, direction, size)
|
||||
|
||||
t.mutex.Lock()
|
||||
if t.maxEntries > 0 && len(t.connections) >= t.maxEntries {
|
||||
t.evictOneLocked()
|
||||
}
|
||||
t.connections[key] = conn
|
||||
t.mutex.Unlock()
|
||||
|
||||
t.sendEvent(nftypes.TypeStart, conn, ruleID)
|
||||
}
|
||||
|
||||
// evictOneLocked removes one entry to make room. Caller must hold t.mutex.
|
||||
// Bounded scan: samples up to evictSampleSize pseudo-random entries (Go map
|
||||
// iteration order is randomized), preferring a tombstone. If no tombstone
|
||||
// found in the sample, evicts the oldest among the sampled entries. O(1)
|
||||
// worst case — cheap enough to run on every insert at cap during abuse.
|
||||
func (t *TCPTracker) evictOneLocked() {
|
||||
var candKey ConnKey
|
||||
var candSeen int64
|
||||
haveCand := false
|
||||
sampled := 0
|
||||
|
||||
for k, c := range t.connections {
|
||||
if c.IsTombstone() {
|
||||
delete(t.connections, k)
|
||||
return
|
||||
}
|
||||
seen := c.lastSeen.Load()
|
||||
if !haveCand || seen < candSeen {
|
||||
candKey = k
|
||||
candSeen = seen
|
||||
haveCand = true
|
||||
}
|
||||
sampled++
|
||||
if sampled >= evictSampleSize {
|
||||
break
|
||||
}
|
||||
}
|
||||
if haveCand {
|
||||
if evicted := t.connections[candKey]; evicted != nil {
|
||||
// TypeEnd is already emitted at the state transition to
|
||||
// TimeWait and when a connection is tombstoned. Only emit
|
||||
// here when we're reaping a still-active flow.
|
||||
if evicted.GetState() != TCPStateTimeWait && !evicted.IsTombstone() {
|
||||
t.sendEvent(nftypes.TypeEnd, evicted, nil)
|
||||
}
|
||||
}
|
||||
delete(t.connections, candKey)
|
||||
}
|
||||
}
|
||||
|
||||
// IsValidInbound checks if an inbound TCP packet matches a tracked connection
|
||||
func (t *TCPTracker) IsValidInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) bool {
|
||||
key := ConnKey{
|
||||
@@ -336,19 +256,12 @@ func (t *TCPTracker) IsValidInbound(srcIP, dstIP netip.Addr, srcPort, dstPort ui
|
||||
return false
|
||||
}
|
||||
|
||||
// Reject illegal flag combinations regardless of state. These never belong
|
||||
// to a legitimate flow and must not advance or tear down state.
|
||||
if !isValidFlagCombination(flags) {
|
||||
if t.logger.Enabled(nblog.LevelWarn) {
|
||||
t.logger.Warn3("TCP illegal flag combination %x for connection %s (state %s)", flags, key, conn.GetState())
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
currentState := conn.GetState()
|
||||
if !t.isValidStateForFlags(currentState, flags) {
|
||||
if t.logger.Enabled(nblog.LevelWarn) {
|
||||
t.logger.Warn3("TCP state %s is not valid with flags %x for connection %s", currentState, flags, key)
|
||||
t.logger.Warn3("TCP state %s is not valid with flags %x for connection %s", currentState, flags, key)
|
||||
// allow all flags for established for now
|
||||
if currentState == TCPStateEstablished {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -357,208 +270,116 @@ func (t *TCPTracker) IsValidInbound(srcIP, dstIP netip.Addr, srcPort, dstPort ui
|
||||
return true
|
||||
}
|
||||
|
||||
// updateState updates the TCP connection state based on flags.
|
||||
// updateState updates the TCP connection state based on flags
|
||||
func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, packetDir nftypes.Direction, size int) {
|
||||
conn.UpdateCounters(packetDir, size)
|
||||
|
||||
// Malformed flag combinations must not refresh lastSeen or drive state,
|
||||
// otherwise spoofed packets keep a dead flow alive past its timeout.
|
||||
if !isValidFlagCombination(flags) {
|
||||
return
|
||||
}
|
||||
|
||||
conn.UpdateLastSeen()
|
||||
conn.UpdateCounters(packetDir, size)
|
||||
|
||||
currentState := conn.GetState()
|
||||
|
||||
if flags&TCPRst != 0 {
|
||||
// Hardening beyond RFC 9293 §3.10.7.4: without sequence tracking we
|
||||
// cannot apply the RFC 5961 in-window RST check, so we conservatively
|
||||
// reject RSTs that the spec would accept (TIME-WAIT with in-window
|
||||
// SEQ, SynSent from same direction as own SYN, etc.).
|
||||
t.handleRst(key, conn, currentState, packetDir)
|
||||
return
|
||||
}
|
||||
|
||||
newState := nextState(currentState, conn.Direction, packetDir, flags)
|
||||
if newState == 0 || !conn.CompareAndSwapState(currentState, newState) {
|
||||
return
|
||||
}
|
||||
t.onTransition(key, conn, currentState, newState, packetDir)
|
||||
}
|
||||
|
||||
// handleRst processes a RST segment. Late RSTs in TimeWait and spoofed RSTs
|
||||
// from the SYN direction are ignored; otherwise the flow is tombstoned.
|
||||
func (t *TCPTracker) handleRst(key ConnKey, conn *TCPConnTrack, currentState TCPState, packetDir nftypes.Direction) {
|
||||
// TimeWait exists to absorb late segments; don't let a late RST
|
||||
// tombstone the entry and break same-4-tuple reuse.
|
||||
if currentState == TCPStateTimeWait {
|
||||
return
|
||||
}
|
||||
// A RST from the same direction as the SYN cannot be a legitimate
|
||||
// response and must not tear down a half-open connection.
|
||||
if currentState == TCPStateSynSent && packetDir == conn.Direction {
|
||||
return
|
||||
}
|
||||
if !conn.CompareAndSwapState(currentState, TCPStateClosed) {
|
||||
return
|
||||
}
|
||||
conn.SetTombstone()
|
||||
if t.logger.Enabled(nblog.LevelTrace) {
|
||||
t.logger.Trace6("TCP connection reset: %s (dir: %s) [in: %d Pkts/%d B, out: %d Pkts/%d B]",
|
||||
key, packetDir, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||
}
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
}
|
||||
|
||||
// stateTransition describes one state's transition logic. It receives the
|
||||
// packet's flags plus whether the packet direction matches the connection's
|
||||
// origin direction (same=true means same side as the SYN initiator). Return 0
|
||||
// for no transition.
|
||||
type stateTransition func(flags uint8, connDir nftypes.Direction, same bool) TCPState
|
||||
|
||||
// stateTable maps each state to its transition function. Centralized here so
|
||||
// nextState stays trivial and each rule is easy to read in isolation.
|
||||
var stateTable = map[TCPState]stateTransition{
|
||||
TCPStateNew: transNew,
|
||||
TCPStateSynSent: transSynSent,
|
||||
TCPStateSynReceived: transSynReceived,
|
||||
TCPStateEstablished: transEstablished,
|
||||
TCPStateFinWait1: transFinWait1,
|
||||
TCPStateFinWait2: transFinWait2,
|
||||
TCPStateClosing: transClosing,
|
||||
TCPStateCloseWait: transCloseWait,
|
||||
TCPStateLastAck: transLastAck,
|
||||
}
|
||||
|
||||
// nextState returns the target TCP state for the given current state and
|
||||
// packet, or 0 if the packet does not trigger a transition.
|
||||
func nextState(currentState TCPState, connDir, packetDir nftypes.Direction, flags uint8) TCPState {
|
||||
fn, ok := stateTable[currentState]
|
||||
if !ok {
|
||||
return 0
|
||||
}
|
||||
return fn(flags, connDir, packetDir == connDir)
|
||||
}
|
||||
|
||||
func transNew(flags uint8, connDir nftypes.Direction, _ bool) TCPState {
|
||||
if flags&TCPSyn != 0 && flags&TCPAck == 0 {
|
||||
if connDir == nftypes.Egress {
|
||||
return TCPStateSynSent
|
||||
if conn.CompareAndSwapState(currentState, TCPStateClosed) {
|
||||
conn.SetTombstone()
|
||||
t.logger.Trace6("TCP connection reset: %s (dir: %s) [in: %d Pkts/%d B, out: %d Pkts/%d B]",
|
||||
key, packetDir, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
}
|
||||
return TCPStateSynReceived
|
||||
return
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func transSynSent(flags uint8, _ nftypes.Direction, same bool) TCPState {
|
||||
if flags&TCPSyn != 0 && flags&TCPAck != 0 {
|
||||
if same {
|
||||
return TCPStateSynReceived // simultaneous open
|
||||
var newState TCPState
|
||||
switch currentState {
|
||||
case TCPStateNew:
|
||||
if flags&TCPSyn != 0 && flags&TCPAck == 0 {
|
||||
if conn.Direction == nftypes.Egress {
|
||||
newState = TCPStateSynSent
|
||||
} else {
|
||||
newState = TCPStateSynReceived
|
||||
}
|
||||
}
|
||||
return TCPStateEstablished
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func transSynReceived(flags uint8, _ nftypes.Direction, same bool) TCPState {
|
||||
if flags&TCPAck != 0 && flags&TCPSyn == 0 && same {
|
||||
return TCPStateEstablished
|
||||
}
|
||||
return 0
|
||||
}
|
||||
case TCPStateSynSent:
|
||||
if flags&TCPSyn != 0 && flags&TCPAck != 0 {
|
||||
if packetDir != conn.Direction {
|
||||
newState = TCPStateEstablished
|
||||
} else {
|
||||
// Simultaneous open
|
||||
newState = TCPStateSynReceived
|
||||
}
|
||||
}
|
||||
|
||||
func transEstablished(flags uint8, _ nftypes.Direction, same bool) TCPState {
|
||||
if flags&TCPFin == 0 {
|
||||
return 0
|
||||
}
|
||||
if same {
|
||||
return TCPStateFinWait1
|
||||
}
|
||||
return TCPStateCloseWait
|
||||
}
|
||||
case TCPStateSynReceived:
|
||||
if flags&TCPAck != 0 && flags&TCPSyn == 0 {
|
||||
if packetDir == conn.Direction {
|
||||
newState = TCPStateEstablished
|
||||
}
|
||||
}
|
||||
|
||||
// transFinWait1 handles the active-close peer response. A FIN carrying our
|
||||
// ACK piggybacked goes straight to TIME-WAIT (RFC 9293 §3.10.7.4, FIN-WAIT-1:
|
||||
// "if our FIN has been ACKed... enter the TIME-WAIT state"); a lone FIN moves
|
||||
// to CLOSING; a pure ACK of our FIN moves to FIN-WAIT-2.
|
||||
func transFinWait1(flags uint8, _ nftypes.Direction, same bool) TCPState {
|
||||
if same {
|
||||
return 0
|
||||
}
|
||||
if flags&TCPFin != 0 && flags&TCPAck != 0 {
|
||||
return TCPStateTimeWait
|
||||
}
|
||||
switch {
|
||||
case flags&TCPFin != 0:
|
||||
return TCPStateClosing
|
||||
case flags&TCPAck != 0:
|
||||
return TCPStateFinWait2
|
||||
}
|
||||
return 0
|
||||
}
|
||||
case TCPStateEstablished:
|
||||
if flags&TCPFin != 0 {
|
||||
if packetDir == conn.Direction {
|
||||
newState = TCPStateFinWait1
|
||||
} else {
|
||||
newState = TCPStateCloseWait
|
||||
}
|
||||
}
|
||||
|
||||
// transFinWait2 ignores own-side FIN retransmits; only the peer's FIN advances.
|
||||
func transFinWait2(flags uint8, _ nftypes.Direction, same bool) TCPState {
|
||||
if flags&TCPFin != 0 && !same {
|
||||
return TCPStateTimeWait
|
||||
}
|
||||
return 0
|
||||
}
|
||||
case TCPStateFinWait1:
|
||||
if packetDir != conn.Direction {
|
||||
switch {
|
||||
case flags&TCPFin != 0 && flags&TCPAck != 0:
|
||||
newState = TCPStateClosing
|
||||
case flags&TCPFin != 0:
|
||||
newState = TCPStateClosing
|
||||
case flags&TCPAck != 0:
|
||||
newState = TCPStateFinWait2
|
||||
}
|
||||
}
|
||||
|
||||
// transClosing completes a simultaneous close on the peer's ACK.
|
||||
func transClosing(flags uint8, _ nftypes.Direction, same bool) TCPState {
|
||||
if flags&TCPAck != 0 && !same {
|
||||
return TCPStateTimeWait
|
||||
}
|
||||
return 0
|
||||
}
|
||||
case TCPStateFinWait2:
|
||||
if flags&TCPFin != 0 {
|
||||
newState = TCPStateTimeWait
|
||||
}
|
||||
|
||||
// transCloseWait only advances to LastAck when WE send FIN, ignoring peer retransmits.
|
||||
func transCloseWait(flags uint8, _ nftypes.Direction, same bool) TCPState {
|
||||
if flags&TCPFin != 0 && same {
|
||||
return TCPStateLastAck
|
||||
}
|
||||
return 0
|
||||
}
|
||||
case TCPStateClosing:
|
||||
if flags&TCPAck != 0 {
|
||||
newState = TCPStateTimeWait
|
||||
}
|
||||
|
||||
// transLastAck closes the flow only on the peer's ACK (not our own ACK retransmits).
|
||||
func transLastAck(flags uint8, _ nftypes.Direction, same bool) TCPState {
|
||||
if flags&TCPAck != 0 && !same {
|
||||
return TCPStateClosed
|
||||
}
|
||||
return 0
|
||||
}
|
||||
case TCPStateCloseWait:
|
||||
if flags&TCPFin != 0 {
|
||||
newState = TCPStateLastAck
|
||||
}
|
||||
|
||||
// onTransition handles logging and flow-event emission after a successful
|
||||
// state transition. TimeWait and Closed are terminal for flow accounting.
|
||||
func (t *TCPTracker) onTransition(key ConnKey, conn *TCPConnTrack, from, to TCPState, packetDir nftypes.Direction) {
|
||||
traceOn := t.logger.Enabled(nblog.LevelTrace)
|
||||
if traceOn {
|
||||
t.logger.Trace4("TCP connection %s transitioned from %s to %s (dir: %s)", key, from, to, packetDir)
|
||||
case TCPStateLastAck:
|
||||
if flags&TCPAck != 0 {
|
||||
newState = TCPStateClosed
|
||||
}
|
||||
}
|
||||
|
||||
switch to {
|
||||
case TCPStateTimeWait:
|
||||
if traceOn {
|
||||
if newState != 0 && conn.CompareAndSwapState(currentState, newState) {
|
||||
t.logger.Trace4("TCP connection %s transitioned from %s to %s (dir: %s)", key, currentState, newState, packetDir)
|
||||
|
||||
switch newState {
|
||||
case TCPStateTimeWait:
|
||||
t.logger.Trace5("TCP connection %s completed [in: %d Pkts/%d B, out: %d Pkts/%d B]",
|
||||
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||
}
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
case TCPStateClosed:
|
||||
conn.SetTombstone()
|
||||
if traceOn {
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
|
||||
case TCPStateClosed:
|
||||
conn.SetTombstone()
|
||||
t.logger.Trace5("TCP connection %s closed gracefully [in: %d Pkts/%d, B out: %d Pkts/%d B]",
|
||||
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
}
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
}
|
||||
}
|
||||
|
||||
// isValidStateForFlags checks if the TCP flags are valid for the current
|
||||
// connection state. Caller must have already verified the flag combination is
|
||||
// legal via isValidFlagCombination.
|
||||
// isValidStateForFlags checks if the TCP flags are valid for the current connection state
|
||||
func (t *TCPTracker) isValidStateForFlags(state TCPState, flags uint8) bool {
|
||||
if !isValidFlagCombination(flags) {
|
||||
return false
|
||||
}
|
||||
if flags&TCPRst != 0 {
|
||||
if state == TCPStateSynSent {
|
||||
return flags&TCPAck != 0
|
||||
@@ -628,24 +449,15 @@ func (t *TCPTracker) cleanup() {
|
||||
timeout = t.waitTimeout
|
||||
case TCPStateEstablished:
|
||||
timeout = t.timeout
|
||||
case TCPStateFinWait1, TCPStateFinWait2, TCPStateClosing:
|
||||
timeout = t.finWaitTimeout
|
||||
case TCPStateCloseWait:
|
||||
timeout = t.closeWaitTimeout
|
||||
case TCPStateLastAck:
|
||||
timeout = t.lastAckTimeout
|
||||
default:
|
||||
// SynSent / SynReceived / New
|
||||
timeout = TCPHandshakeTimeout
|
||||
}
|
||||
|
||||
if conn.timeoutExceeded(timeout) {
|
||||
delete(t.connections, key)
|
||||
|
||||
if t.logger.Enabled(nblog.LevelTrace) {
|
||||
t.logger.Trace6("Cleaned up timed-out TCP connection %s (%s) [in: %d Pkts/%d, B out: %d Pkts/%d B]",
|
||||
key, conn.GetState(), conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||
}
|
||||
t.logger.Trace6("Cleaned up timed-out TCP connection %s (%s) [in: %d Pkts/%d, B out: %d Pkts/%d B]",
|
||||
key, conn.GetState(), conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||
|
||||
// event already handled by state change
|
||||
if currentState != TCPStateTimeWait {
|
||||
|
||||
@@ -1,100 +0,0 @@
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// RST hygiene tests: the tracker currently closes the flow on any RST that
|
||||
// matches the 4-tuple, regardless of direction or state. These tests cover
|
||||
// the minimum checks we want (no SEQ tracking).
|
||||
|
||||
func TestTCPRstInSynSentWrongDirection(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(80)
|
||||
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
|
||||
conn := tracker.connections[key]
|
||||
require.Equal(t, TCPStateSynSent, conn.GetState())
|
||||
|
||||
// A RST arriving in the same direction as the SYN (i.e. TrackOutbound)
|
||||
// cannot be a legitimate response. It must not close the connection.
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPRst|TCPAck, 0)
|
||||
require.Equal(t, TCPStateSynSent, conn.GetState(),
|
||||
"RST in same direction as SYN must not close connection")
|
||||
require.False(t, conn.IsTombstone())
|
||||
}
|
||||
|
||||
func TestTCPRstInTimeWaitIgnored(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(80)
|
||||
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||
|
||||
// Drive to TIME-WAIT via active close.
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||
require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0))
|
||||
require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0))
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||
|
||||
conn := tracker.connections[key]
|
||||
require.Equal(t, TCPStateTimeWait, conn.GetState())
|
||||
require.False(t, conn.IsTombstone(), "TIME-WAIT must not be tombstoned")
|
||||
|
||||
// Late RST during TIME-WAIT must not tombstone the entry (TIME-WAIT
|
||||
// exists to absorb late segments).
|
||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
|
||||
require.Equal(t, TCPStateTimeWait, conn.GetState(),
|
||||
"RST in TIME-WAIT must not transition state")
|
||||
require.False(t, conn.IsTombstone(),
|
||||
"RST in TIME-WAIT must not tombstone the entry")
|
||||
}
|
||||
|
||||
func TestTCPIllegalFlagCombos(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(80)
|
||||
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
conn := tracker.connections[key]
|
||||
|
||||
// Illegal combos must be rejected and must not change state.
|
||||
combos := []struct {
|
||||
name string
|
||||
flags uint8
|
||||
}{
|
||||
{"SYN+RST", TCPSyn | TCPRst},
|
||||
{"FIN+RST", TCPFin | TCPRst},
|
||||
{"SYN+FIN", TCPSyn | TCPFin},
|
||||
{"SYN+FIN+RST", TCPSyn | TCPFin | TCPRst},
|
||||
}
|
||||
|
||||
for _, c := range combos {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
before := conn.GetState()
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, c.flags, 0)
|
||||
require.False(t, valid, "illegal flag combo must be rejected: %s", c.name)
|
||||
require.Equal(t, before, conn.GetState(),
|
||||
"illegal flag combo must not change state")
|
||||
require.False(t, conn.IsTombstone())
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,235 +0,0 @@
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// These tests exercise cases where the TCP state machine currently advances
|
||||
// on retransmitted or wrong-direction segments and tears the flow down
|
||||
// prematurely. They are expected to fail until the direction checks are added.
|
||||
|
||||
func TestTCPCloseWaitRetransmittedPeerFIN(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(80)
|
||||
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
// Peer sends FIN -> CloseWait (our app has not yet closed).
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||
require.True(t, valid)
|
||||
conn := tracker.connections[key]
|
||||
require.Equal(t, TCPStateCloseWait, conn.GetState())
|
||||
|
||||
// Peer retransmits their FIN (ACK may have been delayed). We have NOT
|
||||
// sent our FIN yet, so state must remain CloseWait.
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||
require.True(t, valid, "retransmitted peer FIN must still be accepted")
|
||||
require.Equal(t, TCPStateCloseWait, conn.GetState(),
|
||||
"retransmitted peer FIN must not advance CloseWait to LastAck")
|
||||
|
||||
// Our app finally closes -> LastAck.
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||
require.Equal(t, TCPStateLastAck, conn.GetState())
|
||||
|
||||
// Peer ACK closes.
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||
require.True(t, valid)
|
||||
require.Equal(t, TCPStateClosed, conn.GetState())
|
||||
}
|
||||
|
||||
func TestTCPFinWait2RetransmittedOwnFIN(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(80)
|
||||
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
// We initiate close.
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||
require.True(t, valid)
|
||||
conn := tracker.connections[key]
|
||||
require.Equal(t, TCPStateFinWait2, conn.GetState())
|
||||
|
||||
// Stray retransmit of our own FIN (same direction as originator) must
|
||||
// NOT advance FinWait2 to TimeWait; only the peer's FIN should.
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||
require.Equal(t, TCPStateFinWait2, conn.GetState(),
|
||||
"own FIN retransmit must not advance FinWait2 to TimeWait")
|
||||
|
||||
// Peer FIN -> TimeWait.
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||
require.True(t, valid)
|
||||
require.Equal(t, TCPStateTimeWait, conn.GetState())
|
||||
}
|
||||
|
||||
func TestTCPLastAckDirectionCheck(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(80)
|
||||
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
// Drive to LastAck: peer FIN -> CloseWait, our FIN -> LastAck.
|
||||
require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0))
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||
conn := tracker.connections[key]
|
||||
require.Equal(t, TCPStateLastAck, conn.GetState())
|
||||
|
||||
// Our own ACK retransmit (same direction as originator) must NOT close.
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||
require.Equal(t, TCPStateLastAck, conn.GetState(),
|
||||
"own ACK retransmit in LastAck must not transition to Closed")
|
||||
|
||||
// Peer's ACK -> Closed.
|
||||
require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0))
|
||||
require.Equal(t, TCPStateClosed, conn.GetState())
|
||||
}
|
||||
|
||||
func TestTCPFinWait1OwnAckDoesNotAdvance(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(80)
|
||||
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||
conn := tracker.connections[key]
|
||||
require.Equal(t, TCPStateFinWait1, conn.GetState())
|
||||
|
||||
// Our own ACK retransmit (same direction as originator) must not advance.
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||
require.Equal(t, TCPStateFinWait1, conn.GetState(),
|
||||
"own ACK in FinWait1 must not advance to FinWait2")
|
||||
}
|
||||
|
||||
func TestTCPPerStateTeardownTimeouts(t *testing.T) {
|
||||
// Verify cleanup reaps entries in each teardown state at the configured
|
||||
// per-state timeout, not at the single handshake timeout.
|
||||
t.Setenv(EnvTCPFinWaitTimeout, "50ms")
|
||||
t.Setenv(EnvTCPCloseWaitTimeout, "80ms")
|
||||
t.Setenv(EnvTCPLastAckTimeout, "30ms")
|
||||
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
dstPort := uint16(80)
|
||||
|
||||
// Drives a connection to the target state, forces its lastSeen well
|
||||
// beyond the configured timeout, runs cleanup, and asserts reaping.
|
||||
cases := []struct {
|
||||
name string
|
||||
// drive takes a fresh tracker and returns the conn key after
|
||||
// transitioning the flow into the intended teardown state.
|
||||
drive func(t *testing.T, tr *TCPTracker, srcIP netip.Addr, srcPort uint16) (ConnKey, TCPState)
|
||||
}{
|
||||
{
|
||||
name: "FinWait1",
|
||||
drive: func(t *testing.T, tr *TCPTracker, srcIP netip.Addr, srcPort uint16) (ConnKey, TCPState) {
|
||||
establishConnection(t, tr, srcIP, dstIP, srcPort, dstPort)
|
||||
tr.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) // → FinWait1
|
||||
return ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}, TCPStateFinWait1
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "FinWait2",
|
||||
drive: func(t *testing.T, tr *TCPTracker, srcIP netip.Addr, srcPort uint16) (ConnKey, TCPState) {
|
||||
establishConnection(t, tr, srcIP, dstIP, srcPort, dstPort)
|
||||
tr.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) // FinWait1
|
||||
require.True(t, tr.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)) // → FinWait2
|
||||
return ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}, TCPStateFinWait2
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "CloseWait",
|
||||
drive: func(t *testing.T, tr *TCPTracker, srcIP netip.Addr, srcPort uint16) (ConnKey, TCPState) {
|
||||
establishConnection(t, tr, srcIP, dstIP, srcPort, dstPort)
|
||||
require.True(t, tr.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)) // → CloseWait
|
||||
return ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}, TCPStateCloseWait
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "LastAck",
|
||||
drive: func(t *testing.T, tr *TCPTracker, srcIP netip.Addr, srcPort uint16) (ConnKey, TCPState) {
|
||||
establishConnection(t, tr, srcIP, dstIP, srcPort, dstPort)
|
||||
require.True(t, tr.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)) // CloseWait
|
||||
tr.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) // → LastAck
|
||||
return ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}, TCPStateLastAck
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Use a unique source port per subtest so nothing aliases.
|
||||
port := uint16(12345)
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
require.Equal(t, 50*time.Millisecond, tracker.finWaitTimeout)
|
||||
require.Equal(t, 80*time.Millisecond, tracker.closeWaitTimeout)
|
||||
require.Equal(t, 30*time.Millisecond, tracker.lastAckTimeout)
|
||||
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
port++
|
||||
key, wantState := c.drive(t, tracker, srcIP, port)
|
||||
conn := tracker.connections[key]
|
||||
require.NotNil(t, conn)
|
||||
require.Equal(t, wantState, conn.GetState())
|
||||
|
||||
// Age the entry past the largest per-state timeout.
|
||||
conn.lastSeen.Store(time.Now().Add(-500 * time.Millisecond).UnixNano())
|
||||
tracker.cleanup()
|
||||
_, exists := tracker.connections[key]
|
||||
require.False(t, exists, "%s entry should be reaped", c.name)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTCPEstablishedPSHACKInFinStates(t *testing.T) {
|
||||
// Verifies FIN|PSH|ACK and bare ACK keepalives are not dropped in FIN
|
||||
// teardown states, which some stacks emit during close.
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(80)
|
||||
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
// Peer FIN -> CloseWait.
|
||||
require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0))
|
||||
|
||||
// Peer pushes trailing data + FIN|PSH|ACK (legal).
|
||||
require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPPush|TCPAck, 100),
|
||||
"FIN|PSH|ACK in CloseWait must be accepted")
|
||||
|
||||
// Bare ACK keepalive from peer in CloseWait must be accepted.
|
||||
require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0),
|
||||
"bare ACK in CloseWait must be accepted")
|
||||
}
|
||||
@@ -17,9 +17,6 @@ const (
|
||||
DefaultUDPTimeout = 30 * time.Second
|
||||
// UDPCleanupInterval is how often we check for stale connections
|
||||
UDPCleanupInterval = 15 * time.Second
|
||||
|
||||
// EnvUDPMaxEntries caps the UDP conntrack table size.
|
||||
EnvUDPMaxEntries = "NB_CONNTRACK_UDP_MAX"
|
||||
)
|
||||
|
||||
// UDPConnTrack represents a UDP connection state
|
||||
@@ -37,7 +34,6 @@ type UDPTracker struct {
|
||||
cleanupTicker *time.Ticker
|
||||
tickerCancel context.CancelFunc
|
||||
mutex sync.RWMutex
|
||||
maxEntries int
|
||||
flowLogger nftypes.FlowLogger
|
||||
}
|
||||
|
||||
@@ -55,7 +51,6 @@ func NewUDPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp
|
||||
timeout: timeout,
|
||||
cleanupTicker: time.NewTicker(UDPCleanupInterval),
|
||||
tickerCancel: cancel,
|
||||
maxEntries: envInt(logger, EnvUDPMaxEntries, DefaultMaxUDPEntries),
|
||||
flowLogger: flowLogger,
|
||||
}
|
||||
|
||||
@@ -122,18 +117,13 @@ func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d
|
||||
conn.UpdateCounters(direction, size)
|
||||
|
||||
t.mutex.Lock()
|
||||
if t.maxEntries > 0 && len(t.connections) >= t.maxEntries {
|
||||
t.evictOneLocked()
|
||||
}
|
||||
t.connections[key] = conn
|
||||
t.mutex.Unlock()
|
||||
|
||||
if t.logger.Enabled(nblog.LevelTrace) {
|
||||
if origPort != 0 {
|
||||
t.logger.Trace4("New %s UDP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort)
|
||||
} else {
|
||||
t.logger.Trace2("New %s UDP connection: %s", direction, key)
|
||||
}
|
||||
if origPort != 0 {
|
||||
t.logger.Trace4("New %s UDP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort)
|
||||
} else {
|
||||
t.logger.Trace2("New %s UDP connection: %s", direction, key)
|
||||
}
|
||||
t.sendEvent(nftypes.TypeStart, conn, ruleID)
|
||||
}
|
||||
@@ -161,34 +151,6 @@ func (t *UDPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort
|
||||
return true
|
||||
}
|
||||
|
||||
// evictOneLocked removes one entry to make room. Caller must hold t.mutex.
|
||||
// Bounded sample: picks the oldest among up to evictSampleSize entries.
|
||||
func (t *UDPTracker) evictOneLocked() {
|
||||
var candKey ConnKey
|
||||
var candSeen int64
|
||||
haveCand := false
|
||||
sampled := 0
|
||||
|
||||
for k, c := range t.connections {
|
||||
seen := c.lastSeen.Load()
|
||||
if !haveCand || seen < candSeen {
|
||||
candKey = k
|
||||
candSeen = seen
|
||||
haveCand = true
|
||||
}
|
||||
sampled++
|
||||
if sampled >= evictSampleSize {
|
||||
break
|
||||
}
|
||||
}
|
||||
if haveCand {
|
||||
if evicted := t.connections[candKey]; evicted != nil {
|
||||
t.sendEvent(nftypes.TypeEnd, evicted, nil)
|
||||
}
|
||||
delete(t.connections, candKey)
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupRoutine periodically removes stale connections
|
||||
func (t *UDPTracker) cleanupRoutine(ctx context.Context) {
|
||||
defer t.cleanupTicker.Stop()
|
||||
@@ -211,10 +173,8 @@ func (t *UDPTracker) cleanup() {
|
||||
if conn.timeoutExceeded(t.timeout) {
|
||||
delete(t.connections, key)
|
||||
|
||||
if t.logger.Enabled(nblog.LevelTrace) {
|
||||
t.logger.Trace5("Removed UDP connection %s (timeout) [in: %d Pkts/%d B, out: %d Pkts/%d B]",
|
||||
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||
}
|
||||
t.logger.Trace5("Removed UDP connection %s (timeout) [in: %d Pkts/%d B, out: %d Pkts/%d B]",
|
||||
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -787,9 +787,7 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool {
|
||||
|
||||
srcIP, dstIP := m.extractIPs(d)
|
||||
if !srcIP.IsValid() {
|
||||
if m.logger.Enabled(nblog.LevelError) {
|
||||
m.logger.Error1("Unknown network layer: %v", d.decoded[0])
|
||||
}
|
||||
m.logger.Error1("Unknown network layer: %v", d.decoded[0])
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -903,9 +901,7 @@ func (m *Manager) clampTCPMSS(packetData []byte, d *decoder) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
if m.logger.Enabled(nblog.LevelTrace) {
|
||||
m.logger.Trace2("Clamped TCP MSS from %d to %d", currentMSS, mssClampValue)
|
||||
}
|
||||
m.logger.Trace2("Clamped TCP MSS from %d to %d", currentMSS, mssClampValue)
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -1048,13 +1044,11 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool {
|
||||
|
||||
// TODO: pass fragments of routed packets to forwarder
|
||||
if fragment {
|
||||
if m.logger.Enabled(nblog.LevelTrace) {
|
||||
if d.decoded[0] == layers.LayerTypeIPv4 {
|
||||
m.logger.Trace4("packet is a fragment: src=%v dst=%v id=%v flags=%v",
|
||||
srcIP, dstIP, d.ip4.Id, d.ip4.Flags)
|
||||
} else {
|
||||
m.logger.Trace2("packet is an IPv6 fragment: src=%v dst=%v", srcIP, dstIP)
|
||||
}
|
||||
if d.decoded[0] == layers.LayerTypeIPv4 {
|
||||
m.logger.Trace4("packet is a fragment: src=%v dst=%v id=%v flags=%v",
|
||||
srcIP, dstIP, d.ip4.Id, d.ip4.Flags)
|
||||
} else {
|
||||
m.logger.Trace2("packet is an IPv6 fragment: src=%v dst=%v", srcIP, dstIP)
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -1097,10 +1091,8 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packet
|
||||
pnum := getProtocolFromPacket(d)
|
||||
srcPort, dstPort := getPortsFromPacket(d)
|
||||
|
||||
if m.logger.Enabled(nblog.LevelTrace) {
|
||||
m.logger.Trace6("Dropping local packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
|
||||
ruleID, pnum, srcIP, srcPort, dstIP, dstPort)
|
||||
}
|
||||
m.logger.Trace6("Dropping local packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
|
||||
ruleID, pnum, srcIP, srcPort, dstIP, dstPort)
|
||||
|
||||
m.flowLogger.StoreEvent(nftypes.EventFields{
|
||||
FlowID: uuid.New(),
|
||||
@@ -1150,10 +1142,8 @@ func (m *Manager) handleForwardedLocalTraffic(packetData []byte) bool {
|
||||
func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) bool {
|
||||
// Drop if routing is disabled
|
||||
if !m.routingEnabled.Load() {
|
||||
if m.logger.Enabled(nblog.LevelTrace) {
|
||||
m.logger.Trace2("Dropping routed packet (routing disabled): src=%s dst=%s",
|
||||
srcIP, dstIP)
|
||||
}
|
||||
m.logger.Trace2("Dropping routed packet (routing disabled): src=%s dst=%s",
|
||||
srcIP, dstIP)
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -1170,10 +1160,8 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe
|
||||
if !pass {
|
||||
proto := getProtocolFromPacket(d)
|
||||
|
||||
if m.logger.Enabled(nblog.LevelTrace) {
|
||||
m.logger.Trace6("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
|
||||
ruleID, proto, srcIP, srcPort, dstIP, dstPort)
|
||||
}
|
||||
m.logger.Trace6("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
|
||||
ruleID, proto, srcIP, srcPort, dstIP, dstPort)
|
||||
|
||||
m.flowLogger.StoreEvent(nftypes.EventFields{
|
||||
FlowID: uuid.New(),
|
||||
@@ -1299,9 +1287,7 @@ func getPortsFromPacket(d *decoder) (srcPort, dstPort uint16) {
|
||||
// It returns true, true if the packet is a fragment and valid.
|
||||
func (m *Manager) isValidPacket(d *decoder, packetData []byte) (bool, bool) {
|
||||
if err := d.decodePacket(packetData); err != nil {
|
||||
if m.logger.Enabled(nblog.LevelTrace) {
|
||||
m.logger.Trace1("couldn't decode packet, err: %s", err)
|
||||
}
|
||||
m.logger.Trace1("couldn't decode packet, err: %s", err)
|
||||
return false, false
|
||||
}
|
||||
|
||||
|
||||
@@ -13,7 +13,6 @@ import (
|
||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
|
||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
)
|
||||
|
||||
@@ -98,10 +97,8 @@ func (f *Forwarder) forwardICMPPacket(id stack.TransportEndpointID, payload []by
|
||||
return nil, fmt.Errorf("write ICMP packet: %w", err)
|
||||
}
|
||||
|
||||
if f.logger.Enabled(nblog.LevelTrace) {
|
||||
f.logger.Trace3("forwarder: Forwarded ICMP packet %v type %v code %v",
|
||||
epID(id), icmpType, icmpCode)
|
||||
}
|
||||
f.logger.Trace3("forwarder: Forwarded ICMP packet %v type %v code %v",
|
||||
epID(id), icmpType, icmpCode)
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
@@ -124,14 +121,12 @@ func (f *Forwarder) handleICMPViaSocket(flowID uuid.UUID, id stack.TransportEndp
|
||||
txBytes := f.handleEchoResponse(conn, id, v6)
|
||||
rtt := time.Since(sendTime).Round(10 * time.Microsecond)
|
||||
|
||||
if f.logger.Enabled(nblog.LevelTrace) {
|
||||
proto := "ICMP"
|
||||
if v6 {
|
||||
proto = "ICMPv6"
|
||||
}
|
||||
f.logger.Trace5("forwarder: Forwarded %s echo reply %v type %v code %v (rtt=%v, raw socket)",
|
||||
proto, epID(id), icmpType, icmpCode, rtt)
|
||||
proto := "ICMP"
|
||||
if v6 {
|
||||
proto = "ICMPv6"
|
||||
}
|
||||
f.logger.Trace5("forwarder: Forwarded %s echo reply %v type %v code %v (rtt=%v, raw socket)",
|
||||
proto, epID(id), icmpType, icmpCode, rtt)
|
||||
|
||||
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
|
||||
}
|
||||
@@ -229,17 +224,13 @@ func (f *Forwarder) handleICMPViaPing(flowID uuid.UUID, id stack.TransportEndpoi
|
||||
}
|
||||
rtt := time.Since(pingStart).Round(10 * time.Microsecond)
|
||||
|
||||
if f.logger.Enabled(nblog.LevelTrace) {
|
||||
f.logger.Trace3("forwarder: Forwarded ICMP echo request %v type %v code %v",
|
||||
epID(id), icmpType, icmpCode)
|
||||
}
|
||||
f.logger.Trace3("forwarder: Forwarded ICMP echo request %v type %v code %v",
|
||||
epID(id), icmpType, icmpCode)
|
||||
|
||||
txBytes := f.synthesizeEchoReply(id, icmpData)
|
||||
|
||||
if f.logger.Enabled(nblog.LevelTrace) {
|
||||
f.logger.Trace4("forwarder: Forwarded ICMP echo reply %v type %v code %v (rtt=%v, ping binary)",
|
||||
epID(id), icmpType, icmpCode, rtt)
|
||||
}
|
||||
f.logger.Trace4("forwarder: Forwarded ICMP echo reply %v type %v code %v (rtt=%v, ping binary)",
|
||||
epID(id), icmpType, icmpCode, rtt)
|
||||
|
||||
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
|
||||
}
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
package forwarder
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
"sync"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
@@ -12,9 +15,7 @@ import (
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
||||
"gvisor.dev/gvisor/pkg/waiter"
|
||||
|
||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
"github.com/netbirdio/netbird/util/netrelay"
|
||||
)
|
||||
|
||||
// handleTCP is called by the TCP forwarder for new connections.
|
||||
@@ -36,9 +37,7 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
|
||||
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr)
|
||||
if err != nil {
|
||||
r.Complete(true)
|
||||
if f.logger.Enabled(nblog.LevelTrace) {
|
||||
f.logger.Trace2("forwarder: dial error for %v: %v", epID(id), err)
|
||||
}
|
||||
f.logger.Trace2("forwarder: dial error for %v: %v", epID(id), err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -61,22 +60,64 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
|
||||
inConn := gonet.NewTCPConn(&wq, ep)
|
||||
|
||||
success = true
|
||||
if f.logger.Enabled(nblog.LevelTrace) {
|
||||
f.logger.Trace1("forwarder: established TCP connection %v", epID(id))
|
||||
}
|
||||
f.logger.Trace1("forwarder: established TCP connection %v", epID(id))
|
||||
|
||||
go f.proxyTCP(id, inConn, outConn, ep, flowID)
|
||||
}
|
||||
|
||||
func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn, outConn net.Conn, ep tcpip.Endpoint, flowID uuid.UUID) {
|
||||
// netrelay.Relay copies bidirectionally with proper half-close propagation
|
||||
// and fully closes both conns before returning.
|
||||
bytesFromInToOut, bytesFromOutToIn := netrelay.Relay(f.ctx, inConn, outConn, netrelay.Options{
|
||||
Logger: f.logger,
|
||||
})
|
||||
|
||||
// Close the netstack endpoint after both conns are drained.
|
||||
ep.Close()
|
||||
ctx, cancel := context.WithCancel(f.ctx)
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
// Close connections and endpoint.
|
||||
if err := inConn.Close(); err != nil && !isClosedError(err) {
|
||||
f.logger.Debug1("forwarder: inConn close error: %v", err)
|
||||
}
|
||||
if err := outConn.Close(); err != nil && !isClosedError(err) {
|
||||
f.logger.Debug1("forwarder: outConn close error: %v", err)
|
||||
}
|
||||
|
||||
ep.Close()
|
||||
}()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
var (
|
||||
bytesFromInToOut int64 // bytes from client to server (tx for client)
|
||||
bytesFromOutToIn int64 // bytes from server to client (rx for client)
|
||||
errInToOut error
|
||||
errOutToIn error
|
||||
)
|
||||
|
||||
go func() {
|
||||
bytesFromInToOut, errInToOut = io.Copy(outConn, inConn)
|
||||
cancel()
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
go func() {
|
||||
|
||||
bytesFromOutToIn, errOutToIn = io.Copy(inConn, outConn)
|
||||
cancel()
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if errInToOut != nil {
|
||||
if !isClosedError(errInToOut) {
|
||||
f.logger.Error2("proxyTCP: copy error (in → out) for %s: %v", epID(id), errInToOut)
|
||||
}
|
||||
}
|
||||
if errOutToIn != nil {
|
||||
if !isClosedError(errOutToIn) {
|
||||
f.logger.Error2("proxyTCP: copy error (out → in) for %s: %v", epID(id), errOutToIn)
|
||||
}
|
||||
}
|
||||
|
||||
var rxPackets, txPackets uint64
|
||||
if tcpStats, ok := ep.Stats().(*tcp.Stats); ok {
|
||||
@@ -85,9 +126,7 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
|
||||
txPackets = tcpStats.SegmentsReceived.Value()
|
||||
}
|
||||
|
||||
if f.logger.Enabled(nblog.LevelTrace) {
|
||||
f.logger.Trace5("forwarder: Removed TCP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, bytesFromOutToIn, txPackets, bytesFromInToOut)
|
||||
}
|
||||
f.logger.Trace5("forwarder: Removed TCP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, bytesFromOutToIn, txPackets, bytesFromInToOut)
|
||||
|
||||
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, uint64(bytesFromOutToIn), uint64(bytesFromInToOut), rxPackets, txPackets)
|
||||
}
|
||||
|
||||
@@ -125,9 +125,7 @@ func (f *udpForwarder) cleanup() {
|
||||
delete(f.conns, idle.id)
|
||||
f.Unlock()
|
||||
|
||||
if f.logger.Enabled(nblog.LevelTrace) {
|
||||
f.logger.Trace1("forwarder: cleaned up idle UDP connection %v", epID(idle.id))
|
||||
}
|
||||
f.logger.Trace1("forwarder: cleaned up idle UDP connection %v", epID(idle.id))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -146,9 +144,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool {
|
||||
_, exists := f.udpForwarder.conns[id]
|
||||
f.udpForwarder.RUnlock()
|
||||
if exists {
|
||||
if f.logger.Enabled(nblog.LevelTrace) {
|
||||
f.logger.Trace1("forwarder: existing UDP connection for %v", epID(id))
|
||||
}
|
||||
f.logger.Trace1("forwarder: existing UDP connection for %v", epID(id))
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -210,9 +206,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool {
|
||||
f.udpForwarder.Unlock()
|
||||
|
||||
success = true
|
||||
if f.logger.Enabled(nblog.LevelTrace) {
|
||||
f.logger.Trace1("forwarder: established UDP connection %v", epID(id))
|
||||
}
|
||||
f.logger.Trace1("forwarder: established UDP connection %v", epID(id))
|
||||
|
||||
go f.proxyUDP(connCtx, pConn, id, ep)
|
||||
return true
|
||||
@@ -271,9 +265,7 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
|
||||
txPackets = udpStats.PacketsReceived.Value()
|
||||
}
|
||||
|
||||
if f.logger.Enabled(nblog.LevelTrace) {
|
||||
f.logger.Trace5("forwarder: Removed UDP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, rxBytes, txPackets, txBytes)
|
||||
}
|
||||
f.logger.Trace5("forwarder: Removed UDP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, rxBytes, txPackets, txBytes)
|
||||
|
||||
f.udpForwarder.Lock()
|
||||
delete(f.udpForwarder.conns, id)
|
||||
|
||||
@@ -53,17 +53,16 @@ var levelStrings = map[Level]string{
|
||||
}
|
||||
|
||||
type logMessage struct {
|
||||
level Level
|
||||
argCount uint8
|
||||
format string
|
||||
arg1 any
|
||||
arg2 any
|
||||
arg3 any
|
||||
arg4 any
|
||||
arg5 any
|
||||
arg6 any
|
||||
arg7 any
|
||||
arg8 any
|
||||
level Level
|
||||
format string
|
||||
arg1 any
|
||||
arg2 any
|
||||
arg3 any
|
||||
arg4 any
|
||||
arg5 any
|
||||
arg6 any
|
||||
arg7 any
|
||||
arg8 any
|
||||
}
|
||||
|
||||
// Logger is a high-performance, non-blocking logger
|
||||
@@ -108,13 +107,6 @@ func (l *Logger) SetLevel(level Level) {
|
||||
log.Debugf("Set uspfilter logger loglevel to %v", levelStrings[level])
|
||||
}
|
||||
|
||||
// Enabled reports whether the given level is currently logged. Callers on the
|
||||
// hot path should guard log sites with this to avoid boxing arguments into
|
||||
// any when the level is off.
|
||||
func (l *Logger) Enabled(level Level) bool {
|
||||
return l.level.Load() >= uint32(level)
|
||||
}
|
||||
|
||||
func (l *Logger) Error(format string) {
|
||||
if l.level.Load() >= uint32(LevelError) {
|
||||
select {
|
||||
@@ -163,7 +155,7 @@ func (l *Logger) Trace(format string) {
|
||||
func (l *Logger) Error1(format string, arg1 any) {
|
||||
if l.level.Load() >= uint32(LevelError) {
|
||||
select {
|
||||
case l.msgChannel <- logMessage{level: LevelError, argCount: 1, format: format, arg1: arg1}:
|
||||
case l.msgChannel <- logMessage{level: LevelError, format: format, arg1: arg1}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
@@ -172,16 +164,7 @@ func (l *Logger) Error1(format string, arg1 any) {
|
||||
func (l *Logger) Error2(format string, arg1, arg2 any) {
|
||||
if l.level.Load() >= uint32(LevelError) {
|
||||
select {
|
||||
case l.msgChannel <- logMessage{level: LevelError, argCount: 2, format: format, arg1: arg1, arg2: arg2}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Logger) Warn2(format string, arg1, arg2 any) {
|
||||
if l.level.Load() >= uint32(LevelWarn) {
|
||||
select {
|
||||
case l.msgChannel <- logMessage{level: LevelWarn, argCount: 2, format: format, arg1: arg1, arg2: arg2}:
|
||||
case l.msgChannel <- logMessage{level: LevelError, format: format, arg1: arg1, arg2: arg2}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
@@ -190,7 +173,7 @@ func (l *Logger) Warn2(format string, arg1, arg2 any) {
|
||||
func (l *Logger) Warn3(format string, arg1, arg2, arg3 any) {
|
||||
if l.level.Load() >= uint32(LevelWarn) {
|
||||
select {
|
||||
case l.msgChannel <- logMessage{level: LevelWarn, argCount: 3, format: format, arg1: arg1, arg2: arg2, arg3: arg3}:
|
||||
case l.msgChannel <- logMessage{level: LevelWarn, format: format, arg1: arg1, arg2: arg2, arg3: arg3}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
@@ -199,7 +182,7 @@ func (l *Logger) Warn3(format string, arg1, arg2, arg3 any) {
|
||||
func (l *Logger) Warn4(format string, arg1, arg2, arg3, arg4 any) {
|
||||
if l.level.Load() >= uint32(LevelWarn) {
|
||||
select {
|
||||
case l.msgChannel <- logMessage{level: LevelWarn, argCount: 4, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4}:
|
||||
case l.msgChannel <- logMessage{level: LevelWarn, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
@@ -208,7 +191,7 @@ func (l *Logger) Warn4(format string, arg1, arg2, arg3, arg4 any) {
|
||||
func (l *Logger) Debug1(format string, arg1 any) {
|
||||
if l.level.Load() >= uint32(LevelDebug) {
|
||||
select {
|
||||
case l.msgChannel <- logMessage{level: LevelDebug, argCount: 1, format: format, arg1: arg1}:
|
||||
case l.msgChannel <- logMessage{level: LevelDebug, format: format, arg1: arg1}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
@@ -217,7 +200,7 @@ func (l *Logger) Debug1(format string, arg1 any) {
|
||||
func (l *Logger) Debug2(format string, arg1, arg2 any) {
|
||||
if l.level.Load() >= uint32(LevelDebug) {
|
||||
select {
|
||||
case l.msgChannel <- logMessage{level: LevelDebug, argCount: 2, format: format, arg1: arg1, arg2: arg2}:
|
||||
case l.msgChannel <- logMessage{level: LevelDebug, format: format, arg1: arg1, arg2: arg2}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
@@ -226,59 +209,16 @@ func (l *Logger) Debug2(format string, arg1, arg2 any) {
|
||||
func (l *Logger) Debug3(format string, arg1, arg2, arg3 any) {
|
||||
if l.level.Load() >= uint32(LevelDebug) {
|
||||
select {
|
||||
case l.msgChannel <- logMessage{level: LevelDebug, argCount: 3, format: format, arg1: arg1, arg2: arg2, arg3: arg3}:
|
||||
case l.msgChannel <- logMessage{level: LevelDebug, format: format, arg1: arg1, arg2: arg2, arg3: arg3}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Debugf is the variadic shape. Dispatches to Debug/Debug1/Debug2/Debug3
|
||||
// to avoid allocating an args slice on the fast path when the arg count is
|
||||
// known (0-3). Args beyond 3 land on the general variadic path; callers on
|
||||
// the hot path should prefer DebugN for known counts.
|
||||
func (l *Logger) Debugf(format string, args ...any) {
|
||||
if l.level.Load() < uint32(LevelDebug) {
|
||||
return
|
||||
}
|
||||
switch len(args) {
|
||||
case 0:
|
||||
l.Debug(format)
|
||||
case 1:
|
||||
l.Debug1(format, args[0])
|
||||
case 2:
|
||||
l.Debug2(format, args[0], args[1])
|
||||
case 3:
|
||||
l.Debug3(format, args[0], args[1], args[2])
|
||||
default:
|
||||
l.sendVariadic(LevelDebug, format, args)
|
||||
}
|
||||
}
|
||||
|
||||
// sendVariadic packs a slice of arguments into a logMessage and non-blocking
|
||||
// enqueues it. Used for arg counts beyond the fixed-arity fast paths. Args
|
||||
// beyond the 8-arg slot limit are dropped so callers don't produce silently
|
||||
// empty log lines via uint8 wraparound in argCount.
|
||||
func (l *Logger) sendVariadic(level Level, format string, args []any) {
|
||||
const maxArgs = 8
|
||||
n := len(args)
|
||||
if n > maxArgs {
|
||||
n = maxArgs
|
||||
}
|
||||
msg := logMessage{level: level, argCount: uint8(n), format: format}
|
||||
slots := [maxArgs]*any{&msg.arg1, &msg.arg2, &msg.arg3, &msg.arg4, &msg.arg5, &msg.arg6, &msg.arg7, &msg.arg8}
|
||||
for i := 0; i < n; i++ {
|
||||
*slots[i] = args[i]
|
||||
}
|
||||
select {
|
||||
case l.msgChannel <- msg:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Logger) Trace1(format string, arg1 any) {
|
||||
if l.level.Load() >= uint32(LevelTrace) {
|
||||
select {
|
||||
case l.msgChannel <- logMessage{level: LevelTrace, argCount: 1, format: format, arg1: arg1}:
|
||||
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
@@ -287,7 +227,7 @@ func (l *Logger) Trace1(format string, arg1 any) {
|
||||
func (l *Logger) Trace2(format string, arg1, arg2 any) {
|
||||
if l.level.Load() >= uint32(LevelTrace) {
|
||||
select {
|
||||
case l.msgChannel <- logMessage{level: LevelTrace, argCount: 2, format: format, arg1: arg1, arg2: arg2}:
|
||||
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
@@ -296,7 +236,7 @@ func (l *Logger) Trace2(format string, arg1, arg2 any) {
|
||||
func (l *Logger) Trace3(format string, arg1, arg2, arg3 any) {
|
||||
if l.level.Load() >= uint32(LevelTrace) {
|
||||
select {
|
||||
case l.msgChannel <- logMessage{level: LevelTrace, argCount: 3, format: format, arg1: arg1, arg2: arg2, arg3: arg3}:
|
||||
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
@@ -305,7 +245,7 @@ func (l *Logger) Trace3(format string, arg1, arg2, arg3 any) {
|
||||
func (l *Logger) Trace4(format string, arg1, arg2, arg3, arg4 any) {
|
||||
if l.level.Load() >= uint32(LevelTrace) {
|
||||
select {
|
||||
case l.msgChannel <- logMessage{level: LevelTrace, argCount: 4, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4}:
|
||||
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
@@ -314,7 +254,7 @@ func (l *Logger) Trace4(format string, arg1, arg2, arg3, arg4 any) {
|
||||
func (l *Logger) Trace5(format string, arg1, arg2, arg3, arg4, arg5 any) {
|
||||
if l.level.Load() >= uint32(LevelTrace) {
|
||||
select {
|
||||
case l.msgChannel <- logMessage{level: LevelTrace, argCount: 5, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5}:
|
||||
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
@@ -323,7 +263,7 @@ func (l *Logger) Trace5(format string, arg1, arg2, arg3, arg4, arg5 any) {
|
||||
func (l *Logger) Trace6(format string, arg1, arg2, arg3, arg4, arg5, arg6 any) {
|
||||
if l.level.Load() >= uint32(LevelTrace) {
|
||||
select {
|
||||
case l.msgChannel <- logMessage{level: LevelTrace, argCount: 6, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5, arg6: arg6}:
|
||||
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5, arg6: arg6}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
@@ -333,7 +273,7 @@ func (l *Logger) Trace6(format string, arg1, arg2, arg3, arg4, arg5, arg6 any) {
|
||||
func (l *Logger) Trace8(format string, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8 any) {
|
||||
if l.level.Load() >= uint32(LevelTrace) {
|
||||
select {
|
||||
case l.msgChannel <- logMessage{level: LevelTrace, argCount: 8, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5, arg6: arg6, arg7: arg7, arg8: arg8}:
|
||||
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5, arg6: arg6, arg7: arg7, arg8: arg8}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
@@ -346,8 +286,35 @@ func (l *Logger) formatMessage(buf *[]byte, msg logMessage) {
|
||||
*buf = append(*buf, levelStrings[msg.level]...)
|
||||
*buf = append(*buf, ' ')
|
||||
|
||||
// Count non-nil arguments for switch
|
||||
argCount := 0
|
||||
if msg.arg1 != nil {
|
||||
argCount++
|
||||
if msg.arg2 != nil {
|
||||
argCount++
|
||||
if msg.arg3 != nil {
|
||||
argCount++
|
||||
if msg.arg4 != nil {
|
||||
argCount++
|
||||
if msg.arg5 != nil {
|
||||
argCount++
|
||||
if msg.arg6 != nil {
|
||||
argCount++
|
||||
if msg.arg7 != nil {
|
||||
argCount++
|
||||
if msg.arg8 != nil {
|
||||
argCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var formatted string
|
||||
switch msg.argCount {
|
||||
switch argCount {
|
||||
case 0:
|
||||
formatted = msg.format
|
||||
case 1:
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
"github.com/google/gopacket/layers"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -263,15 +262,11 @@ func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool {
|
||||
}
|
||||
|
||||
if err := m.rewritePacketIP(packetData, d, translatedIP, false); err != nil {
|
||||
if m.logger.Enabled(nblog.LevelError) {
|
||||
m.logger.Error1("failed to rewrite packet destination: %v", err)
|
||||
}
|
||||
m.logger.Error1("failed to rewrite packet destination: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
if m.logger.Enabled(nblog.LevelTrace) {
|
||||
m.logger.Trace2("DNAT: %s -> %s", dstIP, translatedIP)
|
||||
}
|
||||
m.logger.Trace2("DNAT: %s -> %s", dstIP, translatedIP)
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -288,15 +283,11 @@ func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
|
||||
}
|
||||
|
||||
if err := m.rewritePacketIP(packetData, d, originalIP, true); err != nil {
|
||||
if m.logger.Enabled(nblog.LevelError) {
|
||||
m.logger.Error1("failed to rewrite packet source: %v", err)
|
||||
}
|
||||
m.logger.Error1("failed to rewrite packet source: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
if m.logger.Enabled(nblog.LevelTrace) {
|
||||
m.logger.Trace2("Reverse DNAT: %s -> %s", srcIP, originalIP)
|
||||
}
|
||||
m.logger.Trace2("Reverse DNAT: %s -> %s", srcIP, originalIP)
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -621,9 +612,7 @@ func (m *Manager) applyPortRule(packetData []byte, d *decoder, srcIP, dstIP neti
|
||||
}
|
||||
|
||||
if err := rewriteFn(packetData, d, rule.targetPort, destinationPortOffset); err != nil {
|
||||
if m.logger.Enabled(nblog.LevelError) {
|
||||
m.logger.Error1("failed to rewrite port: %v", err)
|
||||
}
|
||||
m.logger.Error1("failed to rewrite port: %v", err)
|
||||
return false
|
||||
}
|
||||
d.dnatOrigPort = rule.origPort
|
||||
|
||||
@@ -30,27 +30,6 @@ import (
|
||||
|
||||
var currentMTU uint16 = iface.DefaultMTU
|
||||
|
||||
// nonRetryableEDECodes lists EDE info codes (RFC 8914) for which a SERVFAIL
|
||||
// from one upstream means another upstream would return the same answer:
|
||||
// DNSSEC validation outcomes and policy-based blocks. Transient errors
|
||||
// (network, cached, not ready) are not included.
|
||||
var nonRetryableEDECodes = map[uint16]struct{}{
|
||||
dns.ExtendedErrorCodeUnsupportedDNSKEYAlgorithm: {},
|
||||
dns.ExtendedErrorCodeUnsupportedDSDigestType: {},
|
||||
dns.ExtendedErrorCodeDNSSECIndeterminate: {},
|
||||
dns.ExtendedErrorCodeDNSBogus: {},
|
||||
dns.ExtendedErrorCodeSignatureExpired: {},
|
||||
dns.ExtendedErrorCodeSignatureNotYetValid: {},
|
||||
dns.ExtendedErrorCodeDNSKEYMissing: {},
|
||||
dns.ExtendedErrorCodeRRSIGsMissing: {},
|
||||
dns.ExtendedErrorCodeNoZoneKeyBitSet: {},
|
||||
dns.ExtendedErrorCodeNSECMissing: {},
|
||||
dns.ExtendedErrorCodeBlocked: {},
|
||||
dns.ExtendedErrorCodeCensored: {},
|
||||
dns.ExtendedErrorCodeFiltered: {},
|
||||
dns.ExtendedErrorCodeProhibited: {},
|
||||
}
|
||||
|
||||
// privateClientIface is the subset of the WireGuard interface needed by GetClientPrivate.
|
||||
type privateClientIface interface {
|
||||
Name() string
|
||||
@@ -271,18 +250,6 @@ func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, w dns.Re
|
||||
var t time.Duration
|
||||
var err error
|
||||
|
||||
// Advertise EDNS0 so the upstream may include Extended DNS Errors
|
||||
// (RFC 8914) in failure responses; we use those to short-circuit
|
||||
// failover for definitive answers like DNSSEC validation failures.
|
||||
// Operate on a copy so the inbound request is unchanged: a client that
|
||||
// did not advertise EDNS0 must not see an OPT in the response.
|
||||
hadEdns := r.IsEdns0() != nil
|
||||
reqUp := r
|
||||
if !hadEdns {
|
||||
reqUp = r.Copy()
|
||||
reqUp.SetEdns0(upstreamUDPSize(), false)
|
||||
}
|
||||
|
||||
var startTime time.Time
|
||||
var upstreamProto *upstreamProtocolResult
|
||||
func() {
|
||||
@@ -290,7 +257,7 @@ func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, w dns.Re
|
||||
defer cancel()
|
||||
ctx, upstreamProto = contextWithupstreamProtocolResult(ctx)
|
||||
startTime = time.Now()
|
||||
rm, t, err = u.upstreamClient.exchange(ctx, upstream.String(), reqUp)
|
||||
rm, t, err = u.upstreamClient.exchange(ctx, upstream.String(), r)
|
||||
}()
|
||||
|
||||
if err != nil {
|
||||
@@ -302,49 +269,13 @@ func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, w dns.Re
|
||||
}
|
||||
|
||||
if rm.Rcode == dns.RcodeServerFailure || rm.Rcode == dns.RcodeRefused {
|
||||
if code, ok := nonRetryableEDE(rm); ok {
|
||||
resutil.SetMeta(w, "ede", edeName(code))
|
||||
if !hadEdns {
|
||||
stripOPT(rm)
|
||||
}
|
||||
u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, upstreamProto, logger)
|
||||
return nil
|
||||
}
|
||||
return &upstreamFailure{upstream: upstream, reason: dns.RcodeToString[rm.Rcode]}
|
||||
}
|
||||
|
||||
if !hadEdns {
|
||||
stripOPT(rm)
|
||||
}
|
||||
u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, upstreamProto, logger)
|
||||
return nil
|
||||
}
|
||||
|
||||
// upstreamUDPSize returns the EDNS0 UDP buffer size we advertise to upstreams,
|
||||
// derived from the tunnel MTU and bounded against underflow.
|
||||
func upstreamUDPSize() uint16 {
|
||||
if currentMTU > ipUDPHeaderSize {
|
||||
return currentMTU - ipUDPHeaderSize
|
||||
}
|
||||
return dns.MinMsgSize
|
||||
}
|
||||
|
||||
// stripOPT removes any OPT pseudo-RRs from the response's Extra section so
|
||||
// the response complies with RFC 6891 when the client did not advertise EDNS0.
|
||||
func stripOPT(rm *dns.Msg) {
|
||||
if len(rm.Extra) == 0 {
|
||||
return
|
||||
}
|
||||
out := rm.Extra[:0]
|
||||
for _, rr := range rm.Extra {
|
||||
if _, ok := rr.(*dns.OPT); ok {
|
||||
continue
|
||||
}
|
||||
out = append(out, rr)
|
||||
}
|
||||
rm.Extra = out
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.AddrPort, startTime time.Time) *upstreamFailure {
|
||||
if !errors.Is(err, context.DeadlineExceeded) && !isTimeout(err) {
|
||||
return &upstreamFailure{upstream: upstream, reason: err.Error()}
|
||||
@@ -406,34 +337,6 @@ func formatFailures(failures []upstreamFailure) string {
|
||||
return strings.Join(parts, ", ")
|
||||
}
|
||||
|
||||
// nonRetryableEDE returns the first non-retryable EDE code carried in the
|
||||
// response, if any.
|
||||
func nonRetryableEDE(rm *dns.Msg) (uint16, bool) {
|
||||
opt := rm.IsEdns0()
|
||||
if opt == nil {
|
||||
return 0, false
|
||||
}
|
||||
for _, o := range opt.Option {
|
||||
ede, ok := o.(*dns.EDNS0_EDE)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if _, ok := nonRetryableEDECodes[ede.InfoCode]; ok {
|
||||
return ede.InfoCode, true
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// edeName returns a human-readable name for an EDE code, falling back to
|
||||
// the numeric code when unknown.
|
||||
func edeName(code uint16) string {
|
||||
if name, ok := dns.ExtendedErrorCodeToString[code]; ok {
|
||||
return name
|
||||
}
|
||||
return fmt.Sprintf("EDE %d", code)
|
||||
}
|
||||
|
||||
// ProbeAvailability tests all upstream servers simultaneously and
|
||||
// disables the resolver if none work
|
||||
func (u *upstreamResolverBase) ProbeAvailability(ctx context.Context) {
|
||||
|
||||
@@ -770,132 +770,3 @@ func TestExchangeWithFallback_TCPTruncatesToClientSize(t *testing.T) {
|
||||
assert.Less(t, len(rm2.Answer), 20, "small EDNS0 client should get fewer records")
|
||||
assert.True(t, rm2.Truncated, "response should be truncated for small buffer client")
|
||||
}
|
||||
|
||||
func msgWithEDE(rcode int, codes ...uint16) *dns.Msg {
|
||||
m := new(dns.Msg)
|
||||
m.Response = true
|
||||
m.Rcode = rcode
|
||||
if len(codes) == 0 {
|
||||
return m
|
||||
}
|
||||
opt := &dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}}
|
||||
opt.SetUDPSize(dns.MinMsgSize)
|
||||
for _, c := range codes {
|
||||
opt.Option = append(opt.Option, &dns.EDNS0_EDE{InfoCode: c})
|
||||
}
|
||||
m.Extra = append(m.Extra, opt)
|
||||
return m
|
||||
}
|
||||
|
||||
func TestNonRetryableEDE(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
msg *dns.Msg
|
||||
wantOK bool
|
||||
wantCode uint16
|
||||
}{
|
||||
{name: "no edns0", msg: msgWithEDE(dns.RcodeServerFailure)},
|
||||
{
|
||||
name: "opt without ede",
|
||||
msg: func() *dns.Msg {
|
||||
m := msgWithEDE(dns.RcodeServerFailure)
|
||||
opt := &dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}}
|
||||
opt.Option = append(opt.Option, &dns.EDNS0_NSID{Code: dns.EDNS0NSID})
|
||||
m.Extra = []dns.RR{opt}
|
||||
return m
|
||||
}(),
|
||||
},
|
||||
{name: "ede dnsbogus", msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeDNSBogus), wantOK: true, wantCode: dns.ExtendedErrorCodeDNSBogus},
|
||||
{name: "ede signature expired", msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeSignatureExpired), wantOK: true, wantCode: dns.ExtendedErrorCodeSignatureExpired},
|
||||
{name: "ede blocked", msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeBlocked), wantOK: true, wantCode: dns.ExtendedErrorCodeBlocked},
|
||||
{name: "ede prohibited", msg: msgWithEDE(dns.RcodeRefused, dns.ExtendedErrorCodeProhibited), wantOK: true, wantCode: dns.ExtendedErrorCodeProhibited},
|
||||
{name: "ede cached error retryable", msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeCachedError)},
|
||||
{name: "ede network error retryable", msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeNetworkError)},
|
||||
{name: "ede not ready retryable", msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeNotReady)},
|
||||
{
|
||||
name: "first non-retryable wins",
|
||||
msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeNetworkError, dns.ExtendedErrorCodeDNSBogus),
|
||||
wantOK: true,
|
||||
wantCode: dns.ExtendedErrorCodeDNSBogus,
|
||||
},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
code, ok := nonRetryableEDE(tc.msg)
|
||||
assert.Equal(t, tc.wantOK, ok, "ok should match")
|
||||
if tc.wantOK {
|
||||
assert.Equal(t, tc.wantCode, code, "code should match")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEDEName(t *testing.T) {
|
||||
assert.Equal(t, "DNSSEC Bogus", edeName(dns.ExtendedErrorCodeDNSBogus))
|
||||
assert.Equal(t, "Signature Expired", edeName(dns.ExtendedErrorCodeSignatureExpired))
|
||||
assert.Equal(t, "EDE 9999", edeName(9999), "unknown code falls back to numeric")
|
||||
}
|
||||
|
||||
func TestStripOPT(t *testing.T) {
|
||||
rm := &dns.Msg{
|
||||
Extra: []dns.RR{
|
||||
&dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}},
|
||||
&dns.A{Hdr: dns.RR_Header{Name: "x.", Rrtype: dns.TypeA}, A: net.IPv4(1, 2, 3, 4)},
|
||||
},
|
||||
}
|
||||
stripOPT(rm)
|
||||
assert.Len(t, rm.Extra, 1, "OPT should be removed, A kept")
|
||||
_, isOPT := rm.Extra[0].(*dns.OPT)
|
||||
assert.False(t, isOPT, "remaining record must not be OPT")
|
||||
}
|
||||
|
||||
func TestUpstreamResolver_NonRetryableEDEShortCircuits(t *testing.T) {
|
||||
upstream1 := netip.MustParseAddrPort("192.0.2.1:53")
|
||||
upstream2 := netip.MustParseAddrPort("192.0.2.2:53")
|
||||
|
||||
servfailWithEDE := msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeDNSBogus)
|
||||
successResp := buildMockResponse(dns.RcodeSuccess, "192.0.2.100")
|
||||
|
||||
var queried []string
|
||||
tracking := &trackingMockClient{
|
||||
inner: &mockUpstreamResolverPerServer{
|
||||
responses: map[string]mockUpstreamResponse{
|
||||
upstream1.String(): {msg: servfailWithEDE},
|
||||
upstream2.String(): {msg: successResp},
|
||||
},
|
||||
rtt: time.Millisecond,
|
||||
},
|
||||
queriedUpstreams: &queried,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
resolver := &upstreamResolverBase{
|
||||
ctx: ctx,
|
||||
upstreamClient: tracking,
|
||||
upstreamServers: []netip.AddrPort{upstream1, upstream2},
|
||||
upstreamTimeout: UpstreamTimeout,
|
||||
}
|
||||
|
||||
var written *dns.Msg
|
||||
w := &test.MockResponseWriter{
|
||||
WriteMsgFunc: func(m *dns.Msg) error {
|
||||
written = m
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
// Client query without EDNS0 must not see an OPT in the response.
|
||||
q := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
|
||||
resolver.ServeDNS(w, q)
|
||||
|
||||
require.NotNil(t, written, "response must be written")
|
||||
assert.Equal(t, dns.RcodeServerFailure, written.Rcode, "SERVFAIL must propagate")
|
||||
assert.Len(t, queried, 1, "only first upstream should be queried")
|
||||
assert.Equal(t, upstream1.String(), queried[0])
|
||||
for _, rr := range written.Extra {
|
||||
_, isOPT := rr.(*dns.OPT)
|
||||
assert.False(t, isOPT, "synthetic OPT must not leak to a non-EDNS0 client")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1979,6 +1979,29 @@ func (e *Engine) GetClientMetrics() *metrics.ClientMetrics {
|
||||
return e.clientMetrics
|
||||
}
|
||||
|
||||
// Performance bundles runtime-adjustable tunnel pool knobs.
|
||||
// See Engine.SetPerformance. Nil fields are ignored.
|
||||
type Performance struct {
|
||||
PreallocatedBuffersPerPool *uint32
|
||||
}
|
||||
|
||||
// SetPerformance applies the given tuning to this engine's live Device.
|
||||
func (e *Engine) SetPerformance(t Performance) error {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
if e.wgInterface == nil {
|
||||
return fmt.Errorf("wg interface not initialized")
|
||||
}
|
||||
dev := e.wgInterface.GetWGDevice()
|
||||
if dev == nil {
|
||||
return fmt.Errorf("wg device not initialized")
|
||||
}
|
||||
if t.PreallocatedBuffersPerPool != nil {
|
||||
dev.SetPreallocatedBuffersPerPool(*t.PreallocatedBuffersPerPool)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func findIPFromInterfaceName(ifaceName string) (net.IP, error) {
|
||||
iface, err := net.InterfaceByName(ifaceName)
|
||||
if err != nil {
|
||||
|
||||
@@ -1,93 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
)
|
||||
|
||||
func TestPersistLoginOverrides(t *testing.T) {
|
||||
strPtr := func(s string) *string { return &s }
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
initialMgmtURL string
|
||||
initialPSK string
|
||||
newMgmtURL string
|
||||
newPSK *string
|
||||
wantMgmtURL string
|
||||
wantPSK string
|
||||
}{
|
||||
{
|
||||
name: "persist new management URL",
|
||||
initialMgmtURL: "https://old.example.com:33073",
|
||||
newMgmtURL: "https://new.example.com:33073",
|
||||
wantMgmtURL: "https://new.example.com:33073",
|
||||
},
|
||||
{
|
||||
name: "persist new pre-shared key",
|
||||
initialMgmtURL: "https://existing.example.com:33073",
|
||||
initialPSK: "old-key",
|
||||
newPSK: strPtr("new-key"),
|
||||
wantMgmtURL: "https://existing.example.com:33073",
|
||||
wantPSK: "new-key",
|
||||
},
|
||||
{
|
||||
name: "persist both",
|
||||
initialMgmtURL: "https://old.example.com:33073",
|
||||
initialPSK: "old-key",
|
||||
newMgmtURL: "https://new.example.com:33073",
|
||||
newPSK: strPtr("new-key"),
|
||||
wantMgmtURL: "https://new.example.com:33073",
|
||||
wantPSK: "new-key",
|
||||
},
|
||||
{
|
||||
name: "no inputs preserves existing",
|
||||
initialMgmtURL: "https://existing.example.com:33073",
|
||||
initialPSK: "existing-key",
|
||||
wantMgmtURL: "https://existing.example.com:33073",
|
||||
wantPSK: "existing-key",
|
||||
},
|
||||
{
|
||||
name: "empty PSK pointer is ignored",
|
||||
initialMgmtURL: "https://existing.example.com:33073",
|
||||
initialPSK: "existing-key",
|
||||
newPSK: strPtr(""),
|
||||
wantMgmtURL: "https://existing.example.com:33073",
|
||||
wantPSK: "existing-key",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
origDefault := profilemanager.DefaultConfigPath
|
||||
t.Cleanup(func() { profilemanager.DefaultConfigPath = origDefault })
|
||||
|
||||
dir := t.TempDir()
|
||||
profilemanager.DefaultConfigPath = filepath.Join(dir, "default.json")
|
||||
|
||||
seed := profilemanager.ConfigInput{
|
||||
ConfigPath: profilemanager.DefaultConfigPath,
|
||||
ManagementURL: tt.initialMgmtURL,
|
||||
}
|
||||
if tt.initialPSK != "" {
|
||||
seed.PreSharedKey = strPtr(tt.initialPSK)
|
||||
}
|
||||
_, err := profilemanager.UpdateOrCreateConfig(seed)
|
||||
require.NoError(t, err, "seed config")
|
||||
|
||||
activeProf := &profilemanager.ActiveProfileState{Name: "default"}
|
||||
err = persistLoginOverrides(activeProf, tt.newMgmtURL, tt.newPSK)
|
||||
require.NoError(t, err, "persistLoginOverrides")
|
||||
|
||||
cfg, err := profilemanager.ReadConfig(profilemanager.DefaultConfigPath)
|
||||
require.NoError(t, err, "read back config")
|
||||
|
||||
require.Equal(t, tt.wantMgmtURL, cfg.ManagementURL.String(), "management URL")
|
||||
require.Equal(t, tt.wantPSK, cfg.PreSharedKey, "pre-shared key")
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -490,11 +490,6 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
|
||||
|
||||
s.mutex.Unlock()
|
||||
|
||||
if err := persistLoginOverrides(activeProf, msg.ManagementUrl, msg.OptionalPreSharedKey); err != nil {
|
||||
log.Errorf("failed to persist login overrides: %v", err)
|
||||
return nil, fmt.Errorf("persist login overrides: %w", err)
|
||||
}
|
||||
|
||||
config, _, err := s.getConfig(activeProf)
|
||||
if err != nil {
|
||||
log.Errorf("failed to get active profile config: %v", err)
|
||||
@@ -969,7 +964,7 @@ func (s *Server) handleActiveProfileLogout(ctx context.Context) (*proto.LogoutRe
|
||||
return &proto.LogoutResponse{}, nil
|
||||
}
|
||||
|
||||
// getConfig reads config file and returns Config and whether the config file already existed. Errors out if it does not exist
|
||||
// GetConfig reads config file and returns Config and whether the config file already existed. Errors out if it does not exist
|
||||
func (s *Server) getConfig(activeProf *profilemanager.ActiveProfileState) (*profilemanager.Config, bool, error) {
|
||||
cfgPath, err := activeProf.FilePath()
|
||||
if err != nil {
|
||||
@@ -1771,29 +1766,3 @@ func sendTerminalNotification() error {
|
||||
|
||||
return wallCmd.Wait()
|
||||
}
|
||||
|
||||
// persistLoginOverrides writes management URL and pre-shared key from a LoginRequest to the
|
||||
// active profile config so that subsequent reads pick them up. Empty/nil values are ignored.
|
||||
func persistLoginOverrides(activeProf *profilemanager.ActiveProfileState, managementURL string, preSharedKey *string) error {
|
||||
if preSharedKey != nil && *preSharedKey == "" {
|
||||
preSharedKey = nil
|
||||
}
|
||||
if managementURL == "" && preSharedKey == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
cfgPath, err := activeProf.FilePath()
|
||||
if err != nil {
|
||||
return fmt.Errorf("active profile file path: %w", err)
|
||||
}
|
||||
|
||||
input := profilemanager.ConfigInput{
|
||||
ConfigPath: cfgPath,
|
||||
ManagementURL: managementURL,
|
||||
PreSharedKey: preSharedKey,
|
||||
}
|
||||
if _, err := profilemanager.UpdateOrCreateConfig(input); err != nil {
|
||||
return fmt.Errorf("update config: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -25,7 +25,6 @@ import (
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/ssh/detection"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
"github.com/netbirdio/netbird/util/netrelay"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -537,7 +536,7 @@ func (c *Client) LocalPortForward(ctx context.Context, localAddr, remoteAddr str
|
||||
continue
|
||||
}
|
||||
|
||||
go c.handleLocalForward(ctx, localConn, remoteAddr)
|
||||
go c.handleLocalForward(localConn, remoteAddr)
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -549,7 +548,7 @@ func (c *Client) LocalPortForward(ctx context.Context, localAddr, remoteAddr str
|
||||
}
|
||||
|
||||
// handleLocalForward handles a single local port forwarding connection
|
||||
func (c *Client) handleLocalForward(ctx context.Context, localConn net.Conn, remoteAddr string) {
|
||||
func (c *Client) handleLocalForward(localConn net.Conn, remoteAddr string) {
|
||||
defer func() {
|
||||
if err := localConn.Close(); err != nil {
|
||||
log.Debugf("local port forwarding: close local connection: %v", err)
|
||||
@@ -572,7 +571,7 @@ func (c *Client) handleLocalForward(ctx context.Context, localConn net.Conn, rem
|
||||
}
|
||||
}()
|
||||
|
||||
netrelay.Relay(ctx, localConn, channel, netrelay.Options{Logger: log.NewEntry(log.StandardLogger())})
|
||||
nbssh.BidirectionalCopy(log.NewEntry(log.StandardLogger()), localConn, channel)
|
||||
}
|
||||
|
||||
// RemotePortForward sets up remote port forwarding, binding on remote and forwarding to localAddr
|
||||
@@ -654,19 +653,16 @@ func (c *Client) handleRemoteForwardChannels(ctx context.Context, localAddr stri
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case newChan, ok := <-channelRequests:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
case newChan := <-channelRequests:
|
||||
if newChan != nil {
|
||||
go c.handleRemoteForwardChannel(ctx, newChan, localAddr)
|
||||
go c.handleRemoteForwardChannel(newChan, localAddr)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleRemoteForwardChannel handles a single forwarded-tcpip channel
|
||||
func (c *Client) handleRemoteForwardChannel(ctx context.Context, newChan ssh.NewChannel, localAddr string) {
|
||||
func (c *Client) handleRemoteForwardChannel(newChan ssh.NewChannel, localAddr string) {
|
||||
channel, reqs, err := newChan.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
@@ -679,14 +675,8 @@ func (c *Client) handleRemoteForwardChannel(ctx context.Context, newChan ssh.New
|
||||
|
||||
go ssh.DiscardRequests(reqs)
|
||||
|
||||
// Bound the dial so a black-holed localAddr can't pin the accepted SSH
|
||||
// channel open indefinitely; the relay itself runs under the outer ctx.
|
||||
dialCtx, cancelDial := context.WithTimeout(ctx, 10*time.Second)
|
||||
var dialer net.Dialer
|
||||
localConn, err := dialer.DialContext(dialCtx, "tcp", localAddr)
|
||||
cancelDial()
|
||||
localConn, err := net.Dial("tcp", localAddr)
|
||||
if err != nil {
|
||||
log.Debugf("remote port forwarding: dial %s: %v", localAddr, err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
@@ -695,7 +685,7 @@ func (c *Client) handleRemoteForwardChannel(ctx context.Context, newChan ssh.New
|
||||
}
|
||||
}()
|
||||
|
||||
netrelay.Relay(ctx, localConn, channel, netrelay.Options{Logger: log.NewEntry(log.StandardLogger())})
|
||||
nbssh.BidirectionalCopy(log.NewEntry(log.StandardLogger()), localConn, channel)
|
||||
}
|
||||
|
||||
// tcpipForwardMsg represents the structure for tcpip-forward requests
|
||||
|
||||
@@ -194,3 +194,63 @@ func buildAddressList(hostname string, remote net.Addr) []string {
|
||||
return addresses
|
||||
}
|
||||
|
||||
// BidirectionalCopy copies data bidirectionally between two io.ReadWriter connections.
|
||||
// It waits for both directions to complete before returning.
|
||||
// The caller is responsible for closing the connections.
|
||||
func BidirectionalCopy(logger *log.Entry, rw1, rw2 io.ReadWriter) {
|
||||
done := make(chan struct{}, 2)
|
||||
|
||||
go func() {
|
||||
if _, err := io.Copy(rw2, rw1); err != nil && !isExpectedCopyError(err) {
|
||||
logger.Debugf("copy error (1->2): %v", err)
|
||||
}
|
||||
done <- struct{}{}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
if _, err := io.Copy(rw1, rw2); err != nil && !isExpectedCopyError(err) {
|
||||
logger.Debugf("copy error (2->1): %v", err)
|
||||
}
|
||||
done <- struct{}{}
|
||||
}()
|
||||
|
||||
<-done
|
||||
<-done
|
||||
}
|
||||
|
||||
func isExpectedCopyError(err error) bool {
|
||||
return errors.Is(err, io.EOF) || errors.Is(err, context.Canceled)
|
||||
}
|
||||
|
||||
// BidirectionalCopyWithContext copies data bidirectionally between two io.ReadWriteCloser connections.
|
||||
// It waits for both directions to complete or for context cancellation before returning.
|
||||
// Both connections are closed when the function returns.
|
||||
func BidirectionalCopyWithContext(logger *log.Entry, ctx context.Context, conn1, conn2 io.ReadWriteCloser) {
|
||||
done := make(chan struct{}, 2)
|
||||
|
||||
go func() {
|
||||
if _, err := io.Copy(conn2, conn1); err != nil && !isExpectedCopyError(err) {
|
||||
logger.Debugf("copy error (1->2): %v", err)
|
||||
}
|
||||
done <- struct{}{}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
if _, err := io.Copy(conn1, conn2); err != nil && !isExpectedCopyError(err) {
|
||||
logger.Debugf("copy error (2->1): %v", err)
|
||||
}
|
||||
done <- struct{}{}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-done:
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-done:
|
||||
}
|
||||
}
|
||||
|
||||
_ = conn1.Close()
|
||||
_ = conn2.Close()
|
||||
}
|
||||
|
||||
@@ -229,35 +229,18 @@ func (m *Manager) buildHostPatterns(peer PeerSSHInfo) []string {
|
||||
|
||||
func (m *Manager) writeSSHConfig(sshConfig string) error {
|
||||
sshConfigPath := filepath.Join(m.sshConfigDir, m.sshConfigFile)
|
||||
sshConfigPathTmp := sshConfigPath + ".tmp"
|
||||
|
||||
if err := os.MkdirAll(m.sshConfigDir, 0755); err != nil {
|
||||
return fmt.Errorf("create SSH config directory %s: %w", m.sshConfigDir, err)
|
||||
}
|
||||
|
||||
tmp, err := os.CreateTemp(m.sshConfigDir, m.sshConfigFile+".*.tmp")
|
||||
if err != nil {
|
||||
return fmt.Errorf("create temp SSH config: %w", err)
|
||||
}
|
||||
tmpPath := tmp.Name()
|
||||
defer func() {
|
||||
if err := os.Remove(tmpPath); err != nil && !os.IsNotExist(err) {
|
||||
log.Debugf("remove temp SSH config %s: %v", tmpPath, err)
|
||||
}
|
||||
}()
|
||||
if err := tmp.Close(); err != nil {
|
||||
return fmt.Errorf("close temp SSH config %s: %w", tmpPath, err)
|
||||
if err := writeFileWithTimeout(sshConfigPathTmp, []byte(sshConfig), 0644); err != nil {
|
||||
return fmt.Errorf("write SSH config file %s: %w", sshConfigPath, err)
|
||||
}
|
||||
|
||||
if err := writeFileWithTimeout(tmpPath, []byte(sshConfig), 0644); err != nil {
|
||||
return fmt.Errorf("write SSH config file %s: %w", tmpPath, err)
|
||||
}
|
||||
|
||||
if err := os.Chmod(tmpPath, 0644); err != nil {
|
||||
return fmt.Errorf("chmod SSH config file %s: %w", tmpPath, err)
|
||||
}
|
||||
|
||||
if err := os.Rename(tmpPath, sshConfigPath); err != nil {
|
||||
return fmt.Errorf("rename SSH config %s -> %s: %w", tmpPath, sshConfigPath, err)
|
||||
if err := os.Rename(sshConfigPathTmp, sshConfigPath); err != nil {
|
||||
return fmt.Errorf("rename ssh config %s -> %s: %w", sshConfigPathTmp, sshConfigPath, err)
|
||||
}
|
||||
|
||||
log.Infof("Created NetBird SSH client config: %s", sshConfigPath)
|
||||
|
||||
@@ -23,7 +23,6 @@ import (
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/ssh/detection"
|
||||
"github.com/netbirdio/netbird/util/netrelay"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
@@ -353,7 +352,7 @@ func (p *SSHProxy) directTCPIPHandler(_ *ssh.Server, _ *cryptossh.ServerConn, ne
|
||||
}
|
||||
go cryptossh.DiscardRequests(clientReqs)
|
||||
|
||||
netrelay.Relay(sshCtx, clientChan, backendChan, netrelay.Options{Logger: log.NewEntry(log.StandardLogger())})
|
||||
nbssh.BidirectionalCopyWithContext(log.NewEntry(log.StandardLogger()), sshCtx, clientChan, backendChan)
|
||||
}
|
||||
|
||||
func (p *SSHProxy) sftpSubsystemHandler(s ssh.Session, jwtToken string) {
|
||||
@@ -592,7 +591,7 @@ func (p *SSHProxy) handleForwardedChannel(sshCtx ssh.Context, sshConn *cryptossh
|
||||
}
|
||||
go cryptossh.DiscardRequests(clientReqs)
|
||||
|
||||
netrelay.Relay(sshCtx, clientChan, backendChan, netrelay.Options{Logger: log.NewEntry(log.StandardLogger())})
|
||||
nbssh.BidirectionalCopyWithContext(log.NewEntry(log.StandardLogger()), sshCtx, clientChan, backendChan)
|
||||
}
|
||||
|
||||
func (p *SSHProxy) dialBackend(ctx context.Context, addr, user, jwtToken string) (*cryptossh.Client, error) {
|
||||
|
||||
@@ -17,7 +17,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
cryptossh "golang.org/x/crypto/ssh"
|
||||
|
||||
"github.com/netbirdio/netbird/util/netrelay"
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
)
|
||||
|
||||
const privilegedPortThreshold = 1024
|
||||
@@ -357,7 +357,7 @@ func (s *Server) handleRemoteForwardConnection(ctx ssh.Context, conn net.Conn, h
|
||||
return
|
||||
}
|
||||
|
||||
netrelay.Relay(ctx, conn, channel, netrelay.Options{Logger: logger})
|
||||
nbssh.BidirectionalCopyWithContext(logger, ctx, conn, channel)
|
||||
}
|
||||
|
||||
// openForwardChannel creates an SSH forwarded-tcpip channel
|
||||
|
||||
@@ -8,9 +8,9 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -27,7 +27,6 @@ import (
|
||||
"github.com/netbirdio/netbird/client/ssh/detection"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/auth/jwt"
|
||||
"github.com/netbirdio/netbird/util/netrelay"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
@@ -54,10 +53,6 @@ const (
|
||||
DefaultJWTMaxTokenAge = 10 * 60
|
||||
)
|
||||
|
||||
// directTCPIPDialTimeout bounds how long relayDirectTCPIP waits on a dial to
|
||||
// the forwarded destination before rejecting the SSH channel.
|
||||
const directTCPIPDialTimeout = 30 * time.Second
|
||||
|
||||
var (
|
||||
ErrPrivilegedUserDisabled = errors.New(msgPrivilegedUserDisabled)
|
||||
ErrUserNotFound = errors.New("user not found")
|
||||
@@ -938,29 +933,5 @@ func (s *Server) directTCPIPHandler(srv *ssh.Server, conn *cryptossh.ServerConn,
|
||||
s.addConnectionPortForward(ctx.User(), ctx.RemoteAddr(), forwardAddr)
|
||||
logger.Infof("local port forwarding: %s", hostPort)
|
||||
|
||||
s.relayDirectTCPIP(ctx, newChan, payload.Host, int(payload.Port), logger)
|
||||
}
|
||||
|
||||
// relayDirectTCPIP is a netrelay-based replacement for gliderlabs'
|
||||
// DirectTCPIPHandler. The upstream handler closes both sides on the first
|
||||
// EOF; netrelay.Relay propagates CloseWrite so each direction drains on its
|
||||
// own terms.
|
||||
func (s *Server) relayDirectTCPIP(ctx ssh.Context, newChan cryptossh.NewChannel, host string, port int, logger *log.Entry) {
|
||||
dest := net.JoinHostPort(host, strconv.Itoa(port))
|
||||
|
||||
dialer := net.Dialer{Timeout: directTCPIPDialTimeout}
|
||||
dconn, err := dialer.DialContext(ctx, "tcp", dest)
|
||||
if err != nil {
|
||||
_ = newChan.Reject(cryptossh.ConnectionFailed, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
ch, reqs, err := newChan.Accept()
|
||||
if err != nil {
|
||||
_ = dconn.Close()
|
||||
return
|
||||
}
|
||||
go cryptossh.DiscardRequests(reqs)
|
||||
|
||||
netrelay.Relay(ctx, dconn, ch, netrelay.Options{Logger: logger})
|
||||
ssh.DirectTCPIPHandler(srv, conn, newChan, ctx)
|
||||
}
|
||||
|
||||
@@ -133,18 +133,13 @@ type ManagementConfig struct {
|
||||
|
||||
// AuthConfig contains authentication/identity provider settings
|
||||
type AuthConfig struct {
|
||||
Issuer string `yaml:"issuer"`
|
||||
LocalAuthDisabled bool `yaml:"localAuthDisabled"`
|
||||
SignKeyRefreshEnabled bool `yaml:"signKeyRefreshEnabled"`
|
||||
MfaSessionMaxLifetime string `yaml:"mfaSessionMaxLifetime"`
|
||||
MfaSessionIdleTimeout string `yaml:"mfaSessionIdleTimeout"`
|
||||
MfaSessionRememberMe bool `yaml:"mfaSessionRememberMe"`
|
||||
SessionCookieEncryptionKey string `yaml:"sessionCookieEncryptionKey"`
|
||||
Storage AuthStorageConfig `yaml:"storage"`
|
||||
DashboardRedirectURIs []string `yaml:"dashboardRedirectURIs"`
|
||||
CLIRedirectURIs []string `yaml:"cliRedirectURIs"`
|
||||
Owner *AuthOwnerConfig `yaml:"owner,omitempty"`
|
||||
DashboardPostLogoutRedirectURIs []string `yaml:"dashboardPostLogoutRedirectURIs"`
|
||||
Issuer string `yaml:"issuer"`
|
||||
LocalAuthDisabled bool `yaml:"localAuthDisabled"`
|
||||
SignKeyRefreshEnabled bool `yaml:"signKeyRefreshEnabled"`
|
||||
Storage AuthStorageConfig `yaml:"storage"`
|
||||
DashboardRedirectURIs []string `yaml:"dashboardRedirectURIs"`
|
||||
CLIRedirectURIs []string `yaml:"cliRedirectURIs"`
|
||||
Owner *AuthOwnerConfig `yaml:"owner,omitempty"`
|
||||
}
|
||||
|
||||
// AuthStorageConfig contains auth storage settings
|
||||
@@ -586,14 +581,10 @@ func (c *CombinedConfig) buildEmbeddedIdPConfig(mgmt ManagementConfig) (*idp.Emb
|
||||
}
|
||||
|
||||
cfg := &idp.EmbeddedIdPConfig{
|
||||
Enabled: true,
|
||||
Issuer: mgmt.Auth.Issuer,
|
||||
LocalAuthDisabled: mgmt.Auth.LocalAuthDisabled,
|
||||
SignKeyRefreshEnabled: mgmt.Auth.SignKeyRefreshEnabled,
|
||||
MfaSessionMaxLifetime: mgmt.Auth.MfaSessionMaxLifetime,
|
||||
MfaSessionIdleTimeout: mgmt.Auth.MfaSessionIdleTimeout,
|
||||
MfaSessionRememberMe: mgmt.Auth.MfaSessionRememberMe,
|
||||
SessionCookieEncryptionKey: mgmt.Auth.SessionCookieEncryptionKey,
|
||||
Enabled: true,
|
||||
Issuer: mgmt.Auth.Issuer,
|
||||
LocalAuthDisabled: mgmt.Auth.LocalAuthDisabled,
|
||||
SignKeyRefreshEnabled: mgmt.Auth.SignKeyRefreshEnabled,
|
||||
Storage: idp.EmbeddedStorageConfig{
|
||||
Type: authStorageType,
|
||||
Config: idp.EmbeddedStorageTypeConfig{
|
||||
@@ -601,9 +592,8 @@ func (c *CombinedConfig) buildEmbeddedIdPConfig(mgmt ManagementConfig) (*idp.Emb
|
||||
DSN: authStorageDSN,
|
||||
},
|
||||
},
|
||||
DashboardRedirectURIs: mgmt.Auth.DashboardRedirectURIs,
|
||||
CLIRedirectURIs: mgmt.Auth.CLIRedirectURIs,
|
||||
DashboardPostLogoutRedirectURIs: mgmt.Auth.DashboardPostLogoutRedirectURIs,
|
||||
DashboardRedirectURIs: mgmt.Auth.DashboardRedirectURIs,
|
||||
CLIRedirectURIs: mgmt.Auth.CLIRedirectURIs,
|
||||
}
|
||||
|
||||
if mgmt.Auth.Owner != nil && mgmt.Auth.Owner.Email != "" {
|
||||
|
||||
@@ -67,10 +67,6 @@ func init() {
|
||||
rootCmd.AddCommand(newTokenCommands())
|
||||
}
|
||||
|
||||
func RootCmd() *cobra.Command {
|
||||
return rootCmd
|
||||
}
|
||||
|
||||
func Execute() error {
|
||||
return rootCmd.Execute()
|
||||
}
|
||||
@@ -172,7 +168,7 @@ func initializeConfig() error {
|
||||
// serverInstances holds all server instances created during startup.
|
||||
type serverInstances struct {
|
||||
relaySrv *relayServer.Server
|
||||
mgmtSrv mgmtServer.Server
|
||||
mgmtSrv *mgmtServer.BaseServer
|
||||
signalSrv *signalServer.Server
|
||||
healthcheck *healthcheck.Server
|
||||
stunServer *stun.Server
|
||||
@@ -328,24 +324,19 @@ func setupServerHooks(servers *serverInstances, cfg *CombinedConfig) {
|
||||
return
|
||||
}
|
||||
|
||||
if s, ok := servers.mgmtSrv.GetContainer(mgmtServer.ContainerKeyBaseServer); ok {
|
||||
if baseServer, ok := s.(*mgmtServer.BaseServer); ok {
|
||||
baseServer.AfterInit(func(s *mgmtServer.BaseServer) {
|
||||
grpcSrv := s.GRPCServer()
|
||||
servers.mgmtSrv.AfterInit(func(s *mgmtServer.BaseServer) {
|
||||
grpcSrv := s.GRPCServer()
|
||||
|
||||
if servers.signalSrv != nil {
|
||||
proto.RegisterSignalExchangeServer(grpcSrv, servers.signalSrv)
|
||||
log.Infof("Signal server registered on port %s", cfg.Server.ListenAddress)
|
||||
}
|
||||
|
||||
s.SetHandlerFunc(createCombinedHandler(grpcSrv, s.APIHandler(), servers.relaySrv, servers.metricsServer.Meter, cfg))
|
||||
if servers.relaySrv != nil {
|
||||
log.Infof("Relay WebSocket handler added (path: /relay)")
|
||||
}
|
||||
})
|
||||
if servers.signalSrv != nil {
|
||||
proto.RegisterSignalExchangeServer(grpcSrv, servers.signalSrv)
|
||||
log.Infof("Signal server registered on port %s", cfg.Server.ListenAddress)
|
||||
}
|
||||
}
|
||||
|
||||
s.SetHandlerFunc(createCombinedHandler(grpcSrv, s.APIHandler(), servers.relaySrv, servers.metricsServer.Meter, cfg))
|
||||
if servers.relaySrv != nil {
|
||||
log.Infof("Relay WebSocket handler added (path: /relay)")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func startServers(wg *sync.WaitGroup, srv *relayServer.Server, httpHealthcheck *healthcheck.Server, stunServer *stun.Server, metricsServer *sharedMetrics.Metrics) {
|
||||
@@ -355,32 +346,38 @@ func startServers(wg *sync.WaitGroup, srv *relayServer.Server, httpHealthcheck *
|
||||
log.Infof("Relay WebSocket multiplexed on management port (no separate relay listener)")
|
||||
}
|
||||
|
||||
wg.Go(func() {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
log.Infof("running metrics server: %s%s", metricsServer.Addr, metricsServer.Endpoint)
|
||||
if err := metricsServer.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
|
||||
log.Fatalf("failed to start metrics server: %v", err)
|
||||
}
|
||||
})
|
||||
}()
|
||||
|
||||
wg.Go(func() {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := httpHealthcheck.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
|
||||
log.Fatalf("failed to start healthcheck server: %v", err)
|
||||
}
|
||||
})
|
||||
}()
|
||||
|
||||
if stunServer != nil {
|
||||
wg.Go(func() {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := stunServer.Listen(); err != nil {
|
||||
if errors.Is(err, stun.ErrServerClosed) {
|
||||
return
|
||||
}
|
||||
log.Errorf("STUN server error: %v", err)
|
||||
}
|
||||
})
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func shutdownServers(ctx context.Context, srv *relayServer.Server, httpHealthcheck *healthcheck.Server, stunServer *stun.Server, mgmtSrv mgmtServer.Server, metricsServer *sharedMetrics.Metrics) error {
|
||||
func shutdownServers(ctx context.Context, srv *relayServer.Server, httpHealthcheck *healthcheck.Server, stunServer *stun.Server, mgmtSrv *mgmtServer.BaseServer, metricsServer *sharedMetrics.Metrics) error {
|
||||
var errs error
|
||||
|
||||
if err := httpHealthcheck.Shutdown(ctx); err != nil {
|
||||
@@ -494,7 +491,7 @@ func handleTLSConfig(cfg *CombinedConfig) (*tls.Config, bool, error) {
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
func createManagementServer(cfg *CombinedConfig, mgmtConfig *nbconfig.Config) (mgmtServer.Server, error) {
|
||||
func createManagementServer(cfg *CombinedConfig, mgmtConfig *nbconfig.Config) (*mgmtServer.BaseServer, error) {
|
||||
mgmt := cfg.Management
|
||||
|
||||
// Extract port from listen address
|
||||
@@ -505,7 +502,7 @@ func createManagementServer(cfg *CombinedConfig, mgmtConfig *nbconfig.Config) (m
|
||||
}
|
||||
mgmtPort, _ := strconv.Atoi(portStr)
|
||||
|
||||
mgmtSrv := newServer(
|
||||
mgmtSrv := mgmtServer.NewServer(
|
||||
&mgmtServer.Config{
|
||||
NbConfig: mgmtConfig,
|
||||
DNSDomain: "",
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
mgmtServer "github.com/netbirdio/netbird/management/internals/server"
|
||||
)
|
||||
|
||||
var newServer = func(cfg *mgmtServer.Config) mgmtServer.Server {
|
||||
return mgmtServer.NewServer(cfg)
|
||||
}
|
||||
|
||||
func SetNewServer(fn func(*mgmtServer.Config) mgmtServer.Server) {
|
||||
newServer = fn
|
||||
}
|
||||
@@ -86,13 +86,6 @@ server:
|
||||
issuer: "https://example.com/oauth2"
|
||||
localAuthDisabled: false
|
||||
signKeyRefreshEnabled: false
|
||||
# MFA session settings (applies when TOTP is enabled for an account)
|
||||
# mfaSessionMaxLifetime: "24h" # Max duration for an MFA session from creation
|
||||
# mfaSessionIdleTimeout: "1h" # MFA session expires after this idle period
|
||||
# mfaSessionRememberMe: false # Pre-check "remember me" on login so the MFA session persists across tabs/restarts
|
||||
# Optional AES key for encrypting embedded IdP session cookies. Can also be set via NB_IDP_SESSION_COOKIE_ENCRYPTION_KEY.
|
||||
# Must be 16/24/32 raw bytes or base64-encoded to one of those lengths (for example: openssl rand -hex 16).
|
||||
# sessionCookieEncryptionKey: ""
|
||||
# OAuth2 redirect URIs for dashboard
|
||||
dashboardRedirectURIs:
|
||||
- "https://app.example.com/nb-auth"
|
||||
@@ -100,9 +93,6 @@ server:
|
||||
# OAuth2 redirect URIs for CLI
|
||||
cliRedirectURIs:
|
||||
- "http://localhost:53000/"
|
||||
# OAuth2 post-logout redirect URIs for dashboard (RP-initiated logout)
|
||||
# dashboardPostLogoutRedirectURIs:
|
||||
# - "https://app.example.com/"
|
||||
# Optional initial admin user
|
||||
# owner:
|
||||
# email: "admin@example.com"
|
||||
|
||||
48
go.mod
48
go.mod
@@ -14,7 +14,7 @@ require (
|
||||
github.com/onsi/gomega v1.27.6
|
||||
github.com/rs/cors v1.8.0
|
||||
github.com/sirupsen/logrus v1.9.4
|
||||
github.com/spf13/cobra v1.10.2
|
||||
github.com/spf13/cobra v1.10.1
|
||||
github.com/spf13/pflag v1.0.9
|
||||
github.com/vishvananda/netlink v1.3.1
|
||||
golang.org/x/crypto v0.50.0
|
||||
@@ -41,11 +41,11 @@ require (
|
||||
github.com/cilium/ebpf v0.15.0
|
||||
github.com/coder/websocket v1.8.14
|
||||
github.com/coreos/go-iptables v0.7.0
|
||||
github.com/coreos/go-oidc/v3 v3.18.0
|
||||
github.com/coreos/go-oidc/v3 v3.14.1
|
||||
github.com/creack/pty v1.1.24
|
||||
github.com/crowdsecurity/crowdsec v1.7.7
|
||||
github.com/crowdsecurity/go-cs-bouncer v0.0.21
|
||||
github.com/dexidp/dex v2.13.0+incompatible
|
||||
github.com/dexidp/dex v0.0.0-00010101000000-000000000000
|
||||
github.com/dexidp/dex/api/v2 v2.4.0
|
||||
github.com/ebitengine/purego v0.8.4
|
||||
github.com/eko/gocache/lib/v4 v4.2.0
|
||||
@@ -53,9 +53,9 @@ require (
|
||||
github.com/eko/gocache/store/redis/v4 v4.2.2
|
||||
github.com/fsnotify/fsnotify v1.9.0
|
||||
github.com/gliderlabs/ssh v0.3.8
|
||||
github.com/go-jose/go-jose/v4 v4.1.4
|
||||
github.com/go-jose/go-jose/v4 v4.1.3
|
||||
github.com/godbus/dbus/v5 v5.1.0
|
||||
github.com/golang-jwt/jwt/v5 v5.3.1
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0
|
||||
github.com/golang/mock v1.6.0
|
||||
github.com/google/go-cmp v0.7.0
|
||||
github.com/google/gopacket v1.1.19
|
||||
@@ -72,7 +72,7 @@ require (
|
||||
github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81
|
||||
github.com/mdlayher/socket v0.5.1
|
||||
github.com/mdp/qrterminal/v3 v3.2.1
|
||||
github.com/miekg/dns v1.1.72
|
||||
github.com/miekg/dns v1.1.59
|
||||
github.com/mitchellh/hashstructure/v2 v2.0.2
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20260416123949-2355d972be42
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45
|
||||
@@ -113,7 +113,7 @@ require (
|
||||
go.opentelemetry.io/otel/exporters/prometheus v0.64.0
|
||||
go.opentelemetry.io/otel/metric v1.43.0
|
||||
go.opentelemetry.io/otel/sdk/metric v1.43.0
|
||||
go.uber.org/mock v0.6.0
|
||||
go.uber.org/mock v0.5.2
|
||||
go.uber.org/zap v1.27.0
|
||||
goauthentik.io/api/v3 v3.2023051.3
|
||||
golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b
|
||||
@@ -141,7 +141,7 @@ require (
|
||||
filippo.io/edwards25519 v1.1.1 // indirect
|
||||
github.com/AppsFlyer/go-sundheit v0.6.0 // indirect
|
||||
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect
|
||||
github.com/Azure/go-ntlmssp v0.1.0 // indirect
|
||||
github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358 // indirect
|
||||
github.com/BurntSushi/toml v1.5.0 // indirect
|
||||
github.com/Masterminds/goutils v1.1.1 // indirect
|
||||
github.com/Masterminds/semver/v3 v3.3.0 // indirect
|
||||
@@ -168,7 +168,6 @@ require (
|
||||
github.com/aws/smithy-go v1.23.0 // indirect
|
||||
github.com/beevik/etree v1.6.0 // indirect
|
||||
github.com/beorn7/perks v1.0.1 // indirect
|
||||
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc // indirect
|
||||
github.com/caddyserver/zerossl v0.1.3 // indirect
|
||||
github.com/cenkalti/backoff/v5 v5.0.3 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
@@ -184,7 +183,6 @@ require (
|
||||
github.com/docker/go-units v0.5.0 // indirect
|
||||
github.com/felixge/httpsnoop v1.0.4 // indirect
|
||||
github.com/fredbi/uri v1.1.1 // indirect
|
||||
github.com/fxamacker/cbor/v2 v2.9.1 // indirect
|
||||
github.com/fyne-io/gl-js v0.2.0 // indirect
|
||||
github.com/fyne-io/glfw-js v0.3.0 // indirect
|
||||
github.com/fyne-io/image v0.1.1 // indirect
|
||||
@@ -192,7 +190,7 @@ require (
|
||||
github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 // indirect
|
||||
github.com/go-gl/gl v0.0.0-20231021071112-07e5d0ea2e71 // indirect
|
||||
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20240506104042-037f3cc74f2a // indirect
|
||||
github.com/go-ldap/ldap/v3 v3.4.13 // indirect
|
||||
github.com/go-ldap/ldap/v3 v3.4.12 // indirect
|
||||
github.com/go-logr/logr v1.4.3 // indirect
|
||||
github.com/go-logr/stdr v1.2.2 // indirect
|
||||
github.com/go-ole/go-ole v1.3.0 // indirect
|
||||
@@ -208,15 +206,11 @@ require (
|
||||
github.com/go-sql-driver/mysql v1.9.3 // indirect
|
||||
github.com/go-text/render v0.2.0 // indirect
|
||||
github.com/go-text/typesetting v0.2.1 // indirect
|
||||
github.com/go-viper/mapstructure/v2 v2.5.0 // indirect
|
||||
github.com/go-webauthn/webauthn v0.16.4 // indirect
|
||||
github.com/go-webauthn/x v0.2.3 // indirect
|
||||
github.com/goccy/go-yaml v1.18.0 // indirect
|
||||
github.com/gogo/protobuf v1.3.2 // indirect
|
||||
github.com/golang-jwt/jwt/v4 v4.5.2 // indirect
|
||||
github.com/google/btree v1.1.2 // indirect
|
||||
github.com/google/go-querystring v1.1.0 // indirect
|
||||
github.com/google/go-tpm v0.9.8 // indirect
|
||||
github.com/google/s2a-go v0.1.9 // indirect
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.14 // indirect
|
||||
github.com/googleapis/gax-go/v2 v2.21.0 // indirect
|
||||
@@ -224,13 +218,7 @@ require (
|
||||
github.com/hack-pad/go-indexeddb v0.3.2 // indirect
|
||||
github.com/hack-pad/safejs v0.1.0 // indirect
|
||||
github.com/hashicorp/errwrap v1.1.0 // indirect
|
||||
github.com/hashicorp/go-cleanhttp v0.5.2 // indirect
|
||||
github.com/hashicorp/go-retryablehttp v0.7.8 // indirect
|
||||
github.com/hashicorp/go-secure-stdlib/parseutil v0.2.0 // indirect
|
||||
github.com/hashicorp/go-secure-stdlib/strutil v0.1.2 // indirect
|
||||
github.com/hashicorp/go-sockaddr v1.0.7 // indirect
|
||||
github.com/hashicorp/go-uuid v1.0.3 // indirect
|
||||
github.com/hashicorp/hcl v1.0.1-vault-7 // indirect
|
||||
github.com/huandu/xstrings v1.5.0 // indirect
|
||||
github.com/huin/goupnp v1.2.0 // indirect
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
@@ -250,13 +238,13 @@ require (
|
||||
github.com/klauspost/cpuid/v2 v2.2.10 // indirect
|
||||
github.com/koron/go-ssdp v0.0.4 // indirect
|
||||
github.com/kr/fs v0.1.0 // indirect
|
||||
github.com/lib/pq v1.12.3 // indirect
|
||||
github.com/lib/pq v1.10.9 // indirect
|
||||
github.com/libdns/libdns v0.2.2 // indirect
|
||||
github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae // indirect
|
||||
github.com/magiconair/properties v1.8.10 // indirect
|
||||
github.com/mailru/easyjson v0.9.0 // indirect
|
||||
github.com/mattermost/xml-roundtrip-validator v0.1.0 // indirect
|
||||
github.com/mattn/go-sqlite3 v1.14.42 // indirect
|
||||
github.com/mattn/go-sqlite3 v1.14.32 // indirect
|
||||
github.com/mdelapenya/tlscert v0.2.0 // indirect
|
||||
github.com/mdlayher/genetlink v1.3.2 // indirect
|
||||
github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42 // indirect
|
||||
@@ -277,10 +265,8 @@ require (
|
||||
github.com/nxadm/tail v1.4.11 // indirect
|
||||
github.com/oklog/ulid v1.3.1 // indirect
|
||||
github.com/onsi/ginkgo/v2 v2.9.5 // indirect
|
||||
github.com/openbao/openbao/api/v2 v2.5.1 // indirect
|
||||
github.com/opencontainers/go-digest v1.0.0 // indirect
|
||||
github.com/opencontainers/image-spec v1.1.1 // indirect
|
||||
github.com/philhofer/fwd v1.2.0 // indirect
|
||||
github.com/pion/dtls/v2 v2.2.10 // indirect
|
||||
github.com/pion/dtls/v3 v3.0.9 // indirect
|
||||
github.com/pion/mdns/v2 v2.0.7 // indirect
|
||||
@@ -289,13 +275,11 @@ require (
|
||||
github.com/pkg/errors v0.9.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
|
||||
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect
|
||||
github.com/pquerna/otp v1.5.0 // indirect
|
||||
github.com/prometheus/client_model v0.6.2 // indirect
|
||||
github.com/prometheus/common v0.67.5 // indirect
|
||||
github.com/prometheus/otlptranslator v1.0.0 // indirect
|
||||
github.com/prometheus/procfs v0.19.2 // indirect
|
||||
github.com/russellhaering/goxmldsig v1.6.0 // indirect
|
||||
github.com/ryanuber/go-glob v1.0.0 // indirect
|
||||
github.com/rymdport/portal v0.4.2 // indirect
|
||||
github.com/shirou/gopsutil/v4 v4.25.8 // indirect
|
||||
github.com/shoenig/go-m1cpu v0.2.1 // indirect
|
||||
@@ -304,13 +288,11 @@ require (
|
||||
github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c // indirect
|
||||
github.com/srwiley/rasterx v0.0.0-20220730225603-2ab79fcdd4ef // indirect
|
||||
github.com/stretchr/objx v0.5.2 // indirect
|
||||
github.com/tinylib/msgp v1.6.3 // indirect
|
||||
github.com/tklauser/go-sysconf v0.3.15 // indirect
|
||||
github.com/tklauser/numcpus v0.10.0 // indirect
|
||||
github.com/vishvananda/netns v0.0.5 // indirect
|
||||
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
|
||||
github.com/wlynxg/anet v0.0.5 // indirect
|
||||
github.com/x448/float16 v0.8.4 // indirect
|
||||
github.com/yuin/goldmark v1.7.8 // indirect
|
||||
github.com/zeebo/blake3 v0.2.3 // indirect
|
||||
go.mongodb.org/mongo-driver v1.17.9 // indirect
|
||||
@@ -335,14 +317,12 @@ replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-2024
|
||||
|
||||
replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949
|
||||
|
||||
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20260107100953-33b7c9d03db0
|
||||
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20260422100739-63c67f59bf58
|
||||
|
||||
replace github.com/cloudflare/circl => codeberg.org/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6
|
||||
replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6
|
||||
|
||||
replace github.com/pion/ice/v4 => github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51
|
||||
|
||||
replace github.com/dexidp/dex => github.com/netbirdio/dex v0.244.1-0.20260512110716-8d70ad8647c1
|
||||
|
||||
replace github.com/dexidp/dex/api/v2 => github.com/netbirdio/dex/api/v2 v2.0.0-20260512110716-8d70ad8647c1
|
||||
replace github.com/dexidp/dex => github.com/netbirdio/dex v0.244.0
|
||||
|
||||
replace github.com/mailru/easyjson => github.com/netbirdio/easyjson v0.9.0
|
||||
|
||||
105
go.sum
105
go.sum
@@ -5,8 +5,6 @@ cloud.google.com/go/auth/oauth2adapt v0.2.8/go.mod h1:XQ9y31RkqZCcwJWNSx2Xvric3R
|
||||
cloud.google.com/go/compute/metadata v0.2.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k=
|
||||
cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdBtwLoEkH9Zs=
|
||||
cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10=
|
||||
codeberg.org/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6 h1:b8xUw3004wk+3ipBhu0VU4RtUJsegMIiqjxSK4++lzA=
|
||||
codeberg.org/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6/go.mod h1:+CauBF6R70Jqcyl8N2hC8pAXYbWkGIezuSbuGLtRhnw=
|
||||
cunicu.li/go-rosenpass v0.4.0 h1:LtPtBgFWY/9emfgC4glKLEqS0MJTylzV6+ChRhiZERw=
|
||||
cunicu.li/go-rosenpass v0.4.0/go.mod h1:MPbjH9nxV4l3vEagKVdFNwHOketqgS5/To1VYJplf/M=
|
||||
dario.cat/mergo v1.0.1 h1:Ra4+bf83h2ztPIQYNP99R6m+Y7KfnARDfID+a+vLl4s=
|
||||
@@ -25,8 +23,8 @@ github.com/AppsFlyer/go-sundheit v0.6.0 h1:d2hBvCjBSb2lUsEWGfPigr4MCOt04sxB+Rppl
|
||||
github.com/AppsFlyer/go-sundheit v0.6.0/go.mod h1:LDdBHD6tQBtmHsdW+i1GwdTt6Wqc0qazf5ZEJVTbTME=
|
||||
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 h1:L/gRVlceqvL25UVaW/CKtUDjefjrs0SPonmDGUVOYP0=
|
||||
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
|
||||
github.com/Azure/go-ntlmssp v0.1.0 h1:DjFo6YtWzNqNvQdrwEyr/e4nhU3vRiwenz5QX7sFz+A=
|
||||
github.com/Azure/go-ntlmssp v0.1.0/go.mod h1:NYqdhxd/8aAct/s4qSYZEerdPuH1liG2/X9DiVTbhpk=
|
||||
github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358 h1:mFRzDkZVAjdal+s7s0MwaRv9igoPqLRdzOLzw/8Xvq8=
|
||||
github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358/go.mod h1:chxPXzSsl7ZWRAuOIE23GDNzjWuZquvFlgA8xmpunjU=
|
||||
github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg=
|
||||
github.com/BurntSushi/toml v1.5.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
|
||||
github.com/Masterminds/goutils v1.1.1 h1:5nUrii3FMTL5diU80unEVvNevw1nH4+ZV4DSLVJLSYI=
|
||||
@@ -93,8 +91,6 @@ github.com/beevik/etree v1.6.0/go.mod h1:bh4zJxiIr62SOf9pRzN7UUYaEDa9HEKafK25+sL
|
||||
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
|
||||
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
|
||||
github.com/bmatcuk/doublestar v1.1.1/go.mod h1:UD6OnuiIn0yFxxA2le/rnRU1G4RaI4UvFv1sNto9p6w=
|
||||
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc h1:biVzkmvwrH8WK8raXaxBx6fRVTlJILwEwQGL1I/ByEI=
|
||||
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
|
||||
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
|
||||
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
|
||||
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
|
||||
@@ -121,8 +117,8 @@ github.com/containerd/platforms v0.2.1 h1:zvwtM3rz2YHPQsF2CHYM8+KtB5dvhISiXh5ZpS
|
||||
github.com/containerd/platforms v0.2.1/go.mod h1:XHCb+2/hzowdiut9rkudds9bE5yJ7npe7dG/wG+uFPw=
|
||||
github.com/coreos/go-iptables v0.7.0 h1:XWM3V+MPRr5/q51NuWSgU0fqMad64Zyxs8ZUoMsamr8=
|
||||
github.com/coreos/go-iptables v0.7.0/go.mod h1:Qe8Bv2Xik5FyTXwgIbLAnv2sWSBmvWdFETJConOQ//Q=
|
||||
github.com/coreos/go-oidc/v3 v3.18.0 h1:V9orjXynvu5wiC9SemFTWnG4F45v403aIcjWo0d41+A=
|
||||
github.com/coreos/go-oidc/v3 v3.18.0/go.mod h1:DYCf24+ncYi+XkIH97GY1+dqoRlbaSI26KVTCI9SrY4=
|
||||
github.com/coreos/go-oidc/v3 v3.14.1 h1:9ePWwfdwC4QKRlCXsJGou56adA/owXczOzwKdOumLqk=
|
||||
github.com/coreos/go-oidc/v3 v3.14.1/go.mod h1:HaZ3szPaZ0e4r6ebqvsLWlk2Tn+aejfmrfah6hnSYEU=
|
||||
github.com/cpuguy83/dockercfg v0.3.2 h1:DlJTyZGBDlXqUZ2Dk2Q3xHs/FtnooJJVaad2S9GKorA=
|
||||
github.com/cpuguy83/dockercfg v0.3.2/go.mod h1:sugsbF4//dDlL/i+S+rtpIWp+5h0BHJHfjj5/jFyUJc=
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
|
||||
@@ -134,10 +130,14 @@ github.com/crowdsecurity/go-cs-bouncer v0.0.21 h1:arPz0VtdVSaz+auOSfHythzkZVLyy1
|
||||
github.com/crowdsecurity/go-cs-bouncer v0.0.21/go.mod h1:4JiH0XXA4KKnnWThItUpe5+heJHWzsLOSA2IWJqUDBA=
|
||||
github.com/crowdsecurity/go-cs-lib v0.0.25 h1:Ov6VPW9yV+OPsbAIQk1iTkEWhwkpaG0v3lrBzeqjzj4=
|
||||
github.com/crowdsecurity/go-cs-lib v0.0.25/go.mod h1:X0GMJY2CxdA1S09SpuqIKaWQsvRGxXmecUp9cP599dE=
|
||||
github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6 h1:/DS5cDX3FJdl+XaN2D7XAwFpuanTxnp52DBLZAaJKx0=
|
||||
github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6/go.mod h1:+CauBF6R70Jqcyl8N2hC8pAXYbWkGIezuSbuGLtRhnw=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dexidp/dex/api/v2 v2.4.0 h1:gNba7n6BKVp8X4Jp24cxYn5rIIGhM6kDOXcZoL6tr9A=
|
||||
github.com/dexidp/dex/api/v2 v2.4.0/go.mod h1:/p550ADvFFh7K95VmhUD+jgm15VdaNnab9td8DHOpyI=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
|
||||
@@ -156,8 +156,6 @@ github.com/eko/gocache/store/go_cache/v4 v4.2.2 h1:tAI9nl6TLoJyKG1ujF0CS0n/IgTEM
|
||||
github.com/eko/gocache/store/go_cache/v4 v4.2.2/go.mod h1:T9zkHokzr8K9EiC7RfMbDg6HSwaV6rv3UdcNu13SGcA=
|
||||
github.com/eko/gocache/store/redis/v4 v4.2.2 h1:Thw31fzGuH3WzJywsdbMivOmP550D6JS7GDHhvCJPA0=
|
||||
github.com/eko/gocache/store/redis/v4 v4.2.2/go.mod h1:LaTxLKx9TG/YUEybQvPMij++D7PBTIJ4+pzvk0ykz0w=
|
||||
github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM=
|
||||
github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU=
|
||||
github.com/felixge/fgprof v0.9.3 h1:VvyZxILNuCiUCSXtPtYmmtGvb65nqXh2QFWc0Wpf2/g=
|
||||
github.com/felixge/fgprof v0.9.3/go.mod h1:RdbpDgzqYVh/T9fPELJyV7EYJuHB55UTEULNun8eiPw=
|
||||
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
|
||||
@@ -173,8 +171,6 @@ github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4
|
||||
github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw=
|
||||
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
|
||||
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
|
||||
github.com/fxamacker/cbor/v2 v2.9.1 h1:2rWm8B193Ll4VdjsJY28jxs70IdDsHRWgQYAI80+rMQ=
|
||||
github.com/fxamacker/cbor/v2 v2.9.1/go.mod h1:vM4b+DJCtHn+zz7h3FFp/hDAI9WNWCsZj23V5ytsSxQ=
|
||||
github.com/fyne-io/gl-js v0.2.0 h1:+EXMLVEa18EfkXBVKhifYB6OGs3HwKO3lUElA0LlAjs=
|
||||
github.com/fyne-io/gl-js v0.2.0/go.mod h1:ZcepK8vmOYLu96JoxbCKJy2ybr+g1pTnaBDdl7c3ajI=
|
||||
github.com/fyne-io/glfw-js v0.3.0 h1:d8k2+Y7l+zy2pc7wlGRyPfTgZoqDf3AI4G+2zOWhWUk=
|
||||
@@ -193,10 +189,10 @@ github.com/go-gl/gl v0.0.0-20231021071112-07e5d0ea2e71 h1:5BVwOaUSBTlVZowGO6VZGw
|
||||
github.com/go-gl/gl v0.0.0-20231021071112-07e5d0ea2e71/go.mod h1:9YTyiznxEY1fVinfM7RvRcjRHbw2xLBJ3AAGIT0I4Nw=
|
||||
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20240506104042-037f3cc74f2a h1:vxnBhFDDT+xzxf1jTJKMKZw3H0swfWk9RpWbBbDK5+0=
|
||||
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20240506104042-037f3cc74f2a/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
|
||||
github.com/go-jose/go-jose/v4 v4.1.4 h1:moDMcTHmvE6Groj34emNPLs/qtYXRVcd6S7NHbHz3kA=
|
||||
github.com/go-jose/go-jose/v4 v4.1.4/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08=
|
||||
github.com/go-ldap/ldap/v3 v3.4.13 h1:+x1nG9h+MZN7h/lUi5Q3UZ0fJ1GyDQYbPvbuH38baDQ=
|
||||
github.com/go-ldap/ldap/v3 v3.4.13/go.mod h1:LxsGZV6vbaK0sIvYfsv47rfh4ca0JXokCoKjZxsszv0=
|
||||
github.com/go-jose/go-jose/v4 v4.1.3 h1:CVLmWDhDVRa6Mi/IgCgaopNosCaHz7zrMeF9MlZRkrs=
|
||||
github.com/go-jose/go-jose/v4 v4.1.3/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08=
|
||||
github.com/go-ldap/ldap/v3 v3.4.12 h1:1b81mv7MagXZ7+1r7cLTWmyuTqVqdwbtJSjC0DAp9s4=
|
||||
github.com/go-ldap/ldap/v3 v3.4.12/go.mod h1:+SPAGcTtOfmGsCb3h1RFiq4xpp4N636G75OEace8lNo=
|
||||
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
||||
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
|
||||
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
||||
@@ -233,20 +229,12 @@ github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI6
|
||||
github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE=
|
||||
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI=
|
||||
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls=
|
||||
github.com/go-test/deep v1.1.1 h1:0r/53hagsehfO4bzD2Pgr/+RgHqhmf+k1Bpse2cTu1U=
|
||||
github.com/go-test/deep v1.1.1/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE=
|
||||
github.com/go-text/render v0.2.0 h1:LBYoTmp5jYiJ4NPqDc2pz17MLmA3wHw1dZSVGcOdeAc=
|
||||
github.com/go-text/render v0.2.0/go.mod h1:CkiqfukRGKJA5vZZISkjSYrcdtgKQWRa2HIzvwNN5SU=
|
||||
github.com/go-text/typesetting v0.2.1 h1:x0jMOGyO3d1qFAPI0j4GSsh7M0Q3Ypjzr4+CEVg82V8=
|
||||
github.com/go-text/typesetting v0.2.1/go.mod h1:mTOxEwasOFpAMBjEQDhdWRckoLLeI/+qrQeBCTGEt6M=
|
||||
github.com/go-text/typesetting-utils v0.0.0-20241103174707-87a29e9e6066 h1:qCuYC+94v2xrb1PoS4NIDe7DGYtLnU2wWiQe9a1B1c0=
|
||||
github.com/go-text/typesetting-utils v0.0.0-20241103174707-87a29e9e6066/go.mod h1:DDxDdQEnB70R8owOx3LVpEFvpMK9eeH1o2r0yZhFI9o=
|
||||
github.com/go-viper/mapstructure/v2 v2.5.0 h1:vM5IJoUAy3d7zRSVtIwQgBj7BiWtMPfmPEgAXnvj1Ro=
|
||||
github.com/go-viper/mapstructure/v2 v2.5.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM=
|
||||
github.com/go-webauthn/webauthn v0.16.4 h1:R9jqR/cYZa7hRquFF7Za/8qoH/K/TIs1/Q/4CyGN+1Q=
|
||||
github.com/go-webauthn/webauthn v0.16.4/go.mod h1:SU2ljAgToTV/YLPI0C05QS4qn+e04WpB5g1RMfcZfS4=
|
||||
github.com/go-webauthn/x v0.2.3 h1:8oArS+Rc1SWFLXhE17KZNx258Z4kUSyaDgsSncCO5RA=
|
||||
github.com/go-webauthn/x v0.2.3/go.mod h1:tM04GF3V6VYq79AZMl7vbj4q6pz9r7L2criWRzbWhPk=
|
||||
github.com/goccy/go-yaml v1.18.0 h1:8W7wMFS12Pcas7KU+VVkaiCng+kG8QiFeFwzFb+rwuw=
|
||||
github.com/goccy/go-yaml v1.18.0/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA=
|
||||
github.com/godbus/dbus/v5 v5.1.0 h1:4KLkAxT3aOY8Li4FRJe/KvhoNFFxo0m6fNuFUO8QJUk=
|
||||
@@ -255,8 +243,8 @@ github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
|
||||
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
|
||||
github.com/golang-jwt/jwt/v4 v4.5.2 h1:YtQM7lnr8iZ+j5q71MGKkNw9Mn7AjHM68uc9g5fXeUI=
|
||||
github.com/golang-jwt/jwt/v4 v4.5.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
|
||||
github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc=
|
||||
github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs=
|
||||
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||
@@ -288,10 +276,6 @@ github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8=
|
||||
github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU=
|
||||
github.com/google/go-tpm v0.9.8 h1:slArAR9Ft+1ybZu0lBwpSmpwhRXaa85hWtMinMyRAWo=
|
||||
github.com/google/go-tpm v0.9.8/go.mod h1:h9jEsEECg7gtLis0upRBQU+GhYVH6jMjrFxI8u6bVUY=
|
||||
github.com/google/go-tpm-tools v0.3.13-0.20230620182252-4639ecce2aba h1:qJEJcuLzH5KDR0gKc0zcktin6KSAwL7+jWKBYceddTc=
|
||||
github.com/google/go-tpm-tools v0.3.13-0.20230620182252-4639ecce2aba/go.mod h1:EFYHy8/1y2KfgTAsx7Luu7NGhoxtuVHnNo8jE7FikKc=
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8=
|
||||
github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo=
|
||||
@@ -324,29 +308,15 @@ github.com/hack-pad/safejs v0.1.0/go.mod h1:HdS+bKF1NrE72VoXZeWzxFOVQVUSqZJAG0xN
|
||||
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
|
||||
github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I=
|
||||
github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
|
||||
github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9neXJWAZQ=
|
||||
github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/SoxCXGY6BqNFT48=
|
||||
github.com/hashicorp/go-hclog v1.6.3 h1:Qr2kF+eVWjTiYmU7Y31tYlP1h0q/X3Nl3tPGdaB11/k=
|
||||
github.com/hashicorp/go-hclog v1.6.3/go.mod h1:W4Qnvbt70Wk/zYJryRzDRU/4r0kIg0PVHBcfoyhpF5M=
|
||||
github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo=
|
||||
github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM=
|
||||
github.com/hashicorp/go-retryablehttp v0.7.8 h1:ylXZWnqa7Lhqpk0L1P1LzDtGcCR0rPVUrx/c8Unxc48=
|
||||
github.com/hashicorp/go-retryablehttp v0.7.8/go.mod h1:rjiScheydd+CxvumBsIrFKlx3iS0jrZ7LvzFGFmuKbw=
|
||||
github.com/hashicorp/go-secure-stdlib/base62 v0.1.2 h1:ET4pqyjiGmY09R5y+rSd70J2w45CtbWDNvGqWp/R3Ng=
|
||||
github.com/hashicorp/go-secure-stdlib/base62 v0.1.2/go.mod h1:EdWO6czbmthiwZ3/PUsDV+UD1D5IRU4ActiaWGwt0Yw=
|
||||
github.com/hashicorp/go-secure-stdlib/parseutil v0.2.0 h1:U+kC2dOhMFQctRfhK0gRctKAPTloZdMU5ZJxaesJ/VM=
|
||||
github.com/hashicorp/go-secure-stdlib/parseutil v0.2.0/go.mod h1:Ll013mhdmsVDuoIXVfBtvgGJsXDYkTw1kooNcoCXuE0=
|
||||
github.com/hashicorp/go-secure-stdlib/strutil v0.1.2 h1:kes8mmyCpxJsI7FTwtzRqEy9CdjCtrXrXGuOpxEA7Ts=
|
||||
github.com/hashicorp/go-secure-stdlib/strutil v0.1.2/go.mod h1:Gou2R9+il93BqX25LAKCLuM+y9U2T4hlwvT1yprcna4=
|
||||
github.com/hashicorp/go-sockaddr v1.0.7 h1:G+pTkSO01HpR5qCxg7lxfsFEZaG+C0VssTy/9dbT+Fw=
|
||||
github.com/hashicorp/go-sockaddr v1.0.7/go.mod h1:FZQbEYa1pxkQ7WLpyXJ6cbjpT8q0YgQaK/JakXqGyWw=
|
||||
github.com/hashicorp/go-uuid v1.0.2/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
|
||||
github.com/hashicorp/go-uuid v1.0.3 h1:2gKiV6YVmrJ1i2CKKa9obLvRieoRGviZFL26PcT/Co8=
|
||||
github.com/hashicorp/go-uuid v1.0.3/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
|
||||
github.com/hashicorp/go-version v1.7.0 h1:5tqGy27NaOTB8yJKUZELlFAS/LTKJkrmONwQKeRZfjY=
|
||||
github.com/hashicorp/go-version v1.7.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA=
|
||||
github.com/hashicorp/hcl v1.0.1-vault-7 h1:ag5OxFVy3QYTFTJODRzTKVZ6xvdfLLCA1cy/Y6xGI0I=
|
||||
github.com/hashicorp/hcl v1.0.1-vault-7/go.mod h1:XYhtn6ijBSAj6n4YqAaf7RBPS4I06AItNorpy+MoQNM=
|
||||
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
|
||||
github.com/huandu/xstrings v1.5.0 h1:2ag3IFq9ZDANvthTwTiqSSZLjDc+BedvHPAp5tJy2TI=
|
||||
github.com/huandu/xstrings v1.5.0/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE=
|
||||
@@ -417,8 +387,8 @@ github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
|
||||
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
|
||||
github.com/leodido/go-urn v1.1.0/go.mod h1:+cyI34gQWZcE1eQU7NVgKkkzdXDQHr1dBMtdAPozLkw=
|
||||
github.com/lib/pq v1.12.3 h1:tTWxr2YLKwIvK90ZXEw8GP7UFHtcbTtty8zsI+YjrfQ=
|
||||
github.com/lib/pq v1.12.3/go.mod h1:/p+8NSbOcwzAEI7wiMXFlgydTwcgTr3OSKMsD2BitpA=
|
||||
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
|
||||
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
|
||||
github.com/libdns/libdns v0.2.2 h1:O6ws7bAfRPaBsgAYt8MDe2HcNBGC29hkZ9MX2eUSX3s=
|
||||
github.com/libdns/libdns v0.2.2/go.mod h1:4Bj9+5CQiNMVGf87wjX4CY3HQJypUHRuLvlsfsZqLWQ=
|
||||
github.com/libdns/route53 v1.5.0 h1:2SKdpPFl/qgWsXQvsLNJJAoX7rSxlk7zgoL4jnWdXVA=
|
||||
@@ -436,13 +406,9 @@ github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8S
|
||||
github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
|
||||
github.com/mattermost/xml-roundtrip-validator v0.1.0 h1:RXbVD2UAl7A7nOTR4u7E3ILa4IbtvKBHw64LDsmu9hU=
|
||||
github.com/mattermost/xml-roundtrip-validator v0.1.0/go.mod h1:qccnGMcpgwcNaBnxqpJpWWUiPNr5H3O8eDgGV9gT5To=
|
||||
github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
|
||||
github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
|
||||
github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-sqlite3 v1.14.42 h1:MigqEP4ZmHw3aIdIT7T+9TLa90Z6smwcthx+Azv4Cgo=
|
||||
github.com/mattn/go-sqlite3 v1.14.42/go.mod h1:pjEuOr8IwzLJP2MfGeTb0A35jauH+C2kbHKBr7yXKVQ=
|
||||
github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs=
|
||||
github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
||||
github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI=
|
||||
github.com/mdelapenya/tlscert v0.2.0/go.mod h1:O4njj3ELLnJjGdkN7M/vIVCpZ+Cf0L6muqOG4tLSl8o=
|
||||
github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy56gw=
|
||||
@@ -455,8 +421,8 @@ github.com/mdp/qrterminal/v3 v3.2.1 h1:6+yQjiiOsSuXT5n9/m60E54vdgFsw0zhADHhHLrFe
|
||||
github.com/mdp/qrterminal/v3 v3.2.1/go.mod h1:jOTmXvnBsMy5xqLniO0R++Jmjs2sTm9dFSuQ5kpz/SU=
|
||||
github.com/mholt/acmez/v2 v2.0.1 h1:3/3N0u1pLjMK4sNEAFSI+bcvzbPhRpY383sy1kLHJ6k=
|
||||
github.com/mholt/acmez/v2 v2.0.1/go.mod h1:fX4c9r5jYwMyMsC+7tkYRxHibkOTgta5DIFGoe67e1U=
|
||||
github.com/miekg/dns v1.1.72 h1:vhmr+TF2A3tuoGNkLDFK9zi36F2LS+hKTRW0Uf8kbzI=
|
||||
github.com/miekg/dns v1.1.72/go.mod h1:+EuEPhdHOsfk6Wk5TT2CzssZdqkmFhf8r+aVyDEToIs=
|
||||
github.com/miekg/dns v1.1.59 h1:C9EXc/UToRwKLhK5wKU/I4QVsBUc8kE6MkHBkeypWZs=
|
||||
github.com/miekg/dns v1.1.59/go.mod h1:nZpewl5p6IvctfgrckopVx2OlSEHPRO/U4SYkRklrEk=
|
||||
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721 h1:RlZweED6sbSArvlE924+mUcZuXKLBHA35U7LN621Bws=
|
||||
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721/go.mod h1:Ickgr2WtCLZ2MDGd4Gr0geeCH5HybhRJbonOgQpvSxc=
|
||||
github.com/mitchellh/copystructure v1.2.0 h1:vpKXTN4ewci03Vljg/q9QvCGUDttBOGBIa15WveJJGw=
|
||||
@@ -485,10 +451,8 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
|
||||
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
|
||||
github.com/netbirdio/dex v0.244.1-0.20260512110716-8d70ad8647c1 h1:4TaYr9O4xX0D2kszeOLclTiCbA3eHq3xWV+9ILJbIYs=
|
||||
github.com/netbirdio/dex v0.244.1-0.20260512110716-8d70ad8647c1/go.mod h1:IHH+H8vK2GfqtIt5u/5OdPh18yk0oDHuj2vz5+Goetg=
|
||||
github.com/netbirdio/dex/api/v2 v2.0.0-20260512110716-8d70ad8647c1 h1:neE7z+FPUkldl3faK/Jt+hJK2L+1XfQ1W33TQhU9m88=
|
||||
github.com/netbirdio/dex/api/v2 v2.0.0-20260512110716-8d70ad8647c1/go.mod h1:awuTyT29CYALpEyET0S307EgNlPWrc7fFKRAyhsO45M=
|
||||
github.com/netbirdio/dex v0.244.0 h1:1GOvi8wnXYassnKGildzNqRHq0RbcfEUw7LKYpKIN7U=
|
||||
github.com/netbirdio/dex v0.244.0/go.mod h1:STGInJhPcAflrHmDO7vyit2kSq03PdL+8zQPoGALtcU=
|
||||
github.com/netbirdio/easyjson v0.9.0 h1:6Nw2lghSVuy8RSkAYDhDv1thBVEmfVbKZnV7T7Z6Aus=
|
||||
github.com/netbirdio/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU=
|
||||
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI=
|
||||
@@ -499,8 +463,8 @@ github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9ax
|
||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
|
||||
github.com/netbirdio/wireguard-go v0.0.0-20260107100953-33b7c9d03db0 h1:h/QnNzm7xzHPm+gajcblYUOclrW2FeNeDlUNj6tTWKQ=
|
||||
github.com/netbirdio/wireguard-go v0.0.0-20260107100953-33b7c9d03db0/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw=
|
||||
github.com/netbirdio/wireguard-go v0.0.0-20260422100739-63c67f59bf58 h1:6REpBYpJBLTTgqCcLGpTqvRDoEoLbA5r2nAXqMd2La0=
|
||||
github.com/netbirdio/wireguard-go v0.0.0-20260422100739-63c67f59bf58/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw=
|
||||
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ=
|
||||
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8=
|
||||
github.com/nicksnyder/go-i18n/v2 v2.5.1 h1:IxtPxYsR9Gp60cGXjfuR/llTqV8aYMsC472zD0D1vHk=
|
||||
@@ -525,8 +489,6 @@ github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7J
|
||||
github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo=
|
||||
github.com/onsi/gomega v1.27.6 h1:ENqfyGeS5AX/rlXDd/ETokDz93u0YufY1Pgxuy/PvWE=
|
||||
github.com/onsi/gomega v1.27.6/go.mod h1:PIQNjfQwkP3aQAH7lf7j87O/5FiNr+ZR8+ipb+qQlhg=
|
||||
github.com/openbao/openbao/api/v2 v2.5.1 h1:Br79D6L20SbAa5P7xqENxmvv8LyI4HoKosPy7klhn4o=
|
||||
github.com/openbao/openbao/api/v2 v2.5.1/go.mod h1:Dh5un77tqGgMbmlVEqjqN+8/dMyUohnkaQVg/wXW0Ig=
|
||||
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
|
||||
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
|
||||
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
|
||||
@@ -539,8 +501,6 @@ github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0
|
||||
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
|
||||
github.com/petermattis/goid v0.0.0-20250303134427-723919f7f203 h1:E7Kmf11E4K7B5hDti2K2NqPb1nlYlGYsu02S1JNd/Bs=
|
||||
github.com/petermattis/goid v0.0.0-20250303134427-723919f7f203/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4=
|
||||
github.com/philhofer/fwd v1.2.0 h1:e6DnBTl7vGY+Gz322/ASL4Gyp1FspeMvx1RNDoToZuM=
|
||||
github.com/philhofer/fwd v1.2.0/go.mod h1:RqIHx9QI14HlwKwm98g9Re5prTQ6LdeRQn+gXJFxsJM=
|
||||
github.com/pion/dtls/v2 v2.2.7/go.mod h1:8WiMkebSHFD0T+dIU+UeBaoV7kDhOW5oDCzZ7WZ/F9s=
|
||||
github.com/pion/dtls/v2 v2.2.10 h1:u2Axk+FyIR1VFTPurktB+1zoEPGIW3bmyj3LEFrXjAA=
|
||||
github.com/pion/dtls/v2 v2.2.10/go.mod h1:d9SYc9fch0CqK90mRk1dC7AkzzpwJj6u2GU3u+9pqFE=
|
||||
@@ -582,8 +542,6 @@ github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH
|
||||
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
|
||||
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 h1:o4JXh1EVt9k/+g42oCprj/FisM4qX9L3sZB3upGN2ZU=
|
||||
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
|
||||
github.com/pquerna/otp v1.5.0 h1:NMMR+WrmaqXU4EzdGJEE1aUUI0AMRzsp96fFFWNPwxs=
|
||||
github.com/pquerna/otp v1.5.0/go.mod h1:dkJfzwRKNiegxyNb54X/3fLwhCynbMspSyWKnvi1AEg=
|
||||
github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o=
|
||||
github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg=
|
||||
github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk=
|
||||
@@ -607,8 +565,6 @@ github.com/rs/xid v1.3.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
|
||||
github.com/russellhaering/goxmldsig v1.6.0 h1:8fdWXEPh2k/NZNQBPFNoVfS3JmzS4ZprY/sAOpKQLks=
|
||||
github.com/russellhaering/goxmldsig v1.6.0/go.mod h1:TrnaquDcYxWXfJrOjeMBTX4mLBeYAqaHEyUeWPxZlBM=
|
||||
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
||||
github.com/ryanuber/go-glob v1.0.0 h1:iQh3xXAumdQ+4Ufa5b25cRpC5TYKlno6hsv6Cb3pkBk=
|
||||
github.com/ryanuber/go-glob v1.0.0/go.mod h1:807d1WSdnB0XRJzKNil9Om6lcp/3a0v4qIHxIXzX/Yc=
|
||||
github.com/rymdport/portal v0.4.2 h1:7jKRSemwlTyVHHrTGgQg7gmNPJs88xkbKcIL3NlcmSU=
|
||||
github.com/rymdport/portal v0.4.2/go.mod h1:kFF4jslnJ8pD5uCi17brj/ODlfIidOxlgUDTO5ncnC4=
|
||||
github.com/shirou/gopsutil/v3 v3.24.4 h1:dEHgzZXt4LMNm+oYELpzl9YCqV65Yr/6SfrvgRBtXeU=
|
||||
@@ -631,8 +587,8 @@ github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 h1:TG/diQgUe0pntT/2D
|
||||
github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8/go.mod h1:P5HUIBuIWKbyjl083/loAegFkfbFNx5i2qEP4CNbm7E=
|
||||
github.com/spf13/cast v1.7.0 h1:ntdiHjuueXFgm5nzDRdOS4yfT43P5Fnud6DH50rz/7w=
|
||||
github.com/spf13/cast v1.7.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
|
||||
github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU=
|
||||
github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4=
|
||||
github.com/spf13/cobra v1.10.1 h1:lJeBwCfmrnXthfAupyUTzJ/J4Nc1RsHC/mSRU2dll/s=
|
||||
github.com/spf13/cobra v1.10.1/go.mod h1:7SmJGaTHFVBY0jW4NXGluQoLvhqFQM+6XSKD+P4XaB0=
|
||||
github.com/spf13/pflag v1.0.9 h1:9exaQaMOCwffKiiiYk6/BndUBv+iRViNW+4lEMi0PvY=
|
||||
github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||
github.com/spkg/bom v0.0.0-20160624110644-59b7046e48ad/go.mod h1:qLr4V1qq6nMqFKkMo8ZTx3f+BZEkzsRUY10Xsm2mwU0=
|
||||
@@ -672,8 +628,6 @@ github.com/ti-mo/conntrack v0.5.1 h1:opEwkFICnDbQc0BUXl73PHBK0h23jEIFVjXsqvF4GY0
|
||||
github.com/ti-mo/conntrack v0.5.1/go.mod h1:T6NCbkMdVU4qEIgwL0njA6lw/iCAbzchlnwm1Sa314o=
|
||||
github.com/ti-mo/netfilter v0.5.2 h1:CTjOwFuNNeZ9QPdRXt1MZFLFUf84cKtiQutNauHWd40=
|
||||
github.com/ti-mo/netfilter v0.5.2/go.mod h1:Btx3AtFiOVdHReTDmP9AE+hlkOcvIy403u7BXXbWZKo=
|
||||
github.com/tinylib/msgp v1.6.3 h1:bCSxiTz386UTgyT1i0MSCvdbWjVW+8sG3PjkGsZQt4s=
|
||||
github.com/tinylib/msgp v1.6.3/go.mod h1:RSp0LW9oSxFut3KzESt5Voq4GVWyS+PSulT77roAqEA=
|
||||
github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI=
|
||||
github.com/tklauser/go-sysconf v0.3.15 h1:VE89k0criAymJ/Os65CSn1IXaol+1wrsFHEB8Ol49K4=
|
||||
github.com/tklauser/go-sysconf v0.3.15/go.mod h1:Dmjwr6tYFIseJw7a3dRLJfsHAMXZ3nEnL/aZY+0IuI4=
|
||||
@@ -692,8 +646,6 @@ github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAh
|
||||
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
|
||||
github.com/wlynxg/anet v0.0.5 h1:J3VJGi1gvo0JwZ/P1/Yc/8p63SoW98B5dHkYDmpgvvU=
|
||||
github.com/wlynxg/anet v0.0.5/go.mod h1:eay5PRQr7fIVAMbTbchTnO9gG65Hg/uYGdc7mguHxoA=
|
||||
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
|
||||
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
|
||||
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
|
||||
@@ -738,15 +690,14 @@ go.opentelemetry.io/proto/otlp v1.0.0 h1:T0TX0tmXU8a3CbNXzEKGeU5mIVOdf0oykP+u2lI
|
||||
go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v80hjKIs5JXpM=
|
||||
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
|
||||
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
|
||||
go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y=
|
||||
go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU=
|
||||
go.uber.org/mock v0.5.2 h1:LbtPTcP8A5k9WPXj54PPPbjcI4Y6lhyOZXn+VS7wNko=
|
||||
go.uber.org/mock v0.5.2/go.mod h1:wLlUxC2vVTPTaE3UD51E0BGOAElKrILxhVSDYQLld5o=
|
||||
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
|
||||
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
|
||||
go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8=
|
||||
go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
|
||||
go.yaml.in/yaml/v2 v2.4.3 h1:6gvOSjQoTB3vt1l+CU+tSyi/HOjfOjRLJ4YwYZGwRO0=
|
||||
go.yaml.in/yaml/v2 v2.4.3/go.mod h1:zSxWcmIDjOzPXpjlTTbAsKokqkDNAVtZO0WOMiT90s8=
|
||||
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
|
||||
goauthentik.io/api/v3 v3.2023051.3 h1:NebAhD/TeTWNo/9X3/Uj+rM5fG1HaiLOlKTNLQv9Qq4=
|
||||
goauthentik.io/api/v3 v3.2023051.3/go.mod h1:nYECml4jGbp/541hj8GcylKQG1gVBsKppHy4+7G8u4U=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
|
||||
@@ -51,70 +51,6 @@ type YAMLConfig struct {
|
||||
// StaticPasswords cause the server use this list of passwords rather than
|
||||
// querying the storage.
|
||||
StaticPasswords []Password `yaml:"staticPasswords" json:"staticPasswords"`
|
||||
|
||||
// Sessions holds authentication session configuration.
|
||||
// Requires DEX_SESSIONS_ENABLED=true feature flag.
|
||||
Sessions *Sessions `yaml:"sessions" json:"sessions"`
|
||||
|
||||
// MFA holds multi-factor authentication configuration.
|
||||
MFA MFAConfig `yaml:"mfa" json:"mfa"`
|
||||
}
|
||||
|
||||
type Sessions struct {
|
||||
// CookieName is the name of the session cookie. Defaults to "dex_session".
|
||||
CookieName string `yaml:"cookieName" json:"cookieName"`
|
||||
// AbsoluteLifetime is the maximum session lifetime from creation. Defaults to "24h".
|
||||
AbsoluteLifetime string `yaml:"absoluteLifetime" json:"absoluteLifetime"`
|
||||
// ValidIfNotUsedFor is the idle timeout. Defaults to "1h".
|
||||
ValidIfNotUsedFor string `yaml:"validIfNotUsedFor" json:"validIfNotUsedFor"`
|
||||
// RememberMeCheckedByDefault controls the default state of the "remember me" checkbox.
|
||||
RememberMeCheckedByDefault *bool `yaml:"rememberMeCheckedByDefault" json:"rememberMeCheckedByDefault"`
|
||||
// CookieEncryptionKey is the AES key for encrypting session cookies.
|
||||
// Must be 16, 24, or 32 bytes for AES-128, AES-192, or AES-256.
|
||||
// If empty, cookies are not encrypted.
|
||||
CookieEncryptionKey string `yaml:"cookieEncryptionKey" json:"cookieEncryptionKey"`
|
||||
// SSOSharedWithDefault is the default SSO sharing policy for clients without explicit ssoSharedWith.
|
||||
// "all" = share with all clients, "none" = share with no one (default: "none").
|
||||
SSOSharedWithDefault string `yaml:"ssoSharedWithDefault" json:"ssoSharedWithDefault"`
|
||||
}
|
||||
|
||||
type MFAConfig struct {
|
||||
Authenticators []MFAAuthenticator `yaml:"authenticators" json:"authenticators"`
|
||||
}
|
||||
|
||||
type MFAAuthenticator struct {
|
||||
ID string `yaml:"id" json:"id"`
|
||||
Type string `yaml:"type" json:"type"`
|
||||
Config map[string]interface{} `yaml:"config" json:"config"`
|
||||
|
||||
ConnectorTypes []string `yaml:"connectorTypes" json:"connectorTypes"`
|
||||
}
|
||||
|
||||
type TOTPConfig struct {
|
||||
Issuer string `yaml:"issuer" json:"issuer"`
|
||||
}
|
||||
|
||||
// WebAuthnConfig holds configuration for a WebAuthn authenticator.
|
||||
type WebAuthnConfig struct {
|
||||
// RPDisplayName is the human-readable relying party name shown in the browser
|
||||
// dialog during key registration and authentication (e.g., "My Company SSO").
|
||||
RPDisplayName string `yaml:"rpDisplayName" json:"rpDisplayName"`
|
||||
// RPID is the relying party identifier — must match the domain in the browser
|
||||
// address bar. If empty, derived from the issuer URL hostname.
|
||||
// Example: "auth.example.com"
|
||||
RPID string `yaml:"rpID" json:"rpID"`
|
||||
// RPOrigins is the list of allowed origins for WebAuthn ceremonies.
|
||||
// If empty, derived from the issuer URL (scheme + host).
|
||||
// Example: ["https://auth.example.com"]
|
||||
RPOrigins []string `yaml:"rpOrigins" json:"rpOrigins"`
|
||||
// AttestationPreference controls what attestation data the authenticator should provide:
|
||||
// "none" — don't request attestation (simpler, more private)
|
||||
// "indirect" — authenticator may anonymize attestation (default)
|
||||
// "direct" — request full attestation (for enterprise key model verification)
|
||||
AttestationPreference string `yaml:"attestationPreference" json:"attestationPreference"`
|
||||
// Timeout is the duration allowed for the browser WebAuthn ceremony
|
||||
// (registration or login). Defaults to "60s".
|
||||
Timeout string `yaml:"timeout" json:"timeout"`
|
||||
}
|
||||
|
||||
// Web is the config format for the HTTP server.
|
||||
@@ -180,6 +116,7 @@ type Storage struct {
|
||||
Config map[string]interface{} `yaml:"config" json:"config"`
|
||||
}
|
||||
|
||||
// Password represents a static user configuration
|
||||
type Password storage.Password
|
||||
|
||||
func (p *Password) UnmarshalYAML(node *yaml.Node) error {
|
||||
@@ -492,98 +429,9 @@ func (c *YAMLConfig) Validate() error {
|
||||
if !c.EnablePasswordDB && len(c.StaticPasswords) != 0 {
|
||||
return fmt.Errorf("cannot specify static passwords without enabling password db")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func buildTotpConfig(auth MFAAuthenticator) (*server.TOTPProvider, error) {
|
||||
data, err := json.Marshal(auth.Config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal TOTP config id: %s - %w", auth.ID, err)
|
||||
}
|
||||
|
||||
var cfg TOTPConfig
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse TOTP config id: %s - %w", auth.ID, err)
|
||||
}
|
||||
|
||||
return server.NewTOTPProvider(cfg.Issuer, auth.ConnectorTypes), nil
|
||||
}
|
||||
|
||||
func buildWebAuthnConfig(auth MFAAuthenticator, issuerURL string) (*server.WebAuthnProvider, error) {
|
||||
data, err := json.Marshal(auth.Config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal WebAuthn config id: %s - %w", auth.ID, err)
|
||||
}
|
||||
|
||||
var cfg WebAuthnConfig
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse WebAuthn config id: %s - %w", auth.ID, err)
|
||||
}
|
||||
|
||||
provider, err := server.NewWebAuthnProvider(cfg.RPDisplayName, cfg.RPID, cfg.RPOrigins,
|
||||
cfg.AttestationPreference, cfg.Timeout, issuerURL, auth.ConnectorTypes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create WebAuthn provider id: %s - err: %w", auth.ID, err)
|
||||
}
|
||||
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
func buildMFAProviders(authenticators []MFAAuthenticator, issuerURL string, logger *slog.Logger) map[string]server.MFAProvider {
|
||||
if len(authenticators) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
providers := make(map[string]server.MFAProvider, len(authenticators))
|
||||
for _, auth := range authenticators {
|
||||
switch auth.Type {
|
||||
case "TOTP":
|
||||
provider, err := buildTotpConfig(auth)
|
||||
if err != nil {
|
||||
logger.Error("failed to parse TOTP config", "id", auth.ID, "err", err)
|
||||
continue
|
||||
}
|
||||
providers[auth.ID] = provider
|
||||
logger.Info("MFA authenticator configured", "id", auth.ID, "type", auth.Type)
|
||||
case "WebAuthn":
|
||||
provider, err := buildWebAuthnConfig(auth, issuerURL)
|
||||
if err != nil {
|
||||
logger.Error("failed to parse WebAuthn config", "id", auth.ID, "err", err)
|
||||
continue
|
||||
}
|
||||
providers[auth.ID] = provider
|
||||
logger.Info("MFA authenticator configured", "id", auth.ID, "type", auth.Type)
|
||||
default:
|
||||
logger.Error("unknown MFA authenticator type, skipping", "id", auth.ID, "type", auth.Type)
|
||||
}
|
||||
}
|
||||
return providers
|
||||
}
|
||||
|
||||
func buildSessionsConfig(sessions *Sessions) *server.SessionConfig {
|
||||
if sessions == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if sessions.RememberMeCheckedByDefault == nil {
|
||||
defaultRememberMeCheckedByDefault := false
|
||||
sessions.RememberMeCheckedByDefault = &defaultRememberMeCheckedByDefault
|
||||
}
|
||||
|
||||
absoluteLifetime, _ := parseDuration(sessions.AbsoluteLifetime)
|
||||
validIfNotUsedFor, _ := parseDuration(sessions.ValidIfNotUsedFor)
|
||||
|
||||
return &server.SessionConfig{
|
||||
CookieEncryptionKey: []byte(sessions.CookieEncryptionKey),
|
||||
CookieName: sessions.CookieName,
|
||||
AbsoluteLifetime: absoluteLifetime,
|
||||
ValidIfNotUsedFor: validIfNotUsedFor,
|
||||
RememberMeCheckedByDefault: *sessions.RememberMeCheckedByDefault,
|
||||
SSOSharedWithDefault: sessions.SSOSharedWithDefault,
|
||||
}
|
||||
}
|
||||
|
||||
// ToServerConfig converts YAMLConfig to dex server.Config
|
||||
func (c *YAMLConfig) ToServerConfig(stor storage.Storage, logger *slog.Logger) server.Config {
|
||||
cfg := server.Config{
|
||||
@@ -600,8 +448,6 @@ func (c *YAMLConfig) ToServerConfig(stor storage.Storage, logger *slog.Logger) s
|
||||
Dir: c.Frontend.Dir,
|
||||
Extra: c.Frontend.Extra,
|
||||
},
|
||||
SessionConfig: buildSessionsConfig(c.Sessions),
|
||||
MFAProviders: buildMFAProviders(c.MFA.Authenticators, c.Issuer, logger),
|
||||
}
|
||||
|
||||
// Use embedded NetBird-styled templates if no custom dir specified
|
||||
@@ -614,6 +460,11 @@ func (c *YAMLConfig) ToServerConfig(stor storage.Storage, logger *slog.Logger) s
|
||||
}
|
||||
|
||||
// Apply expiry settings
|
||||
if c.Expiry.SigningKeys != "" {
|
||||
if d, err := parseDuration(c.Expiry.SigningKeys); err == nil {
|
||||
cfg.RotateKeysAfter = d
|
||||
}
|
||||
}
|
||||
if c.Expiry.IDTokens != "" {
|
||||
if d, err := parseDuration(c.Expiry.IDTokens); err == nil {
|
||||
cfg.IDTokensValidFor = d
|
||||
|
||||
@@ -18,7 +18,6 @@ import (
|
||||
|
||||
dexapi "github.com/dexidp/dex/api/v2"
|
||||
"github.com/dexidp/dex/server"
|
||||
"github.com/dexidp/dex/server/signer"
|
||||
"github.com/dexidp/dex/storage"
|
||||
"github.com/dexidp/dex/storage/sql"
|
||||
jose "github.com/go-jose/go-jose/v4"
|
||||
@@ -71,7 +70,7 @@ func NewProvider(ctx context.Context, config *Config) (*Provider, error) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
|
||||
|
||||
// Ensure data directory exists
|
||||
if err := os.MkdirAll(config.DataDir, 0o700); err != nil {
|
||||
if err := os.MkdirAll(config.DataDir, 0700); err != nil {
|
||||
return nil, fmt.Errorf("failed to create data directory: %w", err)
|
||||
}
|
||||
|
||||
@@ -102,15 +101,6 @@ func NewProvider(ctx context.Context, config *Config) (*Provider, error) {
|
||||
return nil, fmt.Errorf("failed to create refresh token policy: %w", err)
|
||||
}
|
||||
|
||||
localSignerConfig := signer.LocalConfig{
|
||||
KeysRotationPeriod: "6h",
|
||||
}
|
||||
|
||||
localSigner, err := localSignerConfig.Open(ctx, stor, 24*time.Hour, time.Now, logger)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create local signer: %w", err)
|
||||
}
|
||||
|
||||
// Build Dex server config - use Dex's types directly
|
||||
dexConfig := server.Config{
|
||||
Issuer: issuer,
|
||||
@@ -120,12 +110,12 @@ func NewProvider(ctx context.Context, config *Config) (*Provider, error) {
|
||||
ContinueOnConnectorFailure: true,
|
||||
Logger: logger,
|
||||
PrometheusRegistry: prometheus.NewRegistry(),
|
||||
RotateKeysAfter: 6 * time.Hour,
|
||||
IDTokensValidFor: 24 * time.Hour,
|
||||
RefreshTokenPolicy: refreshPolicy,
|
||||
Web: server.WebConfig{
|
||||
Issuer: "NetBird",
|
||||
},
|
||||
Signer: localSigner,
|
||||
}
|
||||
|
||||
dexSrv, err := server.NewServer(ctx, dexConfig)
|
||||
@@ -177,14 +167,6 @@ func NewProviderFromYAML(ctx context.Context, yamlConfig *YAMLConfig) (*Provider
|
||||
return nil, fmt.Errorf("failed to create refresh token policy: %w", err)
|
||||
}
|
||||
|
||||
localSigner, err := getSigner(ctx, stor, yamlConfig, logger)
|
||||
if err != nil {
|
||||
stor.Close()
|
||||
return nil, fmt.Errorf("failed to create local signer: %w", err)
|
||||
}
|
||||
|
||||
dexConfig.Signer = localSigner
|
||||
|
||||
dexSrv, err := server.NewServer(ctx, dexConfig)
|
||||
if err != nil {
|
||||
stor.Close()
|
||||
@@ -200,32 +182,6 @@ func NewProviderFromYAML(ctx context.Context, yamlConfig *YAMLConfig) (*Provider
|
||||
}, nil
|
||||
}
|
||||
|
||||
func getSigner(ctx context.Context, stor storage.Storage, yamlConfig *YAMLConfig, logger *slog.Logger) (signer.Signer, error) {
|
||||
// Parse expiry durations
|
||||
idTokensValidFor := 24 * time.Hour // default
|
||||
if yamlConfig.Expiry.IDTokens != "" {
|
||||
var err error
|
||||
idTokensValidFor, err = parseDuration(yamlConfig.Expiry.IDTokens)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid config value %q for id token expiry: %v", yamlConfig.Expiry.IDTokens, err)
|
||||
}
|
||||
}
|
||||
|
||||
localSignerConfig := &signer.LocalConfig{
|
||||
KeysRotationPeriod: "720h", // 30 Days
|
||||
}
|
||||
|
||||
if yamlConfig.Expiry.SigningKeys != "" {
|
||||
if _, err := parseDuration(yamlConfig.Expiry.SigningKeys); err != nil {
|
||||
return nil, fmt.Errorf("invalid config value %q for signing key expiry: %v", yamlConfig.Expiry.SigningKeys, err)
|
||||
}
|
||||
|
||||
localSignerConfig.KeysRotationPeriod = yamlConfig.Expiry.SigningKeys
|
||||
}
|
||||
|
||||
return localSignerConfig.Open(ctx, stor, idTokensValidFor, time.Now, logger)
|
||||
}
|
||||
|
||||
// initializeStorage sets up connectors, passwords, and clients in storage
|
||||
func initializeStorage(ctx context.Context, stor storage.Storage, cfg *YAMLConfig) error {
|
||||
if cfg.EnablePasswordDB {
|
||||
@@ -285,8 +241,6 @@ func ensureStaticClients(ctx context.Context, stor storage.Storage, clients []st
|
||||
old.RedirectURIs = client.RedirectURIs
|
||||
old.Name = client.Name
|
||||
old.Public = client.Public
|
||||
old.PostLogoutRedirectURIs = client.PostLogoutRedirectURIs
|
||||
old.MFAChain = client.MFAChain
|
||||
return old, nil
|
||||
}); err != nil {
|
||||
return fmt.Errorf("failed to update client %s: %w", client.ID, err)
|
||||
@@ -299,6 +253,9 @@ func ensureStaticClients(ctx context.Context, stor storage.Storage, clients []st
|
||||
func buildDexConfig(yamlConfig *YAMLConfig, stor storage.Storage, logger *slog.Logger) server.Config {
|
||||
cfg := yamlConfig.ToServerConfig(stor, logger)
|
||||
cfg.PrometheusRegistry = prometheus.NewRegistry()
|
||||
if cfg.RotateKeysAfter == 0 {
|
||||
cfg.RotateKeysAfter = 24 * 30 * time.Hour
|
||||
}
|
||||
if cfg.IDTokensValidFor == 0 {
|
||||
cfg.IDTokensValidFor = 24 * time.Hour
|
||||
}
|
||||
@@ -493,34 +450,10 @@ func (p *Provider) Storage() storage.Storage {
|
||||
return p.storage
|
||||
}
|
||||
|
||||
// SetClientsMFAChain updates the MFAChain field on the dashboard and CLI OAuth2 clients.
|
||||
// Pass a non-empty slice (e.g. []string{"default-totp"}) to enable MFA, or nil to disable it.
|
||||
func (p *Provider) SetClientsMFAChain(ctx context.Context, clientIDs []string, mfaChain []string) error {
|
||||
for _, clientID := range clientIDs {
|
||||
if err := p.storage.UpdateClient(ctx, clientID, func(old storage.Client) (storage.Client, error) {
|
||||
old.MFAChain = mfaChain
|
||||
return old, nil
|
||||
}); err != nil {
|
||||
return fmt.Errorf("failed to update MFA chain on client %s: %w", clientID, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Handler returns the Dex server as an http.Handler for embedding in another server.
|
||||
// The handler expects requests with path prefix "/oauth2/".
|
||||
func (p *Provider) Handler() http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Dex's /logout endpoint requires id_token_hint for RP-initiated logout with
|
||||
// post_logout_redirect_uri. If the dashboard calls logout without one, avoid
|
||||
// rendering Dex's non-actionable Bad Request page and send the user home.
|
||||
if strings.HasSuffix(r.URL.Path, "/logout") && r.FormValue("id_token_hint") == "" {
|
||||
http.Redirect(w, r, "/", http.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
|
||||
p.dexServer.ServeHTTP(w, r)
|
||||
})
|
||||
return p.dexServer
|
||||
}
|
||||
|
||||
// CreateUser creates a new user with the given email, username, and password.
|
||||
|
||||
@@ -4,8 +4,6 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
@@ -146,30 +144,6 @@ func TestEncodeDexUserID_MatchesDexFormat(t *testing.T) {
|
||||
assert.Equal(t, knownEncodedID, reEncoded)
|
||||
}
|
||||
|
||||
func TestHandlerRedirectsLogoutWithoutIDTokenHint(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
tmpDir, err := os.MkdirTemp("", "dex-logout-handler-*")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
provider, err := NewProvider(ctx, &Config{
|
||||
Issuer: "http://localhost:5556/oauth2",
|
||||
Port: 5556,
|
||||
DataDir: tmpDir,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = provider.Stop(ctx) }()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/oauth2/logout?post_logout_redirect_uri=https://example.com", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
provider.Handler().ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusSeeOther, rec.Code)
|
||||
require.Equal(t, "/", rec.Header().Get("Location"))
|
||||
}
|
||||
|
||||
func TestCreateUserInTempDB(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
|
||||
@@ -1,12 +0,0 @@
|
||||
{{ template "header.html" . }}
|
||||
|
||||
<script>globalThis.location.replace("/");</script>
|
||||
<noscript>
|
||||
<div class="nb-card">
|
||||
<h1 class="nb-heading">Redirecting…</h1>
|
||||
<p class="nb-subheading">You are being redirected to the NetBird dashboard.</p>
|
||||
<a href="/" class="nb-btn" style="display:block;text-align:center;text-decoration:none">Go to Dashboard</a>
|
||||
</div>
|
||||
</noscript>
|
||||
|
||||
{{ template "footer.html" . }}
|
||||
@@ -1,14 +0,0 @@
|
||||
{{ template "header.html" . }}
|
||||
|
||||
<div class="nb-card">
|
||||
<h1 class="nb-heading">Logged Out</h1>
|
||||
<p class="nb-subheading">You have been successfully logged out.</p>
|
||||
|
||||
{{ if .BackURL }}
|
||||
<div class="nb-back-link">
|
||||
<a href="{{ .BackURL }}" class="nb-link">← Back to Application</a>
|
||||
</div>
|
||||
{{ end }}
|
||||
</div>
|
||||
|
||||
{{ template "footer.html" . }}
|
||||
@@ -18,7 +18,6 @@
|
||||
id="login"
|
||||
name="login"
|
||||
class="nb-input"
|
||||
autocomplete="username"
|
||||
placeholder="Enter your {{ .UsernamePrompt | lower }}"
|
||||
{{ if .Username }}value="{{ .Username }}"{{ else }}autofocus{{ end }}
|
||||
required
|
||||
@@ -32,7 +31,6 @@
|
||||
id="password"
|
||||
name="password"
|
||||
class="nb-input"
|
||||
autocomplete="current-password"
|
||||
placeholder="Enter your password"
|
||||
{{ if .Invalid }}autofocus{{ end }}
|
||||
required
|
||||
|
||||
@@ -1,44 +0,0 @@
|
||||
{{ template "header.html" . }}
|
||||
|
||||
<div class="nb-card">
|
||||
<h1 class="nb-heading">Two-factor authentication</h1>
|
||||
{{ if not (eq .QRCode "") }}
|
||||
<p class="nb-subheading">Scan the QR code below using your authenticator app, then enter the code.</p>
|
||||
<div style="text-align: center; margin: 1em 0;">
|
||||
<img src="data:image/png;base64,{{ .QRCode }}" alt="QR code" width="200" height="200"/>
|
||||
</div>
|
||||
{{ else }}
|
||||
<p class="nb-subheading">Enter the code from your authenticator app.</p>
|
||||
{{ end }}
|
||||
|
||||
<form method="post" action="{{ .PostURL }}">
|
||||
{{ if .Invalid }}
|
||||
<div class="nb-error">
|
||||
Invalid code. Please try again.
|
||||
</div>
|
||||
{{ end }}
|
||||
|
||||
<div class="nb-form-group">
|
||||
<label class="nb-label" for="totp">One-time code</label>
|
||||
<input
|
||||
type="text"
|
||||
id="totp"
|
||||
name="totp"
|
||||
class="nb-input"
|
||||
inputmode="numeric"
|
||||
pattern="[0-9]*"
|
||||
maxlength="6"
|
||||
autocomplete="one-time-code"
|
||||
placeholder="000000"
|
||||
autofocus
|
||||
required
|
||||
>
|
||||
</div>
|
||||
|
||||
<button type="submit" id="submit-login" class="nb-btn">
|
||||
Verify
|
||||
</button>
|
||||
</form>
|
||||
</div>
|
||||
|
||||
{{ template "footer.html" . }}
|
||||
@@ -1,12 +0,0 @@
|
||||
{{ template "header.html" . }}
|
||||
|
||||
<script>globalThis.location.replace("/");</script>
|
||||
<noscript>
|
||||
<div class="nb-card">
|
||||
<h1 class="nb-heading">Redirecting…</h1>
|
||||
<p class="nb-subheading">You are being redirected to the NetBird dashboard.</p>
|
||||
<a href="/" class="nb-btn" style="display:block;text-align:center;text-decoration:none">Go to Dashboard</a>
|
||||
</div>
|
||||
</noscript>
|
||||
|
||||
{{ template "footer.html" . }}
|
||||
@@ -221,13 +221,9 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Controller) bufferSendUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error {
|
||||
func (c *Controller) bufferSendUpdateAccountPeers(ctx context.Context, accountID string) error {
|
||||
log.WithContext(ctx).Tracef("buffer sending update peers for account %s from %s", accountID, util.GetCallerName())
|
||||
|
||||
if c.accountManagerMetrics != nil {
|
||||
c.accountManagerMetrics.CountUpdateAccountPeersTriggered(string(reason.Resource), string(reason.Operation))
|
||||
}
|
||||
|
||||
bufUpd, _ := c.sendAccountUpdateLocks.LoadOrStore(accountID, &bufferUpdate{})
|
||||
b := bufUpd.(*bufferUpdate)
|
||||
|
||||
@@ -574,7 +570,7 @@ func isPeerInPolicySourceGroups(account *types.Account, peerID string, policy *t
|
||||
}
|
||||
|
||||
func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerIDs []string) error {
|
||||
err := c.bufferSendUpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourcePeer, Operation: types.UpdateOperationUpdate})
|
||||
err := c.bufferSendUpdateAccountPeers(ctx, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to buffer update account peers for peer update in account %s: %v", accountID, err)
|
||||
}
|
||||
@@ -584,7 +580,7 @@ func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerI
|
||||
|
||||
func (c *Controller) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error {
|
||||
log.WithContext(ctx).Debugf("OnPeersAdded call to add peers: %v", peerIDs)
|
||||
return c.bufferSendUpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourcePeer, Operation: types.UpdateOperationCreate})
|
||||
return c.bufferSendUpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string) error {
|
||||
@@ -620,7 +616,7 @@ func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerI
|
||||
c.peersUpdateManager.CloseChannel(ctx, peerID)
|
||||
}
|
||||
|
||||
return c.bufferSendUpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourcePeer, Operation: types.UpdateOperationDelete})
|
||||
return c.bufferSendUpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
// GetNetworkMap returns Network map for a given peer (omits original peer from the Peers result)
|
||||
|
||||
@@ -31,7 +31,6 @@ type store interface {
|
||||
|
||||
type proxyManager interface {
|
||||
GetActiveClusterAddresses(ctx context.Context) ([]string, error)
|
||||
GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error)
|
||||
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
||||
ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
|
||||
ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
|
||||
@@ -72,8 +71,8 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d
|
||||
var ret []*domain.Domain
|
||||
|
||||
// Add connected proxy clusters as free domains.
|
||||
// For BYOP accounts, only their own cluster is returned; otherwise shared clusters.
|
||||
allowList, err := m.getClusterAllowList(ctx, accountID)
|
||||
// The cluster address itself is the free domain base (e.g., "eu.proxy.netbird.io").
|
||||
allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses: %v", err)
|
||||
return nil, err
|
||||
@@ -127,8 +126,8 @@ func (m Manager) CreateDomain(ctx context.Context, accountID, userID, domainName
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
// Verify the target cluster is in the available clusters for this account
|
||||
allowList, err := m.getClusterAllowList(ctx, accountID)
|
||||
// Verify the target cluster is in the available clusters
|
||||
allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get active proxy cluster addresses: %w", err)
|
||||
}
|
||||
@@ -274,7 +273,7 @@ func (m Manager) GetClusterDomains() []string {
|
||||
// For free domains (those ending with a known cluster suffix), the cluster is extracted from the domain.
|
||||
// For custom domains, the cluster is determined by checking the registered custom domain's target cluster.
|
||||
func (m Manager) DeriveClusterFromDomain(ctx context.Context, accountID, domain string) (string, error) {
|
||||
allowList, err := m.getClusterAllowList(ctx, accountID)
|
||||
allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get active proxy cluster addresses: %w", err)
|
||||
}
|
||||
@@ -299,17 +298,6 @@ func (m Manager) DeriveClusterFromDomain(ctx context.Context, accountID, domain
|
||||
return "", fmt.Errorf("domain %s does not match any available proxy cluster", domain)
|
||||
}
|
||||
|
||||
func (m Manager) getClusterAllowList(ctx context.Context, accountID string) ([]string, error) {
|
||||
byopAddresses, err := m.proxyManager.GetActiveClusterAddressesForAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get BYOP cluster addresses: %w", err)
|
||||
}
|
||||
if len(byopAddresses) > 0 {
|
||||
return byopAddresses, nil
|
||||
}
|
||||
return m.proxyManager.GetActiveClusterAddresses(ctx)
|
||||
}
|
||||
|
||||
func extractClusterFromCustomDomains(serviceDomain string, customDomains []*domain.Domain) (string, bool) {
|
||||
bestCluster := ""
|
||||
bestLen := -1
|
||||
|
||||
@@ -1,110 +0,0 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type mockProxyManager struct {
|
||||
getActiveClusterAddressesFunc func(ctx context.Context) ([]string, error)
|
||||
getActiveClusterAddressesForAccountFunc func(ctx context.Context, accountID string) ([]string, error)
|
||||
}
|
||||
|
||||
func (m *mockProxyManager) GetActiveClusterAddresses(ctx context.Context) ([]string, error) {
|
||||
if m.getActiveClusterAddressesFunc != nil {
|
||||
return m.getActiveClusterAddressesFunc(ctx)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockProxyManager) GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) {
|
||||
if m.getActiveClusterAddressesForAccountFunc != nil {
|
||||
return m.getActiveClusterAddressesForAccountFunc(ctx, accountID)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockProxyManager) ClusterSupportsCustomPorts(_ context.Context, _ string) *bool {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockProxyManager) ClusterRequireSubdomain(_ context.Context, _ string) *bool {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockProxyManager) ClusterSupportsCrowdSec(_ context.Context, _ string) *bool {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestGetClusterAllowList_BYOPProxy(t *testing.T) {
|
||||
pm := &mockProxyManager{
|
||||
getActiveClusterAddressesForAccountFunc: func(_ context.Context, accID string) ([]string, error) {
|
||||
assert.Equal(t, "acc-123", accID)
|
||||
return []string{"byop.example.com"}, nil
|
||||
},
|
||||
getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
|
||||
t.Fatal("should not call GetActiveClusterAddresses when BYOP addresses exist")
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
|
||||
mgr := Manager{proxyManager: pm}
|
||||
result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"byop.example.com"}, result)
|
||||
}
|
||||
|
||||
func TestGetClusterAllowList_NoBYOP_FallbackToShared(t *testing.T) {
|
||||
pm := &mockProxyManager{
|
||||
getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) {
|
||||
return nil, nil
|
||||
},
|
||||
getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
|
||||
return []string{"eu.proxy.netbird.io", "us.proxy.netbird.io"}, nil
|
||||
},
|
||||
}
|
||||
|
||||
mgr := Manager{proxyManager: pm}
|
||||
result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"eu.proxy.netbird.io", "us.proxy.netbird.io"}, result)
|
||||
}
|
||||
|
||||
func TestGetClusterAllowList_BYOPError_ReturnsError(t *testing.T) {
|
||||
pm := &mockProxyManager{
|
||||
getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) {
|
||||
return nil, errors.New("db error")
|
||||
},
|
||||
getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
|
||||
t.Fatal("should not call GetActiveClusterAddresses when BYOP lookup fails")
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
|
||||
mgr := Manager{proxyManager: pm}
|
||||
result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
assert.Contains(t, err.Error(), "BYOP cluster addresses")
|
||||
}
|
||||
|
||||
func TestGetClusterAllowList_BYOPEmptySlice_FallbackToShared(t *testing.T) {
|
||||
pm := &mockProxyManager{
|
||||
getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) {
|
||||
return []string{}, nil
|
||||
},
|
||||
getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
|
||||
return []string{"eu.proxy.netbird.io"}, nil
|
||||
},
|
||||
}
|
||||
|
||||
mgr := Manager{proxyManager: pm}
|
||||
result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"eu.proxy.netbird.io"}, result)
|
||||
}
|
||||
|
||||
@@ -11,19 +11,15 @@ import (
|
||||
|
||||
// Manager defines the interface for proxy operations
|
||||
type Manager interface {
|
||||
Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, accountID *string, capabilities *Capabilities) (*Proxy, error)
|
||||
Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, capabilities *Capabilities) (*Proxy, error)
|
||||
Disconnect(ctx context.Context, proxyID, sessionID string) error
|
||||
Heartbeat(ctx context.Context, p *Proxy) error
|
||||
GetActiveClusterAddresses(ctx context.Context) ([]string, error)
|
||||
GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error)
|
||||
GetActiveClusters(ctx context.Context) ([]Cluster, error)
|
||||
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
||||
ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
|
||||
ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
|
||||
CleanupStale(ctx context.Context, inactivityDuration time.Duration) error
|
||||
GetAccountProxy(ctx context.Context, accountID string) (*Proxy, error)
|
||||
CountAccountProxies(ctx context.Context, accountID string) (int64, error)
|
||||
IsClusterAddressAvailable(ctx context.Context, clusterAddress, accountID string) (bool, error)
|
||||
DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error
|
||||
}
|
||||
|
||||
// OIDCValidationConfig contains the OIDC configuration needed for token validation.
|
||||
|
||||
@@ -16,16 +16,11 @@ type store interface {
|
||||
DisconnectProxy(ctx context.Context, proxyID, sessionID string) error
|
||||
UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error
|
||||
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
|
||||
GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error)
|
||||
GetActiveProxyClusters(ctx context.Context, accountID string) ([]proxy.Cluster, error)
|
||||
GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error)
|
||||
GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
||||
GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
|
||||
GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
|
||||
CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error
|
||||
GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error)
|
||||
CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error)
|
||||
IsClusterAddressConflicting(ctx context.Context, clusterAddress, accountID string) (bool, error)
|
||||
DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error
|
||||
}
|
||||
|
||||
// Manager handles all proxy operations
|
||||
@@ -49,7 +44,7 @@ func NewManager(store store, meter metric.Meter) (*Manager, error) {
|
||||
|
||||
// Connect registers a new proxy connection in the database.
|
||||
// capabilities may be nil for old proxies that do not report them.
|
||||
func (m *Manager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, accountID *string, capabilities *proxy.Capabilities) (*proxy.Proxy, error) {
|
||||
func (m Manager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, capabilities *proxy.Capabilities) (*proxy.Proxy, error) {
|
||||
now := time.Now()
|
||||
var caps proxy.Capabilities
|
||||
if capabilities != nil {
|
||||
@@ -60,10 +55,9 @@ func (m *Manager) Connect(ctx context.Context, proxyID, sessionID, clusterAddres
|
||||
SessionID: sessionID,
|
||||
ClusterAddress: clusterAddress,
|
||||
IPAddress: ipAddress,
|
||||
AccountID: accountID,
|
||||
LastSeen: now,
|
||||
ConnectedAt: &now,
|
||||
Status: proxy.StatusConnected,
|
||||
Status: "connected",
|
||||
Capabilities: caps,
|
||||
}
|
||||
|
||||
@@ -83,7 +77,7 @@ func (m *Manager) Connect(ctx context.Context, proxyID, sessionID, clusterAddres
|
||||
}
|
||||
|
||||
// Disconnect marks a proxy as disconnected in the database.
|
||||
func (m *Manager) Disconnect(ctx context.Context, proxyID, sessionID string) error {
|
||||
func (m Manager) Disconnect(ctx context.Context, proxyID, sessionID string) error {
|
||||
if err := m.store.DisconnectProxy(ctx, proxyID, sessionID); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to disconnect proxy %s session %s: %v", proxyID, sessionID, err)
|
||||
return err
|
||||
@@ -98,7 +92,7 @@ func (m *Manager) Disconnect(ctx context.Context, proxyID, sessionID string) err
|
||||
}
|
||||
|
||||
// Heartbeat updates the proxy's last seen timestamp.
|
||||
func (m *Manager) Heartbeat(ctx context.Context, p *proxy.Proxy) error {
|
||||
func (m Manager) Heartbeat(ctx context.Context, p *proxy.Proxy) error {
|
||||
if err := m.store.UpdateProxyHeartbeat(ctx, p); err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to update proxy %s heartbeat: %v", p.ID, err)
|
||||
return err
|
||||
@@ -110,7 +104,7 @@ func (m *Manager) Heartbeat(ctx context.Context, p *proxy.Proxy) error {
|
||||
}
|
||||
|
||||
// GetActiveClusterAddresses returns all unique cluster addresses for active proxies
|
||||
func (m *Manager) GetActiveClusterAddresses(ctx context.Context) ([]string, error) {
|
||||
func (m Manager) GetActiveClusterAddresses(ctx context.Context) ([]string, error) {
|
||||
addresses, err := m.store.GetActiveProxyClusterAddresses(ctx)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses: %v", err)
|
||||
@@ -119,6 +113,16 @@ func (m *Manager) GetActiveClusterAddresses(ctx context.Context) ([]string, erro
|
||||
return addresses, nil
|
||||
}
|
||||
|
||||
// GetActiveClusters returns all active proxy clusters with their connected proxy count.
|
||||
func (m Manager) GetActiveClusters(ctx context.Context) ([]proxy.Cluster, error) {
|
||||
clusters, err := m.store.GetActiveProxyClusters(ctx)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get active proxy clusters: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
return clusters, nil
|
||||
}
|
||||
|
||||
// ClusterSupportsCustomPorts returns whether any active proxy in the cluster
|
||||
// supports custom ports. Returns nil when no proxy has reported capabilities.
|
||||
func (m Manager) ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool {
|
||||
@@ -138,44 +142,10 @@ func (m Manager) ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string
|
||||
}
|
||||
|
||||
// CleanupStale removes proxies that haven't sent heartbeat in the specified duration
|
||||
func (m *Manager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error {
|
||||
func (m Manager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error {
|
||||
if err := m.store.CleanupStaleProxies(ctx, inactivityDuration); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to cleanup stale proxies: %v", err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) {
|
||||
addresses, err := m.store.GetActiveProxyClusterAddressesForAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses for account %s: %v", accountID, err)
|
||||
return nil, err
|
||||
}
|
||||
return addresses, nil
|
||||
}
|
||||
|
||||
func (m *Manager) GetAccountProxy(ctx context.Context, accountID string) (*proxy.Proxy, error) {
|
||||
return m.store.GetProxyByAccountID(ctx, accountID)
|
||||
}
|
||||
|
||||
func (m *Manager) CountAccountProxies(ctx context.Context, accountID string) (int64, error) {
|
||||
return m.store.CountProxiesByAccountID(ctx, accountID)
|
||||
}
|
||||
|
||||
func (m *Manager) IsClusterAddressAvailable(ctx context.Context, clusterAddress, accountID string) (bool, error) {
|
||||
conflicting, err := m.store.IsClusterAddressConflicting(ctx, clusterAddress, accountID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return !conflicting, nil
|
||||
}
|
||||
|
||||
func (m *Manager) DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error {
|
||||
if err := m.store.DeleteAccountCluster(ctx, clusterAddress, accountID); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete cluster %s for account %s: %v", clusterAddress, accountID, err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -1,337 +0,0 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.opentelemetry.io/otel/metric/noop"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
)
|
||||
|
||||
type mockStore struct {
|
||||
saveProxyFunc func(ctx context.Context, p *proxy.Proxy) error
|
||||
disconnectProxyFunc func(ctx context.Context, proxyID, sessionID string) error
|
||||
updateProxyHeartbeatFunc func(ctx context.Context, p *proxy.Proxy) error
|
||||
getActiveProxyClusterAddressesFunc func(ctx context.Context) ([]string, error)
|
||||
getActiveProxyClusterAddressesForAccFunc func(ctx context.Context, accountID string) ([]string, error)
|
||||
cleanupStaleProxiesFunc func(ctx context.Context, d time.Duration) error
|
||||
getProxyByAccountIDFunc func(ctx context.Context, accountID string) (*proxy.Proxy, error)
|
||||
countProxiesByAccountIDFunc func(ctx context.Context, accountID string) (int64, error)
|
||||
isClusterAddressConflictingFunc func(ctx context.Context, clusterAddress, accountID string) (bool, error)
|
||||
deleteAccountClusterFunc func(ctx context.Context, clusterAddress, accountID string) error
|
||||
}
|
||||
|
||||
func (m *mockStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error {
|
||||
if m.saveProxyFunc != nil {
|
||||
return m.saveProxyFunc(ctx, p)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (m *mockStore) DisconnectProxy(ctx context.Context, proxyID, sessionID string) error {
|
||||
if m.disconnectProxyFunc != nil {
|
||||
return m.disconnectProxyFunc(ctx, proxyID, sessionID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (m *mockStore) UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error {
|
||||
if m.updateProxyHeartbeatFunc != nil {
|
||||
return m.updateProxyHeartbeatFunc(ctx, p)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (m *mockStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) {
|
||||
if m.getActiveProxyClusterAddressesFunc != nil {
|
||||
return m.getActiveProxyClusterAddressesFunc(ctx)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockStore) GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) {
|
||||
if m.getActiveProxyClusterAddressesForAccFunc != nil {
|
||||
return m.getActiveProxyClusterAddressesForAccFunc(ctx, accountID)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockStore) GetActiveProxyClusters(_ context.Context, _ string) ([]proxy.Cluster, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockStore) CleanupStaleProxies(ctx context.Context, d time.Duration) error {
|
||||
if m.cleanupStaleProxiesFunc != nil {
|
||||
return m.cleanupStaleProxiesFunc(ctx, d)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (m *mockStore) GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error) {
|
||||
if m.getProxyByAccountIDFunc != nil {
|
||||
return m.getProxyByAccountIDFunc(ctx, accountID)
|
||||
}
|
||||
return nil, fmt.Errorf("proxy not found for account %s", accountID)
|
||||
}
|
||||
func (m *mockStore) CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error) {
|
||||
if m.countProxiesByAccountIDFunc != nil {
|
||||
return m.countProxiesByAccountIDFunc(ctx, accountID)
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
func (m *mockStore) IsClusterAddressConflicting(ctx context.Context, clusterAddress, accountID string) (bool, error) {
|
||||
if m.isClusterAddressConflictingFunc != nil {
|
||||
return m.isClusterAddressConflictingFunc(ctx, clusterAddress, accountID)
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
func (m *mockStore) DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error {
|
||||
if m.deleteAccountClusterFunc != nil {
|
||||
return m.deleteAccountClusterFunc(ctx, clusterAddress, accountID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (m *mockStore) GetClusterSupportsCustomPorts(_ context.Context, _ string) *bool {
|
||||
return nil
|
||||
}
|
||||
func (m *mockStore) GetClusterRequireSubdomain(_ context.Context, _ string) *bool {
|
||||
return nil
|
||||
}
|
||||
func (m *mockStore) GetClusterSupportsCrowdSec(_ context.Context, _ string) *bool {
|
||||
return nil
|
||||
}
|
||||
|
||||
func newTestManager(s store) *Manager {
|
||||
meter := noop.NewMeterProvider().Meter("test")
|
||||
m, err := NewManager(s, meter)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func TestConnect_WithAccountID(t *testing.T) {
|
||||
accountID := "acc-123"
|
||||
|
||||
var savedProxy *proxy.Proxy
|
||||
s := &mockStore{
|
||||
saveProxyFunc: func(_ context.Context, p *proxy.Proxy) error {
|
||||
savedProxy = p
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
mgr := newTestManager(s)
|
||||
_, err := mgr.Connect(context.Background(), "proxy-1", "session-1", "cluster.example.com", "10.0.0.1", &accountID, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NotNil(t, savedProxy)
|
||||
assert.Equal(t, "proxy-1", savedProxy.ID)
|
||||
assert.Equal(t, "session-1", savedProxy.SessionID)
|
||||
assert.Equal(t, "cluster.example.com", savedProxy.ClusterAddress)
|
||||
assert.Equal(t, "10.0.0.1", savedProxy.IPAddress)
|
||||
assert.Equal(t, &accountID, savedProxy.AccountID)
|
||||
assert.Equal(t, proxy.StatusConnected, savedProxy.Status)
|
||||
assert.NotNil(t, savedProxy.ConnectedAt)
|
||||
}
|
||||
|
||||
func TestConnect_WithoutAccountID(t *testing.T) {
|
||||
var savedProxy *proxy.Proxy
|
||||
s := &mockStore{
|
||||
saveProxyFunc: func(_ context.Context, p *proxy.Proxy) error {
|
||||
savedProxy = p
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
mgr := newTestManager(s)
|
||||
_, err := mgr.Connect(context.Background(), "proxy-1", "session-1", "eu.proxy.netbird.io", "10.0.0.1", nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NotNil(t, savedProxy)
|
||||
assert.Nil(t, savedProxy.AccountID)
|
||||
assert.Equal(t, proxy.StatusConnected, savedProxy.Status)
|
||||
}
|
||||
|
||||
func TestConnect_StoreError(t *testing.T) {
|
||||
s := &mockStore{
|
||||
saveProxyFunc: func(_ context.Context, _ *proxy.Proxy) error {
|
||||
return errors.New("db error")
|
||||
},
|
||||
}
|
||||
|
||||
mgr := newTestManager(s)
|
||||
_, err := mgr.Connect(context.Background(), "proxy-1", "session-1", "cluster.example.com", "10.0.0.1", nil, nil)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestIsClusterAddressAvailable(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
conflicting bool
|
||||
storeErr error
|
||||
wantResult bool
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "available - no conflict",
|
||||
conflicting: false,
|
||||
wantResult: true,
|
||||
},
|
||||
{
|
||||
name: "not available - conflict exists",
|
||||
conflicting: true,
|
||||
wantResult: false,
|
||||
},
|
||||
{
|
||||
name: "store error",
|
||||
storeErr: errors.New("db error"),
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := &mockStore{
|
||||
isClusterAddressConflictingFunc: func(_ context.Context, _, _ string) (bool, error) {
|
||||
return tt.conflicting, tt.storeErr
|
||||
},
|
||||
}
|
||||
|
||||
mgr := newTestManager(s)
|
||||
result, err := mgr.IsClusterAddressAvailable(context.Background(), "cluster.example.com", "acc-123")
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.wantResult, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCountAccountProxies(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
count int64
|
||||
storeErr error
|
||||
wantCount int64
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "no proxies",
|
||||
count: 0,
|
||||
wantCount: 0,
|
||||
},
|
||||
{
|
||||
name: "one proxy",
|
||||
count: 1,
|
||||
wantCount: 1,
|
||||
},
|
||||
{
|
||||
name: "store error",
|
||||
storeErr: errors.New("db error"),
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := &mockStore{
|
||||
countProxiesByAccountIDFunc: func(_ context.Context, _ string) (int64, error) {
|
||||
return tt.count, tt.storeErr
|
||||
},
|
||||
}
|
||||
|
||||
mgr := newTestManager(s)
|
||||
count, err := mgr.CountAccountProxies(context.Background(), "acc-123")
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.wantCount, count)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAccountProxy(t *testing.T) {
|
||||
accountID := "acc-123"
|
||||
|
||||
t.Run("found", func(t *testing.T) {
|
||||
expected := &proxy.Proxy{
|
||||
ID: "proxy-1",
|
||||
ClusterAddress: "byop.example.com",
|
||||
AccountID: &accountID,
|
||||
Status: proxy.StatusConnected,
|
||||
}
|
||||
s := &mockStore{
|
||||
getProxyByAccountIDFunc: func(_ context.Context, accID string) (*proxy.Proxy, error) {
|
||||
assert.Equal(t, accountID, accID)
|
||||
return expected, nil
|
||||
},
|
||||
}
|
||||
|
||||
mgr := newTestManager(s)
|
||||
p, err := mgr.GetAccountProxy(context.Background(), accountID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, expected, p)
|
||||
})
|
||||
|
||||
t.Run("not found", func(t *testing.T) {
|
||||
s := &mockStore{
|
||||
getProxyByAccountIDFunc: func(_ context.Context, _ string) (*proxy.Proxy, error) {
|
||||
return nil, errors.New("not found")
|
||||
},
|
||||
}
|
||||
|
||||
mgr := newTestManager(s)
|
||||
_, err := mgr.GetAccountProxy(context.Background(), accountID)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDeleteAccountCluster(t *testing.T) {
|
||||
t.Run("success", func(t *testing.T) {
|
||||
var deletedCluster, deletedAccount string
|
||||
s := &mockStore{
|
||||
deleteAccountClusterFunc: func(_ context.Context, clusterAddress, accountID string) error {
|
||||
deletedCluster = clusterAddress
|
||||
deletedAccount = accountID
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
mgr := newTestManager(s)
|
||||
err := mgr.DeleteAccountCluster(context.Background(), "cluster.example.com", "acc-123")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "cluster.example.com", deletedCluster)
|
||||
assert.Equal(t, "acc-123", deletedAccount)
|
||||
})
|
||||
|
||||
t.Run("store error", func(t *testing.T) {
|
||||
s := &mockStore{
|
||||
deleteAccountClusterFunc: func(_ context.Context, _, _ string) error {
|
||||
return errors.New("db error")
|
||||
},
|
||||
}
|
||||
|
||||
mgr := newTestManager(s)
|
||||
err := mgr.DeleteAccountCluster(context.Background(), "cluster.example.com", "acc-123")
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetActiveClusterAddressesForAccount(t *testing.T) {
|
||||
expected := []string{"byop.example.com"}
|
||||
s := &mockStore{
|
||||
getActiveProxyClusterAddressesForAccFunc: func(_ context.Context, accID string) ([]string, error) {
|
||||
assert.Equal(t, "acc-123", accID)
|
||||
return expected, nil
|
||||
},
|
||||
}
|
||||
|
||||
mgr := newTestManager(s)
|
||||
result, err := mgr.GetActiveClusterAddressesForAccount(context.Background(), "acc-123")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, expected, result)
|
||||
}
|
||||
@@ -93,18 +93,18 @@ func (mr *MockManagerMockRecorder) ClusterSupportsCrowdSec(ctx, clusterAddr inte
|
||||
}
|
||||
|
||||
// Connect mocks base method.
|
||||
func (m *MockManager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, accountID *string, capabilities *Capabilities) (*Proxy, error) {
|
||||
func (m *MockManager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, capabilities *Capabilities) (*Proxy, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Connect", ctx, proxyID, sessionID, clusterAddress, ipAddress, accountID, capabilities)
|
||||
ret := m.ctrl.Call(m, "Connect", ctx, proxyID, sessionID, clusterAddress, ipAddress, capabilities)
|
||||
ret0, _ := ret[0].(*Proxy)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Connect indicates an expected call of Connect.
|
||||
func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, sessionID, clusterAddress, ipAddress, accountID, capabilities interface{}) *gomock.Call {
|
||||
func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, sessionID, clusterAddress, ipAddress, capabilities interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, sessionID, clusterAddress, ipAddress, accountID, capabilities)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, sessionID, clusterAddress, ipAddress, capabilities)
|
||||
}
|
||||
|
||||
// Disconnect mocks base method.
|
||||
@@ -136,17 +136,19 @@ func (mr *MockManagerMockRecorder) GetActiveClusterAddresses(ctx interface{}) *g
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusterAddresses", reflect.TypeOf((*MockManager)(nil).GetActiveClusterAddresses), ctx)
|
||||
}
|
||||
|
||||
func (m *MockManager) GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) {
|
||||
// GetActiveClusters mocks base method.
|
||||
func (m *MockManager) GetActiveClusters(ctx context.Context) ([]Cluster, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetActiveClusterAddressesForAccount", ctx, accountID)
|
||||
ret0, _ := ret[0].([]string)
|
||||
ret := m.ctrl.Call(m, "GetActiveClusters", ctx)
|
||||
ret0, _ := ret[0].([]Cluster)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
func (mr *MockManagerMockRecorder) GetActiveClusterAddressesForAccount(ctx, accountID interface{}) *gomock.Call {
|
||||
// GetActiveClusters indicates an expected call of GetActiveClusters.
|
||||
func (mr *MockManagerMockRecorder) GetActiveClusters(ctx interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusterAddressesForAccount", reflect.TypeOf((*MockManager)(nil).GetActiveClusterAddressesForAccount), ctx, accountID)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusters", reflect.TypeOf((*MockManager)(nil).GetActiveClusters), ctx)
|
||||
}
|
||||
|
||||
// Heartbeat mocks base method.
|
||||
@@ -163,65 +165,6 @@ func (mr *MockManagerMockRecorder) Heartbeat(ctx, p interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Heartbeat", reflect.TypeOf((*MockManager)(nil).Heartbeat), ctx, p)
|
||||
}
|
||||
|
||||
// GetAccountProxy mocks base method.
|
||||
func (m *MockManager) GetAccountProxy(ctx context.Context, accountID string) (*Proxy, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetAccountProxy", ctx, accountID)
|
||||
ret0, _ := ret[0].(*Proxy)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetAccountProxy indicates an expected call of GetAccountProxy.
|
||||
func (mr *MockManagerMockRecorder) GetAccountProxy(ctx, accountID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountProxy", reflect.TypeOf((*MockManager)(nil).GetAccountProxy), ctx, accountID)
|
||||
}
|
||||
|
||||
// CountAccountProxies mocks base method.
|
||||
func (m *MockManager) CountAccountProxies(ctx context.Context, accountID string) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "CountAccountProxies", ctx, accountID)
|
||||
ret0, _ := ret[0].(int64)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// CountAccountProxies indicates an expected call of CountAccountProxies.
|
||||
func (mr *MockManagerMockRecorder) CountAccountProxies(ctx, accountID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountAccountProxies", reflect.TypeOf((*MockManager)(nil).CountAccountProxies), ctx, accountID)
|
||||
}
|
||||
|
||||
// IsClusterAddressAvailable mocks base method.
|
||||
func (m *MockManager) IsClusterAddressAvailable(ctx context.Context, clusterAddress, accountID string) (bool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "IsClusterAddressAvailable", ctx, clusterAddress, accountID)
|
||||
ret0, _ := ret[0].(bool)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// IsClusterAddressAvailable indicates an expected call of IsClusterAddressAvailable.
|
||||
func (mr *MockManagerMockRecorder) IsClusterAddressAvailable(ctx, clusterAddress, accountID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsClusterAddressAvailable", reflect.TypeOf((*MockManager)(nil).IsClusterAddressAvailable), ctx, clusterAddress, accountID)
|
||||
}
|
||||
|
||||
// DeleteAccountCluster mocks base method.
|
||||
func (m *MockManager) DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteAccountCluster", ctx, clusterAddress, accountID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteAccountCluster indicates an expected call of DeleteAccountCluster.
|
||||
func (mr *MockManagerMockRecorder) DeleteAccountCluster(ctx, clusterAddress, accountID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAccountCluster", reflect.TypeOf((*MockManager)(nil).DeleteAccountCluster), ctx, clusterAddress, accountID)
|
||||
}
|
||||
|
||||
// MockController is a mock of Controller interface.
|
||||
type MockController struct {
|
||||
ctrl *gomock.Controller
|
||||
|
||||
@@ -1,13 +1,6 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
StatusConnected = "connected"
|
||||
StatusDisconnected = "disconnected"
|
||||
)
|
||||
import "time"
|
||||
|
||||
// Capabilities describes what a proxy can handle, as reported via gRPC.
|
||||
// Nil fields mean the proxy never reported this capability.
|
||||
@@ -28,7 +21,6 @@ type Proxy struct {
|
||||
SessionID string `gorm:"type:varchar(36)"`
|
||||
ClusterAddress string `gorm:"type:varchar(255);not null;index:idx_proxy_cluster_status"`
|
||||
IPAddress string `gorm:"type:varchar(45)"`
|
||||
AccountID *string `gorm:"type:varchar(255);index:idx_proxy_account_id"`
|
||||
LastSeen time.Time `gorm:"not null;index:idx_proxy_last_seen"`
|
||||
ConnectedAt *time.Time
|
||||
DisconnectedAt *time.Time
|
||||
@@ -44,8 +36,6 @@ func (Proxy) TableName() string {
|
||||
|
||||
// Cluster represents a group of proxy nodes serving the same address.
|
||||
type Cluster struct {
|
||||
ID string
|
||||
Address string
|
||||
ConnectedProxies int
|
||||
SelfHosted bool
|
||||
}
|
||||
|
||||
@@ -1,195 +0,0 @@
|
||||
package proxytoken
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
type handler struct {
|
||||
store store.Store
|
||||
permissionsManager permissions.Manager
|
||||
}
|
||||
|
||||
func RegisterEndpoints(s store.Store, permissionsManager permissions.Manager, router *mux.Router) {
|
||||
h := &handler{store: s, permissionsManager: permissionsManager}
|
||||
router.HandleFunc("/reverse-proxies/proxy-tokens", h.listTokens).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/reverse-proxies/proxy-tokens", h.createToken).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/reverse-proxies/proxy-tokens/{tokenId}", h.revokeToken).Methods("DELETE", "OPTIONS")
|
||||
}
|
||||
|
||||
func (h *handler) createToken(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
ok, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Services, operations.Create)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("failed to validate permissions", http.StatusInternalServerError, w)
|
||||
return
|
||||
}
|
||||
if !ok {
|
||||
util.WriteErrorResponse("permission denied", http.StatusForbidden, w)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.ProxyTokenRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
if req.Name == "" || len(req.Name) > 255 {
|
||||
util.WriteErrorResponse("name is required and must be at most 255 characters", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
var expiresIn time.Duration
|
||||
if req.ExpiresIn != nil {
|
||||
if *req.ExpiresIn < 0 {
|
||||
util.WriteErrorResponse("expires_in must be non-negative", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
if *req.ExpiresIn > 0 {
|
||||
expiresIn = time.Duration(*req.ExpiresIn) * time.Second
|
||||
}
|
||||
}
|
||||
|
||||
accountID := userAuth.AccountId
|
||||
generated, err := types.CreateNewProxyAccessToken(req.Name, expiresIn, &accountID, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("failed to generate token", http.StatusInternalServerError, w)
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.store.SaveProxyAccessToken(r.Context(), &generated.ProxyAccessToken); err != nil {
|
||||
util.WriteErrorResponse("failed to save token", http.StatusInternalServerError, w)
|
||||
return
|
||||
}
|
||||
|
||||
resp := toProxyTokenCreatedResponse(generated)
|
||||
util.WriteJSONObject(r.Context(), w, resp)
|
||||
}
|
||||
|
||||
func (h *handler) listTokens(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
ok, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Services, operations.Read)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("failed to validate permissions", http.StatusInternalServerError, w)
|
||||
return
|
||||
}
|
||||
if !ok {
|
||||
util.WriteErrorResponse("permission denied", http.StatusForbidden, w)
|
||||
return
|
||||
}
|
||||
|
||||
tokens, err := h.store.GetProxyAccessTokensByAccountID(r.Context(), store.LockingStrengthNone, userAuth.AccountId)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("failed to list tokens", http.StatusInternalServerError, w)
|
||||
return
|
||||
}
|
||||
|
||||
resp := make([]api.ProxyToken, 0, len(tokens))
|
||||
for _, token := range tokens {
|
||||
resp = append(resp, toProxyTokenResponse(token))
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, resp)
|
||||
}
|
||||
|
||||
func (h *handler) revokeToken(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
ok, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Services, operations.Delete)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("failed to validate permissions", http.StatusInternalServerError, w)
|
||||
return
|
||||
}
|
||||
if !ok {
|
||||
util.WriteErrorResponse("permission denied", http.StatusForbidden, w)
|
||||
return
|
||||
}
|
||||
|
||||
tokenID := mux.Vars(r)["tokenId"]
|
||||
if tokenID == "" {
|
||||
util.WriteErrorResponse("token ID is required", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
token, err := h.store.GetProxyAccessTokenByID(r.Context(), store.LockingStrengthNone, tokenID)
|
||||
if err != nil {
|
||||
if s, ok := status.FromError(err); ok && s.ErrorType == status.NotFound {
|
||||
util.WriteErrorResponse("token not found", http.StatusNotFound, w)
|
||||
} else {
|
||||
util.WriteErrorResponse("failed to retrieve token", http.StatusInternalServerError, w)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if token.AccountID == nil || *token.AccountID != userAuth.AccountId {
|
||||
util.WriteErrorResponse("token not found", http.StatusNotFound, w)
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.store.RevokeProxyAccessToken(r.Context(), tokenID); err != nil {
|
||||
util.WriteErrorResponse("failed to revoke token", http.StatusInternalServerError, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||
}
|
||||
|
||||
func toProxyTokenResponse(token *types.ProxyAccessToken) api.ProxyToken {
|
||||
resp := api.ProxyToken{
|
||||
Id: token.ID,
|
||||
Name: token.Name,
|
||||
Revoked: token.Revoked,
|
||||
}
|
||||
if !token.CreatedAt.IsZero() {
|
||||
resp.CreatedAt = token.CreatedAt
|
||||
}
|
||||
if token.ExpiresAt != nil {
|
||||
resp.ExpiresAt = token.ExpiresAt
|
||||
}
|
||||
if token.LastUsed != nil {
|
||||
resp.LastUsed = token.LastUsed
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
||||
func toProxyTokenCreatedResponse(generated *types.ProxyAccessTokenGenerated) api.ProxyTokenCreated {
|
||||
base := toProxyTokenResponse(&generated.ProxyAccessToken)
|
||||
plainToken := string(generated.PlainToken)
|
||||
return api.ProxyTokenCreated{
|
||||
Id: base.Id,
|
||||
Name: base.Name,
|
||||
CreatedAt: base.CreatedAt,
|
||||
ExpiresAt: base.ExpiresAt,
|
||||
LastUsed: base.LastUsed,
|
||||
Revoked: base.Revoked,
|
||||
PlainToken: plainToken,
|
||||
}
|
||||
}
|
||||
@@ -1,275 +0,0 @@
|
||||
package proxytoken
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
)
|
||||
|
||||
func authContext(accountID, userID string) context.Context {
|
||||
return nbcontext.SetUserAuthInContext(context.Background(), auth.UserAuth{
|
||||
AccountId: accountID,
|
||||
UserId: userID,
|
||||
})
|
||||
}
|
||||
|
||||
func TestCreateToken_AccountScoped(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
accountID := "acc-123"
|
||||
var savedToken *types.ProxyAccessToken
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
mockStore.EXPECT().SaveProxyAccessToken(gomock.Any(), gomock.Any()).DoAndReturn(
|
||||
func(_ context.Context, token *types.ProxyAccessToken) error {
|
||||
savedToken = token
|
||||
return nil
|
||||
},
|
||||
)
|
||||
|
||||
permsMgr := permissions.NewMockManager(ctrl)
|
||||
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), accountID, "user-1", modules.Services, operations.Create).Return(true, nil)
|
||||
|
||||
h := &handler{
|
||||
store: mockStore,
|
||||
permissionsManager: permsMgr,
|
||||
}
|
||||
|
||||
body := `{"name": "my-token"}`
|
||||
req := httptest.NewRequest("POST", "/reverse-proxies/proxy-tokens", bytes.NewBufferString(body))
|
||||
req = req.WithContext(authContext(accountID, "user-1"))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.createToken(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var resp api.ProxyTokenCreated
|
||||
require.NoError(t, json.NewDecoder(w.Body).Decode(&resp))
|
||||
|
||||
assert.NotEmpty(t, resp.PlainToken)
|
||||
assert.Equal(t, "my-token", resp.Name)
|
||||
assert.False(t, resp.Revoked)
|
||||
|
||||
require.NotNil(t, savedToken)
|
||||
require.NotNil(t, savedToken.AccountID)
|
||||
assert.Equal(t, accountID, *savedToken.AccountID)
|
||||
assert.Equal(t, "user-1", savedToken.CreatedBy)
|
||||
}
|
||||
|
||||
func TestCreateToken_WithExpiration(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
var savedToken *types.ProxyAccessToken
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
mockStore.EXPECT().SaveProxyAccessToken(gomock.Any(), gomock.Any()).DoAndReturn(
|
||||
func(_ context.Context, token *types.ProxyAccessToken) error {
|
||||
savedToken = token
|
||||
return nil
|
||||
},
|
||||
)
|
||||
|
||||
permsMgr := permissions.NewMockManager(ctrl)
|
||||
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Create).Return(true, nil)
|
||||
|
||||
h := &handler{
|
||||
store: mockStore,
|
||||
permissionsManager: permsMgr,
|
||||
}
|
||||
|
||||
body := `{"name": "expiring-token", "expires_in": 3600}`
|
||||
req := httptest.NewRequest("POST", "/reverse-proxies/proxy-tokens", bytes.NewBufferString(body))
|
||||
req = req.WithContext(authContext("acc-123", "user-1"))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.createToken(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
require.NotNil(t, savedToken)
|
||||
require.NotNil(t, savedToken.ExpiresAt)
|
||||
assert.True(t, savedToken.ExpiresAt.After(time.Now()))
|
||||
}
|
||||
|
||||
func TestCreateToken_EmptyName(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
permsMgr := permissions.NewMockManager(ctrl)
|
||||
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Create).Return(true, nil)
|
||||
|
||||
h := &handler{
|
||||
permissionsManager: permsMgr,
|
||||
}
|
||||
|
||||
body := `{"name": ""}`
|
||||
req := httptest.NewRequest("POST", "/reverse-proxies/proxy-tokens", bytes.NewBufferString(body))
|
||||
req = req.WithContext(authContext("acc-123", "user-1"))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.createToken(w, req)
|
||||
assert.Equal(t, http.StatusBadRequest, w.Code)
|
||||
}
|
||||
|
||||
func TestCreateToken_PermissionDenied(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
permsMgr := permissions.NewMockManager(ctrl)
|
||||
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Create).Return(false, nil)
|
||||
|
||||
h := &handler{
|
||||
permissionsManager: permsMgr,
|
||||
}
|
||||
|
||||
body := `{"name": "test"}`
|
||||
req := httptest.NewRequest("POST", "/reverse-proxies/proxy-tokens", bytes.NewBufferString(body))
|
||||
req = req.WithContext(authContext("acc-123", "user-1"))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.createToken(w, req)
|
||||
assert.Equal(t, http.StatusForbidden, w.Code)
|
||||
}
|
||||
|
||||
func TestListTokens(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
accountID := "acc-123"
|
||||
now := time.Now()
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
mockStore.EXPECT().GetProxyAccessTokensByAccountID(gomock.Any(), store.LockingStrengthNone, accountID).Return([]*types.ProxyAccessToken{
|
||||
{ID: "tok-1", Name: "token-1", AccountID: &accountID, CreatedAt: now, Revoked: false},
|
||||
{ID: "tok-2", Name: "token-2", AccountID: &accountID, CreatedAt: now, Revoked: true},
|
||||
}, nil)
|
||||
|
||||
permsMgr := permissions.NewMockManager(ctrl)
|
||||
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), accountID, "user-1", modules.Services, operations.Read).Return(true, nil)
|
||||
|
||||
h := &handler{
|
||||
store: mockStore,
|
||||
permissionsManager: permsMgr,
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "/reverse-proxies/proxy-tokens", nil)
|
||||
req = req.WithContext(authContext(accountID, "user-1"))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.listTokens(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var resp []api.ProxyToken
|
||||
require.NoError(t, json.NewDecoder(w.Body).Decode(&resp))
|
||||
require.Len(t, resp, 2)
|
||||
assert.Equal(t, "tok-1", resp[0].Id)
|
||||
assert.False(t, resp[0].Revoked)
|
||||
assert.Equal(t, "tok-2", resp[1].Id)
|
||||
assert.True(t, resp[1].Revoked)
|
||||
}
|
||||
|
||||
func TestRevokeToken_Success(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
accountID := "acc-123"
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
mockStore.EXPECT().GetProxyAccessTokenByID(gomock.Any(), store.LockingStrengthNone, "tok-1").Return(&types.ProxyAccessToken{
|
||||
ID: "tok-1",
|
||||
Name: "test-token",
|
||||
AccountID: &accountID,
|
||||
}, nil)
|
||||
mockStore.EXPECT().RevokeProxyAccessToken(gomock.Any(), "tok-1").Return(nil)
|
||||
|
||||
permsMgr := permissions.NewMockManager(ctrl)
|
||||
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), accountID, "user-1", modules.Services, operations.Delete).Return(true, nil)
|
||||
|
||||
h := &handler{
|
||||
store: mockStore,
|
||||
permissionsManager: permsMgr,
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("DELETE", "/reverse-proxies/proxy-tokens/tok-1", nil)
|
||||
req = req.WithContext(authContext(accountID, "user-1"))
|
||||
req = mux.SetURLVars(req, map[string]string{"tokenId": "tok-1"})
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.revokeToken(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
func TestRevokeToken_WrongAccount(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
otherAccount := "acc-other"
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
mockStore.EXPECT().GetProxyAccessTokenByID(gomock.Any(), store.LockingStrengthNone, "tok-1").Return(&types.ProxyAccessToken{
|
||||
ID: "tok-1",
|
||||
AccountID: &otherAccount,
|
||||
}, nil)
|
||||
|
||||
permsMgr := permissions.NewMockManager(ctrl)
|
||||
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Delete).Return(true, nil)
|
||||
|
||||
h := &handler{
|
||||
store: mockStore,
|
||||
permissionsManager: permsMgr,
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("DELETE", "/reverse-proxies/proxy-tokens/tok-1", nil)
|
||||
req = req.WithContext(authContext("acc-123", "user-1"))
|
||||
req = mux.SetURLVars(req, map[string]string{"tokenId": "tok-1"})
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.revokeToken(w, req)
|
||||
assert.Equal(t, http.StatusNotFound, w.Code)
|
||||
}
|
||||
|
||||
func TestRevokeToken_ManagementWideToken(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
mockStore.EXPECT().GetProxyAccessTokenByID(gomock.Any(), store.LockingStrengthNone, "tok-1").Return(&types.ProxyAccessToken{
|
||||
ID: "tok-1",
|
||||
AccountID: nil,
|
||||
}, nil)
|
||||
|
||||
permsMgr := permissions.NewMockManager(ctrl)
|
||||
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Delete).Return(true, nil)
|
||||
|
||||
h := &handler{
|
||||
store: mockStore,
|
||||
permissionsManager: permsMgr,
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("DELETE", "/reverse-proxies/proxy-tokens/tok-1", nil)
|
||||
req = req.WithContext(authContext("acc-123", "user-1"))
|
||||
req = mux.SetURLVars(req, map[string]string{"tokenId": "tok-1"})
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.revokeToken(w, req)
|
||||
assert.Equal(t, http.StatusNotFound, w.Code)
|
||||
}
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
|
||||
type Manager interface {
|
||||
GetActiveClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error)
|
||||
DeleteAccountCluster(ctx context.Context, accountID, userID, clusterAddress string) error
|
||||
GetAllServices(ctx context.Context, accountID, userID string) ([]*Service, error)
|
||||
GetService(ctx context.Context, accountID, userID, serviceID string) (*Service, error)
|
||||
CreateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error)
|
||||
@@ -29,5 +28,4 @@ type Manager interface {
|
||||
RenewServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error
|
||||
StopServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error
|
||||
StartExposeReaper(ctx context.Context)
|
||||
GetServiceByDomain(ctx context.Context, domain string) (*Service, error)
|
||||
}
|
||||
|
||||
@@ -79,20 +79,6 @@ func (mr *MockManagerMockRecorder) DeleteAllServices(ctx, accountID, userID inte
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAllServices", reflect.TypeOf((*MockManager)(nil).DeleteAllServices), ctx, accountID, userID)
|
||||
}
|
||||
|
||||
// DeleteAccountCluster mocks base method.
|
||||
func (m *MockManager) DeleteAccountCluster(ctx context.Context, accountID, userID, clusterAddress string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteAccountCluster", ctx, accountID, userID, clusterAddress)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteAccountCluster indicates an expected call of DeleteAccountCluster.
|
||||
func (mr *MockManagerMockRecorder) DeleteAccountCluster(ctx, accountID, userID, clusterAddress interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAccountCluster", reflect.TypeOf((*MockManager)(nil).DeleteAccountCluster), ctx, accountID, userID, clusterAddress)
|
||||
}
|
||||
|
||||
// DeleteService mocks base method.
|
||||
func (m *MockManager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -152,21 +138,6 @@ func (mr *MockManagerMockRecorder) GetAllServices(ctx, accountID, userID interfa
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllServices", reflect.TypeOf((*MockManager)(nil).GetAllServices), ctx, accountID, userID)
|
||||
}
|
||||
|
||||
// GetServiceByDomain mocks base method.
|
||||
func (m *MockManager) GetServiceByDomain(ctx context.Context, domain string) (*Service, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetServiceByDomain", ctx, domain)
|
||||
ret0, _ := ret[0].(*Service)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetServiceByDomain indicates an expected call of GetServiceByDomain.
|
||||
func (mr *MockManagerMockRecorder) GetServiceByDomain(ctx, domain interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceByDomain", reflect.TypeOf((*MockManager)(nil).GetServiceByDomain), ctx, domain)
|
||||
}
|
||||
|
||||
// GetGlobalServices mocks base method.
|
||||
func (m *MockManager) GetGlobalServices(ctx context.Context) ([]*Service, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
@@ -35,7 +35,6 @@ func RegisterEndpoints(manager rpservice.Manager, domainManager domainmanager.Ma
|
||||
accesslogsmanager.RegisterEndpoints(router, accessLogsManager)
|
||||
|
||||
router.HandleFunc("/reverse-proxies/clusters", h.getClusters).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/reverse-proxies/clusters/{clusterAddress}", h.deleteCluster).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/reverse-proxies/services", h.getAllServices).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/reverse-proxies/services", h.createService).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/reverse-proxies/services/{serviceId}", h.getService).Methods("GET", "OPTIONS")
|
||||
@@ -196,33 +195,10 @@ func (h *handler) getClusters(w http.ResponseWriter, r *http.Request) {
|
||||
apiClusters := make([]api.ProxyCluster, 0, len(clusters))
|
||||
for _, c := range clusters {
|
||||
apiClusters = append(apiClusters, api.ProxyCluster{
|
||||
Id: c.ID,
|
||||
Address: c.Address,
|
||||
ConnectedProxies: c.ConnectedProxies,
|
||||
SelfHosted: c.SelfHosted,
|
||||
})
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, apiClusters)
|
||||
}
|
||||
|
||||
func (h *handler) deleteCluster(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
clusterAddress := mux.Vars(r)["clusterAddress"]
|
||||
if clusterAddress == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "cluster address is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.manager.DeleteAccountCluster(r.Context(), userAuth.AccountId, userAuth.UserId, clusterAddress); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||
}
|
||||
|
||||
@@ -122,21 +122,7 @@ func (m *Manager) GetActiveClusters(ctx context.Context, accountID, userID strin
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
return m.store.GetActiveProxyClusters(ctx, accountID)
|
||||
}
|
||||
|
||||
// DeleteAccountCluster removes all proxy registrations for the given cluster address
|
||||
// owned by the account.
|
||||
func (m *Manager) DeleteAccountCluster(ctx context.Context, accountID, userID, clusterAddress string) error {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
return m.store.DeleteAccountCluster(ctx, clusterAddress, accountID)
|
||||
return m.store.GetActiveProxyClusters(ctx)
|
||||
}
|
||||
|
||||
func (m *Manager) GetAllServices(ctx context.Context, accountID, userID string) ([]*service.Service, error) {
|
||||
@@ -1000,10 +986,6 @@ func (m *Manager) GetAccountServices(ctx context.Context, accountID string) ([]*
|
||||
return services, nil
|
||||
}
|
||||
|
||||
func (m *Manager) GetServiceByDomain(ctx context.Context, domain string) (*service.Service, error) {
|
||||
return m.store.GetServiceByDomain(ctx, domain)
|
||||
}
|
||||
|
||||
func (m *Manager) GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error) {
|
||||
target, err := m.store.GetServiceTargetByTargetID(ctx, store.LockingStrengthNone, accountID, resourceID)
|
||||
if err != nil {
|
||||
|
||||
@@ -434,7 +434,7 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
|
||||
t.Helper()
|
||||
tokenStore := nbgrpc.NewOneTimeTokenStore(context.Background(), testCacheStore(t))
|
||||
pkceStore := nbgrpc.NewPKCEVerifierStore(context.Background(), testCacheStore(t))
|
||||
srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil)
|
||||
srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
|
||||
return srv
|
||||
}
|
||||
|
||||
@@ -714,7 +714,7 @@ func setupIntegrationTest(t *testing.T) (*Manager, store.Store) {
|
||||
|
||||
tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, testCacheStore(t))
|
||||
pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, testCacheStore(t))
|
||||
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil)
|
||||
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
|
||||
|
||||
proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
|
||||
require.NoError(t, err)
|
||||
@@ -1138,7 +1138,7 @@ func TestDeleteService_DeletesTargets(t *testing.T) {
|
||||
|
||||
tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, testCacheStore(t))
|
||||
pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, testCacheStore(t))
|
||||
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil)
|
||||
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
|
||||
|
||||
proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -193,7 +193,7 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
|
||||
|
||||
func (s *BaseServer) ReverseProxyGRPCServer() *nbgrpc.ProxyServiceServer {
|
||||
return Create(s, func() *nbgrpc.ProxyServiceServer {
|
||||
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.PKCEVerifierStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.ProxyManager(), s.Store())
|
||||
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.PKCEVerifierStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.ProxyManager())
|
||||
s.AfterInit(func(s *BaseServer) {
|
||||
proxyService.SetServiceManager(s.ServiceManager())
|
||||
proxyService.SetProxyController(s.ServiceProxyController())
|
||||
|
||||
@@ -26,7 +26,6 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/networks"
|
||||
"github.com/netbirdio/netbird/management/server/networks/resources"
|
||||
"github.com/netbirdio/netbird/management/server/networks/routers"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
@@ -114,47 +113,30 @@ func (s *BaseServer) AccountManager() account.Manager {
|
||||
})
|
||||
}
|
||||
|
||||
func isMFAEnabledForAccount(accounts []*types.Account) bool {
|
||||
if len(accounts) != 1 {
|
||||
return false
|
||||
}
|
||||
|
||||
settings := accounts[0].Settings
|
||||
return settings != nil && settings.LocalMfaEnabled
|
||||
}
|
||||
|
||||
func (s *BaseServer) IdpManager() idp.Manager {
|
||||
return Create(s, func() idp.Manager {
|
||||
var idpManager idp.Manager
|
||||
var err error
|
||||
|
||||
// Use embedded IdP service if embedded Dex is configured and enabled.
|
||||
// Legacy IdpManager won't be used anymore even if configured.
|
||||
embeddedEnabled := s.Config.EmbeddedIdP != nil && s.Config.EmbeddedIdP.Enabled
|
||||
if embeddedEnabled {
|
||||
embeddedMgr, err := idp.NewEmbeddedIdPManager(context.Background(), s.Config.EmbeddedIdP, s.Metrics())
|
||||
idpManager, err = idp.NewEmbeddedIdPManager(context.Background(), s.Config.EmbeddedIdP, s.Metrics())
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create embedded IDP service: %v", err)
|
||||
}
|
||||
|
||||
if val := isMFAEnabledForAccount(s.Store().GetAllAccounts(context.Background())); val {
|
||||
if err := embeddedMgr.SetMFAEnabled(context.Background(), val); err != nil {
|
||||
log.Errorf("failed to set MFA enabled on embedded IDP: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return embeddedMgr
|
||||
return idpManager
|
||||
}
|
||||
|
||||
// Fall back to external IdP service
|
||||
if s.Config.IdpManagerConfig != nil {
|
||||
idpManager, err := idp.NewManager(context.Background(), *s.Config.IdpManagerConfig, s.Metrics())
|
||||
idpManager, err = idp.NewManager(context.Background(), *s.Config.IdpManagerConfig, s.Metrics())
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create IDP service: %v", err)
|
||||
}
|
||||
|
||||
return idpManager
|
||||
}
|
||||
|
||||
|
||||
return nil
|
||||
return idpManager
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -34,8 +34,6 @@ const (
|
||||
ManagementLegacyPort = 33073
|
||||
// DefaultSelfHostedDomain is the default domain used for self-hosted fresh installs.
|
||||
DefaultSelfHostedDomain = "netbird.selfhosted"
|
||||
|
||||
ContainerKeyBaseServer = "baseServer"
|
||||
)
|
||||
|
||||
type Server interface {
|
||||
@@ -93,7 +91,7 @@ type Config struct {
|
||||
|
||||
// NewServer initializes and configures a new Server instance
|
||||
func NewServer(cfg *Config) *BaseServer {
|
||||
s := &BaseServer{
|
||||
return &BaseServer{
|
||||
Config: cfg.NbConfig,
|
||||
container: make(map[string]any),
|
||||
dnsDomain: cfg.DNSDomain,
|
||||
@@ -106,9 +104,6 @@ func NewServer(cfg *Config) *BaseServer {
|
||||
mgmtMetricsPort: cfg.MgmtMetricsPort,
|
||||
autoResolveDomains: cfg.AutoResolveDomains,
|
||||
}
|
||||
s.container[ContainerKeyBaseServer] = s
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *BaseServer) AfterInit(fn func(s *BaseServer)) {
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
@@ -51,11 +50,6 @@ type ProxyOIDCConfig struct {
|
||||
KeysLocation string
|
||||
}
|
||||
|
||||
// ProxyTokenChecker checks whether a proxy access token is still valid.
|
||||
type ProxyTokenChecker interface {
|
||||
IsProxyAccessTokenValid(ctx context.Context, tokenID string) (bool, error)
|
||||
}
|
||||
|
||||
// ProxyServiceServer implements the ProxyService gRPC server
|
||||
type ProxyServiceServer struct {
|
||||
proto.UnimplementedProxyServiceServer
|
||||
@@ -84,9 +78,6 @@ type ProxyServiceServer struct {
|
||||
// Store for one-time authentication tokens
|
||||
tokenStore *OneTimeTokenStore
|
||||
|
||||
// Checker for proxy access token validity
|
||||
tokenChecker ProxyTokenChecker
|
||||
|
||||
// OIDC configuration for proxy authentication
|
||||
oidcConfig ProxyOIDCConfig
|
||||
|
||||
@@ -132,8 +123,6 @@ type proxyConnection struct {
|
||||
proxyID string
|
||||
sessionID string
|
||||
address string
|
||||
accountID *string
|
||||
tokenID string
|
||||
capabilities *proto.ProxyCapabilities
|
||||
stream proto.ProxyService_GetMappingUpdateServer
|
||||
sendChan chan *proto.GetMappingUpdateResponse
|
||||
@@ -141,19 +130,8 @@ type proxyConnection struct {
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
func enforceAccountScope(ctx context.Context, requestAccountID string) error {
|
||||
token := GetProxyTokenFromContext(ctx)
|
||||
if token == nil || token.AccountID == nil {
|
||||
return nil
|
||||
}
|
||||
if requestAccountID == "" || *token.AccountID != requestAccountID {
|
||||
return status.Errorf(codes.PermissionDenied, "account-scoped token cannot access account %s", requestAccountID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewProxyServiceServer creates a new proxy service server.
|
||||
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, pkceStore *PKCEVerifierStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager, tokenChecker ProxyTokenChecker) *ProxyServiceServer {
|
||||
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, pkceStore *PKCEVerifierStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager) *ProxyServiceServer {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
s := &ProxyServiceServer{
|
||||
accessLogManager: accessLogMgr,
|
||||
@@ -163,7 +141,6 @@ func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeT
|
||||
peersManager: peersManager,
|
||||
usersManager: usersManager,
|
||||
proxyManager: proxyMgr,
|
||||
tokenChecker: tokenChecker,
|
||||
snapshotBatchSize: snapshotBatchSizeFromEnv(),
|
||||
cancel: cancel,
|
||||
}
|
||||
@@ -223,25 +200,6 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
||||
return status.Errorf(codes.InvalidArgument, "proxy address is invalid")
|
||||
}
|
||||
|
||||
var accountID *string
|
||||
token := GetProxyTokenFromContext(ctx)
|
||||
if token != nil && token.AccountID != nil {
|
||||
accountID = token.AccountID
|
||||
|
||||
available, err := s.proxyManager.IsClusterAddressAvailable(ctx, proxyAddress, *accountID)
|
||||
if err != nil {
|
||||
return status.Errorf(codes.Internal, "check cluster address: %v", err)
|
||||
}
|
||||
if !available {
|
||||
return status.Errorf(codes.AlreadyExists, "cluster address %s is already in use", proxyAddress)
|
||||
}
|
||||
}
|
||||
|
||||
var tokenID string
|
||||
if token != nil {
|
||||
tokenID = token.ID
|
||||
}
|
||||
|
||||
sessionID := uuid.NewString()
|
||||
|
||||
if old, loaded := s.connectedProxies.Load(proxyID); loaded {
|
||||
@@ -259,8 +217,6 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
||||
proxyID: proxyID,
|
||||
sessionID: sessionID,
|
||||
address: proxyAddress,
|
||||
accountID: accountID,
|
||||
tokenID: tokenID,
|
||||
capabilities: req.GetCapabilities(),
|
||||
stream: stream,
|
||||
sendChan: make(chan *proto.GetMappingUpdateResponse, 100),
|
||||
@@ -268,6 +224,7 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
// Register proxy in database with capabilities
|
||||
var caps *proxy.Capabilities
|
||||
if c := req.GetCapabilities(); c != nil {
|
||||
caps = &proxy.Capabilities{
|
||||
@@ -276,13 +233,10 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
||||
SupportsCrowdsec: c.SupportsCrowdsec,
|
||||
}
|
||||
}
|
||||
proxyRecord, err := s.proxyManager.Connect(ctx, proxyID, sessionID, proxyAddress, peerInfo, accountID, caps)
|
||||
proxyRecord, err := s.proxyManager.Connect(ctx, proxyID, sessionID, proxyAddress, peerInfo, caps)
|
||||
if err != nil {
|
||||
cancel()
|
||||
if accountID != nil {
|
||||
return status.Errorf(codes.Internal, "failed to register BYOP proxy: %v", err)
|
||||
}
|
||||
log.WithContext(ctx).Warnf("failed to register proxy %s in database: %v", proxyID, err)
|
||||
cancel()
|
||||
return status.Errorf(codes.Internal, "register proxy in database: %v", err)
|
||||
}
|
||||
|
||||
@@ -312,7 +266,6 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
||||
"session_id": sessionID,
|
||||
"address": proxyAddress,
|
||||
"cluster_addr": proxyAddress,
|
||||
"account_id": accountID,
|
||||
"total_proxies": len(s.GetConnectedProxies()),
|
||||
}).Info("Proxy registered in cluster")
|
||||
defer func() {
|
||||
@@ -333,7 +286,7 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
||||
log.Infof("Proxy %s session %s disconnected", proxyID, sessionID)
|
||||
}()
|
||||
|
||||
go s.heartbeat(connCtx, conn, proxyRecord)
|
||||
go s.heartbeat(connCtx, proxyRecord)
|
||||
|
||||
select {
|
||||
case err := <-errChan:
|
||||
@@ -345,9 +298,8 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
||||
}
|
||||
}
|
||||
|
||||
// heartbeat updates the proxy's last_seen timestamp every minute and
|
||||
// disconnects the proxy if its access token has been revoked.
|
||||
func (s *ProxyServiceServer) heartbeat(ctx context.Context, conn *proxyConnection, p *proxy.Proxy) {
|
||||
// heartbeat updates the proxy's last_seen timestamp every minute
|
||||
func (s *ProxyServiceServer) heartbeat(ctx context.Context, p *proxy.Proxy) {
|
||||
ticker := time.NewTicker(1 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
@@ -357,19 +309,6 @@ func (s *ProxyServiceServer) heartbeat(ctx context.Context, conn *proxyConnectio
|
||||
if err := s.proxyManager.Heartbeat(ctx, p); err != nil {
|
||||
log.WithContext(ctx).Debugf("Failed to update proxy %s heartbeat: %v", p.ID, err)
|
||||
}
|
||||
|
||||
if conn.tokenID != "" && s.tokenChecker != nil {
|
||||
valid, err := s.tokenChecker.IsProxyAccessTokenValid(ctx, conn.tokenID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to check token validity for proxy %s: %v", conn.proxyID, err)
|
||||
continue
|
||||
}
|
||||
if !valid {
|
||||
log.WithContext(ctx).Warnf("proxy %s token revoked or expired, disconnecting", conn.proxyID)
|
||||
conn.cancel()
|
||||
return
|
||||
}
|
||||
}
|
||||
case <-ctx.Done():
|
||||
log.WithContext(ctx).Infof("proxy %s heartbeat stopped: context canceled", p.ID)
|
||||
return
|
||||
@@ -377,6 +316,8 @@ func (s *ProxyServiceServer) heartbeat(ctx context.Context, conn *proxyConnectio
|
||||
}
|
||||
}
|
||||
|
||||
// sendSnapshot sends the initial snapshot of services to the connecting proxy.
|
||||
// Only entries matching the proxy's cluster address are sent.
|
||||
func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnection) error {
|
||||
if !isProxyAddressValid(conn.address) {
|
||||
return fmt.Errorf("proxy address is invalid")
|
||||
@@ -414,13 +355,7 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) snapshotServiceMappings(ctx context.Context, conn *proxyConnection) ([]*proto.ProxyMapping, error) {
|
||||
var services []*rpservice.Service
|
||||
var err error
|
||||
if conn.accountID != nil {
|
||||
services, err = s.serviceManager.GetAccountServices(ctx, *conn.accountID)
|
||||
} else {
|
||||
services, err = s.serviceManager.GetGlobalServices(ctx)
|
||||
}
|
||||
services, err := s.serviceManager.GetGlobalServices(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get services from store: %w", err)
|
||||
}
|
||||
@@ -445,14 +380,8 @@ func (s *ProxyServiceServer) snapshotServiceMappings(ctx context.Context, conn *
|
||||
return mappings, nil
|
||||
}
|
||||
|
||||
// isProxyAddressValid validates a proxy address (domain name or IP address)
|
||||
// isProxyAddressValid validates a proxy address
|
||||
func isProxyAddressValid(addr string) bool {
|
||||
if addr == "" {
|
||||
return false
|
||||
}
|
||||
if net.ParseIP(addr) != nil {
|
||||
return true
|
||||
}
|
||||
_, err := domain.ValidateDomains([]string{addr})
|
||||
return err == nil
|
||||
}
|
||||
@@ -476,10 +405,6 @@ func (s *ProxyServiceServer) sender(conn *proxyConnection, errChan chan<- error)
|
||||
func (s *ProxyServiceServer) SendAccessLog(ctx context.Context, req *proto.SendAccessLogRequest) (*proto.SendAccessLogResponse, error) {
|
||||
accessLog := req.GetLog()
|
||||
|
||||
if err := enforceAccountScope(ctx, accessLog.GetAccountId()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fields := log.Fields{
|
||||
"service_id": accessLog.GetServiceId(),
|
||||
"account_id": accessLog.GetAccountId(),
|
||||
@@ -517,32 +442,11 @@ func (s *ProxyServiceServer) SendAccessLog(ctx context.Context, req *proto.SendA
|
||||
// Management should call this when services are created/updated/removed.
|
||||
// For create/update operations a unique one-time auth token is generated per
|
||||
// proxy so that every replica can independently authenticate with management.
|
||||
// BYOP proxies only receive updates for their own account's services.
|
||||
func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateResponse) {
|
||||
log.Debugf("Broadcasting service update to all connected proxy servers")
|
||||
updateAccountIDs := make(map[string]struct{})
|
||||
for _, m := range update.Mapping {
|
||||
if m.AccountId != "" {
|
||||
updateAccountIDs[m.AccountId] = struct{}{}
|
||||
}
|
||||
}
|
||||
s.connectedProxies.Range(func(key, value interface{}) bool {
|
||||
conn := value.(*proxyConnection)
|
||||
connUpdate := update
|
||||
if conn.accountID != nil && len(updateAccountIDs) > 0 {
|
||||
if _, ok := updateAccountIDs[*conn.accountID]; !ok {
|
||||
return true
|
||||
}
|
||||
filtered := filterMappingsForAccount(update.Mapping, *conn.accountID)
|
||||
if len(filtered) == 0 {
|
||||
return true
|
||||
}
|
||||
connUpdate = &proto.GetMappingUpdateResponse{
|
||||
Mapping: filtered,
|
||||
InitialSyncComplete: update.InitialSyncComplete,
|
||||
}
|
||||
}
|
||||
resp := s.perProxyMessage(connUpdate, conn.proxyID)
|
||||
resp := s.perProxyMessage(update, conn.proxyID)
|
||||
if resp == nil {
|
||||
log.Warnf("Token generation failed for proxy %s, disconnecting to force resync", conn.proxyID)
|
||||
conn.cancel()
|
||||
@@ -559,26 +463,6 @@ func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateRes
|
||||
})
|
||||
}
|
||||
|
||||
// ForceDisconnect cancels the gRPC stream for a connected proxy, causing it to disconnect.
|
||||
func (s *ProxyServiceServer) ForceDisconnect(proxyID string) {
|
||||
if connVal, ok := s.connectedProxies.Load(proxyID); ok {
|
||||
conn := connVal.(*proxyConnection)
|
||||
conn.cancel()
|
||||
s.connectedProxies.Delete(proxyID)
|
||||
log.WithFields(log.Fields{"proxyID": proxyID}).Info("force disconnected proxy")
|
||||
}
|
||||
}
|
||||
|
||||
func filterMappingsForAccount(mappings []*proto.ProxyMapping, accountID string) []*proto.ProxyMapping {
|
||||
var filtered []*proto.ProxyMapping
|
||||
for _, m := range mappings {
|
||||
if m.AccountId == accountID {
|
||||
filtered = append(filtered, m)
|
||||
}
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
// GetConnectedProxies returns a list of connected proxy IDs
|
||||
func (s *ProxyServiceServer) GetConnectedProxies() []string {
|
||||
var proxies []string
|
||||
@@ -647,9 +531,6 @@ func (s *ProxyServiceServer) SendServiceUpdateToCluster(ctx context.Context, upd
|
||||
continue
|
||||
}
|
||||
conn := connVal.(*proxyConnection)
|
||||
if conn.accountID != nil && update.AccountId != "" && *conn.accountID != update.AccountId {
|
||||
continue
|
||||
}
|
||||
if !proxyAcceptsMapping(conn, update) {
|
||||
log.WithContext(ctx).Debugf("Skipping proxy %s: does not support custom ports for mapping %s", proxyID, update.Id)
|
||||
continue
|
||||
@@ -737,10 +618,6 @@ func shallowCloneMapping(m *proto.ProxyMapping) *proto.ProxyMapping {
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) {
|
||||
if err := enforceAccountScope(ctx, req.GetAccountId()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
service, err := s.serviceManager.GetServiceByID(ctx, req.GetAccountId(), req.GetId())
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to get service from store: %v", err)
|
||||
@@ -860,10 +737,6 @@ func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authentic
|
||||
|
||||
// SendStatusUpdate handles status updates from proxy clients.
|
||||
func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.SendStatusUpdateRequest) (*proto.SendStatusUpdateResponse, error) {
|
||||
if err := enforceAccountScope(ctx, req.GetAccountId()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
accountID := req.GetAccountId()
|
||||
serviceID := req.GetServiceId()
|
||||
protoStatus := req.GetStatus()
|
||||
@@ -934,10 +807,6 @@ func protoStatusToInternal(protoStatus proto.ProxyStatus) rpservice.Status {
|
||||
|
||||
// CreateProxyPeer handles proxy peer creation with one-time token authentication
|
||||
func (s *ProxyServiceServer) CreateProxyPeer(ctx context.Context, req *proto.CreateProxyPeerRequest) (*proto.CreateProxyPeerResponse, error) {
|
||||
if err := enforceAccountScope(ctx, req.GetAccountId()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
serviceID := req.GetServiceId()
|
||||
accountID := req.GetAccountId()
|
||||
token := req.GetToken()
|
||||
@@ -992,10 +861,6 @@ func strPtr(s string) *string {
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) GetOIDCURL(ctx context.Context, req *proto.GetOIDCURLRequest) (*proto.GetOIDCURLResponse, error) {
|
||||
if err := enforceAccountScope(ctx, req.GetAccountId()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
redirectURL, err := url.Parse(req.GetRedirectUrl())
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "parse redirect url: %v", err)
|
||||
@@ -1124,9 +989,21 @@ func (s *ProxyServiceServer) ValidateState(state string) (verifier, redirectURL
|
||||
|
||||
// GenerateSessionToken creates a signed session JWT for the given domain and user.
|
||||
func (s *ProxyServiceServer) GenerateSessionToken(ctx context.Context, domain, userID string, method proxyauth.Method) (string, error) {
|
||||
service, err := s.getServiceByDomain(ctx, domain)
|
||||
// Find the service by domain to get its signing key
|
||||
services, err := s.serviceManager.GetGlobalServices(ctx)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("service not found for domain %s: %w", domain, err)
|
||||
return "", fmt.Errorf("get services: %w", err)
|
||||
}
|
||||
|
||||
var service *rpservice.Service
|
||||
for _, svc := range services {
|
||||
if svc.Domain == domain {
|
||||
service = svc
|
||||
break
|
||||
}
|
||||
}
|
||||
if service == nil {
|
||||
return "", fmt.Errorf("service not found for domain: %s", domain)
|
||||
}
|
||||
|
||||
if service.SessionPrivateKey == "" {
|
||||
@@ -1224,10 +1101,6 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
|
||||
}, nil
|
||||
}
|
||||
|
||||
if err := enforceAccountScope(ctx, service.AccountID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pubKeyBytes, err := base64.StdEncoding.DecodeString(service.SessionPublicKey)
|
||||
if err != nil {
|
||||
log.WithFields(log.Fields{
|
||||
@@ -1311,7 +1184,18 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) getServiceByDomain(ctx context.Context, domain string) (*rpservice.Service, error) {
|
||||
return s.serviceManager.GetServiceByDomain(ctx, domain)
|
||||
services, err := s.serviceManager.GetGlobalServices(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get services: %w", err)
|
||||
}
|
||||
|
||||
for _, service := range services {
|
||||
if service.Domain == domain {
|
||||
return service, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("service not found for domain: %s", domain)
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) checkGroupAccess(service *rpservice.Service, user *types.User) error {
|
||||
|
||||
@@ -1,29 +0,0 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestIsProxyAddressValid(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
addr string
|
||||
valid bool
|
||||
}{
|
||||
{name: "valid domain", addr: "eu.proxy.netbird.io", valid: true},
|
||||
{name: "valid subdomain", addr: "byop.proxy.example.com", valid: true},
|
||||
{name: "valid IPv4", addr: "10.0.0.1", valid: true},
|
||||
{name: "valid IPv4 public", addr: "203.0.113.10", valid: true},
|
||||
{name: "valid IPv6", addr: "::1", valid: true},
|
||||
{name: "valid IPv6 full", addr: "2001:db8::1", valid: true},
|
||||
{name: "empty string", addr: "", valid: false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equal(t, tt.valid, isProxyAddressValid(tt.addr))
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -153,6 +153,9 @@ func (i *proxyAuthInterceptor) doValidateProxyToken(ctx context.Context) (*types
|
||||
return nil, status.Errorf(codes.Unauthenticated, "invalid token")
|
||||
}
|
||||
|
||||
// TODO: Enforce AccountID scope for "bring your own proxy" feature.
|
||||
// Currently tokens are management-wide; AccountID field is reserved for future use.
|
||||
|
||||
if !token.IsValid() {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "token expired or revoked")
|
||||
}
|
||||
|
||||
@@ -53,10 +53,6 @@ func (m *mockReverseProxyManager) DeleteService(ctx context.Context, accountID,
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockReverseProxyManager) DeleteAccountCluster(_ context.Context, _, _, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockReverseProxyManager) SetCertificateIssuedAt(ctx context.Context, accountID, reverseProxyID string) error {
|
||||
return nil
|
||||
}
|
||||
@@ -95,20 +91,6 @@ func (m *mockReverseProxyManager) StopServiceFromPeer(_ context.Context, _, _, _
|
||||
|
||||
func (m *mockReverseProxyManager) StartExposeReaper(_ context.Context) {}
|
||||
|
||||
func (m *mockReverseProxyManager) GetServiceByDomain(_ context.Context, domain string) (*service.Service, error) {
|
||||
if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
for _, services := range m.proxiesByAccount {
|
||||
for _, svc := range services {
|
||||
if svc.Domain == domain {
|
||||
return svc, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil, errors.New("service not found for domain: " + domain)
|
||||
}
|
||||
|
||||
func (m *mockReverseProxyManager) GetActiveClusters(_ context.Context, _, _ string) ([]proxy.Cluster, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
@@ -12,12 +12,9 @@ import (
|
||||
cachestore "github.com/eko/gocache/lib/v4/store"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc/codes"
|
||||
grpcstatus "google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
@@ -319,58 +316,6 @@ func TestValidateState_RejectsOldTwoPartFormat(t *testing.T) {
|
||||
assert.Contains(t, err.Error(), "invalid state format")
|
||||
}
|
||||
|
||||
func scopedCtx(accountID string) context.Context {
|
||||
token := &types.ProxyAccessToken{
|
||||
ID: "token-1",
|
||||
AccountID: &accountID,
|
||||
}
|
||||
return context.WithValue(context.Background(), ProxyTokenContextKey, token)
|
||||
}
|
||||
|
||||
func globalCtx() context.Context {
|
||||
token := &types.ProxyAccessToken{
|
||||
ID: "token-global",
|
||||
}
|
||||
return context.WithValue(context.Background(), ProxyTokenContextKey, token)
|
||||
}
|
||||
|
||||
func TestEnforceAccountScope_AllowsMatchingAccount(t *testing.T) {
|
||||
err := enforceAccountScope(scopedCtx("acc-1"), "acc-1")
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestEnforceAccountScope_BlocksMismatchedAccount(t *testing.T) {
|
||||
err := enforceAccountScope(scopedCtx("acc-1"), "acc-2")
|
||||
require.Error(t, err)
|
||||
st, ok := grpcstatus.FromError(err)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, codes.PermissionDenied, st.Code())
|
||||
}
|
||||
|
||||
func TestEnforceAccountScope_BlocksEmptyRequestAccountID(t *testing.T) {
|
||||
err := enforceAccountScope(scopedCtx("acc-1"), "")
|
||||
require.Error(t, err)
|
||||
st, ok := grpcstatus.FromError(err)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, codes.PermissionDenied, st.Code())
|
||||
}
|
||||
|
||||
func TestEnforceAccountScope_AllowsGlobalToken(t *testing.T) {
|
||||
err := enforceAccountScope(globalCtx(), "acc-1")
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = enforceAccountScope(globalCtx(), "acc-2")
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = enforceAccountScope(globalCtx(), "")
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestEnforceAccountScope_AllowsNoTokenInContext(t *testing.T) {
|
||||
err := enforceAccountScope(context.Background(), "acc-1")
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestValidateState_RejectsInvalidHMAC(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
pkceStore := NewPKCEVerifierStore(ctx, testCacheStore(t))
|
||||
|
||||
@@ -42,7 +42,7 @@ func setupValidateSessionTest(t *testing.T) *validateSessionTestSetup {
|
||||
tokenStore := NewOneTimeTokenStore(ctx, testCacheStore(t))
|
||||
pkceStore := NewPKCEVerifierStore(ctx, testCacheStore(t))
|
||||
|
||||
proxyService := NewProxyServiceServer(nil, tokenStore, pkceStore, ProxyOIDCConfig{}, nil, usersManager, proxyManager, nil)
|
||||
proxyService := NewProxyServiceServer(nil, tokenStore, pkceStore, ProxyOIDCConfig{}, nil, usersManager, proxyManager)
|
||||
proxyService.SetServiceManager(serviceManager)
|
||||
|
||||
createTestProxies(t, ctx, testStore)
|
||||
@@ -318,17 +318,13 @@ func (m *testValidateSessionServiceManager) StopServiceFromPeer(_ context.Contex
|
||||
|
||||
func (m *testValidateSessionServiceManager) StartExposeReaper(_ context.Context) {}
|
||||
|
||||
func (m *testValidateSessionServiceManager) GetServiceByDomain(ctx context.Context, domain string) (*service.Service, error) {
|
||||
return m.store.GetServiceByDomain(ctx, domain)
|
||||
}
|
||||
|
||||
func (m *testValidateSessionServiceManager) GetActiveClusters(_ context.Context, _, _ string) ([]proxy.Cluster, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
type testValidateSessionProxyManager struct{}
|
||||
|
||||
func (m *testValidateSessionProxyManager) Connect(_ context.Context, _, _, _ string, _ *string, _ *proxy.Capabilities) error {
|
||||
func (m *testValidateSessionProxyManager) Connect(_ context.Context, _, _, _ string, _ *proxy.Capabilities) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -344,10 +340,6 @@ func (m *testValidateSessionProxyManager) GetActiveClusterAddresses(_ context.Co
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionProxyManager) GetActiveClusterAddressesForAccount(_ context.Context, _ string) ([]string, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionProxyManager) GetActiveClusters(_ context.Context) ([]proxy.Cluster, error) {
|
||||
return nil, nil
|
||||
}
|
||||
@@ -356,22 +348,6 @@ func (m *testValidateSessionProxyManager) CleanupStale(_ context.Context, _ time
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionProxyManager) GetAccountProxy(_ context.Context, _ string) (*proxy.Proxy, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionProxyManager) CountAccountProxies(_ context.Context, _ string) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionProxyManager) IsClusterAddressAvailable(_ context.Context, _, _ string) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionProxyManager) DeleteProxy(_ context.Context, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionProxyManager) ClusterSupportsCustomPorts(_ context.Context, _ string) *bool {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -386,9 +386,6 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
||||
if err = am.handleInactivityExpirationSettings(ctx, oldSettings, newSettings, userID, accountID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = am.handleLocalMfaSettings(ctx, oldSettings, newSettings, userID, accountID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if oldSettings.DNSDomain != newSettings.DNSDomain {
|
||||
eventMeta := map[string]any{
|
||||
"old_dns_domain": oldSettings.DNSDomain,
|
||||
@@ -605,29 +602,6 @@ func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.
|
||||
return nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) handleLocalMfaSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) error {
|
||||
if oldSettings.LocalMfaEnabled == newSettings.LocalMfaEnabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
embeddedIdp, ok := am.idpManager.(*idp.EmbeddedIdPManager)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := embeddedIdp.SetMFAEnabled(ctx, newSettings.LocalMfaEnabled); err != nil {
|
||||
return fmt.Errorf("failed to toggle MFA: %w", err)
|
||||
}
|
||||
|
||||
if newSettings.LocalMfaEnabled {
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountLocalMfaEnabled, nil)
|
||||
} else {
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountLocalMfaDisabled, nil)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) {
|
||||
return func() (time.Duration, bool) {
|
||||
//nolint
|
||||
|
||||
@@ -3113,7 +3113,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil, proxyManager, nil)
|
||||
proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil, proxyManager)
|
||||
proxyController, err := proxymanager.NewGRPCController(proxyGrpcServer, noop.Meter{})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
|
||||
@@ -236,11 +236,6 @@ const (
|
||||
// AccountIPv6Disabled indicates that a user disabled IPv6 overlay for the account
|
||||
AccountIPv6Disabled Activity = 122
|
||||
|
||||
// AccountLocalMfaEnabled indicates that a user enabled TOTP MFA for local users
|
||||
AccountLocalMfaEnabled Activity = 123
|
||||
// AccountLocalMfaDisabled indicates that a user disabled TOTP MFA for local users
|
||||
AccountLocalMfaDisabled Activity = 124
|
||||
|
||||
AccountDeleted Activity = 99999
|
||||
)
|
||||
|
||||
@@ -391,9 +386,6 @@ var activityMap = map[Activity]Code{
|
||||
AccountPeerExposeEnabled: {"Account peer expose enabled", "account.setting.peer.expose.enable"},
|
||||
AccountPeerExposeDisabled: {"Account peer expose disabled", "account.setting.peer.expose.disable"},
|
||||
|
||||
AccountLocalMfaEnabled: {"Account local MFA enabled", "account.setting.local.mfa.enable"},
|
||||
AccountLocalMfaDisabled: {"Account local MFA disabled", "account.setting.local.mfa.disable"},
|
||||
|
||||
DomainAdded: {"Domain added", "domain.add"},
|
||||
DomainDeleted: {"Domain deleted", "domain.delete"},
|
||||
DomainValidated: {"Domain validated", "domain.validate"},
|
||||
|
||||
@@ -16,7 +16,6 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxytoken"
|
||||
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager"
|
||||
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
@@ -145,9 +144,6 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
|
||||
if serviceManager != nil && reverseProxyDomainManager != nil {
|
||||
reverseproxymanager.RegisterEndpoints(serviceManager, *reverseProxyDomainManager, reverseProxyAccessLogsManager, permissionsManager, router)
|
||||
}
|
||||
|
||||
proxytoken.RegisterEndpoints(accountManager.GetStore(), permissionsManager, router)
|
||||
|
||||
// Register OAuth callback handler for proxy authentication
|
||||
if proxyGRPCServer != nil {
|
||||
oauthHandler := proxy.NewAuthCallbackHandler(proxyGRPCServer, trustedHTTPProxies)
|
||||
|
||||
@@ -277,9 +277,6 @@ func (h *handler) updateAccountRequestSettings(req api.PutApiAccountsAccountIdJS
|
||||
if req.Settings.AutoUpdateAlways != nil {
|
||||
returnSettings.AutoUpdateAlways = *req.Settings.AutoUpdateAlways
|
||||
}
|
||||
if req.Settings.LocalMfaEnabled != nil {
|
||||
returnSettings.LocalMfaEnabled = *req.Settings.LocalMfaEnabled
|
||||
}
|
||||
if req.Settings.Ipv6EnabledGroups != nil {
|
||||
returnSettings.IPv6EnabledGroups = *req.Settings.Ipv6EnabledGroups
|
||||
}
|
||||
@@ -415,7 +412,6 @@ func toAccountResponse(accountID string, settings *types.Settings, meta *types.A
|
||||
Ipv6EnabledGroups: &settings.IPv6EnabledGroups,
|
||||
EmbeddedIdpEnabled: &settings.EmbeddedIdpEnabled,
|
||||
LocalAuthDisabled: &settings.LocalAuthDisabled,
|
||||
LocalMfaEnabled: &settings.LocalMfaEnabled,
|
||||
}
|
||||
|
||||
if settings.NetworkRange.IsValid() {
|
||||
|
||||
@@ -131,7 +131,6 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
||||
AutoUpdateVersion: sr(""),
|
||||
EmbeddedIdpEnabled: br(false),
|
||||
LocalAuthDisabled: br(false),
|
||||
LocalMfaEnabled: br(false),
|
||||
},
|
||||
expectedArray: true,
|
||||
expectedID: accountID,
|
||||
@@ -158,7 +157,6 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
||||
AutoUpdateVersion: sr(""),
|
||||
EmbeddedIdpEnabled: br(false),
|
||||
LocalAuthDisabled: br(false),
|
||||
LocalMfaEnabled: br(false),
|
||||
},
|
||||
expectedArray: false,
|
||||
expectedID: accountID,
|
||||
@@ -185,7 +183,6 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
||||
AutoUpdateVersion: sr("latest"),
|
||||
EmbeddedIdpEnabled: br(false),
|
||||
LocalAuthDisabled: br(false),
|
||||
LocalMfaEnabled: br(false),
|
||||
},
|
||||
expectedArray: false,
|
||||
expectedID: accountID,
|
||||
@@ -212,7 +209,6 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
||||
AutoUpdateVersion: sr(""),
|
||||
EmbeddedIdpEnabled: br(false),
|
||||
LocalAuthDisabled: br(false),
|
||||
LocalMfaEnabled: br(false),
|
||||
},
|
||||
expectedArray: false,
|
||||
expectedID: accountID,
|
||||
@@ -239,7 +235,6 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
||||
AutoUpdateVersion: sr(""),
|
||||
EmbeddedIdpEnabled: br(false),
|
||||
LocalAuthDisabled: br(false),
|
||||
LocalMfaEnabled: br(false),
|
||||
},
|
||||
expectedArray: false,
|
||||
expectedID: accountID,
|
||||
@@ -266,7 +261,6 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
||||
AutoUpdateVersion: sr(""),
|
||||
EmbeddedIdpEnabled: br(false),
|
||||
LocalAuthDisabled: br(false),
|
||||
LocalMfaEnabled: br(false),
|
||||
},
|
||||
expectedArray: false,
|
||||
expectedID: accountID,
|
||||
|
||||
@@ -216,7 +216,6 @@ func setupAuthCallbackTest(t *testing.T) *testSetup {
|
||||
nil,
|
||||
usersManager,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
|
||||
proxyService.SetServiceManager(&testServiceManager{store: testStore})
|
||||
@@ -390,10 +389,6 @@ func (m *testServiceManager) DeleteService(_ context.Context, _, _, _ string) er
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testServiceManager) DeleteAccountCluster(_ context.Context, _, _, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testServiceManager) SetCertificateIssuedAt(_ context.Context, _, _ string) error {
|
||||
return nil
|
||||
}
|
||||
@@ -440,10 +435,6 @@ func (m *testServiceManager) StopServiceFromPeer(_ context.Context, _, _, _ stri
|
||||
|
||||
func (m *testServiceManager) StartExposeReaper(_ context.Context) {}
|
||||
|
||||
func (m *testServiceManager) GetServiceByDomain(ctx context.Context, domain string) (*service.Service, error) {
|
||||
return m.store.GetServiceByDomain(ctx, domain)
|
||||
}
|
||||
|
||||
func (m *testServiceManager) GetActiveClusters(_ context.Context, _, _ string) ([]nbproxy.Cluster, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
@@ -109,7 +109,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create proxy manager: %v", err)
|
||||
}
|
||||
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr, nil)
|
||||
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr)
|
||||
domainManager := manager.NewManager(store, proxyMgr, permissionsManager, am)
|
||||
serviceProxyController, err := proxymanager.NewGRPCController(proxyServiceServer, noopMeter)
|
||||
if err != nil {
|
||||
@@ -238,7 +238,7 @@ func BuildApiBlackBoxWithDBStateAndPeerChannel(t testing_tools.TB, sqlFile strin
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create proxy manager: %v", err)
|
||||
}
|
||||
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr, nil)
|
||||
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr)
|
||||
domainManager := manager.NewManager(store, proxyMgr, permissionsManager, am)
|
||||
serviceProxyController, err := proxymanager.NewGRPCController(proxyServiceServer, noopMeter)
|
||||
if err != nil {
|
||||
|
||||
@@ -2,11 +2,9 @@ package idp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/dexidp/dex/storage"
|
||||
@@ -19,13 +17,12 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
staticClientDashboard = "netbird-dashboard"
|
||||
staticClientCLI = "netbird-cli"
|
||||
defaultCLIRedirectURL1 = "http://localhost:53000/"
|
||||
defaultCLIRedirectURL2 = "http://localhost:54000/"
|
||||
defaultScopes = "openid profile email groups"
|
||||
defaultUserIDClaim = "sub"
|
||||
sessionCookieEncryptionKeyEnv = "NB_IDP_SESSION_COOKIE_ENCRYPTION_KEY"
|
||||
staticClientDashboard = "netbird-dashboard"
|
||||
staticClientCLI = "netbird-cli"
|
||||
defaultCLIRedirectURL1 = "http://localhost:53000/"
|
||||
defaultCLIRedirectURL2 = "http://localhost:54000/"
|
||||
defaultScopes = "openid profile email groups"
|
||||
defaultUserIDClaim = "sub"
|
||||
)
|
||||
|
||||
// EmbeddedIdPConfig contains configuration for the embedded Dex OIDC identity provider
|
||||
@@ -52,26 +49,6 @@ type EmbeddedIdPConfig struct {
|
||||
// Existing local users are preserved and will be able to login again if re-enabled.
|
||||
// Cannot be enabled if no external identity provider connectors are configured.
|
||||
LocalAuthDisabled bool
|
||||
// MfaSessionMaxLifetime is the maximum MFA session duration from creation (e.g. "24h").
|
||||
// Defaults to "24h" if empty.
|
||||
MfaSessionMaxLifetime string
|
||||
// MfaSessionIdleTimeout is the idle timeout after which the MFA session expires (e.g. "1h").
|
||||
// Defaults to "1h" if empty.
|
||||
MfaSessionIdleTimeout string
|
||||
// MfaSessionRememberMe controls the default state of the "remember me" checkbox on the
|
||||
// login screen. When true, the session cookie persists across browser tabs/restarts so
|
||||
// MFA is not re-prompted until the session expires. Defaults to false.
|
||||
MfaSessionRememberMe bool
|
||||
// SessionCookieEncryptionKey is the optional AES key used to encrypt embedded IdP session cookies.
|
||||
// It can also be set with NB_IDP_SESSION_COOKIE_ENCRYPTION_KEY. The value must be 16, 24, or 32
|
||||
// bytes when provided as a raw string, or base64-encoded to one of those lengths.
|
||||
SessionCookieEncryptionKey string
|
||||
// Dashboard Post logout redirect URIs, these are required to tell
|
||||
// Dex what to allow when an RP-Initiated logout is started by the frontend
|
||||
// at least one of these must match the dashboard base URL or the dashboard
|
||||
// DASHBOARD_POST_LOGOUT_URL environment variable
|
||||
// WARNING: Dex only uses exact match, not wildcards
|
||||
DashboardPostLogoutRedirectURIs []string
|
||||
// StaticConnectors are additional connectors to seed during initialization
|
||||
StaticConnectors []dex.Connector
|
||||
}
|
||||
@@ -149,11 +126,6 @@ func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) {
|
||||
// todo: resolve import cycle
|
||||
dashboardRedirectURIs = append(dashboardRedirectURIs, baseURL+"/api/reverse-proxy/callback")
|
||||
|
||||
dashboardPostLogoutRedirectURIs := c.DashboardPostLogoutRedirectURIs
|
||||
// It is safe to assume that most installations will share the location of the
|
||||
// MGMT api and the dashboard, adding baseURL means less configuration for the instance admin
|
||||
dashboardPostLogoutRedirectURIs = append(dashboardPostLogoutRedirectURIs, baseURL)
|
||||
|
||||
cfg := &dex.YAMLConfig{
|
||||
Issuer: c.Issuer,
|
||||
Storage: dex.Storage{
|
||||
@@ -176,11 +148,10 @@ func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) {
|
||||
EnablePasswordDB: true,
|
||||
StaticClients: []storage.Client{
|
||||
{
|
||||
ID: staticClientDashboard,
|
||||
Name: "NetBird Dashboard",
|
||||
Public: true,
|
||||
RedirectURIs: dashboardRedirectURIs,
|
||||
PostLogoutRedirectURIs: sanitizePostLogoutRedirectURIs(dashboardPostLogoutRedirectURIs),
|
||||
ID: staticClientDashboard,
|
||||
Name: "NetBird Dashboard",
|
||||
Public: true,
|
||||
RedirectURIs: dashboardRedirectURIs,
|
||||
},
|
||||
{
|
||||
ID: staticClientCLI,
|
||||
@@ -192,12 +163,6 @@ func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) {
|
||||
StaticConnectors: c.StaticConnectors,
|
||||
}
|
||||
|
||||
// Always initialize MFA providers and sessions so TOTP can be toggled at runtime.
|
||||
// MFAChain on clients is NOT set here — it's synced from the DB setting on startup.
|
||||
if err := configureMFA(cfg, c.MfaSessionMaxLifetime, c.MfaSessionIdleTimeout, c.MfaSessionRememberMe, c.SessionCookieEncryptionKey); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Add owner user if provided
|
||||
if c.Owner != nil && c.Owner.Email != "" && c.Owner.Hash != "" {
|
||||
username := c.Owner.Username
|
||||
@@ -217,100 +182,6 @@ func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) {
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// Due to how the frontend generates the logout, sometimes it appends a trailing slash
|
||||
// and because Dex only allows exact matches, we need to make sure we always have both
|
||||
// versions of each provided uri
|
||||
func sanitizePostLogoutRedirectURIs(uris []string) []string {
|
||||
result := make([]string, 0)
|
||||
for _, uri := range uris {
|
||||
if strings.HasSuffix(uri, "/") {
|
||||
result = append(result, uri)
|
||||
result = append(result, strings.TrimSuffix(uri, "/"))
|
||||
} else {
|
||||
result = append(result, uri)
|
||||
result = append(result, uri+"/")
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func configureMFA(cfg *dex.YAMLConfig, sessionMaxLifetime, sessionIdleTimeout string, rememberMe bool, sessionCookieEncryptionKey string) error {
|
||||
cfg.MFA.Authenticators = []dex.MFAAuthenticator{{
|
||||
ID: "default-totp",
|
||||
// Has to be caps otherwise it will fail
|
||||
Type: "TOTP",
|
||||
Config: map[string]interface{}{
|
||||
"issuer": "NetBird",
|
||||
},
|
||||
ConnectorTypes: []string{"local"},
|
||||
}}
|
||||
|
||||
if sessionMaxLifetime == "" {
|
||||
sessionMaxLifetime = "24h"
|
||||
}
|
||||
if sessionIdleTimeout == "" {
|
||||
sessionIdleTimeout = "1h"
|
||||
}
|
||||
|
||||
cookieEncryptionKey, err := resolveSessionCookieEncryptionKey(sessionCookieEncryptionKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cfg.Sessions = &dex.Sessions{
|
||||
CookieName: "netbird-session",
|
||||
AbsoluteLifetime: sessionMaxLifetime,
|
||||
ValidIfNotUsedFor: sessionIdleTimeout,
|
||||
RememberMeCheckedByDefault: &rememberMe,
|
||||
SSOSharedWithDefault: "all",
|
||||
CookieEncryptionKey: cookieEncryptionKey,
|
||||
}
|
||||
// Absolutely required, otherwise the dex server will omit the MFA configuration entirely
|
||||
os.Setenv("DEX_SESSIONS_ENABLED", "true")
|
||||
|
||||
// Note: MFAChain on clients is NOT set here.
|
||||
// It is toggled at runtime via SetMFAEnabled() based on the account settings DB value.
|
||||
return nil
|
||||
}
|
||||
|
||||
func resolveSessionCookieEncryptionKey(configuredKey string) (string, error) {
|
||||
key := strings.TrimSpace(configuredKey)
|
||||
if key == "" {
|
||||
key = strings.TrimSpace(os.Getenv(sessionCookieEncryptionKeyEnv))
|
||||
}
|
||||
if key == "" {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
if validSessionCookieEncryptionKeyLength(len([]byte(key))) {
|
||||
return key, nil
|
||||
}
|
||||
|
||||
for _, encoding := range []*base64.Encoding{
|
||||
base64.StdEncoding,
|
||||
base64.RawStdEncoding,
|
||||
base64.URLEncoding,
|
||||
base64.RawURLEncoding,
|
||||
} {
|
||||
decoded, err := encoding.DecodeString(key)
|
||||
if err == nil && validSessionCookieEncryptionKeyLength(len(decoded)) {
|
||||
return string(decoded), nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("invalid embedded IdP session cookie encryption key: %s (or sessionCookieEncryptionKey) must be 16, 24, or 32 bytes as a raw string or base64-encoded to one of those lengths; got %d raw bytes", sessionCookieEncryptionKeyEnv, len([]byte(key)))
|
||||
}
|
||||
|
||||
func validSessionCookieEncryptionKeyLength(length int) bool {
|
||||
switch length {
|
||||
case 16, 24, 32:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Compile-time check that EmbeddedIdPManager implements Manager interface
|
||||
var _ Manager = (*EmbeddedIdPManager)(nil)
|
||||
|
||||
@@ -344,7 +215,6 @@ type EmbeddedIdPManager struct {
|
||||
provider *dex.Provider
|
||||
appMetrics telemetry.AppMetrics
|
||||
config EmbeddedIdPConfig
|
||||
mfaEnabled bool
|
||||
}
|
||||
|
||||
// NewEmbeddedIdPManager creates a new instance of EmbeddedIdPManager from a configuration.
|
||||
@@ -771,27 +641,6 @@ func (m *EmbeddedIdPManager) IsLocalAuthDisabled() bool {
|
||||
return m.config.LocalAuthDisabled
|
||||
}
|
||||
|
||||
// SetMFAEnabled enables or disables TOTP MFA for local users by updating the MFAChain on OAuth2 clients.
|
||||
func (m *EmbeddedIdPManager) SetMFAEnabled(ctx context.Context, enabled bool) error {
|
||||
var mfaChain []string
|
||||
if enabled {
|
||||
mfaChain = []string{"default-totp"}
|
||||
}
|
||||
if err := m.provider.SetClientsMFAChain(ctx, []string{
|
||||
staticClientCLI,
|
||||
staticClientDashboard,
|
||||
}, mfaChain); err != nil {
|
||||
return fmt.Errorf("failed to set MFA enabled=%v: %w", enabled, err)
|
||||
}
|
||||
m.mfaEnabled = enabled
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsMFAEnabled returns whether TOTP MFA is currently enabled for local users.
|
||||
func (m *EmbeddedIdPManager) IsMFAEnabled() bool {
|
||||
return m.mfaEnabled
|
||||
}
|
||||
|
||||
// HasNonLocalConnectors checks if there are any identity provider connectors other than local.
|
||||
func (m *EmbeddedIdPManager) HasNonLocalConnectors(ctx context.Context) (bool, error) {
|
||||
return m.provider.HasNonLocalConnectors(ctx)
|
||||
|
||||
@@ -2,7 +2,6 @@ package idp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
@@ -314,72 +313,6 @@ func TestEmbeddedIdPManager_UpdateUserPassword(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestEmbeddedIdPConfig_ToYAMLConfig_SessionCookieEncryptionKey(t *testing.T) {
|
||||
t.Setenv(sessionCookieEncryptionKeyEnv, "")
|
||||
|
||||
rawKey := "0123456789abcdef0123456789abcdef"
|
||||
config := &EmbeddedIdPConfig{
|
||||
Enabled: true,
|
||||
Issuer: "http://localhost:5556/dex",
|
||||
SessionCookieEncryptionKey: base64.StdEncoding.EncodeToString([]byte(rawKey)),
|
||||
Storage: EmbeddedStorageConfig{
|
||||
Type: "sqlite3",
|
||||
Config: EmbeddedStorageTypeConfig{
|
||||
File: filepath.Join(t.TempDir(), "dex.db"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
yamlConfig, err := config.ToYAMLConfig()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, yamlConfig.Sessions)
|
||||
assert.Equal(t, rawKey, yamlConfig.Sessions.CookieEncryptionKey)
|
||||
}
|
||||
|
||||
func TestResolveSessionCookieEncryptionKey(t *testing.T) {
|
||||
rawKey := "0123456789abcdef0123456789abcdef"
|
||||
|
||||
t.Run("uses raw configured key", func(t *testing.T) {
|
||||
t.Setenv(sessionCookieEncryptionKeyEnv, "")
|
||||
|
||||
key, err := resolveSessionCookieEncryptionKey(rawKey)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, rawKey, key)
|
||||
})
|
||||
|
||||
t.Run("uses base64 configured key", func(t *testing.T) {
|
||||
t.Setenv(sessionCookieEncryptionKeyEnv, "")
|
||||
|
||||
key, err := resolveSessionCookieEncryptionKey(base64.StdEncoding.EncodeToString([]byte(rawKey)))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, rawKey, key)
|
||||
})
|
||||
|
||||
t.Run("falls back to env var", func(t *testing.T) {
|
||||
t.Setenv(sessionCookieEncryptionKeyEnv, rawKey)
|
||||
|
||||
key, err := resolveSessionCookieEncryptionKey("")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, rawKey, key)
|
||||
})
|
||||
|
||||
t.Run("empty key disables encryption", func(t *testing.T) {
|
||||
t.Setenv(sessionCookieEncryptionKeyEnv, "")
|
||||
|
||||
key, err := resolveSessionCookieEncryptionKey("")
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, key)
|
||||
})
|
||||
|
||||
t.Run("rejects invalid key length", func(t *testing.T) {
|
||||
t.Setenv(sessionCookieEncryptionKeyEnv, "")
|
||||
|
||||
_, err := resolveSessionCookieEncryptionKey("32")
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), sessionCookieEncryptionKeyEnv)
|
||||
})
|
||||
}
|
||||
|
||||
func TestEmbeddedIdPManager_GetLocalKeysLocation(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
|
||||
@@ -236,6 +236,7 @@ func (s *SqlStore) GetPeerJobs(ctx context.Context, accountID, peerID string) ([
|
||||
Where(accountAndPeerIDQueryCondition, accountID, peerID).
|
||||
Order("created_at DESC").
|
||||
Find(&jobs).Error
|
||||
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to fetch jobs from store: %s", err)
|
||||
return nil, err
|
||||
@@ -462,6 +463,7 @@ func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer.
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -1512,7 +1514,6 @@ func (s *SqlStore) getAccount(ctx context.Context, accountID string) (*types.Acc
|
||||
settings_jwt_groups_enabled, settings_jwt_groups_claim_name, settings_jwt_allow_groups,
|
||||
settings_routing_peer_dns_resolution_enabled, settings_dns_domain, settings_network_range,
|
||||
settings_network_range_v6, settings_ipv6_enabled_groups, settings_lazy_connection_enabled,
|
||||
settings_local_mfa_enabled,
|
||||
-- Embedded ExtraSettings
|
||||
settings_extra_peer_approval_enabled, settings_extra_user_approval_required,
|
||||
settings_extra_integrated_validator, settings_extra_integrated_validator_groups
|
||||
@@ -1534,7 +1535,6 @@ func (s *SqlStore) getAccount(ctx context.Context, accountID string) (*types.Acc
|
||||
sNetworkRangeV6 sql.NullString
|
||||
sIPv6EnabledGroups sql.NullString
|
||||
sLazyConnectionEnabled sql.NullBool
|
||||
sLocalMFAEnabled sql.NullBool
|
||||
sExtraPeerApprovalEnabled sql.NullBool
|
||||
sExtraUserApprovalRequired sql.NullBool
|
||||
sExtraIntegratedValidator sql.NullString
|
||||
@@ -1557,7 +1557,6 @@ func (s *SqlStore) getAccount(ctx context.Context, accountID string) (*types.Acc
|
||||
&sJWTGroupsEnabled, &sJWTGroupsClaimName, &sJWTAllowGroups,
|
||||
&sRoutingPeerDNSResolutionEnabled, &sDNSDomain, &sNetworkRange,
|
||||
&sNetworkRangeV6, &sIPv6EnabledGroups, &sLazyConnectionEnabled,
|
||||
&sLocalMFAEnabled,
|
||||
&sExtraPeerApprovalEnabled, &sExtraUserApprovalRequired,
|
||||
&sExtraIntegratedValidator, &sExtraIntegratedValidatorGroups,
|
||||
)
|
||||
@@ -1620,9 +1619,6 @@ func (s *SqlStore) getAccount(ctx context.Context, accountID string) (*types.Acc
|
||||
if sLazyConnectionEnabled.Valid {
|
||||
account.Settings.LazyConnectionEnabled = sLazyConnectionEnabled.Bool
|
||||
}
|
||||
if sLocalMFAEnabled.Valid {
|
||||
account.Settings.LocalMfaEnabled = sLocalMFAEnabled.Bool
|
||||
}
|
||||
if sJWTAllowGroups.Valid {
|
||||
_ = json.Unmarshal([]byte(sJWTAllowGroups.String), &account.Settings.JWTAllowGroups)
|
||||
}
|
||||
@@ -3065,6 +3061,7 @@ func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peer
|
||||
GroupID: groupID,
|
||||
PeerID: peerID,
|
||||
}).Error
|
||||
|
||||
if err != nil {
|
||||
return status.Errorf(status.Internal, "error adding peer to group 'All': %v", err)
|
||||
}
|
||||
@@ -3084,6 +3081,7 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountID, peerID, groupI
|
||||
Columns: []clause.Column{{Name: "group_id"}, {Name: "peer_id"}},
|
||||
DoNothing: true,
|
||||
}).Create(peer).Error
|
||||
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to add peer %s to group %s for account %s: %v", peerID, groupID, accountID, err)
|
||||
return status.Errorf(status.Internal, "failed to add peer to group")
|
||||
@@ -3096,6 +3094,7 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountID, peerID, groupI
|
||||
func (s *SqlStore) RemovePeerFromGroup(ctx context.Context, peerID string, groupID string) error {
|
||||
err := s.db.
|
||||
Delete(&types.GroupPeer{}, "group_id = ? AND peer_id = ?", groupID, peerID).Error
|
||||
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to remove peer %s from group %s: %v", peerID, groupID, err)
|
||||
return status.Errorf(status.Internal, "failed to remove peer from group")
|
||||
@@ -3108,6 +3107,7 @@ func (s *SqlStore) RemovePeerFromGroup(ctx context.Context, peerID string, group
|
||||
func (s *SqlStore) RemovePeerFromAllGroups(ctx context.Context, peerID string) error {
|
||||
err := s.db.
|
||||
Delete(&types.GroupPeer{}, "peer_id = ?", peerID).Error
|
||||
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to remove peer %s from all groups: %v", peerID, err)
|
||||
return status.Errorf(status.Internal, "failed to remove peer from all groups")
|
||||
@@ -4513,47 +4513,6 @@ func (s *SqlStore) RevokeProxyAccessToken(ctx context.Context, tokenID string) e
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetProxyAccessTokensByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.ProxyAccessToken, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var tokens []*types.ProxyAccessToken
|
||||
result := tx.Where("account_id = ?", accountID).Find(&tokens)
|
||||
if result.Error != nil {
|
||||
return nil, status.Errorf(status.Internal, "get proxy access tokens by account: %v", result.Error)
|
||||
}
|
||||
|
||||
return tokens, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) IsProxyAccessTokenValid(ctx context.Context, tokenID string) (bool, error) {
|
||||
token, err := s.GetProxyAccessTokenByID(ctx, LockingStrengthNone, tokenID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return token.IsValid(), nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetProxyAccessTokenByID(ctx context.Context, lockStrength LockingStrength, tokenID string) (*types.ProxyAccessToken, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var token types.ProxyAccessToken
|
||||
result := tx.Take(&token, idQueryCondition, tokenID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "proxy access token not found")
|
||||
}
|
||||
return nil, status.Errorf(status.Internal, "get proxy access token by ID: %v", result.Error)
|
||||
}
|
||||
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
// MarkProxyAccessTokenUsed updates the last used timestamp for a proxy access token.
|
||||
func (s *SqlStore) MarkProxyAccessTokenUsed(ctx context.Context, tokenID string) error {
|
||||
result := s.db.Model(&types.ProxyAccessToken{}).
|
||||
@@ -5005,6 +4964,7 @@ func (s *SqlStore) UpdateService(ctx context.Context, service *rpservice.Service
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to update service to store: %v", err)
|
||||
return status.Errorf(status.Internal, "failed to update service to store")
|
||||
@@ -5528,7 +5488,7 @@ func (s *SqlStore) DisconnectProxy(ctx context.Context, proxyID, sessionID strin
|
||||
Model(&proxy.Proxy{}).
|
||||
Where("id = ? AND session_id = ?", proxyID, sessionID).
|
||||
Updates(map[string]any{
|
||||
"status": proxy.StatusDisconnected,
|
||||
"status": "disconnected",
|
||||
"disconnected_at": now,
|
||||
"last_seen": now,
|
||||
})
|
||||
@@ -5559,7 +5519,7 @@ func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) err
|
||||
if result.RowsAffected == 0 {
|
||||
p.LastSeen = now
|
||||
p.ConnectedAt = &now
|
||||
p.Status = proxy.StatusConnected
|
||||
p.Status = "connected"
|
||||
if err := s.db.Create(p).Error; err != nil {
|
||||
log.WithContext(ctx).Debugf("proxy %s session %s: heartbeat fallback insert skipped: %v", p.ID, p.SessionID, err)
|
||||
}
|
||||
@@ -5568,15 +5528,13 @@ func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) err
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetActiveProxyClusterAddresses returns the unique cluster addresses of active
|
||||
// shared proxies (those without an account scope). BYOP cluster addresses are
|
||||
// excluded; use GetActiveProxyClusterAddressesForAccount to retrieve them.
|
||||
// GetActiveProxyClusterAddresses returns all unique cluster addresses for active proxies
|
||||
func (s *SqlStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) {
|
||||
var addresses []string
|
||||
|
||||
result := s.db.
|
||||
Model(&proxy.Proxy{}).
|
||||
Where("account_id IS NULL AND status = ? AND last_seen > ?", proxy.StatusConnected, time.Now().Add(-proxyActiveThreshold)).
|
||||
Where("status = ? AND last_seen > ?", "connected", time.Now().Add(-proxyActiveThreshold)).
|
||||
Distinct("cluster_address").
|
||||
Pluck("cluster_address", &addresses)
|
||||
|
||||
@@ -5588,75 +5546,13 @@ func (s *SqlStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string
|
||||
return addresses, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) {
|
||||
var addresses []string
|
||||
|
||||
result := s.db.
|
||||
Model(&proxy.Proxy{}).
|
||||
Where("account_id = ? AND status = ? AND last_seen > ?", accountID, proxy.StatusConnected, time.Now().Add(-proxyActiveThreshold)).
|
||||
Distinct("cluster_address").
|
||||
Pluck("cluster_address", &addresses)
|
||||
|
||||
if result.Error != nil {
|
||||
return nil, status.Errorf(status.Internal, "failed to get active proxy cluster addresses for account")
|
||||
}
|
||||
|
||||
return addresses, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error) {
|
||||
var p proxy.Proxy
|
||||
result := s.db.Where("account_id = ?", accountID).Take(&p)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "proxy not found for account")
|
||||
}
|
||||
return nil, status.Errorf(status.Internal, "get proxy by account ID: %v", result.Error)
|
||||
}
|
||||
return &p, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error) {
|
||||
var count int64
|
||||
result := s.db.Model(&proxy.Proxy{}).Where("account_id = ?", accountID).Count(&count)
|
||||
if result.Error != nil {
|
||||
return 0, status.Errorf(status.Internal, "count proxies by account ID: %v", result.Error)
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) IsClusterAddressConflicting(ctx context.Context, clusterAddress, accountID string) (bool, error) {
|
||||
var count int64
|
||||
result := s.db.
|
||||
Model(&proxy.Proxy{}).
|
||||
Where("cluster_address = ? AND (account_id IS NULL OR account_id != ?)", clusterAddress, accountID).
|
||||
Count(&count)
|
||||
if result.Error != nil {
|
||||
return false, status.Errorf(status.Internal, "check cluster address conflict: %v", result.Error)
|
||||
}
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error {
|
||||
result := s.db.
|
||||
Where("cluster_address = ? AND account_id = ?", clusterAddress, accountID).
|
||||
Delete(&proxy.Proxy{})
|
||||
if result.Error != nil {
|
||||
return status.Errorf(status.Internal, "delete account cluster: %v", result.Error)
|
||||
}
|
||||
if result.RowsAffected == 0 {
|
||||
return status.Errorf(status.NotFound, "cluster not found")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetActiveProxyClusters(ctx context.Context, accountID string) ([]proxy.Cluster, error) {
|
||||
// GetActiveProxyClusters returns all active proxy clusters with their connected proxy count.
|
||||
func (s *SqlStore) GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error) {
|
||||
var clusters []proxy.Cluster
|
||||
|
||||
result := s.db.Model(&proxy.Proxy{}).
|
||||
Select("MIN(id) as id, cluster_address as address, COUNT(*) as connected_proxies, COUNT(account_id) > 0 as self_hosted").
|
||||
Where("status = ? AND last_seen > ? AND (account_id IS NULL OR account_id = ?)",
|
||||
proxy.StatusConnected, time.Now().Add(-proxyActiveThreshold), accountID).
|
||||
Select("cluster_address as address, COUNT(*) as connected_proxies").
|
||||
Where("status = ? AND last_seen > ?", "connected", time.Now().Add(-proxyActiveThreshold)).
|
||||
Group("cluster_address").
|
||||
Scan(&clusters)
|
||||
|
||||
@@ -5724,6 +5620,7 @@ func (s *SqlStore) getClusterUnanimousCapability(ctx context.Context, clusterAdd
|
||||
Where("cluster_address = ? AND status = ? AND last_seen > ?",
|
||||
clusterAddr, "connected", time.Now().Add(-proxyActiveThreshold)).
|
||||
Scan(&result).Error
|
||||
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("query cluster capability %s for %s: %v", column, clusterAddr, err)
|
||||
return nil
|
||||
@@ -5765,6 +5662,7 @@ func (s *SqlStore) getClusterCapability(ctx context.Context, clusterAddr, column
|
||||
Where("cluster_address = ? AND status = ? AND last_seen > ?",
|
||||
clusterAddr, "connected", time.Now().Add(-proxyActiveThreshold)).
|
||||
Scan(&result).Error
|
||||
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("query cluster capability %s for %s: %v", column, clusterAddr, err)
|
||||
return nil
|
||||
|
||||
@@ -114,9 +114,6 @@ type Store interface {
|
||||
|
||||
GetProxyAccessTokenByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken types.HashedProxyToken) (*types.ProxyAccessToken, error)
|
||||
GetAllProxyAccessTokens(ctx context.Context, lockStrength LockingStrength) ([]*types.ProxyAccessToken, error)
|
||||
GetProxyAccessTokensByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.ProxyAccessToken, error)
|
||||
GetProxyAccessTokenByID(ctx context.Context, lockStrength LockingStrength, tokenID string) (*types.ProxyAccessToken, error)
|
||||
IsProxyAccessTokenValid(ctx context.Context, tokenID string) (bool, error)
|
||||
SaveProxyAccessToken(ctx context.Context, token *types.ProxyAccessToken) error
|
||||
RevokeProxyAccessToken(ctx context.Context, tokenID string) error
|
||||
MarkProxyAccessTokenUsed(ctx context.Context, tokenID string) error
|
||||
@@ -291,16 +288,11 @@ type Store interface {
|
||||
DisconnectProxy(ctx context.Context, proxyID, sessionID string) error
|
||||
UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error
|
||||
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
|
||||
GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error)
|
||||
GetActiveProxyClusters(ctx context.Context, accountID string) ([]proxy.Cluster, error)
|
||||
GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error)
|
||||
GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
||||
GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
|
||||
GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
|
||||
CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error
|
||||
GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error)
|
||||
CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error)
|
||||
IsClusterAddressConflicting(ctx context.Context, clusterAddress, accountID string) (bool, error)
|
||||
DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error
|
||||
|
||||
GetCustomDomainsCounts(ctx context.Context) (total int64, validated int64, err error)
|
||||
|
||||
@@ -504,9 +496,6 @@ func getMigrationsPostAuto(ctx context.Context) []migrationFunc {
|
||||
func(db *gorm.DB) error {
|
||||
return migration.CreateIndexIfNotExists[nbpeer.Peer](ctx, db, "idx_peers_key_unique", "key")
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migration.DropIndex[proxy.Proxy](ctx, db, "idx_proxy_account_id_unique")
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -166,6 +166,20 @@ func (mr *MockStoreMockRecorder) CleanupStaleProxies(ctx, inactivityDuration int
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanupStaleProxies", reflect.TypeOf((*MockStore)(nil).CleanupStaleProxies), ctx, inactivityDuration)
|
||||
}
|
||||
|
||||
// GetClusterSupportsCrowdSec mocks base method.
|
||||
func (m *MockStore) GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetClusterSupportsCrowdSec", ctx, clusterAddr)
|
||||
ret0, _ := ret[0].(*bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetClusterSupportsCrowdSec indicates an expected call of GetClusterSupportsCrowdSec.
|
||||
func (mr *MockStoreMockRecorder) GetClusterSupportsCrowdSec(ctx, clusterAddr interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterSupportsCrowdSec", reflect.TypeOf((*MockStore)(nil).GetClusterSupportsCrowdSec), ctx, clusterAddr)
|
||||
}
|
||||
|
||||
// Close mocks base method.
|
||||
func (m *MockStore) Close(ctx context.Context) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -224,21 +238,6 @@ func (mr *MockStoreMockRecorder) CountEphemeralServicesByPeer(ctx, lockStrength,
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountEphemeralServicesByPeer", reflect.TypeOf((*MockStore)(nil).CountEphemeralServicesByPeer), ctx, lockStrength, accountID, peerID)
|
||||
}
|
||||
|
||||
// CountProxiesByAccountID mocks base method.
|
||||
func (m *MockStore) CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "CountProxiesByAccountID", ctx, accountID)
|
||||
ret0, _ := ret[0].(int64)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// CountProxiesByAccountID indicates an expected call of CountProxiesByAccountID.
|
||||
func (mr *MockStoreMockRecorder) CountProxiesByAccountID(ctx, accountID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountProxiesByAccountID", reflect.TypeOf((*MockStore)(nil).CountProxiesByAccountID), ctx, accountID)
|
||||
}
|
||||
|
||||
// CreateAccessLog mocks base method.
|
||||
func (m *MockStore) CreateAccessLog(ctx context.Context, log *accesslogs.AccessLogEntry) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -577,20 +576,6 @@ func (mr *MockStoreMockRecorder) DeletePostureChecks(ctx, accountID, postureChec
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeletePostureChecks", reflect.TypeOf((*MockStore)(nil).DeletePostureChecks), ctx, accountID, postureChecksID)
|
||||
}
|
||||
|
||||
// DeleteAccountCluster mocks base method.
|
||||
func (m *MockStore) DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteAccountCluster", ctx, clusterAddress, accountID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteAccountCluster indicates an expected call of DeleteAccountCluster.
|
||||
func (mr *MockStoreMockRecorder) DeleteAccountCluster(ctx, clusterAddress, accountID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAccountCluster", reflect.TypeOf((*MockStore)(nil).DeleteAccountCluster), ctx, clusterAddress, accountID)
|
||||
}
|
||||
|
||||
// DeleteRoute mocks base method.
|
||||
func (m *MockStore) DeleteRoute(ctx context.Context, accountID, routeID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -1317,34 +1302,19 @@ func (mr *MockStoreMockRecorder) GetActiveProxyClusterAddresses(ctx interface{})
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveProxyClusterAddresses", reflect.TypeOf((*MockStore)(nil).GetActiveProxyClusterAddresses), ctx)
|
||||
}
|
||||
|
||||
// GetActiveProxyClusterAddressesForAccount mocks base method.
|
||||
func (m *MockStore) GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetActiveProxyClusterAddressesForAccount", ctx, accountID)
|
||||
ret0, _ := ret[0].([]string)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetActiveProxyClusterAddressesForAccount indicates an expected call of GetActiveProxyClusterAddressesForAccount.
|
||||
func (mr *MockStoreMockRecorder) GetActiveProxyClusterAddressesForAccount(ctx, accountID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveProxyClusterAddressesForAccount", reflect.TypeOf((*MockStore)(nil).GetActiveProxyClusterAddressesForAccount), ctx, accountID)
|
||||
}
|
||||
|
||||
// GetActiveProxyClusters mocks base method.
|
||||
func (m *MockStore) GetActiveProxyClusters(ctx context.Context, accountID string) ([]proxy.Cluster, error) {
|
||||
func (m *MockStore) GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetActiveProxyClusters", ctx, accountID)
|
||||
ret := m.ctrl.Call(m, "GetActiveProxyClusters", ctx)
|
||||
ret0, _ := ret[0].([]proxy.Cluster)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetActiveProxyClusters indicates an expected call of GetActiveProxyClusters.
|
||||
func (mr *MockStoreMockRecorder) GetActiveProxyClusters(ctx, accountID interface{}) *gomock.Call {
|
||||
func (mr *MockStoreMockRecorder) GetActiveProxyClusters(ctx interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveProxyClusters", reflect.TypeOf((*MockStore)(nil).GetActiveProxyClusters), ctx, accountID)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveProxyClusters", reflect.TypeOf((*MockStore)(nil).GetActiveProxyClusters), ctx)
|
||||
}
|
||||
|
||||
// GetAllAccounts mocks base method.
|
||||
@@ -1420,20 +1390,6 @@ func (mr *MockStoreMockRecorder) GetClusterRequireSubdomain(ctx, clusterAddr int
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterRequireSubdomain", reflect.TypeOf((*MockStore)(nil).GetClusterRequireSubdomain), ctx, clusterAddr)
|
||||
}
|
||||
|
||||
// GetClusterSupportsCrowdSec mocks base method.
|
||||
func (m *MockStore) GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetClusterSupportsCrowdSec", ctx, clusterAddr)
|
||||
ret0, _ := ret[0].(*bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetClusterSupportsCrowdSec indicates an expected call of GetClusterSupportsCrowdSec.
|
||||
func (mr *MockStoreMockRecorder) GetClusterSupportsCrowdSec(ctx, clusterAddr interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterSupportsCrowdSec", reflect.TypeOf((*MockStore)(nil).GetClusterSupportsCrowdSec), ctx, clusterAddr)
|
||||
}
|
||||
|
||||
// GetClusterSupportsCustomPorts mocks base method.
|
||||
func (m *MockStore) GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -2003,51 +1959,6 @@ func (mr *MockStoreMockRecorder) GetProxyAccessTokenByHashedToken(ctx, lockStren
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxyAccessTokenByHashedToken", reflect.TypeOf((*MockStore)(nil).GetProxyAccessTokenByHashedToken), ctx, lockStrength, hashedToken)
|
||||
}
|
||||
|
||||
// GetProxyAccessTokenByID mocks base method.
|
||||
func (m *MockStore) GetProxyAccessTokenByID(ctx context.Context, lockStrength LockingStrength, tokenID string) (*types2.ProxyAccessToken, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetProxyAccessTokenByID", ctx, lockStrength, tokenID)
|
||||
ret0, _ := ret[0].(*types2.ProxyAccessToken)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetProxyAccessTokenByID indicates an expected call of GetProxyAccessTokenByID.
|
||||
func (mr *MockStoreMockRecorder) GetProxyAccessTokenByID(ctx, lockStrength, tokenID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxyAccessTokenByID", reflect.TypeOf((*MockStore)(nil).GetProxyAccessTokenByID), ctx, lockStrength, tokenID)
|
||||
}
|
||||
|
||||
// GetProxyAccessTokensByAccountID mocks base method.
|
||||
func (m *MockStore) GetProxyAccessTokensByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types2.ProxyAccessToken, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetProxyAccessTokensByAccountID", ctx, lockStrength, accountID)
|
||||
ret0, _ := ret[0].([]*types2.ProxyAccessToken)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetProxyAccessTokensByAccountID indicates an expected call of GetProxyAccessTokensByAccountID.
|
||||
func (mr *MockStoreMockRecorder) GetProxyAccessTokensByAccountID(ctx, lockStrength, accountID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxyAccessTokensByAccountID", reflect.TypeOf((*MockStore)(nil).GetProxyAccessTokensByAccountID), ctx, lockStrength, accountID)
|
||||
}
|
||||
|
||||
// GetProxyByAccountID mocks base method.
|
||||
func (m *MockStore) GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetProxyByAccountID", ctx, accountID)
|
||||
ret0, _ := ret[0].(*proxy.Proxy)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetProxyByAccountID indicates an expected call of GetProxyByAccountID.
|
||||
func (mr *MockStoreMockRecorder) GetProxyByAccountID(ctx, accountID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxyByAccountID", reflect.TypeOf((*MockStore)(nil).GetProxyByAccountID), ctx, accountID)
|
||||
}
|
||||
|
||||
// GetResourceGroups mocks base method.
|
||||
func (m *MockStore) GetResourceGroups(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) ([]*types2.Group, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -2480,21 +2391,6 @@ func (mr *MockStoreMockRecorder) IncrementSetupKeyUsage(ctx, setupKeyID interfac
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IncrementSetupKeyUsage", reflect.TypeOf((*MockStore)(nil).IncrementSetupKeyUsage), ctx, setupKeyID)
|
||||
}
|
||||
|
||||
// IsClusterAddressConflicting mocks base method.
|
||||
func (m *MockStore) IsClusterAddressConflicting(ctx context.Context, clusterAddress, accountID string) (bool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "IsClusterAddressConflicting", ctx, clusterAddress, accountID)
|
||||
ret0, _ := ret[0].(bool)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// IsClusterAddressConflicting indicates an expected call of IsClusterAddressConflicting.
|
||||
func (mr *MockStoreMockRecorder) IsClusterAddressConflicting(ctx, clusterAddress, accountID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsClusterAddressConflicting", reflect.TypeOf((*MockStore)(nil).IsClusterAddressConflicting), ctx, clusterAddress, accountID)
|
||||
}
|
||||
|
||||
// IsPrimaryAccount mocks base method.
|
||||
func (m *MockStore) IsPrimaryAccount(ctx context.Context, accountID string) (bool, string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -2511,21 +2407,6 @@ func (mr *MockStoreMockRecorder) IsPrimaryAccount(ctx, accountID interface{}) *g
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsPrimaryAccount", reflect.TypeOf((*MockStore)(nil).IsPrimaryAccount), ctx, accountID)
|
||||
}
|
||||
|
||||
// IsProxyAccessTokenValid mocks base method.
|
||||
func (m *MockStore) IsProxyAccessTokenValid(ctx context.Context, tokenID string) (bool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "IsProxyAccessTokenValid", ctx, tokenID)
|
||||
ret0, _ := ret[0].(bool)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// IsProxyAccessTokenValid indicates an expected call of IsProxyAccessTokenValid.
|
||||
func (mr *MockStoreMockRecorder) IsProxyAccessTokenValid(ctx, tokenID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsProxyAccessTokenValid", reflect.TypeOf((*MockStore)(nil).IsProxyAccessTokenValid), ctx, tokenID)
|
||||
}
|
||||
|
||||
// ListCustomDomains mocks base method.
|
||||
func (m *MockStore) ListCustomDomains(ctx context.Context, accountID string) ([]*domain.Domain, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
@@ -80,10 +80,6 @@ type Settings struct {
|
||||
// LocalAuthDisabled indicates if local (email/password) authentication is disabled.
|
||||
// This is a runtime-only field, not stored in the database.
|
||||
LocalAuthDisabled bool `gorm:"-"`
|
||||
|
||||
// LocalMfaEnabled indicates if TOTP MFA is enabled for local users.
|
||||
// Only applicable when the embedded IDP is enabled.
|
||||
LocalMfaEnabled bool
|
||||
}
|
||||
|
||||
// Copy copies the Settings struct
|
||||
@@ -112,7 +108,6 @@ func (s *Settings) Copy() *Settings {
|
||||
IPv6EnabledGroups: slices.Clone(s.IPv6EnabledGroups),
|
||||
EmbeddedIdpEnabled: s.EmbeddedIdpEnabled,
|
||||
LocalAuthDisabled: s.LocalAuthDisabled,
|
||||
LocalMfaEnabled: s.LocalMfaEnabled,
|
||||
}
|
||||
if s.Extra != nil {
|
||||
settings.Extra = s.Extra.Copy()
|
||||
|
||||
@@ -109,6 +109,22 @@ var debugStopCmd = &cobra.Command{
|
||||
SilenceUsage: true,
|
||||
}
|
||||
|
||||
var debugPerfCmd = &cobra.Command{
|
||||
Use: "perf <pool-cap>",
|
||||
Short: "Live-retune the tunnel buffer pool cap on all running clients",
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: runDebugPerfSet,
|
||||
SilenceUsage: true,
|
||||
}
|
||||
|
||||
var debugRuntimeCmd = &cobra.Command{
|
||||
Use: "runtime",
|
||||
Short: "Show runtime stats (heap, goroutines, RSS)",
|
||||
Args: cobra.NoArgs,
|
||||
RunE: runDebugRuntime,
|
||||
SilenceUsage: true,
|
||||
}
|
||||
|
||||
var debugCaptureCmd = &cobra.Command{
|
||||
Use: "capture <account-id> [filter expression]",
|
||||
Short: "Capture packets on a client's WireGuard interface",
|
||||
@@ -159,6 +175,8 @@ func init() {
|
||||
debugCmd.AddCommand(debugLogCmd)
|
||||
debugCmd.AddCommand(debugStartCmd)
|
||||
debugCmd.AddCommand(debugStopCmd)
|
||||
debugCmd.AddCommand(debugPerfCmd)
|
||||
debugCmd.AddCommand(debugRuntimeCmd)
|
||||
debugCmd.AddCommand(debugCaptureCmd)
|
||||
|
||||
rootCmd.AddCommand(debugCmd)
|
||||
@@ -220,6 +238,18 @@ func runDebugStop(cmd *cobra.Command, args []string) error {
|
||||
return getDebugClient(cmd).StopClient(cmd.Context(), args[0])
|
||||
}
|
||||
|
||||
func runDebugPerfSet(cmd *cobra.Command, args []string) error {
|
||||
n, err := strconv.ParseUint(args[0], 10, 32)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid value %q: %w", args[0], err)
|
||||
}
|
||||
return getDebugClient(cmd).PerfSet(cmd.Context(), uint32(n))
|
||||
}
|
||||
|
||||
func runDebugRuntime(cmd *cobra.Command, _ []string) error {
|
||||
return getDebugClient(cmd).Runtime(cmd.Context())
|
||||
}
|
||||
|
||||
func runDebugCapture(cmd *cobra.Command, args []string) error {
|
||||
duration, _ := cmd.Flags().GetDuration("duration")
|
||||
forcePcap, _ := cmd.Flags().GetBool("pcap")
|
||||
|
||||
@@ -15,11 +15,22 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
|
||||
"github.com/netbirdio/netbird/client/embed"
|
||||
"github.com/netbirdio/netbird/proxy"
|
||||
nbacme "github.com/netbirdio/netbird/proxy/internal/acme"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
const (
|
||||
// envPreallocatedBuffers caps the per-tunnel buffer pool. Zero (unset)
|
||||
// keeps the upstream uncapped default.
|
||||
envPreallocatedBuffers = "NB_PROXY_PREALLOCATED_BUFFERS"
|
||||
// envMaxBatchSize overrides the per-tunnel batch size, which controls
|
||||
// how many buffers each receive/TUN worker eagerly allocates. Zero
|
||||
// (unset) keeps the platform default.
|
||||
envMaxBatchSize = "NB_PROXY_MAX_BATCH_SIZE"
|
||||
)
|
||||
|
||||
const DefaultManagementURL = "https://api.netbird.io:443"
|
||||
|
||||
// envProxyToken is the environment variable name for the proxy access token.
|
||||
@@ -145,6 +156,45 @@ func runServer(cmd *cobra.Command, args []string) error {
|
||||
|
||||
logger.Infof("configured log level: %s", level)
|
||||
|
||||
var wgPool, wgBatch uint64
|
||||
var perf embed.Performance
|
||||
if raw := os.Getenv(envPreallocatedBuffers); raw != "" {
|
||||
n, err := strconv.ParseUint(raw, 10, 32)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid %s %q: %w", envPreallocatedBuffers, raw, err)
|
||||
}
|
||||
wgPool = n
|
||||
v := uint32(n)
|
||||
perf.PreallocatedBuffersPerPool = &v
|
||||
logger.Infof("tunnel preallocated buffers per pool: %d", n)
|
||||
}
|
||||
if raw := os.Getenv(envMaxBatchSize); raw != "" {
|
||||
n, err := strconv.ParseUint(raw, 10, 32)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid %s %q: %w", envMaxBatchSize, raw, err)
|
||||
}
|
||||
wgBatch = n
|
||||
v := uint32(n)
|
||||
perf.MaxBatchSize = &v
|
||||
logger.Infof("tunnel max batch size override: %d", n)
|
||||
}
|
||||
if wgPool > 0 {
|
||||
// Each bind recv goroutine (IPv4 + IPv6 + ICE relay) plus
|
||||
// RoutineReadFromTUN eagerly reserves `batch` message buffers for
|
||||
// the lifetime of the Device. A pool cap below that floor blocks
|
||||
// the receive pipeline at startup.
|
||||
batch := wgBatch
|
||||
if batch == 0 {
|
||||
batch = 128
|
||||
}
|
||||
const recvGoroutines = 4
|
||||
floor := batch * recvGoroutines
|
||||
if wgPool < floor {
|
||||
logger.Warnf("%s=%d is below the eager-allocation floor (~%d for batch=%d); startup may deadlock",
|
||||
envPreallocatedBuffers, wgPool, floor, batch)
|
||||
}
|
||||
}
|
||||
|
||||
switch forwardedProto {
|
||||
case "auto", "http", "https":
|
||||
default:
|
||||
@@ -184,6 +234,7 @@ func runServer(cmd *cobra.Command, args []string) error {
|
||||
CertLockMethod: nbacme.CertLockMethod(certLockMethod),
|
||||
WildcardCertDir: wildcardCertDir,
|
||||
WireguardPort: wgPort,
|
||||
Performance: perf,
|
||||
ProxyProtocol: proxyProtocol,
|
||||
PreSharedKey: preSharedKey,
|
||||
SupportsCustomPorts: supportsCustomPorts,
|
||||
|
||||
@@ -284,6 +284,63 @@ func (c *Client) printLogLevelResult(data map[string]any) {
|
||||
}
|
||||
}
|
||||
|
||||
// PerfSet live-retunes the tunnel buffer pool cap on all running embedded
|
||||
// clients. Batch size is not live-tunable; configure it at proxy startup.
|
||||
func (c *Client) PerfSet(ctx context.Context, value uint32) error {
|
||||
path := fmt.Sprintf("/debug/perf?value=%d", value)
|
||||
return c.fetchAndPrint(ctx, path, c.printPerfSet)
|
||||
}
|
||||
|
||||
func (c *Client) printPerfSet(data map[string]any) {
|
||||
if errMsg, ok := data["error"].(string); ok && errMsg != "" {
|
||||
c.printError(data)
|
||||
return
|
||||
}
|
||||
val, _ := data["value"].(float64)
|
||||
applied, _ := data["applied"].(float64)
|
||||
_, _ = fmt.Fprintf(c.out, "Pool cap set to: %d\n", uint32(val))
|
||||
_, _ = fmt.Fprintf(c.out, "Applied to %d live clients\n", int(applied))
|
||||
if failed, ok := data["failed"].(map[string]any); ok && len(failed) > 0 {
|
||||
_, _ = fmt.Fprintln(c.out, "Failed:")
|
||||
for k, v := range failed {
|
||||
_, _ = fmt.Fprintf(c.out, " %s: %v\n", k, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Runtime fetches runtime stats (heap, goroutines, RSS).
|
||||
func (c *Client) Runtime(ctx context.Context) error {
|
||||
return c.fetchAndPrint(ctx, "/debug/runtime", c.printRuntime)
|
||||
}
|
||||
|
||||
func (c *Client) printRuntime(data map[string]any) {
|
||||
i := func(k string) uint64 {
|
||||
v, _ := data[k].(float64)
|
||||
return uint64(v)
|
||||
}
|
||||
mb := func(n uint64) string { return fmt.Sprintf("%.1f MB", float64(n)/(1<<20)) }
|
||||
|
||||
_, _ = fmt.Fprintf(c.out, "Uptime: %v\n", data["uptime"])
|
||||
_, _ = fmt.Fprintf(c.out, "Go: %v on %d CPU (GOMAXPROCS=%d)\n", data["go_version"], uint32(i("num_cpu")), uint32(i("gomaxprocs")))
|
||||
_, _ = fmt.Fprintf(c.out, "Goroutines: %d\n", i("goroutines"))
|
||||
_, _ = fmt.Fprintf(c.out, "Live objects: %d\n", i("live_objects"))
|
||||
_, _ = fmt.Fprintf(c.out, "GC: %d cycles, %v pause total\n", i("num_gc"), time.Duration(i("pause_total_ns")))
|
||||
_, _ = fmt.Fprintln(c.out, "Heap:")
|
||||
_, _ = fmt.Fprintf(c.out, " alloc: %s\n", mb(i("heap_alloc")))
|
||||
_, _ = fmt.Fprintf(c.out, " in-use: %s\n", mb(i("heap_inuse")))
|
||||
_, _ = fmt.Fprintf(c.out, " idle: %s\n", mb(i("heap_idle")))
|
||||
_, _ = fmt.Fprintf(c.out, " released: %s\n", mb(i("heap_released")))
|
||||
_, _ = fmt.Fprintf(c.out, " sys: %s\n", mb(i("heap_sys")))
|
||||
_, _ = fmt.Fprintf(c.out, "Total sys: %s\n", mb(i("sys")))
|
||||
if _, ok := data["vm_rss"]; ok {
|
||||
_, _ = fmt.Fprintln(c.out, "Process:")
|
||||
_, _ = fmt.Fprintf(c.out, " VmRSS: %s\n", mb(i("vm_rss")))
|
||||
_, _ = fmt.Fprintf(c.out, " VmSize: %s\n", mb(i("vm_size")))
|
||||
_, _ = fmt.Fprintf(c.out, " VmData: %s\n", mb(i("vm_data")))
|
||||
}
|
||||
_, _ = fmt.Fprintf(c.out, "Clients: %d (%d started)\n", i("clients"), i("started"))
|
||||
}
|
||||
|
||||
// StartClient starts a specific client.
|
||||
func (c *Client) StartClient(ctx context.Context, accountID string) error {
|
||||
path := "/debug/clients/" + url.PathEscape(accountID) + "/start"
|
||||
|
||||
@@ -11,6 +11,8 @@ import (
|
||||
"maps"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -59,6 +61,7 @@ func sortedAccountIDs(m map[types.AccountID]roundtrip.ClientDebugInfo) []types.A
|
||||
type clientProvider interface {
|
||||
GetClient(accountID types.AccountID) (*nbembed.Client, bool)
|
||||
ListClientsForDebug() map[types.AccountID]roundtrip.ClientDebugInfo
|
||||
ListClientsForStartup() map[types.AccountID]*nbembed.Client
|
||||
}
|
||||
|
||||
// healthChecker provides health probe state.
|
||||
@@ -140,6 +143,10 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
h.handleListClients(w, r, wantJSON)
|
||||
case "/debug/health":
|
||||
h.handleHealth(w, r, wantJSON)
|
||||
case "/debug/perf":
|
||||
h.handlePerf(w, r)
|
||||
case "/debug/runtime":
|
||||
h.handleRuntime(w, r)
|
||||
default:
|
||||
if h.handleClientRoutes(w, r, path, wantJSON) {
|
||||
return
|
||||
@@ -233,10 +240,10 @@ func (h *Handler) handleIndex(w http.ResponseWriter, _ *http.Request, wantJSON b
|
||||
}
|
||||
|
||||
if wantJSON {
|
||||
clientsJSON := make([]map[string]interface{}, 0, len(clients))
|
||||
clientsJSON := make([]map[string]any, 0, len(clients))
|
||||
for _, id := range sortedIDs {
|
||||
info := clients[id]
|
||||
clientsJSON = append(clientsJSON, map[string]interface{}{
|
||||
clientsJSON = append(clientsJSON, map[string]any{
|
||||
"account_id": info.AccountID,
|
||||
"service_count": info.ServiceCount,
|
||||
"service_keys": info.ServiceKeys,
|
||||
@@ -245,7 +252,7 @@ func (h *Handler) handleIndex(w http.ResponseWriter, _ *http.Request, wantJSON b
|
||||
"age": time.Since(info.CreatedAt).Round(time.Second).String(),
|
||||
})
|
||||
}
|
||||
resp := map[string]interface{}{
|
||||
resp := map[string]any{
|
||||
"version": version.NetbirdVersion(),
|
||||
"uptime": time.Since(h.startTime).Round(time.Second).String(),
|
||||
"client_count": len(clients),
|
||||
@@ -323,10 +330,10 @@ func (h *Handler) handleListClients(w http.ResponseWriter, _ *http.Request, want
|
||||
sortedIDs := sortedAccountIDs(clients)
|
||||
|
||||
if wantJSON {
|
||||
clientsJSON := make([]map[string]interface{}, 0, len(clients))
|
||||
clientsJSON := make([]map[string]any, 0, len(clients))
|
||||
for _, id := range sortedIDs {
|
||||
info := clients[id]
|
||||
clientsJSON = append(clientsJSON, map[string]interface{}{
|
||||
clientsJSON = append(clientsJSON, map[string]any{
|
||||
"account_id": info.AccountID,
|
||||
"service_count": info.ServiceCount,
|
||||
"service_keys": info.ServiceKeys,
|
||||
@@ -335,7 +342,7 @@ func (h *Handler) handleListClients(w http.ResponseWriter, _ *http.Request, want
|
||||
"age": time.Since(info.CreatedAt).Round(time.Second).String(),
|
||||
})
|
||||
}
|
||||
h.writeJSON(w, map[string]interface{}{
|
||||
h.writeJSON(w, map[string]any{
|
||||
"uptime": time.Since(h.startTime).Round(time.Second).String(),
|
||||
"client_count": len(clients),
|
||||
"clients": clientsJSON,
|
||||
@@ -421,7 +428,7 @@ func (h *Handler) handleClientStatus(w http.ResponseWriter, r *http.Request, acc
|
||||
})
|
||||
|
||||
if wantJSON {
|
||||
h.writeJSON(w, map[string]interface{}{
|
||||
h.writeJSON(w, map[string]any{
|
||||
"account_id": accountID,
|
||||
"status": overview.FullDetailSummary(),
|
||||
})
|
||||
@@ -504,20 +511,20 @@ func (h *Handler) handleClientTools(w http.ResponseWriter, _ *http.Request, acco
|
||||
func (h *Handler) handlePingTCP(w http.ResponseWriter, r *http.Request, accountID types.AccountID) {
|
||||
client, ok := h.provider.GetClient(accountID)
|
||||
if !ok {
|
||||
h.writeJSON(w, map[string]interface{}{"error": "client not found"})
|
||||
h.writeJSON(w, map[string]any{"error": "client not found"})
|
||||
return
|
||||
}
|
||||
|
||||
host := r.URL.Query().Get("host")
|
||||
portStr := r.URL.Query().Get("port")
|
||||
if host == "" || portStr == "" {
|
||||
h.writeJSON(w, map[string]interface{}{"error": "host and port parameters required"})
|
||||
h.writeJSON(w, map[string]any{"error": "host and port parameters required"})
|
||||
return
|
||||
}
|
||||
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil || port < 1 || port > 65535 {
|
||||
h.writeJSON(w, map[string]interface{}{"error": "invalid port"})
|
||||
h.writeJSON(w, map[string]any{"error": "invalid port"})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -541,7 +548,7 @@ func (h *Handler) handlePingTCP(w http.ResponseWriter, r *http.Request, accountI
|
||||
|
||||
conn, err := client.Dial(ctx, network, address)
|
||||
if err != nil {
|
||||
h.writeJSON(w, map[string]interface{}{
|
||||
h.writeJSON(w, map[string]any{
|
||||
"success": false,
|
||||
"host": host,
|
||||
"port": port,
|
||||
@@ -556,39 +563,38 @@ func (h *Handler) handlePingTCP(w http.ResponseWriter, r *http.Request, accountI
|
||||
}
|
||||
|
||||
latency := time.Since(start)
|
||||
resp := map[string]interface{}{
|
||||
h.writeJSON(w, map[string]any{
|
||||
"success": true,
|
||||
"host": host,
|
||||
"port": port,
|
||||
"remote": remote,
|
||||
"latency_ms": latency.Milliseconds(),
|
||||
"latency": formatDuration(latency),
|
||||
}
|
||||
h.writeJSON(w, resp)
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) handleLogLevel(w http.ResponseWriter, r *http.Request, accountID types.AccountID) {
|
||||
client, ok := h.provider.GetClient(accountID)
|
||||
if !ok {
|
||||
h.writeJSON(w, map[string]interface{}{"error": "client not found"})
|
||||
h.writeJSON(w, map[string]any{"error": "client not found"})
|
||||
return
|
||||
}
|
||||
|
||||
level := r.URL.Query().Get("level")
|
||||
if level == "" {
|
||||
h.writeJSON(w, map[string]interface{}{"error": "level parameter required (trace, debug, info, warn, error)"})
|
||||
h.writeJSON(w, map[string]any{"error": "level parameter required (trace, debug, info, warn, error)"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := client.SetLogLevel(level); err != nil {
|
||||
h.writeJSON(w, map[string]interface{}{
|
||||
h.writeJSON(w, map[string]any{
|
||||
"success": false,
|
||||
"error": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
h.writeJSON(w, map[string]interface{}{
|
||||
h.writeJSON(w, map[string]any{
|
||||
"success": true,
|
||||
"level": level,
|
||||
})
|
||||
@@ -599,7 +605,7 @@ const clientActionTimeout = 30 * time.Second
|
||||
func (h *Handler) handleClientStart(w http.ResponseWriter, r *http.Request, accountID types.AccountID) {
|
||||
client, ok := h.provider.GetClient(accountID)
|
||||
if !ok {
|
||||
h.writeJSON(w, map[string]interface{}{"error": "client not found"})
|
||||
h.writeJSON(w, map[string]any{"error": "client not found"})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -607,14 +613,14 @@ func (h *Handler) handleClientStart(w http.ResponseWriter, r *http.Request, acco
|
||||
defer cancel()
|
||||
|
||||
if err := client.Start(ctx); err != nil {
|
||||
h.writeJSON(w, map[string]interface{}{
|
||||
h.writeJSON(w, map[string]any{
|
||||
"success": false,
|
||||
"error": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
h.writeJSON(w, map[string]interface{}{
|
||||
h.writeJSON(w, map[string]any{
|
||||
"success": true,
|
||||
"message": "client started",
|
||||
})
|
||||
@@ -623,7 +629,7 @@ func (h *Handler) handleClientStart(w http.ResponseWriter, r *http.Request, acco
|
||||
func (h *Handler) handleClientStop(w http.ResponseWriter, r *http.Request, accountID types.AccountID) {
|
||||
client, ok := h.provider.GetClient(accountID)
|
||||
if !ok {
|
||||
h.writeJSON(w, map[string]interface{}{"error": "client not found"})
|
||||
h.writeJSON(w, map[string]any{"error": "client not found"})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -631,19 +637,125 @@ func (h *Handler) handleClientStop(w http.ResponseWriter, r *http.Request, accou
|
||||
defer cancel()
|
||||
|
||||
if err := client.Stop(ctx); err != nil {
|
||||
h.writeJSON(w, map[string]interface{}{
|
||||
h.writeJSON(w, map[string]any{
|
||||
"success": false,
|
||||
"error": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
h.writeJSON(w, map[string]interface{}{
|
||||
h.writeJSON(w, map[string]any{
|
||||
"success": true,
|
||||
"message": "client stopped",
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) handlePerf(w http.ResponseWriter, r *http.Request) {
|
||||
raw := r.URL.Query().Get("value")
|
||||
if raw == "" {
|
||||
http.Error(w, "value parameter is required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
n, err := strconv.ParseUint(raw, 10, 32)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("invalid value %q: %v", raw, err), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
capN := uint32(n)
|
||||
applied := 0
|
||||
failed := map[string]string{}
|
||||
for accountID, client := range h.provider.ListClientsForStartup() {
|
||||
if err := client.SetPerformance(nbembed.Performance{PreallocatedBuffersPerPool: &capN}); err != nil {
|
||||
failed[string(accountID)] = err.Error()
|
||||
continue
|
||||
}
|
||||
applied++
|
||||
}
|
||||
|
||||
resp := map[string]any{
|
||||
"success": true,
|
||||
"value": capN,
|
||||
"applied": applied,
|
||||
}
|
||||
if len(failed) > 0 {
|
||||
resp["failed"] = failed
|
||||
}
|
||||
h.writeJSON(w, resp)
|
||||
}
|
||||
|
||||
// handleRuntime returns cheap runtime and process stats. Safe to hit on a
|
||||
// running proxy; does not read pprof profiles.
|
||||
func (h *Handler) handleRuntime(w http.ResponseWriter, _ *http.Request) {
|
||||
var m runtime.MemStats
|
||||
runtime.ReadMemStats(&m)
|
||||
|
||||
clients := h.provider.ListClientsForDebug()
|
||||
started := 0
|
||||
for _, c := range clients {
|
||||
if c.HasClient {
|
||||
started++
|
||||
}
|
||||
}
|
||||
|
||||
resp := map[string]any{
|
||||
"uptime": time.Since(h.startTime).Round(time.Second).String(),
|
||||
"goroutines": runtime.NumGoroutine(),
|
||||
"num_cpu": runtime.NumCPU(),
|
||||
"gomaxprocs": runtime.GOMAXPROCS(0),
|
||||
"go_version": runtime.Version(),
|
||||
"heap_alloc": m.HeapAlloc,
|
||||
"heap_inuse": m.HeapInuse,
|
||||
"heap_idle": m.HeapIdle,
|
||||
"heap_released": m.HeapReleased,
|
||||
"heap_sys": m.HeapSys,
|
||||
"sys": m.Sys,
|
||||
"live_objects": m.Mallocs - m.Frees,
|
||||
"num_gc": m.NumGC,
|
||||
"pause_total_ns": m.PauseTotalNs,
|
||||
"clients": len(clients),
|
||||
"started": started,
|
||||
}
|
||||
|
||||
if proc := readProcStatus(); proc != nil {
|
||||
resp["vm_rss"] = proc["VmRSS"]
|
||||
resp["vm_size"] = proc["VmSize"]
|
||||
resp["vm_data"] = proc["VmData"]
|
||||
}
|
||||
|
||||
h.writeJSON(w, resp)
|
||||
}
|
||||
|
||||
// readProcStatus parses /proc/self/status on Linux and returns size fields
|
||||
// in bytes. Returns nil on non-Linux or read failure.
|
||||
func readProcStatus() map[string]uint64 {
|
||||
raw, err := os.ReadFile("/proc/self/status")
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
out := map[string]uint64{}
|
||||
for _, line := range strings.Split(string(raw), "\n") {
|
||||
k, v, ok := strings.Cut(line, ":")
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if k != "VmRSS" && k != "VmSize" && k != "VmData" {
|
||||
continue
|
||||
}
|
||||
fields := strings.Fields(v)
|
||||
if len(fields) < 1 {
|
||||
continue
|
||||
}
|
||||
n, err := strconv.ParseUint(fields[0], 10, 64)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
// Values are reported in kB.
|
||||
out[k] = n * 1024
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
const maxCaptureDuration = 30 * time.Minute
|
||||
|
||||
// handleCapture streams a pcap or text packet capture for the given client.
|
||||
@@ -772,7 +884,7 @@ func (h *Handler) handleHealth(w http.ResponseWriter, r *http.Request, wantJSON
|
||||
h.writeJSON(w, resp)
|
||||
}
|
||||
|
||||
func (h *Handler) renderTemplate(w http.ResponseWriter, name string, data interface{}) {
|
||||
func (h *Handler) renderTemplate(w http.ResponseWriter, name string, data any) {
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
tmpl := h.getTemplates()
|
||||
if tmpl == nil {
|
||||
@@ -785,7 +897,7 @@ func (h *Handler) renderTemplate(w http.ResponseWriter, name string, data interf
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) writeJSON(w http.ResponseWriter, v interface{}) {
|
||||
func (h *Handler) writeJSON(w http.ResponseWriter, v any) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
enc := json.NewEncoder(w)
|
||||
enc.SetIndent("", " ")
|
||||
|
||||
@@ -112,6 +112,7 @@ type ClientConfig struct {
|
||||
MgmtAddr string
|
||||
WGPort uint16
|
||||
PreSharedKey string
|
||||
Performance embed.Performance
|
||||
}
|
||||
|
||||
type statusNotifier interface {
|
||||
@@ -267,6 +268,7 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account
|
||||
BlockInbound: true,
|
||||
WireguardPort: &wgPort,
|
||||
PreSharedKey: n.clientCfg.PreSharedKey,
|
||||
Performance: n.clientCfg.Performance,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create netbird client: %w", err)
|
||||
|
||||
@@ -25,12 +25,6 @@ func (c *peekedConn) Read(b []byte) (int, error) {
|
||||
return c.reader.Read(b)
|
||||
}
|
||||
|
||||
// halfCloser matches connections that support shutting down the write
|
||||
// side while keeping the read side open (e.g. *net.TCPConn).
|
||||
type halfCloser interface {
|
||||
CloseWrite() error
|
||||
}
|
||||
|
||||
// CloseWrite delegates to the underlying connection if it supports
|
||||
// half-close (e.g. *net.TCPConn). Without this, embedding net.Conn
|
||||
// as an interface hides the concrete type's CloseWrite method, making
|
||||
|
||||
156
proxy/internal/tcp/relay.go
Normal file
156
proxy/internal/tcp/relay.go
Normal file
@@ -0,0 +1,156 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/netutil"
|
||||
)
|
||||
|
||||
// errIdleTimeout is returned when a relay connection is closed due to inactivity.
|
||||
var errIdleTimeout = errors.New("idle timeout")
|
||||
|
||||
// DefaultIdleTimeout is the default idle timeout for TCP relay connections.
|
||||
// A zero value disables idle timeout checking.
|
||||
const DefaultIdleTimeout = 5 * time.Minute
|
||||
|
||||
// halfCloser is implemented by connections that support half-close
|
||||
// (e.g. *net.TCPConn). When one copy direction finishes, we signal
|
||||
// EOF to the remote by closing the write side while keeping the read
|
||||
// side open so the other direction can drain.
|
||||
type halfCloser interface {
|
||||
CloseWrite() error
|
||||
}
|
||||
|
||||
// copyBufPool avoids allocating a new 32KB buffer per io.Copy call.
|
||||
var copyBufPool = sync.Pool{
|
||||
New: func() any {
|
||||
buf := make([]byte, 32*1024)
|
||||
return &buf
|
||||
},
|
||||
}
|
||||
|
||||
// Relay copies data bidirectionally between src and dst until both
|
||||
// sides are done or the context is canceled. When idleTimeout is
|
||||
// non-zero, each direction's read is deadline-guarded; if no data
|
||||
// flows within the timeout the connection is torn down. When one
|
||||
// direction finishes, it half-closes the write side of the
|
||||
// destination (if supported) to signal EOF, allowing the other
|
||||
// direction to drain gracefully before the full connection teardown.
|
||||
func Relay(ctx context.Context, logger *log.Entry, src, dst net.Conn, idleTimeout time.Duration) (srcToDst, dstToSrc int64) {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
_ = src.Close()
|
||||
_ = dst.Close()
|
||||
}()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
var errSrcToDst, errDstToSrc error
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
srcToDst, errSrcToDst = copyWithIdleTimeout(dst, src, idleTimeout)
|
||||
halfClose(dst)
|
||||
cancel()
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
dstToSrc, errDstToSrc = copyWithIdleTimeout(src, dst, idleTimeout)
|
||||
halfClose(src)
|
||||
cancel()
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if errors.Is(errSrcToDst, errIdleTimeout) || errors.Is(errDstToSrc, errIdleTimeout) {
|
||||
logger.Debug("relay closed due to idle timeout")
|
||||
}
|
||||
if errSrcToDst != nil && !isExpectedCopyError(errSrcToDst) {
|
||||
logger.Debugf("relay copy error (src→dst): %v", errSrcToDst)
|
||||
}
|
||||
if errDstToSrc != nil && !isExpectedCopyError(errDstToSrc) {
|
||||
logger.Debugf("relay copy error (dst→src): %v", errDstToSrc)
|
||||
}
|
||||
|
||||
return srcToDst, dstToSrc
|
||||
}
|
||||
|
||||
// copyWithIdleTimeout copies from src to dst using a pooled buffer.
|
||||
// When idleTimeout > 0 it sets a read deadline on src before each
|
||||
// read and treats a timeout as an idle-triggered close.
|
||||
func copyWithIdleTimeout(dst io.Writer, src io.Reader, idleTimeout time.Duration) (int64, error) {
|
||||
bufp := copyBufPool.Get().(*[]byte)
|
||||
defer copyBufPool.Put(bufp)
|
||||
|
||||
if idleTimeout <= 0 {
|
||||
return io.CopyBuffer(dst, src, *bufp)
|
||||
}
|
||||
|
||||
conn, ok := src.(net.Conn)
|
||||
if !ok {
|
||||
return io.CopyBuffer(dst, src, *bufp)
|
||||
}
|
||||
|
||||
buf := *bufp
|
||||
var total int64
|
||||
for {
|
||||
if err := conn.SetReadDeadline(time.Now().Add(idleTimeout)); err != nil {
|
||||
return total, err
|
||||
}
|
||||
nr, readErr := src.Read(buf)
|
||||
if nr > 0 {
|
||||
n, err := checkedWrite(dst, buf[:nr])
|
||||
total += n
|
||||
if err != nil {
|
||||
return total, err
|
||||
}
|
||||
}
|
||||
if readErr != nil {
|
||||
if netutil.IsTimeout(readErr) {
|
||||
return total, errIdleTimeout
|
||||
}
|
||||
return total, readErr
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// checkedWrite writes buf to dst and returns the number of bytes written.
|
||||
// It guards against short writes and negative counts per io.Copy convention.
|
||||
func checkedWrite(dst io.Writer, buf []byte) (int64, error) {
|
||||
nw, err := dst.Write(buf)
|
||||
if nw < 0 || nw > len(buf) {
|
||||
nw = 0
|
||||
}
|
||||
if err != nil {
|
||||
return int64(nw), err
|
||||
}
|
||||
if nw != len(buf) {
|
||||
return int64(nw), io.ErrShortWrite
|
||||
}
|
||||
return int64(nw), nil
|
||||
}
|
||||
|
||||
func isExpectedCopyError(err error) bool {
|
||||
return errors.Is(err, errIdleTimeout) || netutil.IsExpectedError(err)
|
||||
}
|
||||
|
||||
// halfClose attempts to half-close the write side of the connection.
|
||||
// If the connection does not support half-close, this is a no-op.
|
||||
func halfClose(conn net.Conn) {
|
||||
if hc, ok := conn.(halfCloser); ok {
|
||||
// Best-effort; the full close will follow shortly.
|
||||
_ = hc.CloseWrite()
|
||||
}
|
||||
}
|
||||
@@ -13,13 +13,8 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/netutil"
|
||||
"github.com/netbirdio/netbird/util/netrelay"
|
||||
)
|
||||
|
||||
func testRelay(ctx context.Context, logger *log.Entry, src, dst net.Conn, idleTimeout time.Duration) (int64, int64) {
|
||||
return netrelay.Relay(ctx, src, dst, netrelay.Options{IdleTimeout: idleTimeout, Logger: logger})
|
||||
}
|
||||
|
||||
func TestRelay_BidirectionalCopy(t *testing.T) {
|
||||
srcClient, srcServer := net.Pipe()
|
||||
dstClient, dstServer := net.Pipe()
|
||||
@@ -46,7 +41,7 @@ func TestRelay_BidirectionalCopy(t *testing.T) {
|
||||
srcClient.Close()
|
||||
}()
|
||||
|
||||
s2d, d2s := testRelay(ctx, logger, srcServer, dstServer, 0)
|
||||
s2d, d2s := Relay(ctx, logger, srcServer, dstServer, 0)
|
||||
|
||||
assert.Equal(t, int64(len(srcData)), s2d, "bytes src→dst")
|
||||
assert.Equal(t, int64(len(dstData)), d2s, "bytes dst→src")
|
||||
@@ -63,7 +58,7 @@ func TestRelay_ContextCancellation(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
testRelay(ctx, logger, srcServer, dstServer, 0)
|
||||
Relay(ctx, logger, srcServer, dstServer, 0)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -90,7 +85,7 @@ func TestRelay_OneSideClosed(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
testRelay(ctx, logger, srcServer, dstServer, 0)
|
||||
Relay(ctx, logger, srcServer, dstServer, 0)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -134,7 +129,7 @@ func TestRelay_LargeTransfer(t *testing.T) {
|
||||
dstClient.Close()
|
||||
}()
|
||||
|
||||
s2d, _ := testRelay(ctx, logger, srcServer, dstServer, 0)
|
||||
s2d, _ := Relay(ctx, logger, srcServer, dstServer, 0)
|
||||
assert.Equal(t, int64(len(data)), s2d, "should transfer all bytes")
|
||||
require.NoError(t, <-errCh)
|
||||
}
|
||||
@@ -187,7 +182,7 @@ func TestRelay_IdleTimeout(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
var s2d, d2s int64
|
||||
go func() {
|
||||
s2d, d2s = testRelay(ctx, logger, srcServer, dstServer, 200*time.Millisecond)
|
||||
s2d, d2s = Relay(ctx, logger, srcServer, dstServer, 200*time.Millisecond)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
|
||||
@@ -16,7 +16,6 @@ import (
|
||||
"github.com/netbirdio/netbird/proxy/internal/accesslog"
|
||||
"github.com/netbirdio/netbird/proxy/internal/restrict"
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
"github.com/netbirdio/netbird/util/netrelay"
|
||||
)
|
||||
|
||||
// defaultDialTimeout is the fallback dial timeout when no per-route
|
||||
@@ -529,14 +528,11 @@ func (r *Router) relayTCP(ctx context.Context, conn net.Conn, sni SNIHost, route
|
||||
|
||||
idleTimeout := route.SessionIdleTimeout
|
||||
if idleTimeout <= 0 {
|
||||
idleTimeout = netrelay.DefaultIdleTimeout
|
||||
idleTimeout = DefaultIdleTimeout
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
s2d, d2s := netrelay.Relay(svcCtx, conn, backend, netrelay.Options{
|
||||
IdleTimeout: idleTimeout,
|
||||
Logger: entry,
|
||||
})
|
||||
s2d, d2s := Relay(svcCtx, entry, conn, backend, idleTimeout)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
if obs != nil {
|
||||
|
||||
@@ -1,409 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.opentelemetry.io/otel/metric/noop"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/grpc/metadata"
|
||||
grpcstatus "google.golang.org/grpc/status"
|
||||
|
||||
proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/management/server/users"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
type byopTestSetup struct {
|
||||
store store.Store
|
||||
proxyService *nbgrpc.ProxyServiceServer
|
||||
grpcServer *grpc.Server
|
||||
grpcAddr string
|
||||
cleanup func()
|
||||
|
||||
accountA string
|
||||
accountB string
|
||||
accountAToken types.PlainProxyToken
|
||||
accountBToken types.PlainProxyToken
|
||||
accountACluster string
|
||||
accountBCluster string
|
||||
}
|
||||
|
||||
func setupBYOPIntegrationTest(t *testing.T) *byopTestSetup {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
|
||||
testStore, storeCleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
|
||||
require.NoError(t, err)
|
||||
|
||||
accountAID := "byop-account-a"
|
||||
accountBID := "byop-account-b"
|
||||
|
||||
for _, acc := range []*types.Account{
|
||||
{Id: accountAID, Domain: "a.test.com", DomainCategory: "private", IsDomainPrimaryAccount: true, CreatedAt: time.Now()},
|
||||
{Id: accountBID, Domain: "b.test.com", DomainCategory: "private", IsDomainPrimaryAccount: true, CreatedAt: time.Now()},
|
||||
} {
|
||||
require.NoError(t, testStore.SaveAccount(ctx, acc))
|
||||
}
|
||||
|
||||
pub, priv, err := ed25519.GenerateKey(rand.Reader)
|
||||
require.NoError(t, err)
|
||||
pubKey := base64.StdEncoding.EncodeToString(pub)
|
||||
privKey := base64.StdEncoding.EncodeToString(priv)
|
||||
|
||||
clusterA := "byop-a.proxy.test"
|
||||
clusterB := "byop-b.proxy.test"
|
||||
|
||||
services := []*service.Service{
|
||||
{
|
||||
ID: "svc-a1", AccountID: accountAID, Name: "App A1",
|
||||
Domain: "app1." + clusterA, ProxyCluster: clusterA, Enabled: true,
|
||||
SessionPrivateKey: privKey, SessionPublicKey: pubKey,
|
||||
Targets: []*service.Target{{Path: strPtr("/"), Host: "10.0.0.1", Port: 8080, Protocol: "http", TargetId: "peer-a1", TargetType: "peer", Enabled: true}},
|
||||
},
|
||||
{
|
||||
ID: "svc-a2", AccountID: accountAID, Name: "App A2",
|
||||
Domain: "app2." + clusterA, ProxyCluster: clusterA, Enabled: true,
|
||||
SessionPrivateKey: privKey, SessionPublicKey: pubKey,
|
||||
Targets: []*service.Target{{Path: strPtr("/"), Host: "10.0.0.2", Port: 8080, Protocol: "http", TargetId: "peer-a2", TargetType: "peer", Enabled: true}},
|
||||
},
|
||||
{
|
||||
ID: "svc-b1", AccountID: accountBID, Name: "App B1",
|
||||
Domain: "app1." + clusterB, ProxyCluster: clusterB, Enabled: true,
|
||||
SessionPrivateKey: privKey, SessionPublicKey: pubKey,
|
||||
Targets: []*service.Target{{Path: strPtr("/"), Host: "10.0.0.3", Port: 8080, Protocol: "http", TargetId: "peer-b1", TargetType: "peer", Enabled: true}},
|
||||
},
|
||||
}
|
||||
for _, svc := range services {
|
||||
require.NoError(t, testStore.CreateService(ctx, svc))
|
||||
}
|
||||
|
||||
tokenA, err := types.CreateNewProxyAccessToken("byop-token-a", 0, &accountAID, "admin-a")
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, testStore.SaveProxyAccessToken(ctx, &tokenA.ProxyAccessToken))
|
||||
|
||||
tokenB, err := types.CreateNewProxyAccessToken("byop-token-b", 0, &accountBID, "admin-b")
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, testStore.SaveProxyAccessToken(ctx, &tokenB.ProxyAccessToken))
|
||||
|
||||
cacheStore, err := nbcache.NewStore(ctx, 30*time.Minute, 10*time.Minute, 100)
|
||||
require.NoError(t, err)
|
||||
|
||||
tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, cacheStore)
|
||||
pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, cacheStore)
|
||||
|
||||
meter := noop.NewMeterProvider().Meter("test")
|
||||
realProxyManager, err := proxymanager.NewManager(testStore, meter)
|
||||
require.NoError(t, err)
|
||||
|
||||
oidcConfig := nbgrpc.ProxyOIDCConfig{
|
||||
Issuer: "https://fake-issuer.example.com",
|
||||
ClientID: "test-client",
|
||||
HMACKey: []byte("test-hmac-key"),
|
||||
}
|
||||
|
||||
usersManager := users.NewManager(testStore)
|
||||
|
||||
proxyService := nbgrpc.NewProxyServiceServer(
|
||||
&testAccessLogManager{},
|
||||
tokenStore,
|
||||
pkceStore,
|
||||
oidcConfig,
|
||||
nil,
|
||||
usersManager,
|
||||
realProxyManager,
|
||||
nil,
|
||||
)
|
||||
|
||||
svcMgr := &storeBackedServiceManager{store: testStore, tokenStore: tokenStore}
|
||||
proxyService.SetServiceManager(svcMgr)
|
||||
|
||||
proxyController := &testProxyController{}
|
||||
proxyService.SetProxyController(proxyController)
|
||||
|
||||
_, streamInterceptor, authClose := nbgrpc.NewProxyAuthInterceptors(testStore)
|
||||
|
||||
lis, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
grpcServer := grpc.NewServer(grpc.StreamInterceptor(streamInterceptor))
|
||||
proto.RegisterProxyServiceServer(grpcServer, proxyService)
|
||||
|
||||
go func() {
|
||||
if err := grpcServer.Serve(lis); err != nil {
|
||||
t.Logf("gRPC server error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return &byopTestSetup{
|
||||
store: testStore,
|
||||
proxyService: proxyService,
|
||||
grpcServer: grpcServer,
|
||||
grpcAddr: lis.Addr().String(),
|
||||
cleanup: func() {
|
||||
grpcServer.GracefulStop()
|
||||
authClose()
|
||||
storeCleanup()
|
||||
},
|
||||
accountA: accountAID,
|
||||
accountB: accountBID,
|
||||
accountAToken: tokenA.PlainToken,
|
||||
accountBToken: tokenB.PlainToken,
|
||||
accountACluster: clusterA,
|
||||
accountBCluster: clusterB,
|
||||
}
|
||||
}
|
||||
|
||||
func byopContext(ctx context.Context, token types.PlainProxyToken) context.Context {
|
||||
md := metadata.Pairs("authorization", "Bearer "+string(token))
|
||||
return metadata.NewOutgoingContext(ctx, md)
|
||||
}
|
||||
|
||||
func receiveBYOPMappings(t *testing.T, stream proto.ProxyService_GetMappingUpdateClient) []*proto.ProxyMapping {
|
||||
t.Helper()
|
||||
var mappings []*proto.ProxyMapping
|
||||
for {
|
||||
msg, err := stream.Recv()
|
||||
require.NoError(t, err)
|
||||
mappings = append(mappings, msg.GetMapping()...)
|
||||
if msg.GetInitialSyncComplete() {
|
||||
break
|
||||
}
|
||||
}
|
||||
return mappings
|
||||
}
|
||||
|
||||
func TestIntegration_BYOPProxy_ReceivesOnlyAccountServices(t *testing.T) {
|
||||
setup := setupBYOPIntegrationTest(t)
|
||||
defer setup.cleanup()
|
||||
|
||||
conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
client := proto.NewProxyServiceClient(conn)
|
||||
|
||||
ctx, cancel := context.WithTimeout(byopContext(context.Background(), setup.accountAToken), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
stream, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{
|
||||
ProxyId: "byop-proxy-a",
|
||||
Version: "test-v1",
|
||||
Address: setup.accountACluster,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
mappings := receiveBYOPMappings(t, stream)
|
||||
|
||||
assert.Len(t, mappings, 2, "BYOP proxy should receive only account A's 2 services")
|
||||
for _, m := range mappings {
|
||||
assert.Equal(t, setup.accountA, m.GetAccountId(), "all mappings should belong to account A")
|
||||
t.Logf("received mapping: id=%s domain=%s account=%s", m.GetId(), m.GetDomain(), m.GetAccountId())
|
||||
}
|
||||
|
||||
ids := map[string]bool{}
|
||||
for _, m := range mappings {
|
||||
ids[m.GetId()] = true
|
||||
}
|
||||
assert.True(t, ids["svc-a1"], "should contain svc-a1")
|
||||
assert.True(t, ids["svc-a2"], "should contain svc-a2")
|
||||
assert.False(t, ids["svc-b1"], "should NOT contain account B's svc-b1")
|
||||
}
|
||||
|
||||
func TestIntegration_BYOPProxy_AccountBReceivesOnlyItsServices(t *testing.T) {
|
||||
setup := setupBYOPIntegrationTest(t)
|
||||
defer setup.cleanup()
|
||||
|
||||
conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
client := proto.NewProxyServiceClient(conn)
|
||||
|
||||
ctx, cancel := context.WithTimeout(byopContext(context.Background(), setup.accountBToken), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
stream, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{
|
||||
ProxyId: "byop-proxy-b",
|
||||
Version: "test-v1",
|
||||
Address: setup.accountBCluster,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
mappings := receiveBYOPMappings(t, stream)
|
||||
|
||||
assert.Len(t, mappings, 1, "BYOP proxy B should receive only 1 service")
|
||||
assert.Equal(t, "svc-b1", mappings[0].GetId())
|
||||
assert.Equal(t, setup.accountB, mappings[0].GetAccountId())
|
||||
}
|
||||
|
||||
func TestIntegration_BYOPProxy_MultiplePerAccount(t *testing.T) {
|
||||
setup := setupBYOPIntegrationTest(t)
|
||||
defer setup.cleanup()
|
||||
|
||||
conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
client := proto.NewProxyServiceClient(conn)
|
||||
|
||||
ctx1, cancel1 := context.WithTimeout(byopContext(context.Background(), setup.accountAToken), 5*time.Second)
|
||||
defer cancel1()
|
||||
|
||||
stream1, err := client.GetMappingUpdate(ctx1, &proto.GetMappingUpdateRequest{
|
||||
ProxyId: "byop-proxy-a-first",
|
||||
Version: "test-v1",
|
||||
Address: setup.accountACluster,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
mappings1 := receiveBYOPMappings(t, stream1)
|
||||
assert.Len(t, mappings1, 2, "first BYOP proxy should receive account A's 2 services")
|
||||
|
||||
ctx2, cancel2 := context.WithTimeout(byopContext(context.Background(), setup.accountAToken), 5*time.Second)
|
||||
defer cancel2()
|
||||
|
||||
stream2, err := client.GetMappingUpdate(ctx2, &proto.GetMappingUpdateRequest{
|
||||
ProxyId: "byop-proxy-a-second",
|
||||
Version: "test-v1",
|
||||
Address: setup.accountACluster,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
mappings2 := receiveBYOPMappings(t, stream2)
|
||||
assert.Len(t, mappings2, 2, "second BYOP proxy from same account should also receive the 2 services")
|
||||
for _, m := range mappings2 {
|
||||
assert.Equal(t, setup.accountA, m.GetAccountId())
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegration_BYOPProxy_ClusterAddressConflict(t *testing.T) {
|
||||
setup := setupBYOPIntegrationTest(t)
|
||||
defer setup.cleanup()
|
||||
|
||||
conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
client := proto.NewProxyServiceClient(conn)
|
||||
|
||||
ctx1, cancel1 := context.WithTimeout(byopContext(context.Background(), setup.accountAToken), 5*time.Second)
|
||||
defer cancel1()
|
||||
|
||||
stream1, err := client.GetMappingUpdate(ctx1, &proto.GetMappingUpdateRequest{
|
||||
ProxyId: "byop-proxy-a-cluster",
|
||||
Version: "test-v1",
|
||||
Address: setup.accountACluster,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_ = receiveBYOPMappings(t, stream1)
|
||||
|
||||
ctx2, cancel2 := context.WithTimeout(byopContext(context.Background(), setup.accountBToken), 5*time.Second)
|
||||
defer cancel2()
|
||||
|
||||
stream2, err := client.GetMappingUpdate(ctx2, &proto.GetMappingUpdateRequest{
|
||||
ProxyId: "byop-proxy-b-conflict",
|
||||
Version: "test-v1",
|
||||
Address: setup.accountACluster,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = stream2.Recv()
|
||||
require.Error(t, err)
|
||||
|
||||
st, ok := grpcstatus.FromError(err)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, codes.AlreadyExists, st.Code(), "cluster address conflict should return AlreadyExists")
|
||||
t.Logf("expected rejection: %s", st.Message())
|
||||
}
|
||||
|
||||
func TestIntegration_BYOPProxy_SameProxyReconnects(t *testing.T) {
|
||||
setup := setupBYOPIntegrationTest(t)
|
||||
defer setup.cleanup()
|
||||
|
||||
conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
client := proto.NewProxyServiceClient(conn)
|
||||
|
||||
proxyID := "byop-proxy-reconnect"
|
||||
|
||||
ctx1, cancel1 := context.WithTimeout(byopContext(context.Background(), setup.accountAToken), 5*time.Second)
|
||||
stream1, err := client.GetMappingUpdate(ctx1, &proto.GetMappingUpdateRequest{
|
||||
ProxyId: proxyID,
|
||||
Version: "test-v1",
|
||||
Address: setup.accountACluster,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
firstMappings := receiveBYOPMappings(t, stream1)
|
||||
cancel1()
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
ctx2, cancel2 := context.WithTimeout(byopContext(context.Background(), setup.accountAToken), 5*time.Second)
|
||||
defer cancel2()
|
||||
|
||||
stream2, err := client.GetMappingUpdate(ctx2, &proto.GetMappingUpdateRequest{
|
||||
ProxyId: proxyID,
|
||||
Version: "test-v1",
|
||||
Address: setup.accountACluster,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
secondMappings := receiveBYOPMappings(t, stream2)
|
||||
|
||||
assert.Equal(t, len(firstMappings), len(secondMappings), "reconnect should receive same mappings")
|
||||
|
||||
firstIDs := map[string]bool{}
|
||||
for _, m := range firstMappings {
|
||||
firstIDs[m.GetId()] = true
|
||||
}
|
||||
for _, m := range secondMappings {
|
||||
assert.True(t, firstIDs[m.GetId()], "mapping %s should be present on reconnect", m.GetId())
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegration_BYOPProxy_UnauthenticatedRejected(t *testing.T) {
|
||||
setup := setupBYOPIntegrationTest(t)
|
||||
defer setup.cleanup()
|
||||
|
||||
conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
client := proto.NewProxyServiceClient(conn)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
stream, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{
|
||||
ProxyId: "no-auth-proxy",
|
||||
Version: "test-v1",
|
||||
Address: "some.cluster.io",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = stream.Recv()
|
||||
require.Error(t, err)
|
||||
|
||||
st, ok := grpcstatus.FromError(err)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, codes.Unauthenticated, st.Code())
|
||||
}
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
@@ -141,7 +140,6 @@ func setupIntegrationTest(t *testing.T) *integrationTestSetup {
|
||||
nil,
|
||||
usersManager,
|
||||
proxyManager,
|
||||
nil,
|
||||
)
|
||||
|
||||
// Use store-backed service manager
|
||||
@@ -203,8 +201,8 @@ func (m *testAccessLogManager) GetAllAccessLogs(_ context.Context, _, _ string,
|
||||
// testProxyManager is a mock implementation of proxy.Manager for testing.
|
||||
type testProxyManager struct{}
|
||||
|
||||
func (m *testProxyManager) Connect(_ context.Context, proxyID, sessionID, _, _ string, _ *string, _ *nbproxy.Capabilities) (*nbproxy.Proxy, error) {
|
||||
return &nbproxy.Proxy{ID: proxyID, SessionID: sessionID, Status: nbproxy.StatusConnected}, nil
|
||||
func (m *testProxyManager) Connect(_ context.Context, proxyID, sessionID, _, _ string, _ *nbproxy.Capabilities) (*nbproxy.Proxy, error) {
|
||||
return &nbproxy.Proxy{ID: proxyID, SessionID: sessionID, Status: "connected"}, nil
|
||||
}
|
||||
|
||||
func (m *testProxyManager) Disconnect(_ context.Context, _, _ string) error {
|
||||
@@ -219,10 +217,6 @@ func (m *testProxyManager) GetActiveClusterAddresses(_ context.Context) ([]strin
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *testProxyManager) GetActiveClusterAddressesForAccount(_ context.Context, _ string) ([]string, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *testProxyManager) GetActiveClusters(_ context.Context) ([]nbproxy.Cluster, error) {
|
||||
return nil, nil
|
||||
}
|
||||
@@ -243,22 +237,6 @@ func (m *testProxyManager) CleanupStale(_ context.Context, _ time.Duration) erro
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testProxyManager) GetAccountProxy(_ context.Context, accountID string) (*nbproxy.Proxy, error) {
|
||||
return nil, fmt.Errorf("proxy not found for account %s", accountID)
|
||||
}
|
||||
|
||||
func (m *testProxyManager) CountAccountProxies(_ context.Context, _ string) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *testProxyManager) IsClusterAddressAvailable(_ context.Context, _, _ string) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (m *testProxyManager) DeleteAccountCluster(_ context.Context, _, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// testProxyController is a mock implementation of rpservice.ProxyController for testing.
|
||||
type testProxyController struct{}
|
||||
|
||||
@@ -312,10 +290,6 @@ func (m *storeBackedServiceManager) DeleteService(ctx context.Context, accountID
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *storeBackedServiceManager) DeleteAccountCluster(_ context.Context, _, _, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *storeBackedServiceManager) SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error {
|
||||
return nil
|
||||
}
|
||||
@@ -362,10 +336,6 @@ func (m *storeBackedServiceManager) StopServiceFromPeer(_ context.Context, _, _,
|
||||
|
||||
func (m *storeBackedServiceManager) StartExposeReaper(_ context.Context) {}
|
||||
|
||||
func (m *storeBackedServiceManager) GetServiceByDomain(ctx context.Context, domain string) (*service.Service, error) {
|
||||
return m.store.GetServiceByDomain(ctx, domain)
|
||||
}
|
||||
|
||||
func (m *storeBackedServiceManager) GetActiveClusters(_ context.Context, _, _ string) ([]nbproxy.Cluster, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
@@ -37,6 +37,7 @@ import (
|
||||
"google.golang.org/grpc/keepalive"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/netbirdio/netbird/client/embed"
|
||||
"github.com/netbirdio/netbird/proxy/internal/accesslog"
|
||||
"github.com/netbirdio/netbird/proxy/internal/acme"
|
||||
"github.com/netbirdio/netbird/proxy/internal/auth"
|
||||
@@ -162,6 +163,9 @@ type Server struct {
|
||||
// single-account deployments; multiple accounts will fail to bind
|
||||
// the same port.
|
||||
WireguardPort uint16
|
||||
// Performance configures the tunnel pool/batch sizes for every
|
||||
// embedded client this proxy spawns.
|
||||
Performance embed.Performance
|
||||
// ProxyProtocol enables PROXY protocol (v1/v2) on TCP listeners.
|
||||
// When enabled, the real client IP is extracted from the PROXY header
|
||||
// sent by upstream L4 proxies that support PROXY protocol.
|
||||
@@ -281,6 +285,7 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
|
||||
MgmtAddr: s.ManagementAddress,
|
||||
WGPort: s.WireguardPort,
|
||||
PreSharedKey: s.PreSharedKey,
|
||||
Performance: s.Performance,
|
||||
}, s.Logger, s, s.mgmtClient)
|
||||
|
||||
// Create health checker before the mapping worker so it can track
|
||||
|
||||
@@ -381,10 +381,6 @@ components:
|
||||
type: boolean
|
||||
readOnly: true
|
||||
example: false
|
||||
local_mfa_enabled:
|
||||
description: Enables or disables TOTP multi-factor authentication for local users. Only applicable when the embedded identity provider is enabled.
|
||||
type: boolean
|
||||
example: false
|
||||
ipv6_enabled_groups:
|
||||
description: List of group IDs whose peers receive IPv6 overlay addresses. Peers not in any of these groups will not be allocated an IPv6 address. New accounts default to the All group.
|
||||
type: array
|
||||
@@ -3355,64 +3351,10 @@ components:
|
||||
example: false
|
||||
required:
|
||||
- enabled
|
||||
ProxyTokenRequest:
|
||||
type: object
|
||||
properties:
|
||||
name:
|
||||
type: string
|
||||
description: Human-readable token name
|
||||
example: "my-proxy-token"
|
||||
expires_in:
|
||||
type: integer
|
||||
minimum: 0
|
||||
description: Token expiration in seconds (0 = never expires)
|
||||
example: 0
|
||||
required:
|
||||
- name
|
||||
ProxyToken:
|
||||
type: object
|
||||
properties:
|
||||
id:
|
||||
type: string
|
||||
name:
|
||||
type: string
|
||||
expires_at:
|
||||
type: string
|
||||
format: date-time
|
||||
created_at:
|
||||
type: string
|
||||
format: date-time
|
||||
last_used:
|
||||
type: string
|
||||
format: date-time
|
||||
revoked:
|
||||
type: boolean
|
||||
required:
|
||||
- id
|
||||
- name
|
||||
- created_at
|
||||
- revoked
|
||||
ProxyTokenCreated:
|
||||
type: object
|
||||
description: Returned on creation — plain_token is shown only once
|
||||
allOf:
|
||||
- $ref: '#/components/schemas/ProxyToken'
|
||||
- type: object
|
||||
properties:
|
||||
plain_token:
|
||||
type: string
|
||||
description: The plain text token (shown only once)
|
||||
example: "nbx_abc123..."
|
||||
required:
|
||||
- plain_token
|
||||
ProxyCluster:
|
||||
type: object
|
||||
description: A proxy cluster represents a group of proxy nodes serving the same address
|
||||
properties:
|
||||
id:
|
||||
type: string
|
||||
description: Unique identifier of a proxy in this cluster
|
||||
example: "chlfq4q5r8kc73b0qjpg"
|
||||
address:
|
||||
type: string
|
||||
description: Cluster address used for CNAME targets
|
||||
@@ -3421,15 +3363,9 @@ components:
|
||||
type: integer
|
||||
description: Number of proxy nodes connected in this cluster
|
||||
example: 3
|
||||
self_hosted:
|
||||
type: boolean
|
||||
description: Whether this cluster is a self-hosted (BYOP) proxy managed by the account owner
|
||||
example: false
|
||||
required:
|
||||
- id
|
||||
- address
|
||||
- connected_proxies
|
||||
- self_hosted
|
||||
ReverseProxyDomainType:
|
||||
type: string
|
||||
description: Type of Reverse Proxy Domain
|
||||
@@ -11435,111 +11371,6 @@ paths:
|
||||
"$ref": "#/components/responses/forbidden"
|
||||
'500':
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
/api/reverse-proxies/clusters/{clusterAddress}:
|
||||
delete:
|
||||
summary: Delete a self-hosted proxy cluster
|
||||
description: Removes all self-hosted (BYOP) proxy registrations for the given cluster address owned by the account.
|
||||
tags: [ Services ]
|
||||
security:
|
||||
- BearerAuth: [ ]
|
||||
- TokenAuth: [ ]
|
||||
parameters:
|
||||
- in: path
|
||||
name: clusterAddress
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
description: The address of the proxy cluster
|
||||
responses:
|
||||
'200':
|
||||
description: Proxy cluster deleted successfully
|
||||
content: { }
|
||||
'400':
|
||||
"$ref": "#/components/responses/bad_request"
|
||||
'401':
|
||||
"$ref": "#/components/responses/requires_authentication"
|
||||
'403':
|
||||
"$ref": "#/components/responses/forbidden"
|
||||
'404':
|
||||
"$ref": "#/components/responses/not_found"
|
||||
'500':
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
/api/reverse-proxies/proxy-tokens:
|
||||
get:
|
||||
summary: List Proxy Tokens
|
||||
description: Returns all proxy access tokens for the account
|
||||
tags: [ Self-Hosted Proxies ]
|
||||
security:
|
||||
- BearerAuth: [ ]
|
||||
- TokenAuth: [ ]
|
||||
responses:
|
||||
'200':
|
||||
description: A JSON Array of proxy tokens
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/ProxyToken'
|
||||
'401':
|
||||
"$ref": "#/components/responses/requires_authentication"
|
||||
'403':
|
||||
"$ref": "#/components/responses/forbidden"
|
||||
'500':
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
post:
|
||||
summary: Create a Proxy Token
|
||||
description: Generate an account-scoped proxy access token for self-hosted proxy registration
|
||||
tags: [ Self-Hosted Proxies ]
|
||||
security:
|
||||
- BearerAuth: [ ]
|
||||
- TokenAuth: [ ]
|
||||
requestBody:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ProxyTokenRequest'
|
||||
responses:
|
||||
'200':
|
||||
description: Proxy token created (plain token shown once)
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ProxyTokenCreated'
|
||||
'400':
|
||||
"$ref": "#/components/responses/bad_request"
|
||||
'401':
|
||||
"$ref": "#/components/responses/requires_authentication"
|
||||
'403':
|
||||
"$ref": "#/components/responses/forbidden"
|
||||
'500':
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
/api/reverse-proxies/proxy-tokens/{tokenId}:
|
||||
delete:
|
||||
summary: Revoke a Proxy Token
|
||||
description: Revoke an account-scoped proxy access token
|
||||
tags: [ Self-Hosted Proxies ]
|
||||
security:
|
||||
- BearerAuth: [ ]
|
||||
- TokenAuth: [ ]
|
||||
parameters:
|
||||
- in: path
|
||||
name: tokenId
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
description: The unique identifier of the proxy token
|
||||
responses:
|
||||
'200':
|
||||
description: Token revoked
|
||||
'401':
|
||||
"$ref": "#/components/responses/requires_authentication"
|
||||
'403':
|
||||
"$ref": "#/components/responses/forbidden"
|
||||
'404':
|
||||
"$ref": "#/components/responses/not_found"
|
||||
'500':
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
/api/reverse-proxies/services:
|
||||
get:
|
||||
summary: List all Services
|
||||
|
||||
@@ -1486,9 +1486,6 @@ type AccountSettings struct {
|
||||
// LocalAuthDisabled Indicates whether local (email/password) authentication is disabled. When true, users can only authenticate via external identity providers. This is a read-only field.
|
||||
LocalAuthDisabled *bool `json:"local_auth_disabled,omitempty"`
|
||||
|
||||
// LocalMfaEnabled Enables or disables TOTP multi-factor authentication for local users. Only applicable when the embedded identity provider is enabled.
|
||||
LocalMfaEnabled *bool `json:"local_mfa_enabled,omitempty"`
|
||||
|
||||
// NetworkRange Allows to define a custom network range for the account in CIDR format
|
||||
NetworkRange *string `json:"network_range,omitempty"`
|
||||
|
||||
@@ -3785,49 +3782,11 @@ type ProxyAccessLogsResponse struct {
|
||||
|
||||
// ProxyCluster A proxy cluster represents a group of proxy nodes serving the same address
|
||||
type ProxyCluster struct {
|
||||
// Id Unique identifier of a proxy in this cluster
|
||||
Id string `json:"id"`
|
||||
|
||||
// Address Cluster address used for CNAME targets
|
||||
Address string `json:"address"`
|
||||
|
||||
// ConnectedProxies Number of proxy nodes connected in this cluster
|
||||
ConnectedProxies int `json:"connected_proxies"`
|
||||
|
||||
// SelfHosted Whether this cluster is a self-hosted (BYOP) proxy managed by the account owner
|
||||
SelfHosted bool `json:"self_hosted"`
|
||||
}
|
||||
|
||||
// ProxyToken defines model for ProxyToken.
|
||||
type ProxyToken struct {
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
ExpiresAt *time.Time `json:"expires_at,omitempty"`
|
||||
Id string `json:"id"`
|
||||
LastUsed *time.Time `json:"last_used,omitempty"`
|
||||
Name string `json:"name"`
|
||||
Revoked bool `json:"revoked"`
|
||||
}
|
||||
|
||||
// ProxyTokenCreated defines model for ProxyTokenCreated.
|
||||
type ProxyTokenCreated struct {
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
ExpiresAt *time.Time `json:"expires_at,omitempty"`
|
||||
Id string `json:"id"`
|
||||
LastUsed *time.Time `json:"last_used,omitempty"`
|
||||
Name string `json:"name"`
|
||||
|
||||
// PlainToken The plain text token (shown only once)
|
||||
PlainToken string `json:"plain_token"`
|
||||
Revoked bool `json:"revoked"`
|
||||
}
|
||||
|
||||
// ProxyTokenRequest defines model for ProxyTokenRequest.
|
||||
type ProxyTokenRequest struct {
|
||||
// ExpiresIn Token expiration in seconds (0 = never expires)
|
||||
ExpiresIn *int `json:"expires_in,omitempty"`
|
||||
|
||||
// Name Human-readable token name
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
// Resource defines model for Resource.
|
||||
@@ -5198,9 +5157,6 @@ type PutApiPostureChecksPostureCheckIdJSONRequestBody = PostureCheckUpdate
|
||||
// PostApiReverseProxiesDomainsJSONRequestBody defines body for PostApiReverseProxiesDomains for application/json ContentType.
|
||||
type PostApiReverseProxiesDomainsJSONRequestBody = ReverseProxyDomainRequest
|
||||
|
||||
// PostApiReverseProxiesProxyTokensJSONRequestBody defines body for PostApiReverseProxiesProxyTokens for application/json ContentType.
|
||||
type PostApiReverseProxiesProxyTokensJSONRequestBody = ProxyTokenRequest
|
||||
|
||||
// PostApiReverseProxiesServicesJSONRequestBody defines body for PostApiReverseProxiesServices for application/json ContentType.
|
||||
type PostApiReverseProxiesServicesJSONRequestBody = ServiceRequest
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/coder/websocket"
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -34,7 +35,13 @@ func (d Dialer) Dial(ctx context.Context, address, serverName string) (net.Conn,
|
||||
var underlying net.Conn
|
||||
opts := createDialOptions(serverName, &underlying)
|
||||
|
||||
wsConn, resp, err := websocket.Dial(ctx, wsURL, opts)
|
||||
parsedURL, err := url.Parse(wsURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
parsedURL.Path = relay.WebSocketURLPath
|
||||
|
||||
wsConn, resp, err := websocket.Dial(ctx, parsedURL.String(), opts)
|
||||
if err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return nil, err
|
||||
@@ -50,27 +57,12 @@ func (d Dialer) Dial(ctx context.Context, address, serverName string) (net.Conn,
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// prepareURL rewrites a rel://host[:port] or rels://host[:port] address into a
|
||||
// ws://host[:port]/relay or wss://host[:port]/relay URL, preserving any
|
||||
// non-standard port from the input.
|
||||
func prepareURL(address string) (string, error) {
|
||||
parsed, err := url.Parse(address)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("parse relay address %q: %w", address, err)
|
||||
if !strings.HasPrefix(address, "rel:") && !strings.HasPrefix(address, "rels:") {
|
||||
return "", fmt.Errorf("unsupported scheme: %s", address)
|
||||
}
|
||||
switch parsed.Scheme {
|
||||
case "rel":
|
||||
parsed.Scheme = "ws"
|
||||
case "rels":
|
||||
parsed.Scheme = "wss"
|
||||
default:
|
||||
return "", fmt.Errorf("unsupported scheme: %s", parsed.Scheme)
|
||||
}
|
||||
if parsed.Host == "" {
|
||||
return "", fmt.Errorf("missing host in relay address %q", address)
|
||||
}
|
||||
parsed.Path = relay.WebSocketURLPath
|
||||
return parsed.String(), nil
|
||||
|
||||
return strings.Replace(address, "rel", "ws", 1), nil
|
||||
}
|
||||
|
||||
// httpClientNbDialer builds the http client used by the websocket library.
|
||||
|
||||
@@ -1,76 +0,0 @@
|
||||
package ws
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestPrepareURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
want string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "rel scheme with non-standard port",
|
||||
input: "rel://test-domain-2:45678",
|
||||
want: "ws://test-domain-2:45678/relay",
|
||||
},
|
||||
{
|
||||
name: "rels scheme with non-standard port",
|
||||
input: "rels://test-domain-2:45678",
|
||||
want: "wss://test-domain-2:45678/relay",
|
||||
},
|
||||
{
|
||||
name: "rel scheme without port",
|
||||
input: "rel://test-domain-2",
|
||||
want: "ws://test-domain-2/relay",
|
||||
},
|
||||
{
|
||||
name: "rels scheme without port",
|
||||
input: "rels://test-domain-2",
|
||||
want: "wss://test-domain-2/relay",
|
||||
},
|
||||
{
|
||||
name: "rel scheme with IP and port",
|
||||
input: "rel://1.2.3.4:45678",
|
||||
want: "ws://1.2.3.4:45678/relay",
|
||||
},
|
||||
{
|
||||
name: "rel scheme with hostname starting with rel",
|
||||
input: "rel://relay.example.com:45678",
|
||||
want: "ws://relay.example.com:45678/relay",
|
||||
},
|
||||
{
|
||||
name: "rel scheme with IPv6 and port",
|
||||
input: "rel://[2001:db8::1]:45678",
|
||||
want: "ws://[2001:db8::1]:45678/relay",
|
||||
},
|
||||
{
|
||||
name: "rels scheme with IPv6 loopback and port",
|
||||
input: "rels://[::1]:45678",
|
||||
want: "wss://[::1]:45678/relay",
|
||||
},
|
||||
{
|
||||
name: "unsupported scheme",
|
||||
input: "http://test-domain-2:45678",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "no scheme",
|
||||
input: "test-domain-2:45678",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := prepareURL(tt.input)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Fatalf("prepareURL(%q) err = %v, wantErr %v", tt.input, err, tt.wantErr)
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Errorf("prepareURL(%q) = %q, want %q", tt.input, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// TextWriter writes human-readable one-line-per-packet summaries.
|
||||
@@ -595,45 +594,19 @@ func formatDNSResponse(d *layers.DNS, rd string, plen int) string {
|
||||
anCount := d.ANCount
|
||||
nsCount := d.NSCount
|
||||
arCount := d.ARCount
|
||||
ede := formatEDE(d)
|
||||
|
||||
if d.ResponseCode != layers.DNSResponseCodeNoErr {
|
||||
return fmt.Sprintf("%04x %d/%d/%d %s%s (%d)", d.ID, anCount, nsCount, arCount, d.ResponseCode, ede, plen)
|
||||
return fmt.Sprintf("%04x %d/%d/%d %s (%d)", d.ID, anCount, nsCount, arCount, d.ResponseCode, plen)
|
||||
}
|
||||
|
||||
if anCount > 0 && len(d.Answers) > 0 {
|
||||
rr := d.Answers[0]
|
||||
if rdata := shortRData(&rr); rdata != "" {
|
||||
return fmt.Sprintf("%04x %d/%d/%d %s %s%s (%d)", d.ID, anCount, nsCount, arCount, rr.Type, rdata, ede, plen)
|
||||
return fmt.Sprintf("%04x %d/%d/%d %s %s (%d)", d.ID, anCount, nsCount, arCount, rr.Type, rdata, plen)
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%04x %d/%d/%d%s (%d)", d.ID, anCount, nsCount, arCount, ede, plen)
|
||||
}
|
||||
|
||||
// dnsOPTCodeEDE is the EDNS0 option code for Extended DNS Errors (RFC 8914).
|
||||
const dnsOPTCodeEDE layers.DNSOptionCode = layers.DNSOptionCode(dns.EDNS0EDE)
|
||||
|
||||
// formatEDE returns " EDE=Name" for the first Extended DNS Error option
|
||||
// found in the response, or empty string if none is present.
|
||||
func formatEDE(d *layers.DNS) string {
|
||||
for _, rr := range d.Additionals {
|
||||
if rr.Type != layers.DNSTypeOPT {
|
||||
continue
|
||||
}
|
||||
for _, opt := range rr.OPT {
|
||||
if opt.Code != dnsOPTCodeEDE || len(opt.Data) < 2 {
|
||||
continue
|
||||
}
|
||||
info := binary.BigEndian.Uint16(opt.Data[:2])
|
||||
name, ok := dns.ExtendedErrorCodeToString[info]
|
||||
if !ok {
|
||||
name = fmt.Sprintf("%d", info)
|
||||
}
|
||||
return " EDE=" + name
|
||||
}
|
||||
}
|
||||
return ""
|
||||
return fmt.Sprintf("%04x %d/%d/%d (%d)", d.ID, anCount, nsCount, arCount, plen)
|
||||
}
|
||||
|
||||
func shortRData(rr *layers.DNSResourceRecord) string {
|
||||
|
||||
@@ -1,238 +0,0 @@
|
||||
// Package netrelay provides a bidirectional byte-copy helper for TCP-like
|
||||
// connections with correct half-close propagation.
|
||||
//
|
||||
// When one direction reads EOF, the write side of the opposite connection is
|
||||
// half-closed (CloseWrite) so the peer sees FIN, then the second direction is
|
||||
// allowed to drain to its own EOF before both connections are fully closed.
|
||||
// This preserves TCP half-close semantics (e.g. shutdown(SHUT_WR)) that the
|
||||
// naive "cancel-both-on-first-EOF" pattern breaks.
|
||||
package netrelay
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
// DebugLogger is the minimal interface netrelay uses to surface teardown
|
||||
// errors. Both *logrus.Entry and *nblog.Logger (via its Debugf method)
|
||||
// satisfy it, so callers can pass whichever they already use without an
|
||||
// adapter. Debugf is the only required method; callers with richer
|
||||
// loggers just expose this one shape here.
|
||||
type DebugLogger interface {
|
||||
Debugf(format string, args ...any)
|
||||
}
|
||||
|
||||
// DefaultIdleTimeout is a reasonable default for Options.IdleTimeout. Callers
|
||||
// that want an idle timeout but have no specific preference can use this.
|
||||
const DefaultIdleTimeout = 5 * time.Minute
|
||||
|
||||
// halfCloser is implemented by connections that support half-close
|
||||
// (e.g. *net.TCPConn, *gonet.TCPConn).
|
||||
type halfCloser interface {
|
||||
CloseWrite() error
|
||||
}
|
||||
|
||||
var copyBufPool = sync.Pool{
|
||||
New: func() any {
|
||||
buf := make([]byte, 32*1024)
|
||||
return &buf
|
||||
},
|
||||
}
|
||||
|
||||
// Options configures Relay behavior. The zero value is valid: no idle timeout,
|
||||
// no logging.
|
||||
type Options struct {
|
||||
// IdleTimeout tears down the session if no bytes flow in either
|
||||
// direction within this window. It is a connection-wide watchdog, so a
|
||||
// long unidirectional transfer on one side keeps the other side alive.
|
||||
// Zero disables idle tracking.
|
||||
IdleTimeout time.Duration
|
||||
// Logger receives debug-level copy/idle errors. Nil suppresses logging.
|
||||
// Any logger with Debug/Debugf methods is accepted (logrus.Entry,
|
||||
// uspfilter's nblog.Logger, etc.).
|
||||
Logger DebugLogger
|
||||
}
|
||||
|
||||
// Relay copies bytes in both directions between a and b until both directions
|
||||
// EOF or ctx is canceled. On each direction's EOF it half-closes the
|
||||
// opposite conn's write side (best effort) so the peer sees FIN while the
|
||||
// other direction drains. Both conns are fully closed when Relay returns.
|
||||
//
|
||||
// a and b only need to implement io.ReadWriteCloser; connections that also
|
||||
// implement CloseWrite (e.g. *net.TCPConn, ssh.Channel) get proper half-close
|
||||
// propagation. Options.IdleTimeout, when set, is enforced by a connection-wide
|
||||
// watchdog that tracks reads in either direction.
|
||||
//
|
||||
// Return values are byte counts: aToB (a.Read → b.Write) and bToA (b.Read →
|
||||
// a.Write). Errors are logged via Options.Logger when set; they are not
|
||||
// returned because a relay always terminates on some kind of EOF/cancel.
|
||||
func Relay(ctx context.Context, a, b io.ReadWriteCloser, opts Options) (aToB, bToA int64) {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
closeDone := make(chan struct{})
|
||||
defer func() {
|
||||
cancel()
|
||||
<-closeDone
|
||||
}()
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
_ = a.Close()
|
||||
_ = b.Close()
|
||||
close(closeDone)
|
||||
}()
|
||||
|
||||
// Both sides must support CloseWrite to propagate half-close. If either
|
||||
// doesn't, a direction's EOF can't be signaled to the peer and the other
|
||||
// direction would block forever waiting for data; in that case we fall
|
||||
// back to the cancel-both-on-first-EOF behavior.
|
||||
_, aHC := a.(halfCloser)
|
||||
_, bHC := b.(halfCloser)
|
||||
halfCloseSupported := aHC && bHC
|
||||
|
||||
var (
|
||||
lastActivity atomic.Int64
|
||||
idleHit atomic.Bool
|
||||
)
|
||||
lastActivity.Store(time.Now().UnixNano())
|
||||
|
||||
if opts.IdleTimeout > 0 {
|
||||
go watchdog(ctx, cancel, &lastActivity, &idleHit, opts.IdleTimeout)
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
var errAToB, errBToA error
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
aToB, errAToB = copyTracked(b, a, &lastActivity)
|
||||
if halfCloseSupported && isCleanEOF(errAToB) {
|
||||
halfClose(b)
|
||||
} else {
|
||||
cancel()
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
bToA, errBToA = copyTracked(a, b, &lastActivity)
|
||||
if halfCloseSupported && isCleanEOF(errBToA) {
|
||||
halfClose(a)
|
||||
} else {
|
||||
cancel()
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if opts.Logger != nil {
|
||||
if idleHit.Load() {
|
||||
opts.Logger.Debugf("relay closed due to idle timeout")
|
||||
}
|
||||
if errAToB != nil && !isExpectedCopyError(errAToB) {
|
||||
opts.Logger.Debugf("relay copy error (a→b): %v", errAToB)
|
||||
}
|
||||
if errBToA != nil && !isExpectedCopyError(errBToA) {
|
||||
opts.Logger.Debugf("relay copy error (b→a): %v", errBToA)
|
||||
}
|
||||
}
|
||||
|
||||
return aToB, bToA
|
||||
}
|
||||
|
||||
// watchdog enforces a connection-wide idle timeout. It cancels ctx when no
|
||||
// activity has been seen on either direction for idle. It exits as soon as
|
||||
// ctx is canceled so it doesn't outlive the relay.
|
||||
func watchdog(ctx context.Context, cancel context.CancelFunc, lastActivity *atomic.Int64, idleHit *atomic.Bool, idle time.Duration) {
|
||||
// Cap the tick at 50ms so detection latency stays bounded regardless of
|
||||
// how large idle is, and fall back to idle/2 when that is smaller so
|
||||
// very short timeouts (mainly in tests) are still caught promptly.
|
||||
tick := min(idle/2, 50*time.Millisecond)
|
||||
if tick <= 0 {
|
||||
tick = time.Millisecond
|
||||
}
|
||||
t := time.NewTicker(tick)
|
||||
defer t.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-t.C:
|
||||
last := time.Unix(0, lastActivity.Load())
|
||||
if time.Since(last) >= idle {
|
||||
idleHit.Store(true)
|
||||
cancel()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// copyTracked copies from src to dst using a pooled buffer, updating
|
||||
// lastActivity on every successful read so a shared watchdog can enforce a
|
||||
// connection-wide idle timeout.
|
||||
func copyTracked(dst io.Writer, src io.Reader, lastActivity *atomic.Int64) (int64, error) {
|
||||
bufp := copyBufPool.Get().(*[]byte)
|
||||
defer copyBufPool.Put(bufp)
|
||||
|
||||
buf := *bufp
|
||||
var total int64
|
||||
for {
|
||||
nr, readErr := src.Read(buf)
|
||||
if nr > 0 {
|
||||
lastActivity.Store(time.Now().UnixNano())
|
||||
n, werr := checkedWrite(dst, buf[:nr])
|
||||
total += n
|
||||
if werr != nil {
|
||||
return total, werr
|
||||
}
|
||||
}
|
||||
if readErr != nil {
|
||||
return total, readErr
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func checkedWrite(dst io.Writer, buf []byte) (int64, error) {
|
||||
nw, err := dst.Write(buf)
|
||||
if nw < 0 || nw > len(buf) {
|
||||
nw = 0
|
||||
}
|
||||
if err != nil {
|
||||
return int64(nw), err
|
||||
}
|
||||
if nw != len(buf) {
|
||||
return int64(nw), io.ErrShortWrite
|
||||
}
|
||||
return int64(nw), nil
|
||||
}
|
||||
|
||||
func halfClose(conn io.ReadWriteCloser) {
|
||||
if hc, ok := conn.(halfCloser); ok {
|
||||
_ = hc.CloseWrite()
|
||||
}
|
||||
}
|
||||
|
||||
// isCleanEOF reports whether a copy terminated on a graceful end-of-stream.
|
||||
// Only in that case is it correct to propagate the EOF via CloseWrite on the
|
||||
// peer; any other error means the flow is broken and both directions should
|
||||
// tear down.
|
||||
func isCleanEOF(err error) bool {
|
||||
return err == nil || errors.Is(err, io.EOF)
|
||||
}
|
||||
|
||||
func isExpectedCopyError(err error) bool {
|
||||
return errors.Is(err, net.ErrClosed) ||
|
||||
errors.Is(err, context.Canceled) ||
|
||||
errors.Is(err, io.EOF) ||
|
||||
errors.Is(err, syscall.ECONNRESET) ||
|
||||
errors.Is(err, syscall.EPIPE) ||
|
||||
errors.Is(err, syscall.ECONNABORTED)
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user