mirror of
https://github.com/netbirdio/netbird.git
synced 2026-06-02 22:19:54 +00:00
Compare commits
13 Commits
test/perft
...
test/conne
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2ea7fb7b21 | ||
|
|
93f530637d | ||
|
|
3a84475d14 | ||
|
|
3d7368e51f | ||
|
|
318cf59d66 | ||
|
|
e9b2a6e808 | ||
|
|
2dbdb5c1a7 | ||
|
|
2cdab6d7b7 | ||
|
|
e49c0e8862 | ||
|
|
e7c84d0ead | ||
|
|
1c934cca64 | ||
|
|
4aff4a6424 | ||
|
|
1bd7190954 |
@@ -358,9 +358,9 @@ func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 {
|
||||
// Fast path for IPv4 addresses (4 bytes) - most common case
|
||||
if len(oldBytes) == 4 && len(newBytes) == 4 {
|
||||
sum += uint32(^binary.BigEndian.Uint16(oldBytes[0:2]))
|
||||
sum += uint32(^binary.BigEndian.Uint16(oldBytes[2:4]))
|
||||
sum += uint32(^binary.BigEndian.Uint16(oldBytes[2:4])) //nolint:gosec // length checked above
|
||||
sum += uint32(binary.BigEndian.Uint16(newBytes[0:2]))
|
||||
sum += uint32(binary.BigEndian.Uint16(newBytes[2:4]))
|
||||
sum += uint32(binary.BigEndian.Uint16(newBytes[2:4])) //nolint:gosec // length checked above
|
||||
} else {
|
||||
// Fallback for other lengths
|
||||
for i := 0; i < len(oldBytes)-1; i += 2 {
|
||||
|
||||
@@ -589,6 +589,101 @@ func Test_ConnectPeers(t *testing.T) {
|
||||
|
||||
}
|
||||
|
||||
func Test_UserSpaceAddAllowedIPs(t *testing.T) {
|
||||
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+5)
|
||||
wgIP := "10.99.99.21/30"
|
||||
wgPort := 33105
|
||||
|
||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
opts := WGIFaceOpts{
|
||||
IFaceName: ifaceName,
|
||||
Address: wgIP,
|
||||
WGPort: wgPort,
|
||||
WGPrivKey: key,
|
||||
MTU: DefaultMTU,
|
||||
TransportNet: newNet,
|
||||
}
|
||||
|
||||
iface, err := NewWGIFace(opts)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = iface.Create()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := iface.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
|
||||
_, err = iface.Up()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
keepAlive := 15 * time.Second
|
||||
initialAllowedIP := netip.MustParsePrefix("10.99.99.22/32")
|
||||
endpoint, err := net.ResolveUDPAddr("udp", "127.0.0.1:9905")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Add peer with initial endpoint and first allowed IP
|
||||
err = iface.UpdatePeer(peerPubKey, []netip.Prefix{initialAllowedIP}, keepAlive, endpoint, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Phase 1: generate 500 allowed IPs into a list
|
||||
const extraIPs = 500
|
||||
addedPrefixes := make([]netip.Prefix, 0, extraIPs)
|
||||
for i := 0; i < extraIPs; i++ {
|
||||
// Use 172.16.x.y/32 range: i encoded as two octets
|
||||
prefix := netip.MustParsePrefix(fmt.Sprintf("172.16.%d.%d/32", i/256, i%256))
|
||||
addedPrefixes = append(addedPrefixes, prefix)
|
||||
}
|
||||
|
||||
// Phase 2: iterate over the list and add each allowed IP to the peer
|
||||
phase2Start := time.Now()
|
||||
for _, prefix := range addedPrefixes {
|
||||
if addErr := iface.AddAllowedIP(peerPubKey, prefix); addErr != nil {
|
||||
t.Fatalf("failed to add allowed IP %s: %v", prefix, addErr)
|
||||
}
|
||||
}
|
||||
t.Logf("Phase 2 (add %d IPs to peer): %s", extraIPs, time.Since(phase2Start))
|
||||
|
||||
// Verify the peer has all 101 allowed IPs (1 initial + 100 added)
|
||||
peer, err := getPeer(ifaceName, peerPubKey)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if peer.Endpoint.String() != endpoint.String() {
|
||||
t.Fatalf("expected endpoint %s, got %s", endpoint, peer.Endpoint)
|
||||
}
|
||||
|
||||
allExpected := append([]netip.Prefix{initialAllowedIP}, addedPrefixes...)
|
||||
if len(peer.AllowedIPs) != len(allExpected) {
|
||||
t.Fatalf("expected %d allowed IPs, got %d", len(allExpected), len(peer.AllowedIPs))
|
||||
}
|
||||
|
||||
allowedIPSet := make(map[string]struct{}, len(peer.AllowedIPs))
|
||||
for _, aip := range peer.AllowedIPs {
|
||||
allowedIPSet[aip.String()] = struct{}{}
|
||||
}
|
||||
for _, expected := range allExpected {
|
||||
if _, found := allowedIPSet[expected.String()]; !found {
|
||||
t.Errorf("expected allowed IP %s not found in peer config", expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func getPeer(ifaceName, peerPubKey string) (wgtypes.Peer, error) {
|
||||
wg, err := wgctrl.New()
|
||||
if err != nil {
|
||||
|
||||
@@ -290,6 +290,10 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
return wrapErr(err)
|
||||
}
|
||||
|
||||
if relayClient.IsDisableRelay() {
|
||||
relayURLs = []string{}
|
||||
}
|
||||
|
||||
relayManager := relayClient.NewManager(engineCtx, relayURLs, myPrivateKey.PublicKey().String(), engineConfig.MTU)
|
||||
c.statusRecorder.SetRelayMgr(relayManager)
|
||||
if len(relayURLs) > 0 {
|
||||
@@ -310,6 +314,8 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
c.engineMutex.Lock()
|
||||
engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks, stateManager)
|
||||
engine.SetSyncResponsePersistence(c.persistSyncResponse)
|
||||
engine.SetReadyChan(runningChan)
|
||||
runningChan = nil
|
||||
c.engine = engine
|
||||
c.engineMutex.Unlock()
|
||||
|
||||
@@ -330,11 +336,6 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
|
||||
state.Set(StatusConnected)
|
||||
|
||||
if runningChan != nil {
|
||||
close(runningChan)
|
||||
runningChan = nil
|
||||
}
|
||||
|
||||
<-engineCtx.Done()
|
||||
|
||||
c.engineMutex.Lock()
|
||||
|
||||
@@ -28,8 +28,8 @@ import (
|
||||
"github.com/netbirdio/netbird/client/firewall"
|
||||
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||
"github.com/netbirdio/netbird/client/internal/acl"
|
||||
"github.com/netbirdio/netbird/client/internal/debug"
|
||||
@@ -217,6 +217,10 @@ type Engine struct {
|
||||
// WireGuard interface monitor
|
||||
wgIfaceMonitor *WGIfaceMonitor
|
||||
|
||||
// readyChan is closed when the first sync message is received from management
|
||||
readyChan chan struct{}
|
||||
readyChanOnce sync.Once
|
||||
|
||||
// shutdownWg tracks all long-running goroutines to ensure clean shutdown
|
||||
shutdownWg sync.WaitGroup
|
||||
|
||||
@@ -275,6 +279,10 @@ func NewEngine(
|
||||
return engine
|
||||
}
|
||||
|
||||
func (e *Engine) SetReadyChan(ch chan struct{}) {
|
||||
e.readyChan = ch
|
||||
}
|
||||
|
||||
func (e *Engine) Stop() error {
|
||||
if e == nil {
|
||||
// this seems to be a very odd case but there was the possibility if the netbird down command comes before the engine is fully started
|
||||
@@ -834,6 +842,13 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||
defer func() {
|
||||
log.Infof("sync finished in %s", time.Since(started))
|
||||
}()
|
||||
|
||||
e.readyChanOnce.Do(func() {
|
||||
if e.readyChan != nil {
|
||||
close(e.readyChan)
|
||||
}
|
||||
})
|
||||
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
|
||||
@@ -880,9 +895,11 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||
// todo update signal
|
||||
}
|
||||
|
||||
uCheckTime := time.Now()
|
||||
if err := e.updateChecksIfNew(update.Checks); err != nil {
|
||||
return err
|
||||
}
|
||||
log.Infof("update check finished in %s", time.Since(uCheckTime))
|
||||
|
||||
nm := update.GetNetworkMap()
|
||||
if nm == nil {
|
||||
@@ -925,7 +942,9 @@ func (e *Engine) handleRelayUpdate(update *mgmProto.RelayConfig) error {
|
||||
return fmt.Errorf("update relay token: %w", err)
|
||||
}
|
||||
|
||||
e.relayManager.UpdateServerURLs(update.Urls)
|
||||
if !relayClient.IsDisableRelay() {
|
||||
e.relayManager.UpdateServerURLs(update.Urls)
|
||||
}
|
||||
|
||||
// Just in case the agent started with an MGM server where the relay was disabled but was later enabled.
|
||||
// We can ignore all errors because the guard will manage the reconnection retries.
|
||||
|
||||
@@ -434,14 +434,14 @@ func (conn *Conn) onICEStateDisconnected(sessionChanged bool) {
|
||||
conn.resetEndpoint()
|
||||
}
|
||||
|
||||
// todo consider to move after the ConfigureWGEndpoint
|
||||
conn.wgProxyRelay.Work()
|
||||
|
||||
presharedKey := conn.presharedKey(conn.rosenpassRemoteKey)
|
||||
if err := conn.endpointUpdater.ConfigureWGEndpoint(conn.wgProxyRelay.EndpointAddr(), presharedKey); err != nil {
|
||||
if err := conn.endpointUpdater.SwitchWGEndpoint(conn.wgProxyRelay.EndpointAddr(), presharedKey); err != nil {
|
||||
conn.Log.Errorf("failed to switch to relay conn: %v", err)
|
||||
}
|
||||
|
||||
conn.wgProxyRelay.Work()
|
||||
conn.currentConnPriority = conntype.Relay
|
||||
} else {
|
||||
conn.Log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", conntype.None.String())
|
||||
@@ -503,20 +503,22 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
|
||||
return
|
||||
}
|
||||
|
||||
wgProxy.Work()
|
||||
presharedKey := conn.presharedKey(rci.rosenpassPubKey)
|
||||
controller := isController(conn.config)
|
||||
|
||||
if controller {
|
||||
wgProxy.Work()
|
||||
}
|
||||
conn.enableWgWatcherIfNeeded()
|
||||
|
||||
if err := conn.endpointUpdater.ConfigureWGEndpoint(wgProxy.EndpointAddr(), presharedKey); err != nil {
|
||||
if err := conn.endpointUpdater.ConfigureWGEndpoint(wgProxy.EndpointAddr(), conn.presharedKey(rci.rosenpassPubKey)); err != nil {
|
||||
if err := wgProxy.CloseConn(); err != nil {
|
||||
conn.Log.Warnf("Failed to close relay connection: %v", err)
|
||||
}
|
||||
conn.Log.Errorf("Failed to update WireGuard peer configuration: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
wgConfigWorkaround()
|
||||
if !controller {
|
||||
wgProxy.Work()
|
||||
}
|
||||
conn.rosenpassRemoteKey = rci.rosenpassPubKey
|
||||
conn.currentConnPriority = conntype.Relay
|
||||
conn.statusRelay.SetConnected()
|
||||
@@ -877,9 +879,3 @@ func isController(config ConnConfig) bool {
|
||||
func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool {
|
||||
return remoteRosenpassPubKey != nil
|
||||
}
|
||||
|
||||
// wgConfigWorkaround is a workaround for the issue with WireGuard configuration update
|
||||
// When update a peer configuration in near to each other time, the second update can be ignored by WireGuard
|
||||
func wgConfigWorkaround() {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
@@ -34,28 +34,27 @@ func NewEndpointUpdater(log *logrus.Entry, wgConfig WgConfig, initiator bool) *E
|
||||
}
|
||||
}
|
||||
|
||||
// ConfigureWGEndpoint sets up the WireGuard endpoint configuration.
|
||||
// The initiator immediately configures the endpoint, while the non-initiator
|
||||
// waits for a fallback period before configuring to avoid handshake congestion.
|
||||
func (e *EndpointUpdater) ConfigureWGEndpoint(addr *net.UDPAddr, presharedKey *wgtypes.Key) error {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
|
||||
if e.initiator {
|
||||
e.log.Debugf("configure up WireGuard as initiatr")
|
||||
return e.updateWireGuardPeer(addr, presharedKey)
|
||||
e.log.Debugf("configure up WireGuard as initiator")
|
||||
return e.configureAsInitiator(addr, presharedKey)
|
||||
}
|
||||
|
||||
e.log.Debugf("configure up WireGuard as responder")
|
||||
return e.configureAsResponder(addr, presharedKey)
|
||||
}
|
||||
|
||||
func (e *EndpointUpdater) SwitchWGEndpoint(addr *net.UDPAddr, presharedKey *wgtypes.Key) error {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
|
||||
// prevent to run new update while cancel the previous update
|
||||
e.waitForCloseTheDelayedUpdate()
|
||||
|
||||
var ctx context.Context
|
||||
ctx, e.cancelFunc = context.WithCancel(context.Background())
|
||||
e.updateWg.Add(1)
|
||||
go e.scheduleDelayedUpdate(ctx, addr, presharedKey)
|
||||
|
||||
e.log.Debugf("configure up WireGuard and wait for handshake")
|
||||
return e.updateWireGuardPeer(nil, presharedKey)
|
||||
return e.updateWireGuardPeer(addr, presharedKey)
|
||||
}
|
||||
|
||||
func (e *EndpointUpdater) RemoveWgPeer() error {
|
||||
@@ -67,9 +66,37 @@ func (e *EndpointUpdater) RemoveWgPeer() error {
|
||||
}
|
||||
|
||||
func (e *EndpointUpdater) RemoveEndpointAddress() error {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
|
||||
e.waitForCloseTheDelayedUpdate()
|
||||
return e.wgConfig.WgInterface.RemoveEndpointAddress(e.wgConfig.RemoteKey)
|
||||
}
|
||||
|
||||
func (e *EndpointUpdater) configureAsInitiator(addr *net.UDPAddr, presharedKey *wgtypes.Key) error {
|
||||
if err := e.updateWireGuardPeer(addr, presharedKey); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *EndpointUpdater) configureAsResponder(addr *net.UDPAddr, presharedKey *wgtypes.Key) error {
|
||||
// prevent to run new update while cancel the previous update
|
||||
e.waitForCloseTheDelayedUpdate()
|
||||
|
||||
e.log.Debugf("configure up WireGuard and wait for handshake")
|
||||
var ctx context.Context
|
||||
ctx, e.cancelFunc = context.WithCancel(context.Background())
|
||||
e.updateWg.Add(1)
|
||||
go e.scheduleDelayedUpdate(ctx, addr, presharedKey)
|
||||
|
||||
if err := e.updateWireGuardPeer(nil, presharedKey); err != nil {
|
||||
e.waitForCloseTheDelayedUpdate()
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *EndpointUpdater) waitForCloseTheDelayedUpdate() {
|
||||
if e.cancelFunc == nil {
|
||||
return
|
||||
@@ -105,3 +132,9 @@ func (e *EndpointUpdater) updateWireGuardPeer(endpoint *net.UDPAddr, presharedKe
|
||||
presharedKey,
|
||||
)
|
||||
}
|
||||
|
||||
// wgConfigWorkaround is a workaround for the issue with WireGuard configuration update
|
||||
// When update a peer configuration in near to each other time, the second update can be ignored by WireGuard
|
||||
func wgConfigWorkaround() {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -125,6 +126,10 @@ func GenerateICECredentials() (string, string, error) {
|
||||
}
|
||||
|
||||
func CandidateTypes() []ice.CandidateType {
|
||||
if relayClient.IsDisableRelay() {
|
||||
return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive, ice.CandidateTypeRelay}
|
||||
}
|
||||
|
||||
if hasICEForceRelayConn() {
|
||||
return []ice.CandidateType{ice.CandidateTypeRelay}
|
||||
}
|
||||
|
||||
@@ -488,15 +488,17 @@ func createManagementServer(cfg *CombinedConfig, mgmtConfig *nbconfig.Config) (*
|
||||
mgmtPort, _ := strconv.Atoi(portStr)
|
||||
|
||||
mgmtSrv := mgmtServer.NewServer(
|
||||
mgmtConfig,
|
||||
dnsDomain,
|
||||
singleAccModeDomain,
|
||||
mgmtPort,
|
||||
cfg.Server.MetricsPort,
|
||||
mgmt.DisableAnonymousMetrics,
|
||||
mgmt.DisableGeoliteUpdate,
|
||||
// Always enable user deletion from IDP in combined server (embedded IdP is always enabled)
|
||||
true,
|
||||
&mgmtServer.Config{
|
||||
NbConfig: mgmtConfig,
|
||||
DNSDomain: dnsDomain,
|
||||
MgmtSingleAccModeDomain: singleAccModeDomain,
|
||||
MgmtPort: mgmtPort,
|
||||
MgmtMetricsPort: cfg.Server.MetricsPort,
|
||||
DisableMetrics: mgmt.DisableAnonymousMetrics,
|
||||
DisableGeoliteUpdate: mgmt.DisableGeoliteUpdate,
|
||||
// Always enable user deletion from IDP in combined server (embedded IdP is always enabled)
|
||||
UserDeleteFromIDPEnabled: true,
|
||||
},
|
||||
)
|
||||
|
||||
return mgmtSrv, nil
|
||||
|
||||
@@ -1,463 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/client/embed"
|
||||
)
|
||||
|
||||
const (
|
||||
echoPort = 9000
|
||||
connectTimeout = 120 * time.Second
|
||||
startTimeout = 60 * time.Second
|
||||
stopTimeout = 30 * time.Second
|
||||
)
|
||||
|
||||
type peerInfo struct {
|
||||
client *embed.Client
|
||||
tunnelIP string
|
||||
name string
|
||||
}
|
||||
|
||||
type pairStats struct {
|
||||
from string
|
||||
to string
|
||||
sent int64
|
||||
received int64
|
||||
lost int64
|
||||
rtts []time.Duration
|
||||
}
|
||||
|
||||
func (s *pairStats) summary() (avgRTT, minRTT, maxRTT time.Duration, lossPercent float64) {
|
||||
if len(s.rtts) == 0 {
|
||||
return 0, 0, 0, 100
|
||||
}
|
||||
minRTT = s.rtts[0]
|
||||
maxRTT = s.rtts[0]
|
||||
var total time.Duration
|
||||
for _, rtt := range s.rtts {
|
||||
total += rtt
|
||||
if rtt < minRTT {
|
||||
minRTT = rtt
|
||||
}
|
||||
if rtt > maxRTT {
|
||||
maxRTT = rtt
|
||||
}
|
||||
}
|
||||
avgRTT = total / time.Duration(len(s.rtts))
|
||||
if s.sent > 0 {
|
||||
lossPercent = float64(s.lost) / float64(s.sent) * 100
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func main() {
|
||||
managementURL := flag.String("management-url", "", "Management server URL (required)")
|
||||
setupKey := flag.String("setup-key", "", "Reusable setup key (required)")
|
||||
numPeers := flag.Int("peers", 5, "Number of peers to spawn")
|
||||
forceRelay := flag.Bool("force-relay", false, "Force relay connections (NB_FORCE_RELAY=true)")
|
||||
duration := flag.Duration("duration", 30*time.Second, "Traffic test duration")
|
||||
packetSize := flag.Int("packet-size", 512, "UDP packet size in bytes")
|
||||
logLevel := flag.String("log-level", "panic", "Client log level (trace, debug, info, warn, error, panic)")
|
||||
flag.Parse()
|
||||
|
||||
if *managementURL == "" || *setupKey == "" {
|
||||
fmt.Fprintln(os.Stderr, "Error: --management-url and --setup-key are required")
|
||||
flag.Usage()
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if *numPeers < 2 {
|
||||
fmt.Fprintln(os.Stderr, "Error: --peers must be at least 2")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Minimum packet size: 8 bytes for timestamp + 8 bytes for sequence number
|
||||
if *packetSize < 16 {
|
||||
fmt.Fprintln(os.Stderr, "Error: --packet-size must be at least 16")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if *forceRelay {
|
||||
os.Setenv("NB_FORCE_RELAY", "true")
|
||||
}
|
||||
os.Setenv("NB_USE_NETSTACK_MODE", "true")
|
||||
|
||||
fmt.Println("=== NetBird Performance Test ===")
|
||||
fmt.Printf("Management URL: %s\n", *managementURL)
|
||||
fmt.Printf("Peers: %d\n", *numPeers)
|
||||
fmt.Printf("Force relay: %v\n", *forceRelay)
|
||||
fmt.Printf("Duration: %s\n", *duration)
|
||||
fmt.Printf("Packet size: %d bytes\n", *packetSize)
|
||||
fmt.Println()
|
||||
|
||||
// Phase 1: Create peers
|
||||
fmt.Println("--- Phase 1: Creating peers ---")
|
||||
peers := make([]peerInfo, *numPeers)
|
||||
for i := 0; i < *numPeers; i++ {
|
||||
name := fmt.Sprintf("perf-peer-%d", i)
|
||||
port := 0
|
||||
c, err := embed.New(embed.Options{
|
||||
DeviceName: name,
|
||||
SetupKey: *setupKey,
|
||||
ManagementURL: *managementURL,
|
||||
WireguardPort: &port,
|
||||
LogLevel: *logLevel,
|
||||
LogOutput: io.Discard,
|
||||
})
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error creating peer %s: %v\n", name, err)
|
||||
os.Exit(1)
|
||||
}
|
||||
peers[i] = peerInfo{client: c, name: name}
|
||||
fmt.Printf(" Created %s\n", name)
|
||||
}
|
||||
|
||||
// Phase 2: Start peers in parallel
|
||||
fmt.Println("\n--- Phase 2: Starting peers ---")
|
||||
startTime := time.Now()
|
||||
var wg sync.WaitGroup
|
||||
startErrors := make([]error, *numPeers)
|
||||
|
||||
for i := range peers {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), startTimeout)
|
||||
defer cancel()
|
||||
t := time.Now()
|
||||
if err := peers[idx].client.Start(ctx); err != nil {
|
||||
startErrors[idx] = err
|
||||
return
|
||||
}
|
||||
fmt.Printf(" %s started in %s\n", peers[idx].name, time.Since(t).Round(time.Millisecond))
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
// Check for start errors
|
||||
var failed bool
|
||||
for i, err := range startErrors {
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, " Error starting %s: %v\n", peers[i].name, err)
|
||||
failed = true
|
||||
}
|
||||
}
|
||||
if failed {
|
||||
cleanup(peers)
|
||||
os.Exit(1)
|
||||
}
|
||||
fmt.Printf(" All peers started in %s\n", time.Since(startTime).Round(time.Millisecond))
|
||||
|
||||
// Get tunnel IPs
|
||||
for i := range peers {
|
||||
status, err := peers[i].client.Status()
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, " Error getting status for %s: %v\n", peers[i].name, err)
|
||||
cleanup(peers)
|
||||
os.Exit(1)
|
||||
}
|
||||
ip := status.LocalPeerState.IP
|
||||
// Strip CIDR suffix if present (e.g. "100.64.0.1/16" -> "100.64.0.1")
|
||||
if idx := strings.Index(ip, "/"); idx != -1 {
|
||||
ip = ip[:idx]
|
||||
}
|
||||
peers[i].tunnelIP = ip
|
||||
fmt.Printf(" %s -> %s\n", peers[i].name, peers[i].tunnelIP)
|
||||
}
|
||||
|
||||
// Phase 3: Wait for connections
|
||||
fmt.Println("\n--- Phase 3: Waiting for peer connections ---")
|
||||
connStart := time.Now()
|
||||
expectedPeers := *numPeers - 1
|
||||
deadline := time.After(connectTimeout)
|
||||
ticker := time.NewTicker(2 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
allConnected := false
|
||||
waitLoop:
|
||||
for {
|
||||
select {
|
||||
case <-deadline:
|
||||
fmt.Fprintf(os.Stderr, " Timeout waiting for connections after %s\n", connectTimeout)
|
||||
printConnectionStatus(peers)
|
||||
cleanup(peers)
|
||||
os.Exit(1)
|
||||
case <-ticker.C:
|
||||
allConnected = true
|
||||
for i := range peers {
|
||||
connected := countConnectedPeers(peers[i].client)
|
||||
if connected < expectedPeers {
|
||||
allConnected = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if allConnected {
|
||||
break waitLoop
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Printf(" All peers connected in %s\n", time.Since(connStart).Round(time.Millisecond))
|
||||
printConnectionStatus(peers)
|
||||
|
||||
// Phase 4: Traffic test
|
||||
fmt.Printf("\n--- Phase 4: Traffic test (%s) ---\n", *duration)
|
||||
|
||||
// Start echo listeners on all peers
|
||||
listeners := make([]net.PacketConn, *numPeers)
|
||||
for i := range peers {
|
||||
conn, err := peers[i].client.ListenUDP(fmt.Sprintf(":%d", echoPort))
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, " Error creating listener on %s: %v\n", peers[i].name, err)
|
||||
cleanup(peers)
|
||||
os.Exit(1)
|
||||
}
|
||||
listeners[i] = conn
|
||||
go echoServer(conn, *packetSize)
|
||||
fmt.Printf(" Echo listener started on %s:%d\n", peers[i].tunnelIP, echoPort)
|
||||
}
|
||||
|
||||
// Run traffic between all pairs (i < j)
|
||||
var statsMu sync.Mutex
|
||||
var allStats []pairStats
|
||||
|
||||
var trafficWg sync.WaitGroup
|
||||
for i := 0; i < *numPeers; i++ {
|
||||
for j := i + 1; j < *numPeers; j++ {
|
||||
trafficWg.Add(1)
|
||||
go func(from, to int) {
|
||||
defer trafficWg.Done()
|
||||
stats := runTraffic(peers[from].client, peers[from].name, peers[to].tunnelIP, peers[to].name, *duration, *packetSize)
|
||||
statsMu.Lock()
|
||||
allStats = append(allStats, stats)
|
||||
statsMu.Unlock()
|
||||
}(i, j)
|
||||
}
|
||||
}
|
||||
trafficWg.Wait()
|
||||
|
||||
// Close listeners
|
||||
for _, l := range listeners {
|
||||
if l != nil {
|
||||
l.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 5: Report
|
||||
fmt.Println("\n--- Phase 5: Results ---")
|
||||
printReport(allStats)
|
||||
|
||||
// Cleanup
|
||||
fmt.Println("\n--- Cleanup ---")
|
||||
cleanup(peers)
|
||||
fmt.Println("Done.")
|
||||
}
|
||||
|
||||
func countConnectedPeers(c *embed.Client) int {
|
||||
status, err := c.Status()
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
count := 0
|
||||
for _, p := range status.Peers {
|
||||
if p.ConnStatus == embed.PeerStatusConnected {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
func printConnectionStatus(peers []peerInfo) {
|
||||
for i := range peers {
|
||||
status, err := peers[i].client.Status()
|
||||
if err != nil {
|
||||
fmt.Printf(" %s: error getting status: %v\n", peers[i].name, err)
|
||||
continue
|
||||
}
|
||||
connected := 0
|
||||
relayed := 0
|
||||
for _, p := range status.Peers {
|
||||
if p.ConnStatus == embed.PeerStatusConnected {
|
||||
connected++
|
||||
if p.Relayed {
|
||||
relayed++
|
||||
}
|
||||
}
|
||||
}
|
||||
connType := "direct"
|
||||
if relayed > 0 {
|
||||
connType = fmt.Sprintf("%d direct, %d relayed", connected-relayed, relayed)
|
||||
}
|
||||
fmt.Printf(" %s: %d/%d connected (%s)\n", peers[i].name, connected, len(status.Peers), connType)
|
||||
}
|
||||
}
|
||||
|
||||
func echoServer(conn net.PacketConn, maxSize int) {
|
||||
buf := make([]byte, maxSize+100)
|
||||
for {
|
||||
n, addr, err := conn.ReadFrom(buf)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, _ = conn.WriteTo(buf[:n], addr)
|
||||
}
|
||||
}
|
||||
|
||||
func runTraffic(client *embed.Client, fromName, toIP, toName string, duration time.Duration, packetSize int) pairStats {
|
||||
stats := pairStats{
|
||||
from: fromName,
|
||||
to: toName,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), duration+10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
conn, err := client.Dial(ctx, "udp", fmt.Sprintf("%s:%d", toIP, echoPort))
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, " Error dialing %s -> %s: %v\n", fromName, toName, err)
|
||||
return stats
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
deadline := time.Now().Add(duration)
|
||||
buf := make([]byte, packetSize)
|
||||
recvBuf := make([]byte, packetSize+100)
|
||||
var seq uint64
|
||||
|
||||
for time.Now().Before(deadline) {
|
||||
seq++
|
||||
// Encode timestamp and sequence number
|
||||
binary.BigEndian.PutUint64(buf[0:8], uint64(time.Now().UnixNano()))
|
||||
binary.BigEndian.PutUint64(buf[8:16], seq)
|
||||
|
||||
stats.sent++
|
||||
conn.SetWriteDeadline(time.Now().Add(5 * time.Second))
|
||||
_, err := conn.Write(buf)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, " Error sending packet to %s: %v\n", toName, err)
|
||||
stats.lost++
|
||||
continue
|
||||
}
|
||||
|
||||
conn.SetReadDeadline(time.Now().Add(5 * time.Second))
|
||||
n, err := conn.Read(recvBuf)
|
||||
if err != nil {
|
||||
stats.lost++
|
||||
continue
|
||||
}
|
||||
|
||||
if n >= 8 {
|
||||
sentNano := binary.BigEndian.Uint64(recvBuf[0:8])
|
||||
rtt := time.Since(time.Unix(0, int64(sentNano)))
|
||||
stats.received++
|
||||
stats.rtts = append(stats.rtts, rtt)
|
||||
} else {
|
||||
stats.received++
|
||||
}
|
||||
|
||||
// Small sleep to avoid flooding
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
|
||||
return stats
|
||||
}
|
||||
|
||||
func printReport(allStats []pairStats) {
|
||||
if len(allStats) == 0 {
|
||||
fmt.Println(" No traffic data collected.")
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Printf(" %-30s %8s %8s %8s %8s %10s %10s %10s\n",
|
||||
"Pair", "Sent", "Recv", "Lost", "Loss%", "Avg RTT", "Min RTT", "Max RTT")
|
||||
fmt.Println(" " + strings.Repeat("-", 108))
|
||||
|
||||
var totalSent, totalRecv, totalLost int64
|
||||
var totalRTTs []time.Duration
|
||||
|
||||
for _, s := range allStats {
|
||||
avg, min, max, loss := s.summary()
|
||||
pair := fmt.Sprintf("%s -> %s", s.from, s.to)
|
||||
fmt.Printf(" %-30s %8d %8d %8d %7.1f%% %10s %10s %10s\n",
|
||||
pair, s.sent, s.received, s.lost, loss,
|
||||
avg.Round(time.Microsecond), min.Round(time.Microsecond), max.Round(time.Microsecond))
|
||||
totalSent += s.sent
|
||||
totalRecv += s.received
|
||||
totalLost += s.lost
|
||||
totalRTTs = append(totalRTTs, s.rtts...)
|
||||
}
|
||||
|
||||
fmt.Println(" " + strings.Repeat("-", 108))
|
||||
|
||||
// Overall summary
|
||||
var overallLoss float64
|
||||
if totalSent > 0 {
|
||||
overallLoss = float64(totalLost) / float64(totalSent) * 100
|
||||
}
|
||||
|
||||
var avgRTT, minRTT, maxRTT time.Duration
|
||||
if len(totalRTTs) > 0 {
|
||||
minRTT = totalRTTs[0]
|
||||
maxRTT = totalRTTs[0]
|
||||
var total time.Duration
|
||||
for _, rtt := range totalRTTs {
|
||||
total += rtt
|
||||
if rtt < minRTT {
|
||||
minRTT = rtt
|
||||
}
|
||||
if rtt > maxRTT {
|
||||
maxRTT = rtt
|
||||
}
|
||||
}
|
||||
avgRTT = total / time.Duration(len(totalRTTs))
|
||||
}
|
||||
|
||||
fmt.Printf(" %-30s %8d %8d %8d %7.1f%% %10s %10s %10s\n",
|
||||
"TOTAL", totalSent, totalRecv, totalLost, overallLoss,
|
||||
avgRTT.Round(time.Microsecond), minRTT.Round(time.Microsecond), maxRTT.Round(time.Microsecond))
|
||||
|
||||
// Extra stats
|
||||
if len(totalRTTs) > 0 {
|
||||
fmt.Println()
|
||||
var sumSq float64
|
||||
avgNs := float64(avgRTT.Nanoseconds())
|
||||
for _, rtt := range totalRTTs {
|
||||
diff := float64(rtt.Nanoseconds()) - avgNs
|
||||
sumSq += diff * diff
|
||||
}
|
||||
stddev := time.Duration(math.Sqrt(sumSq / float64(len(totalRTTs))))
|
||||
|
||||
fmt.Printf(" RTT stddev: %s\n", stddev.Round(time.Microsecond))
|
||||
fmt.Printf(" Pairs tested: %d\n", len(allStats))
|
||||
}
|
||||
}
|
||||
|
||||
func cleanup(peers []peerInfo) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), stopTimeout)
|
||||
defer cancel()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := range peers {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
if err := peers[idx].client.Stop(ctx); err != nil {
|
||||
fmt.Fprintf(os.Stderr, " Error stopping %s: %v\n", peers[idx].name, err)
|
||||
} else {
|
||||
fmt.Printf(" Stopped %s\n", peers[idx].name)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
@@ -1,61 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
BINARY="$SCRIPT_DIR/perftest"
|
||||
|
||||
# Defaults
|
||||
MANAGEMENT_URL="${MANAGEMENT_URL:-}"
|
||||
SETUP_KEY="${SETUP_KEY:-}"
|
||||
PEERS="${PEERS:-5}"
|
||||
DURATION="${DURATION:-30s}"
|
||||
PACKET_SIZE="${PACKET_SIZE:-512}"
|
||||
FORCE_RELAY="${FORCE_RELAY:-false}"
|
||||
LOG_LEVEL="${LOG_LEVEL:-panic}"
|
||||
|
||||
usage() {
|
||||
cat <<EOF
|
||||
Usage: MANAGEMENT_URL=... SETUP_KEY=... $0 [options]
|
||||
|
||||
Environment variables (or flags):
|
||||
MANAGEMENT_URL Management server URL (required)
|
||||
SETUP_KEY Reusable setup key (required). Use ephemeral.
|
||||
PEERS Number of peers (default: 5)
|
||||
DURATION Traffic test duration (default: 30s)
|
||||
PACKET_SIZE UDP packet size in bytes (default: 512)
|
||||
FORCE_RELAY Force relay mode (default: false)
|
||||
LOG_LEVEL Client log level (default: panic)
|
||||
|
||||
All extra arguments are passed directly to the binary.
|
||||
EOF
|
||||
exit 1
|
||||
}
|
||||
|
||||
if [[ -z "$MANAGEMENT_URL" || -z "$SETUP_KEY" ]]; then
|
||||
echo "Error: MANAGEMENT_URL and SETUP_KEY must be set"
|
||||
echo
|
||||
usage
|
||||
fi
|
||||
|
||||
# Build
|
||||
echo "Building perftest..."
|
||||
cd "$SCRIPT_DIR"
|
||||
go build -o "$BINARY" .
|
||||
echo "Build OK: $BINARY"
|
||||
echo
|
||||
|
||||
# Run
|
||||
ARGS=(
|
||||
--management-url "$MANAGEMENT_URL"
|
||||
--setup-key "$SETUP_KEY"
|
||||
--peers "$PEERS"
|
||||
--duration "$DURATION"
|
||||
--packet-size "$PACKET_SIZE"
|
||||
--log-level "$LOG_LEVEL"
|
||||
)
|
||||
|
||||
if [[ "$FORCE_RELAY" == "true" ]]; then
|
||||
ARGS+=(--force-relay)
|
||||
fi
|
||||
|
||||
exec "$BINARY" "${ARGS[@]}" "$@"
|
||||
2
go.mod
2
go.mod
@@ -40,7 +40,7 @@ require (
|
||||
github.com/c-robinson/iplib v1.0.3
|
||||
github.com/caddyserver/certmagic v0.21.3
|
||||
github.com/cilium/ebpf v0.15.0
|
||||
github.com/coder/websocket v1.8.13
|
||||
github.com/coder/websocket v1.8.14
|
||||
github.com/coreos/go-iptables v0.7.0
|
||||
github.com/coreos/go-oidc/v3 v3.14.1
|
||||
github.com/creack/pty v1.1.24
|
||||
|
||||
4
go.sum
4
go.sum
@@ -107,8 +107,8 @@ github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UF
|
||||
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/cilium/ebpf v0.15.0 h1:7NxJhNiBT3NG8pZJ3c+yfrVdHY8ScgKD27sScgjLMMk=
|
||||
github.com/cilium/ebpf v0.15.0/go.mod h1:DHp1WyrLeiBh19Cf/tfiSMhqheEiK8fXFZ4No0P1Hso=
|
||||
github.com/coder/websocket v1.8.13 h1:f3QZdXy7uGVz+4uCJy2nTZyM0yTBj8yANEHhqlXZ9FE=
|
||||
github.com/coder/websocket v1.8.13/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs=
|
||||
github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g=
|
||||
github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg=
|
||||
github.com/containerd/containerd v1.7.29 h1:90fWABQsaN9mJhGkoVnuzEY+o1XDPbg9BTC9QTAHnuE=
|
||||
github.com/containerd/containerd v1.7.29/go.mod h1:azUkWcOvHrWvaiUjSQH0fjzuHIwSPg1WL5PshGP4Szs=
|
||||
github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I=
|
||||
|
||||
@@ -99,15 +99,16 @@ func NewProvider(ctx context.Context, config *Config) (*Provider, error) {
|
||||
|
||||
// Build Dex server config - use Dex's types directly
|
||||
dexConfig := server.Config{
|
||||
Issuer: issuer,
|
||||
Storage: stor,
|
||||
SkipApprovalScreen: true,
|
||||
SupportedResponseTypes: []string{"code"},
|
||||
Logger: logger,
|
||||
PrometheusRegistry: prometheus.NewRegistry(),
|
||||
RotateKeysAfter: 6 * time.Hour,
|
||||
IDTokensValidFor: 24 * time.Hour,
|
||||
RefreshTokenPolicy: refreshPolicy,
|
||||
Issuer: issuer,
|
||||
Storage: stor,
|
||||
SkipApprovalScreen: true,
|
||||
SupportedResponseTypes: []string{"code"},
|
||||
ContinueOnConnectorFailure: true,
|
||||
Logger: logger,
|
||||
PrometheusRegistry: prometheus.NewRegistry(),
|
||||
RotateKeysAfter: 6 * time.Hour,
|
||||
IDTokensValidFor: 24 * time.Hour,
|
||||
RefreshTokenPolicy: refreshPolicy,
|
||||
Web: server.WebConfig{
|
||||
Issuer: "NetBird",
|
||||
},
|
||||
@@ -260,6 +261,7 @@ func buildDexConfig(yamlConfig *YAMLConfig, stor storage.Storage, logger *slog.L
|
||||
if len(cfg.SupportedResponseTypes) == 0 {
|
||||
cfg.SupportedResponseTypes = []string{"code"}
|
||||
}
|
||||
cfg.ContinueOnConnectorFailure = true
|
||||
return cfg
|
||||
}
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ package dex
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
@@ -195,3 +196,64 @@ enablePasswordDB: true
|
||||
|
||||
t.Logf("User lookup successful: rawID=%s, connectorID=%s", rawID, connID)
|
||||
}
|
||||
|
||||
func TestNewProvider_ContinueOnConnectorFailure(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
tmpDir, err := os.MkdirTemp("", "dex-connector-failure-*")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
config := &Config{
|
||||
Issuer: "http://localhost:5556/dex",
|
||||
Port: 5556,
|
||||
DataDir: tmpDir,
|
||||
}
|
||||
|
||||
provider, err := NewProvider(ctx, config)
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = provider.Stop(ctx) }()
|
||||
|
||||
// The provider should have started successfully even though
|
||||
// ContinueOnConnectorFailure is an internal Dex config field.
|
||||
// We verify the provider is functional by performing a basic operation.
|
||||
assert.NotNil(t, provider.dexServer)
|
||||
assert.NotNil(t, provider.storage)
|
||||
}
|
||||
|
||||
func TestBuildDexConfig_ContinueOnConnectorFailure(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "dex-build-config-*")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
yamlContent := `
|
||||
issuer: http://localhost:5556/dex
|
||||
storage:
|
||||
type: sqlite3
|
||||
config:
|
||||
file: ` + filepath.Join(tmpDir, "dex.db") + `
|
||||
web:
|
||||
http: 127.0.0.1:5556
|
||||
enablePasswordDB: true
|
||||
`
|
||||
configPath := filepath.Join(tmpDir, "config.yaml")
|
||||
err = os.WriteFile(configPath, []byte(yamlContent), 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
yamlConfig, err := LoadConfig(configPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
stor, err := yamlConfig.Storage.OpenStorage(slog.New(slog.NewTextHandler(os.Stderr, nil)))
|
||||
require.NoError(t, err)
|
||||
defer stor.Close()
|
||||
|
||||
err = initializeStorage(ctx, stor, yamlConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
|
||||
cfg := buildDexConfig(yamlConfig, stor, logger)
|
||||
|
||||
assert.True(t, cfg.ContinueOnConnectorFailure,
|
||||
"buildDexConfig must set ContinueOnConnectorFailure to true so management starts even if an external IdP is down")
|
||||
}
|
||||
|
||||
@@ -577,9 +577,6 @@ render_docker_compose_traefik_builtin() {
|
||||
proxy:
|
||||
image: $NETBIRD_PROXY_IMAGE
|
||||
container_name: netbird-proxy
|
||||
# Hairpin NAT fix: route domain back to traefik's static IP within Docker
|
||||
extra_hosts:
|
||||
- \"$NETBIRD_DOMAIN:$TRAEFIK_IP\"
|
||||
ports:
|
||||
- 51820:51820/udp
|
||||
restart: unless-stopped
|
||||
@@ -822,9 +819,6 @@ NB_PROXY_TOKEN=$PROXY_TOKEN
|
||||
NB_PROXY_CERTIFICATE_DIRECTORY=/certs
|
||||
NB_PROXY_ACME_CERTIFICATES=true
|
||||
NB_PROXY_ACME_CHALLENGE_TYPE=tls-alpn-01
|
||||
NB_PROXY_OIDC_CLIENT_ID=netbird-proxy
|
||||
NB_PROXY_OIDC_ENDPOINT=$NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN/oauth2
|
||||
NB_PROXY_OIDC_SCOPES=openid,profile,email
|
||||
NB_PROXY_FORWARDED_PROTO=https
|
||||
# Enable PROXY protocol to preserve client IPs through L4 proxies (Traefik TCP passthrough)
|
||||
NB_PROXY_PROXY_PROTOCOL=true
|
||||
|
||||
@@ -29,11 +29,11 @@ import (
|
||||
"github.com/netbirdio/netbird/util/crypt"
|
||||
)
|
||||
|
||||
var newServer = func(config *nbconfig.Config, dnsDomain, mgmtSingleAccModeDomain string, mgmtPort int, mgmtMetricsPort int, disableMetrics, disableGeoliteUpdate, userDeleteFromIDPEnabled bool) server.Server {
|
||||
return server.NewServer(config, dnsDomain, mgmtSingleAccModeDomain, mgmtPort, mgmtMetricsPort, disableMetrics, disableGeoliteUpdate, userDeleteFromIDPEnabled)
|
||||
var newServer = func(cfg *server.Config) server.Server {
|
||||
return server.NewServer(cfg)
|
||||
}
|
||||
|
||||
func SetNewServer(fn func(config *nbconfig.Config, dnsDomain, mgmtSingleAccModeDomain string, mgmtPort int, mgmtMetricsPort int, disableMetrics, disableGeoliteUpdate, userDeleteFromIDPEnabled bool) server.Server) {
|
||||
func SetNewServer(fn func(*server.Config) server.Server) {
|
||||
newServer = fn
|
||||
}
|
||||
|
||||
@@ -110,7 +110,17 @@ var (
|
||||
mgmtSingleAccModeDomain = ""
|
||||
}
|
||||
|
||||
srv := newServer(config, dnsDomain, mgmtSingleAccModeDomain, mgmtPort, mgmtMetricsPort, disableMetrics, disableGeoliteUpdate, userDeleteFromIDPEnabled)
|
||||
srv := newServer(&server.Config{
|
||||
NbConfig: config,
|
||||
DNSDomain: dnsDomain,
|
||||
MgmtSingleAccModeDomain: mgmtSingleAccModeDomain,
|
||||
MgmtPort: mgmtPort,
|
||||
MgmtMetricsPort: mgmtMetricsPort,
|
||||
DisableLegacyManagementPort: disableLegacyManagementPort,
|
||||
DisableMetrics: disableMetrics,
|
||||
DisableGeoliteUpdate: disableGeoliteUpdate,
|
||||
UserDeleteFromIDPEnabled: userDeleteFromIDPEnabled,
|
||||
})
|
||||
go func() {
|
||||
if err := srv.Start(cmd.Context()); err != nil {
|
||||
log.Fatalf("Server error: %v", err)
|
||||
|
||||
@@ -16,21 +16,22 @@ const (
|
||||
)
|
||||
|
||||
var (
|
||||
dnsDomain string
|
||||
mgmtDataDir string
|
||||
logLevel string
|
||||
logFile string
|
||||
disableMetrics bool
|
||||
disableSingleAccMode bool
|
||||
disableGeoliteUpdate bool
|
||||
idpSignKeyRefreshEnabled bool
|
||||
userDeleteFromIDPEnabled bool
|
||||
mgmtPort int
|
||||
mgmtMetricsPort int
|
||||
mgmtLetsencryptDomain string
|
||||
mgmtSingleAccModeDomain string
|
||||
certFile string
|
||||
certKey string
|
||||
dnsDomain string
|
||||
mgmtDataDir string
|
||||
logLevel string
|
||||
logFile string
|
||||
disableMetrics bool
|
||||
disableSingleAccMode bool
|
||||
disableGeoliteUpdate bool
|
||||
idpSignKeyRefreshEnabled bool
|
||||
userDeleteFromIDPEnabled bool
|
||||
mgmtPort int
|
||||
mgmtMetricsPort int
|
||||
disableLegacyManagementPort bool
|
||||
mgmtLetsencryptDomain string
|
||||
mgmtSingleAccModeDomain string
|
||||
certFile string
|
||||
certKey string
|
||||
|
||||
rootCmd = &cobra.Command{
|
||||
Use: "netbird-mgmt",
|
||||
@@ -55,6 +56,7 @@ func Execute() error {
|
||||
|
||||
func init() {
|
||||
mgmtCmd.Flags().IntVar(&mgmtPort, "port", 80, "server port to listen on (defaults to 443 if TLS is enabled, 80 otherwise")
|
||||
mgmtCmd.Flags().BoolVar(&disableLegacyManagementPort, "disable-legacy-port", false, "disabling the old legacy port (33073)")
|
||||
mgmtCmd.Flags().IntVar(&mgmtMetricsPort, "metrics-port", 9090, "metrics endpoint http port. Metrics are accessible under host:metrics-port/metrics")
|
||||
mgmtCmd.Flags().StringVar(&mgmtDataDir, "datadir", defaultMgmtDataDir, "server data directory location")
|
||||
mgmtCmd.Flags().StringVar(&nbconfig.MgmtConfigPath, "config", defaultMgmtConfig, "Netbird config file location. Config params specified via command line (e.g. datadir) have a precedence over configuration from this file")
|
||||
|
||||
@@ -50,13 +50,14 @@ type BaseServer struct {
|
||||
// AfterInit is a function that will be called after the server is initialized
|
||||
afterInit []func(s *BaseServer)
|
||||
|
||||
disableMetrics bool
|
||||
dnsDomain string
|
||||
disableGeoliteUpdate bool
|
||||
userDeleteFromIDPEnabled bool
|
||||
mgmtSingleAccModeDomain string
|
||||
mgmtMetricsPort int
|
||||
mgmtPort int
|
||||
disableMetrics bool
|
||||
dnsDomain string
|
||||
disableGeoliteUpdate bool
|
||||
userDeleteFromIDPEnabled bool
|
||||
mgmtSingleAccModeDomain string
|
||||
mgmtMetricsPort int
|
||||
mgmtPort int
|
||||
disableLegacyManagementPort bool
|
||||
|
||||
proxyAuthClose func()
|
||||
|
||||
@@ -69,18 +70,32 @@ type BaseServer struct {
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// Config holds the configuration parameters for creating a new server
|
||||
type Config struct {
|
||||
NbConfig *nbconfig.Config
|
||||
DNSDomain string
|
||||
MgmtSingleAccModeDomain string
|
||||
MgmtPort int
|
||||
MgmtMetricsPort int
|
||||
DisableLegacyManagementPort bool
|
||||
DisableMetrics bool
|
||||
DisableGeoliteUpdate bool
|
||||
UserDeleteFromIDPEnabled bool
|
||||
}
|
||||
|
||||
// NewServer initializes and configures a new Server instance
|
||||
func NewServer(config *nbconfig.Config, dnsDomain, mgmtSingleAccModeDomain string, mgmtPort, mgmtMetricsPort int, disableMetrics, disableGeoliteUpdate, userDeleteFromIDPEnabled bool) *BaseServer {
|
||||
func NewServer(cfg *Config) *BaseServer {
|
||||
return &BaseServer{
|
||||
Config: config,
|
||||
container: make(map[string]any),
|
||||
dnsDomain: dnsDomain,
|
||||
mgmtSingleAccModeDomain: mgmtSingleAccModeDomain,
|
||||
disableMetrics: disableMetrics,
|
||||
disableGeoliteUpdate: disableGeoliteUpdate,
|
||||
userDeleteFromIDPEnabled: userDeleteFromIDPEnabled,
|
||||
mgmtPort: mgmtPort,
|
||||
mgmtMetricsPort: mgmtMetricsPort,
|
||||
Config: cfg.NbConfig,
|
||||
container: make(map[string]any),
|
||||
dnsDomain: cfg.DNSDomain,
|
||||
mgmtSingleAccModeDomain: cfg.MgmtSingleAccModeDomain,
|
||||
disableMetrics: cfg.DisableMetrics,
|
||||
disableGeoliteUpdate: cfg.DisableGeoliteUpdate,
|
||||
userDeleteFromIDPEnabled: cfg.UserDeleteFromIDPEnabled,
|
||||
mgmtPort: cfg.MgmtPort,
|
||||
disableLegacyManagementPort: cfg.DisableLegacyManagementPort,
|
||||
mgmtMetricsPort: cfg.MgmtMetricsPort,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -152,7 +167,7 @@ func (s *BaseServer) Start(ctx context.Context) error {
|
||||
}
|
||||
|
||||
var compatListener net.Listener
|
||||
if s.mgmtPort != ManagementLegacyPort {
|
||||
if s.mgmtPort != ManagementLegacyPort && !s.disableLegacyManagementPort {
|
||||
// The Management gRPC server was running on port 33073 previously. Old agents that are already connected to it
|
||||
// are using port 33073. For compatibility purposes we keep running a 2nd gRPC server on port 33073.
|
||||
compatListener, err = s.serveGRPC(srvCtx, s.GRPCServer(), ManagementLegacyPort)
|
||||
|
||||
@@ -224,6 +224,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
||||
s.syncSem.Add(1)
|
||||
|
||||
reqStart := time.Now()
|
||||
syncStart := reqStart.UTC()
|
||||
|
||||
ctx := srv.Context()
|
||||
|
||||
@@ -300,7 +301,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
||||
metahash := metaHash(peerMeta, realIP.String())
|
||||
s.loginFilter.addLogin(peerKey.String(), metahash)
|
||||
|
||||
peer, netMap, postureChecks, dnsFwdPort, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP, reqStart)
|
||||
peer, netMap, postureChecks, dnsFwdPort, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP, syncStart)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("error while syncing peer %s: %v", peerKey.String(), err)
|
||||
s.syncSem.Add(-1)
|
||||
@@ -311,7 +312,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("error while sending initial sync for %s: %v", peerKey.String(), err)
|
||||
s.syncSem.Add(-1)
|
||||
s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, reqStart)
|
||||
s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, syncStart)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -319,7 +320,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("error while notify peer connected for %s: %v", peerKey.String(), err)
|
||||
s.syncSem.Add(-1)
|
||||
s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, reqStart)
|
||||
s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, syncStart)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -336,7 +337,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
||||
|
||||
s.syncSem.Add(-1)
|
||||
|
||||
return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv, reqStart)
|
||||
return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv, syncStart)
|
||||
}
|
||||
|
||||
func (s *Server) handleHandshake(ctx context.Context, srv proto.ManagementService_JobServer) (wgtypes.Key, error) {
|
||||
|
||||
@@ -6,14 +6,14 @@ import (
|
||||
"os"
|
||||
"os/signal"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/crypto/acme"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy"
|
||||
nbacme "github.com/netbirdio/netbird/proxy/internal/acme"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
@@ -46,10 +46,6 @@ var (
|
||||
debugEndpoint bool
|
||||
debugEndpointAddr string
|
||||
healthAddr string
|
||||
oidcClientID string
|
||||
oidcClientSecret string
|
||||
oidcEndpoint string
|
||||
oidcScopes string
|
||||
forwardedProto string
|
||||
trustedProxies string
|
||||
certFile string
|
||||
@@ -81,10 +77,6 @@ func init() {
|
||||
rootCmd.Flags().BoolVar(&debugEndpoint, "debug-endpoint", envBoolOrDefault("NB_PROXY_DEBUG_ENDPOINT", false), "Enable debug HTTP endpoint")
|
||||
rootCmd.Flags().StringVar(&debugEndpointAddr, "debug-endpoint-addr", envStringOrDefault("NB_PROXY_DEBUG_ENDPOINT_ADDRESS", "localhost:8444"), "Address for the debug HTTP endpoint")
|
||||
rootCmd.Flags().StringVar(&healthAddr, "health-addr", envStringOrDefault("NB_PROXY_HEALTH_ADDRESS", "localhost:8080"), "Address for the health probe endpoint (liveness/readiness/startup)")
|
||||
rootCmd.Flags().StringVar(&oidcClientID, "oidc-id", envStringOrDefault("NB_PROXY_OIDC_CLIENT_ID", "netbird-proxy"), "The OAuth2 Client ID for OIDC User Authentication")
|
||||
rootCmd.Flags().StringVar(&oidcClientSecret, "oidc-secret", envStringOrDefault("NB_PROXY_OIDC_CLIENT_SECRET", ""), "The OAuth2 Client Secret for OIDC User Authentication")
|
||||
rootCmd.Flags().StringVar(&oidcEndpoint, "oidc-endpoint", envStringOrDefault("NB_PROXY_OIDC_ENDPOINT", ""), "The OIDC Endpoint for OIDC User Authentication")
|
||||
rootCmd.Flags().StringVar(&oidcScopes, "oidc-scopes", envStringOrDefault("NB_PROXY_OIDC_SCOPES", "openid,profile,email"), "The OAuth2 scopes for OIDC User Authentication, comma separated")
|
||||
rootCmd.Flags().StringVar(&forwardedProto, "forwarded-proto", envStringOrDefault("NB_PROXY_FORWARDED_PROTO", "auto"), "X-Forwarded-Proto value for backends: auto, http, or https")
|
||||
rootCmd.Flags().StringVar(&trustedProxies, "trusted-proxies", envStringOrDefault("NB_PROXY_TRUSTED_PROXIES", ""), "Comma-separated list of trusted upstream proxy CIDR ranges (e.g. '10.0.0.0/8,192.168.1.1')")
|
||||
rootCmd.Flags().StringVar(&certFile, "cert-file", envStringOrDefault("NB_PROXY_CERTIFICATE_FILE", "tls.crt"), "TLS certificate filename within the certificate directory")
|
||||
@@ -159,10 +151,6 @@ func runServer(cmd *cobra.Command, args []string) error {
|
||||
DebugEndpointEnabled: debugEndpoint,
|
||||
DebugEndpointAddress: debugEndpointAddr,
|
||||
HealthAddress: healthAddr,
|
||||
OIDCClientId: oidcClientID,
|
||||
OIDCClientSecret: oidcClientSecret,
|
||||
OIDCEndpoint: oidcEndpoint,
|
||||
OIDCScopes: strings.Split(oidcScopes, ","),
|
||||
ForwardedProto: forwardedProto,
|
||||
TrustedProxies: parsedTrustedProxies,
|
||||
CertLockMethod: nbacme.CertLockMethod(certLockMethod),
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"github.com/rs/xid"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/proxy"
|
||||
"github.com/netbirdio/netbird/proxy/internal/responsewriter"
|
||||
"github.com/netbirdio/netbird/proxy/web"
|
||||
)
|
||||
|
||||
@@ -27,8 +28,8 @@ func (l *Logger) Middleware(next http.Handler) http.Handler {
|
||||
|
||||
// Use a response writer wrapper so we can access the status code later.
|
||||
sw := &statusWriter{
|
||||
w: w,
|
||||
status: http.StatusOK,
|
||||
PassthroughWriter: responsewriter.New(w),
|
||||
status: http.StatusOK,
|
||||
}
|
||||
|
||||
// Resolve the source IP using trusted proxy configuration before passing
|
||||
|
||||
@@ -1,26 +1,18 @@
|
||||
package accesslog
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"github.com/netbirdio/netbird/proxy/internal/responsewriter"
|
||||
)
|
||||
|
||||
// statusWriter is a simple wrapper around an http.ResponseWriter
|
||||
// that captures the setting of the status code via the WriteHeader
|
||||
// function and stores it so that it can be retrieved later.
|
||||
// statusWriter captures the HTTP status code from WriteHeader calls.
|
||||
// It embeds responsewriter.PassthroughWriter which handles all the optional
|
||||
// interfaces (Hijacker, Flusher, Pusher) automatically.
|
||||
type statusWriter struct {
|
||||
w http.ResponseWriter
|
||||
*responsewriter.PassthroughWriter
|
||||
status int
|
||||
}
|
||||
|
||||
func (w *statusWriter) Header() http.Header {
|
||||
return w.w.Header()
|
||||
}
|
||||
|
||||
func (w *statusWriter) Write(data []byte) (int, error) {
|
||||
return w.w.Write(data)
|
||||
}
|
||||
|
||||
func (w *statusWriter) WriteHeader(status int) {
|
||||
w.status = status
|
||||
w.w.WriteHeader(status)
|
||||
w.PassthroughWriter.WriteHeader(status)
|
||||
}
|
||||
|
||||
49
proxy/internal/conntrack/conn.go
Normal file
49
proxy/internal/conntrack/conn.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"net"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// trackedConn wraps a net.Conn and removes itself from the tracker on Close.
|
||||
type trackedConn struct {
|
||||
net.Conn
|
||||
tracker *HijackTracker
|
||||
}
|
||||
|
||||
func (c *trackedConn) Close() error {
|
||||
c.tracker.conns.Delete(c)
|
||||
return c.Conn.Close()
|
||||
}
|
||||
|
||||
// trackingWriter wraps an http.ResponseWriter and intercepts Hijack calls
|
||||
// to replace the raw connection with a trackedConn that auto-deregisters.
|
||||
type trackingWriter struct {
|
||||
http.ResponseWriter
|
||||
tracker *HijackTracker
|
||||
}
|
||||
|
||||
func (w *trackingWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
hijacker, ok := w.ResponseWriter.(http.Hijacker)
|
||||
if !ok {
|
||||
return nil, nil, http.ErrNotSupported
|
||||
}
|
||||
conn, buf, err := hijacker.Hijack()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
tc := &trackedConn{Conn: conn, tracker: w.tracker}
|
||||
w.tracker.conns.Store(tc, struct{}{})
|
||||
return tc, buf, nil
|
||||
}
|
||||
|
||||
func (w *trackingWriter) Flush() {
|
||||
if flusher, ok := w.ResponseWriter.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
func (w *trackingWriter) Unwrap() http.ResponseWriter {
|
||||
return w.ResponseWriter
|
||||
}
|
||||
41
proxy/internal/conntrack/hijacked.go
Normal file
41
proxy/internal/conntrack/hijacked.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// HijackTracker tracks connections that have been hijacked (e.g. WebSocket
|
||||
// upgrades). http.Server.Shutdown does not close hijacked connections, so
|
||||
// they must be tracked and closed explicitly during graceful shutdown.
|
||||
//
|
||||
// Use Middleware as the outermost HTTP middleware to ensure hijacked
|
||||
// connections are tracked and automatically deregistered when closed.
|
||||
type HijackTracker struct {
|
||||
conns sync.Map // net.Conn → struct{}
|
||||
}
|
||||
|
||||
// Middleware returns an HTTP middleware that wraps the ResponseWriter so that
|
||||
// hijacked connections are tracked and automatically deregistered from the
|
||||
// tracker when closed. This should be the outermost middleware in the chain.
|
||||
func (t *HijackTracker) Middleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
next.ServeHTTP(&trackingWriter{ResponseWriter: w, tracker: t}, r)
|
||||
})
|
||||
}
|
||||
|
||||
// CloseAll closes all tracked hijacked connections and returns the number
|
||||
// of connections that were closed.
|
||||
func (t *HijackTracker) CloseAll() int {
|
||||
var count int
|
||||
t.conns.Range(func(key, _ any) bool {
|
||||
if conn, ok := key.(net.Conn); ok {
|
||||
_ = conn.Close()
|
||||
count++
|
||||
}
|
||||
t.conns.Delete(key)
|
||||
return true
|
||||
})
|
||||
return count
|
||||
}
|
||||
@@ -5,9 +5,11 @@ import (
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/proxy"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/promauto"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/proxy"
|
||||
"github.com/netbirdio/netbird/proxy/internal/responsewriter"
|
||||
)
|
||||
|
||||
type Metrics struct {
|
||||
@@ -60,18 +62,18 @@ func New(reg prometheus.Registerer) *Metrics {
|
||||
}
|
||||
|
||||
type responseInterceptor struct {
|
||||
http.ResponseWriter
|
||||
*responsewriter.PassthroughWriter
|
||||
status int
|
||||
size int
|
||||
}
|
||||
|
||||
func (w *responseInterceptor) WriteHeader(status int) {
|
||||
w.status = status
|
||||
w.ResponseWriter.WriteHeader(status)
|
||||
w.PassthroughWriter.WriteHeader(status)
|
||||
}
|
||||
|
||||
func (w *responseInterceptor) Write(b []byte) (int, error) {
|
||||
size, err := w.ResponseWriter.Write(b)
|
||||
size, err := w.PassthroughWriter.Write(b)
|
||||
w.size += size
|
||||
return size, err
|
||||
}
|
||||
@@ -81,7 +83,7 @@ func (m *Metrics) Middleware(next http.Handler) http.Handler {
|
||||
m.requestsTotal.Inc()
|
||||
m.activeRequests.Inc()
|
||||
|
||||
interceptor := &responseInterceptor{ResponseWriter: w}
|
||||
interceptor := &responseInterceptor{PassthroughWriter: responsewriter.New(w)}
|
||||
|
||||
start := time.Now()
|
||||
next.ServeHTTP(interceptor, r)
|
||||
|
||||
53
proxy/internal/responsewriter/responsewriter.go
Normal file
53
proxy/internal/responsewriter/responsewriter.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package responsewriter
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"net"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// PassthroughWriter wraps an http.ResponseWriter and preserves optional
|
||||
// interfaces like Hijacker, Flusher, and Pusher by delegating to the underlying
|
||||
// ResponseWriter if it supports them.
|
||||
//
|
||||
// This is the standard pattern for Go middleware that needs to wrap ResponseWriter
|
||||
// while maintaining support for protocol upgrades (WebSocket), streaming (Flusher),
|
||||
// and HTTP/2 server push.
|
||||
type PassthroughWriter struct {
|
||||
http.ResponseWriter
|
||||
}
|
||||
|
||||
// New creates a new wrapper around the given ResponseWriter.
|
||||
func New(w http.ResponseWriter) *PassthroughWriter {
|
||||
return &PassthroughWriter{ResponseWriter: w}
|
||||
}
|
||||
|
||||
// Hijack implements http.Hijacker interface if the underlying ResponseWriter supports it.
|
||||
// This is required for WebSocket connections and other protocol upgrades.
|
||||
func (w *PassthroughWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
if hijacker, ok := w.ResponseWriter.(http.Hijacker); ok {
|
||||
return hijacker.Hijack()
|
||||
}
|
||||
return nil, nil, http.ErrNotSupported
|
||||
}
|
||||
|
||||
// Flush implements http.Flusher interface if the underlying ResponseWriter supports it.
|
||||
func (w *PassthroughWriter) Flush() {
|
||||
if flusher, ok := w.ResponseWriter.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
// Push implements http.Pusher interface if the underlying ResponseWriter supports it.
|
||||
func (w *PassthroughWriter) Push(target string, opts *http.PushOptions) error {
|
||||
if pusher, ok := w.ResponseWriter.(http.Pusher); ok {
|
||||
return pusher.Push(target, opts)
|
||||
}
|
||||
return http.ErrNotSupported
|
||||
}
|
||||
|
||||
// Unwrap returns the underlying ResponseWriter.
|
||||
// This is required for http.ResponseController (Go 1.20+) to work correctly.
|
||||
func (w *PassthroughWriter) Unwrap() http.ResponseWriter {
|
||||
return w.ResponseWriter
|
||||
}
|
||||
@@ -23,7 +23,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
proxyproto "github.com/pires/go-proxyproto"
|
||||
"github.com/pires/go-proxyproto"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -37,6 +37,7 @@ import (
|
||||
"github.com/netbirdio/netbird/proxy/internal/acme"
|
||||
"github.com/netbirdio/netbird/proxy/internal/auth"
|
||||
"github.com/netbirdio/netbird/proxy/internal/certwatch"
|
||||
"github.com/netbirdio/netbird/proxy/internal/conntrack"
|
||||
"github.com/netbirdio/netbird/proxy/internal/debug"
|
||||
proxygrpc "github.com/netbirdio/netbird/proxy/internal/grpc"
|
||||
"github.com/netbirdio/netbird/proxy/internal/health"
|
||||
@@ -64,6 +65,11 @@ type Server struct {
|
||||
healthChecker *health.Checker
|
||||
meter *metrics.Metrics
|
||||
|
||||
// hijackTracker tracks hijacked connections (e.g. WebSocket upgrades)
|
||||
// so they can be closed during graceful shutdown, since http.Server.Shutdown
|
||||
// does not handle them.
|
||||
hijackTracker conntrack.HijackTracker
|
||||
|
||||
// Mostly used for debugging on management.
|
||||
startTime time.Time
|
||||
|
||||
@@ -83,11 +89,7 @@ type Server struct {
|
||||
ACMEChallengeType string
|
||||
// CertLockMethod controls how ACME certificate locks are coordinated
|
||||
// across replicas. Default: CertLockAuto (detect environment).
|
||||
CertLockMethod acme.CertLockMethod
|
||||
OIDCClientId string
|
||||
OIDCClientSecret string
|
||||
OIDCEndpoint string
|
||||
OIDCScopes []string
|
||||
CertLockMethod acme.CertLockMethod
|
||||
|
||||
// DebugEndpointEnabled enables the debug HTTP endpoint.
|
||||
DebugEndpointEnabled bool
|
||||
@@ -185,10 +187,18 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
// Build the handler chain from inside out.
|
||||
handler := http.Handler(s.proxy)
|
||||
handler = s.auth.Protect(handler)
|
||||
handler = web.AssetHandler(handler)
|
||||
handler = accessLog.Middleware(handler)
|
||||
handler = s.meter.Middleware(handler)
|
||||
handler = s.hijackTracker.Middleware(handler)
|
||||
|
||||
// Start the reverse proxy HTTPS server.
|
||||
s.https = &http.Server{
|
||||
Addr: addr,
|
||||
Handler: s.meter.Middleware(accessLog.Middleware(web.AssetHandler(s.auth.Protect(s.proxy)))),
|
||||
Handler: handler,
|
||||
TLSConfig: tlsConfig,
|
||||
ErrorLog: newHTTPServerLogger(s.Logger, logtagValueHTTPS),
|
||||
}
|
||||
@@ -457,7 +467,12 @@ func (s *Server) gracefulShutdown() {
|
||||
s.Logger.Warnf("https server drain: %v", err)
|
||||
}
|
||||
|
||||
// Step 4: Stop all remaining background services.
|
||||
// Step 4: Close hijacked connections (WebSocket) that Shutdown does not handle.
|
||||
if n := s.hijackTracker.CloseAll(); n > 0 {
|
||||
s.Logger.Infof("closed %d hijacked connection(s)", n)
|
||||
}
|
||||
|
||||
// Step 5: Stop all remaining background services.
|
||||
s.shutdownServices()
|
||||
s.Logger.Info("graceful shutdown complete")
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/relay/protocol"
|
||||
nbRelay "github.com/netbirdio/netbird/shared/relay"
|
||||
)
|
||||
|
||||
const Proto protocol.Protocol = "quic"
|
||||
@@ -27,7 +28,7 @@ type Listener struct {
|
||||
func (l *Listener) Listen(acceptFn func(conn net.Conn)) error {
|
||||
quicCfg := &quic.Config{
|
||||
EnableDatagrams: true,
|
||||
InitialPacketSize: 1452,
|
||||
InitialPacketSize: nbRelay.QUICInitialPacketSize,
|
||||
}
|
||||
listener, err := quic.ListenAddr(l.Address, l.TLSConfig, quicCfg)
|
||||
if err != nil {
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
nbRelay "github.com/netbirdio/netbird/shared/relay"
|
||||
quictls "github.com/netbirdio/netbird/shared/relay/tls"
|
||||
)
|
||||
|
||||
@@ -42,7 +43,7 @@ func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) {
|
||||
KeepAlivePeriod: 30 * time.Second,
|
||||
MaxIdleTimeout: 4 * time.Minute,
|
||||
EnableDatagrams: true,
|
||||
InitialPacketSize: 1452,
|
||||
InitialPacketSize: nbRelay.QUICInitialPacketSize,
|
||||
}
|
||||
|
||||
udpConn, err := nbnet.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
|
||||
15
shared/relay/client/env.go
Normal file
15
shared/relay/client/env.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
const (
|
||||
envKeyNBDebugDisableRelay = "NB_DEBUG_DISABLE_RELAY"
|
||||
)
|
||||
|
||||
func IsDisableRelay() bool {
|
||||
v, _ := strconv.ParseBool(os.Getenv(envKeyNBDebugDisableRelay))
|
||||
return v
|
||||
}
|
||||
@@ -3,4 +3,9 @@ package relay
|
||||
const (
|
||||
// WebSocketURLPath is the path for the websocket relay connection
|
||||
WebSocketURLPath = "/relay"
|
||||
|
||||
// QUICInitialPacketSize is the conservative initial QUIC packet size (bytes)
|
||||
// for unknown-path PMTU, per RFC 9000 §14: 1280 (IPv6 min MTU) − 40 (IPv6
|
||||
// header) − 8 (UDP header) = 1232. DPLPMTUD may probe larger sizes later.
|
||||
QUICInitialPacketSize = 1232
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user