Merge branch 'main' into feature/disable-legacy-port

This commit is contained in:
pascal
2026-02-17 16:12:49 +01:00
20 changed files with 588 additions and 101 deletions

View File

@@ -358,9 +358,9 @@ func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 {
// Fast path for IPv4 addresses (4 bytes) - most common case // Fast path for IPv4 addresses (4 bytes) - most common case
if len(oldBytes) == 4 && len(newBytes) == 4 { if len(oldBytes) == 4 && len(newBytes) == 4 {
sum += uint32(^binary.BigEndian.Uint16(oldBytes[0:2])) sum += uint32(^binary.BigEndian.Uint16(oldBytes[0:2]))
sum += uint32(^binary.BigEndian.Uint16(oldBytes[2:4])) sum += uint32(^binary.BigEndian.Uint16(oldBytes[2:4])) //nolint:gosec // length checked above
sum += uint32(binary.BigEndian.Uint16(newBytes[0:2])) sum += uint32(binary.BigEndian.Uint16(newBytes[0:2]))
sum += uint32(binary.BigEndian.Uint16(newBytes[2:4])) sum += uint32(binary.BigEndian.Uint16(newBytes[2:4])) //nolint:gosec // length checked above
} else { } else {
// Fallback for other lengths // Fallback for other lengths
for i := 0; i < len(oldBytes)-1; i += 2 { for i := 0; i < len(oldBytes)-1; i += 2 {

View File

@@ -410,7 +410,7 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr) conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr)
} }
func (conn *Conn) onICEStateDisconnected() { func (conn *Conn) onICEStateDisconnected(sessionChanged bool) {
conn.mu.Lock() conn.mu.Lock()
defer conn.mu.Unlock() defer conn.mu.Unlock()
@@ -430,6 +430,10 @@ func (conn *Conn) onICEStateDisconnected() {
if conn.isReadyToUpgrade() { if conn.isReadyToUpgrade() {
conn.Log.Infof("ICE disconnected, set Relay to active connection") conn.Log.Infof("ICE disconnected, set Relay to active connection")
conn.dumpState.SwitchToRelay() conn.dumpState.SwitchToRelay()
if sessionChanged {
conn.resetEndpoint()
}
conn.wgProxyRelay.Work() conn.wgProxyRelay.Work()
presharedKey := conn.presharedKey(conn.rosenpassRemoteKey) presharedKey := conn.presharedKey(conn.rosenpassRemoteKey)
@@ -757,6 +761,17 @@ func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) {
return wgProxy, nil return wgProxy, nil
} }
// resetEndpoint clears local WireGuard endpoint state after the remote peer
// started a new ICE session, so the next handshake begins from scratch.
func (conn *Conn) resetEndpoint() {
	// Only the controller side drives the reset; the controlled side keeps
	// its endpoint untouched. NOTE(review): assumes exactly one side of the
	// pair is the controller — confirm against isController.
	if !isController(conn.config) {
		return
	}
	conn.Log.Infof("reset wg endpoint")
	// Restart the watcher's handshake timeout so a stale last-handshake time
	// is not mistaken for a live connection.
	conn.wgWatcher.Reset()
	// Best effort: a failure to clear the address is logged but does not
	// abort the reset.
	if err := conn.endpointUpdater.RemoveEndpointAddress(); err != nil {
		conn.Log.Warnf("failed to remove endpoint address before update: %v", err)
	}
}
func (conn *Conn) isReadyToUpgrade() bool { func (conn *Conn) isReadyToUpgrade() bool {
return conn.wgProxyRelay != nil && conn.currentConnPriority != conntype.Relay return conn.wgProxyRelay != nil && conn.currentConnPriority != conntype.Relay
} }

View File

@@ -66,6 +66,10 @@ func (e *EndpointUpdater) RemoveWgPeer() error {
return e.wgConfig.WgInterface.RemovePeer(e.wgConfig.RemoteKey) return e.wgConfig.WgInterface.RemovePeer(e.wgConfig.RemoteKey)
} }
// RemoveEndpointAddress removes the remote peer's endpoint address from the
// WireGuard interface, delegating to the configured interface implementation.
func (e *EndpointUpdater) RemoveEndpointAddress() error {
	return e.wgConfig.WgInterface.RemoveEndpointAddress(e.wgConfig.RemoteKey)
}
func (e *EndpointUpdater) waitForCloseTheDelayedUpdate() { func (e *EndpointUpdater) waitForCloseTheDelayedUpdate() {
if e.cancelFunc == nil { if e.cancelFunc == nil {
return return

View File

@@ -32,6 +32,8 @@ type WGWatcher struct {
enabled bool enabled bool
muEnabled sync.RWMutex muEnabled sync.RWMutex
resetCh chan struct{}
} }
func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey string, stateDump *stateDump) *WGWatcher { func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey string, stateDump *stateDump) *WGWatcher {
@@ -40,6 +42,7 @@ func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey strin
wgIfaceStater: wgIfaceStater, wgIfaceStater: wgIfaceStater,
peerKey: peerKey, peerKey: peerKey,
stateDump: stateDump, stateDump: stateDump,
resetCh: make(chan struct{}, 1),
} }
} }
@@ -76,6 +79,15 @@ func (w *WGWatcher) IsEnabled() bool {
return w.enabled return w.enabled
} }
// Reset signals the watcher that the WireGuard peer has been reset and a new
// handshake is expected. This restarts the handshake timeout from scratch.
// Reset never blocks: resetCh is buffered with capacity 1, so repeated calls
// while a reset is already pending are coalesced into a single signal.
func (w *WGWatcher) Reset() {
	select {
	case w.resetCh <- struct{}{}:
	default:
		// A reset signal is already queued; dropping this one is safe
		// because the watcher reacts identically to one or many signals.
	}
}
// wgStateCheck help to check the state of the WireGuard handshake and relay connection // wgStateCheck help to check the state of the WireGuard handshake and relay connection
func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, onDisconnectedFn func(), enabledTime time.Time, initialHandshake time.Time) { func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, onDisconnectedFn func(), enabledTime time.Time, initialHandshake time.Time) {
w.log.Infof("WireGuard watcher started") w.log.Infof("WireGuard watcher started")
@@ -105,6 +117,12 @@ func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, onDisconnectedFn
w.stateDump.WGcheckSuccess() w.stateDump.WGcheckSuccess()
w.log.Debugf("WireGuard watcher reset timer: %v", resetTime) w.log.Debugf("WireGuard watcher reset timer: %v", resetTime)
case <-w.resetCh:
w.log.Infof("WireGuard watcher received peer reset, restarting handshake timeout")
lastHandshake = time.Time{}
enabledTime = time.Now()
timer.Stop()
timer.Reset(wgHandshakeOvertime)
case <-ctx.Done(): case <-ctx.Done():
w.log.Infof("WireGuard watcher stopped") w.log.Infof("WireGuard watcher stopped")
return return

View File

@@ -52,8 +52,9 @@ type WorkerICE struct {
// increase by one when disconnecting the agent // increase by one when disconnecting the agent
// with it the remote peer can discard the already deprecated offer/answer // with it the remote peer can discard the already deprecated offer/answer
// Without it the remote peer may recreate a workable ICE connection // Without it the remote peer may recreate a workable ICE connection
sessionID ICESessionID sessionID ICESessionID
muxAgent sync.Mutex remoteSessionChanged bool
muxAgent sync.Mutex
localUfrag string localUfrag string
localPwd string localPwd string
@@ -106,6 +107,7 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
return return
} }
w.log.Debugf("agent already exists, recreate the connection") w.log.Debugf("agent already exists, recreate the connection")
w.remoteSessionChanged = true
w.agentDialerCancel() w.agentDialerCancel()
if w.agent != nil { if w.agent != nil {
if err := w.agent.Close(); err != nil { if err := w.agent.Close(); err != nil {
@@ -306,13 +308,17 @@ func (w *WorkerICE) connect(ctx context.Context, agent *icemaker.ThreadSafeAgent
w.conn.onICEConnectionIsReady(selectedPriority(pair), ci) w.conn.onICEConnectionIsReady(selectedPriority(pair), ci)
} }
func (w *WorkerICE) closeAgent(agent *icemaker.ThreadSafeAgent, cancel context.CancelFunc) { func (w *WorkerICE) closeAgent(agent *icemaker.ThreadSafeAgent, cancel context.CancelFunc) bool {
cancel() cancel()
if err := agent.Close(); err != nil { if err := agent.Close(); err != nil {
w.log.Warnf("failed to close ICE agent: %s", err) w.log.Warnf("failed to close ICE agent: %s", err)
} }
w.muxAgent.Lock() w.muxAgent.Lock()
defer w.muxAgent.Unlock()
sessionChanged := w.remoteSessionChanged
w.remoteSessionChanged = false
if w.agent == agent { if w.agent == agent {
// consider to remove from here and move to the OnNewOffer // consider to remove from here and move to the OnNewOffer
@@ -325,7 +331,7 @@ func (w *WorkerICE) closeAgent(agent *icemaker.ThreadSafeAgent, cancel context.C
w.agentConnecting = false w.agentConnecting = false
w.remoteSessionID = "" w.remoteSessionID = ""
} }
w.muxAgent.Unlock() return sessionChanged
} }
func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) { func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) {
@@ -426,11 +432,11 @@ func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dia
// ice.ConnectionStateClosed happens when we recreate the agent. For the P2P to TURN switch important to // ice.ConnectionStateClosed happens when we recreate the agent. For the P2P to TURN switch important to
// notify the conn.onICEStateDisconnected changes to update the current used priority // notify the conn.onICEStateDisconnected changes to update the current used priority
w.closeAgent(agent, dialerCancel) sessionChanged := w.closeAgent(agent, dialerCancel)
if w.lastKnownState == ice.ConnectionStateConnected { if w.lastKnownState == ice.ConnectionStateConnected {
w.lastKnownState = ice.ConnectionStateDisconnected w.lastKnownState = ice.ConnectionStateDisconnected
w.conn.onICEStateDisconnected() w.conn.onICEStateDisconnected(sessionChanged)
} }
default: default:
return return

1
go.mod
View File

@@ -83,6 +83,7 @@ require (
github.com/pion/stun/v3 v3.1.0 github.com/pion/stun/v3 v3.1.0
github.com/pion/transport/v3 v3.1.1 github.com/pion/transport/v3 v3.1.1
github.com/pion/turn/v3 v3.0.1 github.com/pion/turn/v3 v3.0.1
github.com/pires/go-proxyproto v0.11.0
github.com/pkg/sftp v1.13.9 github.com/pkg/sftp v1.13.9
github.com/prometheus/client_golang v1.23.2 github.com/prometheus/client_golang v1.23.2
github.com/quic-go/quic-go v0.55.0 github.com/quic-go/quic-go v0.55.0

2
go.sum
View File

@@ -474,6 +474,8 @@ github.com/pion/turn/v3 v3.0.1 h1:wLi7BTQr6/Q20R0vt/lHbjv6y4GChFtC33nkYbasoT8=
github.com/pion/turn/v3 v3.0.1/go.mod h1:MrJDKgqryDyWy1/4NT9TWfXWGMC7UHT6pJIv1+gMeNE= github.com/pion/turn/v3 v3.0.1/go.mod h1:MrJDKgqryDyWy1/4NT9TWfXWGMC7UHT6pJIv1+gMeNE=
github.com/pion/turn/v4 v4.1.1 h1:9UnY2HB99tpDyz3cVVZguSxcqkJ1DsTSZ+8TGruh4fc= github.com/pion/turn/v4 v4.1.1 h1:9UnY2HB99tpDyz3cVVZguSxcqkJ1DsTSZ+8TGruh4fc=
github.com/pion/turn/v4 v4.1.1/go.mod h1:2123tHk1O++vmjI5VSD0awT50NywDAq5A2NNNU4Jjs8= github.com/pion/turn/v4 v4.1.1/go.mod h1:2123tHk1O++vmjI5VSD0awT50NywDAq5A2NNNU4Jjs8=
github.com/pires/go-proxyproto v0.11.0 h1:gUQpS85X/VJMdUsYyEgyn59uLJvGqPhJV5YvG68wXH4=
github.com/pires/go-proxyproto v0.11.0/go.mod h1:ZKAAyp3cgy5Y5Mo4n9AlScrkCZwUy0g3Jf+slqQVcuU=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=

View File

@@ -99,15 +99,16 @@ func NewProvider(ctx context.Context, config *Config) (*Provider, error) {
// Build Dex server config - use Dex's types directly // Build Dex server config - use Dex's types directly
dexConfig := server.Config{ dexConfig := server.Config{
Issuer: issuer, Issuer: issuer,
Storage: stor, Storage: stor,
SkipApprovalScreen: true, SkipApprovalScreen: true,
SupportedResponseTypes: []string{"code"}, SupportedResponseTypes: []string{"code"},
Logger: logger, ContinueOnConnectorFailure: true,
PrometheusRegistry: prometheus.NewRegistry(), Logger: logger,
RotateKeysAfter: 6 * time.Hour, PrometheusRegistry: prometheus.NewRegistry(),
IDTokensValidFor: 24 * time.Hour, RotateKeysAfter: 6 * time.Hour,
RefreshTokenPolicy: refreshPolicy, IDTokensValidFor: 24 * time.Hour,
RefreshTokenPolicy: refreshPolicy,
Web: server.WebConfig{ Web: server.WebConfig{
Issuer: "NetBird", Issuer: "NetBird",
}, },
@@ -260,6 +261,7 @@ func buildDexConfig(yamlConfig *YAMLConfig, stor storage.Storage, logger *slog.L
if len(cfg.SupportedResponseTypes) == 0 { if len(cfg.SupportedResponseTypes) == 0 {
cfg.SupportedResponseTypes = []string{"code"} cfg.SupportedResponseTypes = []string{"code"}
} }
cfg.ContinueOnConnectorFailure = true
return cfg return cfg
} }

View File

@@ -2,6 +2,7 @@ package dex
import ( import (
"context" "context"
"log/slog"
"os" "os"
"path/filepath" "path/filepath"
"testing" "testing"
@@ -195,3 +196,64 @@ enablePasswordDB: true
t.Logf("User lookup successful: rawID=%s, connectorID=%s", rawID, connID) t.Logf("User lookup successful: rawID=%s, connectorID=%s", rawID, connID)
} }
func TestNewProvider_ContinueOnConnectorFailure(t *testing.T) {
ctx := context.Background()
tmpDir, err := os.MkdirTemp("", "dex-connector-failure-*")
require.NoError(t, err)
defer os.RemoveAll(tmpDir)
config := &Config{
Issuer: "http://localhost:5556/dex",
Port: 5556,
DataDir: tmpDir,
}
provider, err := NewProvider(ctx, config)
require.NoError(t, err)
defer func() { _ = provider.Stop(ctx) }()
// The provider should have started successfully even though
// ContinueOnConnectorFailure is an internal Dex config field.
// We verify the provider is functional by performing a basic operation.
assert.NotNil(t, provider.dexServer)
assert.NotNil(t, provider.storage)
}
// TestBuildDexConfig_ContinueOnConnectorFailure asserts that buildDexConfig
// forces ContinueOnConnectorFailure on, so management keeps starting even if
// an external IdP connector is unavailable.
func TestBuildDexConfig_ContinueOnConnectorFailure(t *testing.T) {
	// t.TempDir replaces MkdirTemp + RemoveAll and fails the test on error.
	tmpDir := t.TempDir()

	yamlContent := `
issuer: http://localhost:5556/dex
storage:
  type: sqlite3
  config:
    file: ` + filepath.Join(tmpDir, "dex.db") + `
web:
  http: 127.0.0.1:5556
enablePasswordDB: true
`
	configPath := filepath.Join(tmpDir, "config.yaml")
	require.NoError(t, os.WriteFile(configPath, []byte(yamlContent), 0o644))

	yamlConfig, err := LoadConfig(configPath)
	require.NoError(t, err)

	// Build the logger once; the original constructed two identical loggers.
	logger := slog.New(slog.NewTextHandler(os.Stderr, nil))

	stor, err := yamlConfig.Storage.OpenStorage(logger)
	require.NoError(t, err)
	defer stor.Close()

	ctx := context.Background()
	require.NoError(t, initializeStorage(ctx, stor, yamlConfig))

	cfg := buildDexConfig(yamlConfig, stor, logger)
	assert.True(t, cfg.ContinueOnConnectorFailure,
		"buildDexConfig must set ContinueOnConnectorFailure to true so management starts even if an external IdP is down")
}

View File

@@ -329,6 +329,9 @@ initialize_default_values() {
BIND_LOCALHOST_ONLY="true" BIND_LOCALHOST_ONLY="true"
EXTERNAL_PROXY_NETWORK="" EXTERNAL_PROXY_NETWORK=""
# Traefik static IP within the internal bridge network
TRAEFIK_IP="172.30.0.10"
# NetBird Proxy configuration # NetBird Proxy configuration
ENABLE_PROXY="false" ENABLE_PROXY="false"
PROXY_DOMAIN="" PROXY_DOMAIN=""
@@ -393,7 +396,7 @@ check_existing_installation() {
echo "Generated files already exist, if you want to reinitialize the environment, please remove them first." echo "Generated files already exist, if you want to reinitialize the environment, please remove them first."
echo "You can use the following commands:" echo "You can use the following commands:"
echo " $DOCKER_COMPOSE_COMMAND down --volumes # to remove all containers and volumes" echo " $DOCKER_COMPOSE_COMMAND down --volumes # to remove all containers and volumes"
echo " rm -f docker-compose.yml dashboard.env config.yaml proxy.env nginx-netbird.conf caddyfile-netbird.txt npm-advanced-config.txt" echo " rm -f docker-compose.yml dashboard.env config.yaml proxy.env traefik-dynamic.yaml nginx-netbird.conf caddyfile-netbird.txt npm-advanced-config.txt"
echo "Be aware that this will remove all data from the database, and you will have to reconfigure the dashboard." echo "Be aware that this will remove all data from the database, and you will have to reconfigure the dashboard."
exit 1 exit 1
fi fi
@@ -412,6 +415,8 @@ generate_configuration_files() {
# This will be overwritten with the actual token after netbird-server starts # This will be overwritten with the actual token after netbird-server starts
echo "# Placeholder - will be updated with token after netbird-server starts" > proxy.env echo "# Placeholder - will be updated with token after netbird-server starts" > proxy.env
echo "NB_PROXY_TOKEN=placeholder" >> proxy.env echo "NB_PROXY_TOKEN=placeholder" >> proxy.env
# TCP ServersTransport for PROXY protocol v2 to the proxy backend
render_traefik_dynamic > traefik-dynamic.yaml
fi fi
;; ;;
1) 1)
@@ -559,10 +564,14 @@ init_environment() {
############################################ ############################################
render_docker_compose_traefik_builtin() { render_docker_compose_traefik_builtin() {
# Generate proxy service section if enabled # Generate proxy service section and Traefik dynamic config if enabled
local proxy_service="" local proxy_service=""
local proxy_volumes="" local proxy_volumes=""
local traefik_file_provider=""
local traefik_dynamic_volume=""
if [[ "$ENABLE_PROXY" == "true" ]]; then if [[ "$ENABLE_PROXY" == "true" ]]; then
traefik_file_provider=' - "--providers.file.filename=/etc/traefik/dynamic.yaml"'
traefik_dynamic_volume=" - ./traefik-dynamic.yaml:/etc/traefik/dynamic.yaml:ro"
proxy_service=" proxy_service="
# NetBird Proxy - exposes internal resources to the internet # NetBird Proxy - exposes internal resources to the internet
proxy: proxy:
@@ -570,7 +579,7 @@ render_docker_compose_traefik_builtin() {
container_name: netbird-proxy container_name: netbird-proxy
# Hairpin NAT fix: route domain back to traefik's static IP within Docker # Hairpin NAT fix: route domain back to traefik's static IP within Docker
extra_hosts: extra_hosts:
- \"$NETBIRD_DOMAIN:172.30.0.10\" - \"$NETBIRD_DOMAIN:$TRAEFIK_IP\"
ports: ports:
- 51820:51820/udp - 51820:51820/udp
restart: unless-stopped restart: unless-stopped
@@ -590,6 +599,7 @@ render_docker_compose_traefik_builtin() {
- traefik.tcp.routers.proxy-passthrough.service=proxy-tls - traefik.tcp.routers.proxy-passthrough.service=proxy-tls
- traefik.tcp.routers.proxy-passthrough.priority=1 - traefik.tcp.routers.proxy-passthrough.priority=1
- traefik.tcp.services.proxy-tls.loadbalancer.server.port=8443 - traefik.tcp.services.proxy-tls.loadbalancer.server.port=8443
- traefik.tcp.services.proxy-tls.loadbalancer.serverstransport=pp-v2@file
logging: logging:
driver: \"json-file\" driver: \"json-file\"
options: options:
@@ -609,7 +619,7 @@ services:
restart: unless-stopped restart: unless-stopped
networks: networks:
netbird: netbird:
ipv4_address: 172.30.0.10 ipv4_address: $TRAEFIK_IP
command: command:
# Logging # Logging
- "--log.level=INFO" - "--log.level=INFO"
@@ -636,12 +646,14 @@ services:
# gRPC transport settings # gRPC transport settings
- "--serverstransport.forwardingtimeouts.responseheadertimeout=0s" - "--serverstransport.forwardingtimeouts.responseheadertimeout=0s"
- "--serverstransport.forwardingtimeouts.idleconntimeout=0s" - "--serverstransport.forwardingtimeouts.idleconntimeout=0s"
$traefik_file_provider
ports: ports:
- '443:443' - '443:443'
- '80:80' - '80:80'
volumes: volumes:
- /var/run/docker.sock:/var/run/docker.sock:ro - /var/run/docker.sock:/var/run/docker.sock:ro
- netbird_traefik_letsencrypt:/letsencrypt - netbird_traefik_letsencrypt:/letsencrypt
$traefik_dynamic_volume
logging: logging:
driver: "json-file" driver: "json-file"
options: options:
@@ -751,6 +763,10 @@ server:
cliRedirectURIs: cliRedirectURIs:
- "http://localhost:53000/" - "http://localhost:53000/"
reverseProxy:
trustedHTTPProxies:
- "$TRAEFIK_IP/32"
store: store:
engine: "sqlite" engine: "sqlite"
encryptionKey: "$DATASTORE_ENCRYPTION_KEY" encryptionKey: "$DATASTORE_ENCRYPTION_KEY"
@@ -780,6 +796,17 @@ EOF
return 0 return 0
} }
# Emit the Traefik dynamic-configuration YAML that defines the "pp-v2" TCP
# serversTransport with PROXY protocol v2 enabled; the compose labels
# reference it as pp-v2@file for the proxy backend.
render_traefik_dynamic() {
    # Quoted heredoc delimiter ('EOF') suppresses variable expansion, so the
    # YAML below is written out verbatim.
    cat <<'EOF'
tcp:
  serversTransports:
    pp-v2:
      proxyProtocol:
        version: 2
EOF
    return 0
}
render_proxy_env() { render_proxy_env() {
cat <<EOF cat <<EOF
# NetBird Proxy Configuration # NetBird Proxy Configuration
@@ -799,6 +826,10 @@ NB_PROXY_OIDC_CLIENT_ID=netbird-proxy
NB_PROXY_OIDC_ENDPOINT=$NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN/oauth2 NB_PROXY_OIDC_ENDPOINT=$NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN/oauth2
NB_PROXY_OIDC_SCOPES=openid,profile,email NB_PROXY_OIDC_SCOPES=openid,profile,email
NB_PROXY_FORWARDED_PROTO=https NB_PROXY_FORWARDED_PROTO=https
# Enable PROXY protocol to preserve client IPs through L4 proxies (Traefik TCP passthrough)
NB_PROXY_PROXY_PROTOCOL=true
# Trust Traefik's IP for PROXY protocol headers
NB_PROXY_TRUSTED_PROXIES=$TRAEFIK_IP
EOF EOF
return 0 return 0
} }
@@ -1176,8 +1207,9 @@ print_builtin_traefik_instructions() {
echo " The proxy service is enabled and running." echo " The proxy service is enabled and running."
echo " Any domain NOT matching $NETBIRD_DOMAIN will be passed through to the proxy." echo " Any domain NOT matching $NETBIRD_DOMAIN will be passed through to the proxy."
echo " The proxy handles its own TLS certificates via ACME TLS-ALPN-01 challenge." echo " The proxy handles its own TLS certificates via ACME TLS-ALPN-01 challenge."
echo " Point your proxy domain to this server's domain address like in the example below:" echo " Point your proxy domain to this server's domain address like in the examples below:"
echo "" echo ""
echo " $PROXY_DOMAIN CNAME $NETBIRD_DOMAIN"
echo " *.$PROXY_DOMAIN CNAME $NETBIRD_DOMAIN" echo " *.$PROXY_DOMAIN CNAME $NETBIRD_DOMAIN"
echo "" echo ""
fi fi

View File

@@ -224,6 +224,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
s.syncSem.Add(1) s.syncSem.Add(1)
reqStart := time.Now() reqStart := time.Now()
syncStart := reqStart.UTC()
ctx := srv.Context() ctx := srv.Context()
@@ -300,7 +301,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
metahash := metaHash(peerMeta, realIP.String()) metahash := metaHash(peerMeta, realIP.String())
s.loginFilter.addLogin(peerKey.String(), metahash) s.loginFilter.addLogin(peerKey.String(), metahash)
peer, netMap, postureChecks, dnsFwdPort, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP, reqStart) peer, netMap, postureChecks, dnsFwdPort, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP, syncStart)
if err != nil { if err != nil {
log.WithContext(ctx).Debugf("error while syncing peer %s: %v", peerKey.String(), err) log.WithContext(ctx).Debugf("error while syncing peer %s: %v", peerKey.String(), err)
s.syncSem.Add(-1) s.syncSem.Add(-1)
@@ -311,7 +312,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
if err != nil { if err != nil {
log.WithContext(ctx).Debugf("error while sending initial sync for %s: %v", peerKey.String(), err) log.WithContext(ctx).Debugf("error while sending initial sync for %s: %v", peerKey.String(), err)
s.syncSem.Add(-1) s.syncSem.Add(-1)
s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, reqStart) s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, syncStart)
return err return err
} }
@@ -319,7 +320,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
if err != nil { if err != nil {
log.WithContext(ctx).Debugf("error while notify peer connected for %s: %v", peerKey.String(), err) log.WithContext(ctx).Debugf("error while notify peer connected for %s: %v", peerKey.String(), err)
s.syncSem.Add(-1) s.syncSem.Add(-1)
s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, reqStart) s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, syncStart)
return err return err
} }
@@ -336,7 +337,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
s.syncSem.Add(-1) s.syncSem.Add(-1)
return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv, reqStart) return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv, syncStart)
} }
func (s *Server) handleHandshake(ctx context.Context, srv proto.ManagementService_JobServer) (wgtypes.Key, error) { func (s *Server) handleHandshake(ctx context.Context, srv proto.ManagementService_JobServer) (wgtypes.Key, error) {

View File

@@ -56,6 +56,7 @@ var (
certKeyFile string certKeyFile string
certLockMethod string certLockMethod string
wgPort int wgPort int
proxyProtocol bool
) )
var rootCmd = &cobra.Command{ var rootCmd = &cobra.Command{
@@ -90,6 +91,7 @@ func init() {
rootCmd.Flags().StringVar(&certKeyFile, "cert-key-file", envStringOrDefault("NB_PROXY_CERTIFICATE_KEY_FILE", "tls.key"), "TLS certificate key filename within the certificate directory") rootCmd.Flags().StringVar(&certKeyFile, "cert-key-file", envStringOrDefault("NB_PROXY_CERTIFICATE_KEY_FILE", "tls.key"), "TLS certificate key filename within the certificate directory")
rootCmd.Flags().StringVar(&certLockMethod, "cert-lock-method", envStringOrDefault("NB_PROXY_CERT_LOCK_METHOD", "auto"), "Certificate lock method for cross-replica coordination: auto, flock, or k8s-lease") rootCmd.Flags().StringVar(&certLockMethod, "cert-lock-method", envStringOrDefault("NB_PROXY_CERT_LOCK_METHOD", "auto"), "Certificate lock method for cross-replica coordination: auto, flock, or k8s-lease")
rootCmd.Flags().IntVar(&wgPort, "wg-port", envIntOrDefault("NB_PROXY_WG_PORT", 0), "WireGuard listen port (0 = random). Fixed port only works with single-account deployments") rootCmd.Flags().IntVar(&wgPort, "wg-port", envIntOrDefault("NB_PROXY_WG_PORT", 0), "WireGuard listen port (0 = random). Fixed port only works with single-account deployments")
rootCmd.Flags().BoolVar(&proxyProtocol, "proxy-protocol", envBoolOrDefault("NB_PROXY_PROXY_PROTOCOL", false), "Enable PROXY protocol on TCP listeners to preserve client IPs behind L4 proxies")
} }
// Execute runs the root command. // Execute runs the root command.
@@ -165,6 +167,7 @@ func runServer(cmd *cobra.Command, args []string) error {
TrustedProxies: parsedTrustedProxies, TrustedProxies: parsedTrustedProxies,
CertLockMethod: nbacme.CertLockMethod(certLockMethod), CertLockMethod: nbacme.CertLockMethod(certLockMethod),
WireguardPort: wgPort, WireguardPort: wgPort,
ProxyProtocol: proxyProtocol,
} }
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT) ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT)

View File

@@ -9,6 +9,7 @@ import (
"github.com/rs/xid" "github.com/rs/xid"
"github.com/netbirdio/netbird/proxy/internal/proxy" "github.com/netbirdio/netbird/proxy/internal/proxy"
"github.com/netbirdio/netbird/proxy/internal/responsewriter"
"github.com/netbirdio/netbird/proxy/web" "github.com/netbirdio/netbird/proxy/web"
) )
@@ -27,8 +28,8 @@ func (l *Logger) Middleware(next http.Handler) http.Handler {
// Use a response writer wrapper so we can access the status code later. // Use a response writer wrapper so we can access the status code later.
sw := &statusWriter{ sw := &statusWriter{
w: w, PassthroughWriter: responsewriter.New(w),
status: http.StatusOK, status: http.StatusOK,
} }
// Resolve the source IP using trusted proxy configuration before passing // Resolve the source IP using trusted proxy configuration before passing

View File

@@ -1,26 +1,18 @@
package accesslog package accesslog
import ( import (
"net/http" "github.com/netbirdio/netbird/proxy/internal/responsewriter"
) )
// statusWriter is a simple wrapper around an http.ResponseWriter // statusWriter captures the HTTP status code from WriteHeader calls.
// that captures the setting of the status code via the WriteHeader // It embeds responsewriter.PassthroughWriter which handles all the optional
// function and stores it so that it can be retrieved later. // interfaces (Hijacker, Flusher, Pusher) automatically.
type statusWriter struct { type statusWriter struct {
w http.ResponseWriter *responsewriter.PassthroughWriter
status int status int
} }
func (w *statusWriter) Header() http.Header {
return w.w.Header()
}
func (w *statusWriter) Write(data []byte) (int, error) {
return w.w.Write(data)
}
func (w *statusWriter) WriteHeader(status int) { func (w *statusWriter) WriteHeader(status int) {
w.status = status w.status = status
w.w.WriteHeader(status) w.PassthroughWriter.WriteHeader(status)
} }

View File

@@ -0,0 +1,49 @@
package conntrack
import (
"bufio"
"net"
"net/http"
)
// trackedConn wraps a net.Conn and removes itself from the tracker on Close.
type trackedConn struct {
	net.Conn
	// tracker is the registry this connection deregisters from on Close.
	tracker *HijackTracker
}

// Close deregisters the connection from the tracker, then closes the
// underlying connection, returning the underlying Close error.
func (c *trackedConn) Close() error {
	// Delete first so the tracker never holds a reference to an
	// already-closed connection.
	c.tracker.conns.Delete(c)
	return c.Conn.Close()
}
// trackingWriter decorates an http.ResponseWriter so that any hijacked
// connection is swapped for a trackedConn, which deregisters itself from the
// tracker when closed.
type trackingWriter struct {
	http.ResponseWriter
	tracker *HijackTracker
}

// Hijack takes over the underlying connection (when the wrapped writer
// supports http.Hijacker) and registers the resulting connection with the
// tracker so it can be force-closed during shutdown.
func (w *trackingWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	hj, ok := w.ResponseWriter.(http.Hijacker)
	if !ok {
		return nil, nil, http.ErrNotSupported
	}

	rawConn, rw, err := hj.Hijack()
	if err != nil {
		return nil, nil, err
	}

	wrapped := &trackedConn{Conn: rawConn, tracker: w.tracker}
	w.tracker.conns.Store(wrapped, struct{}{})
	return wrapped, rw, nil
}

// Flush forwards to the wrapped writer's http.Flusher, if it has one.
func (w *trackingWriter) Flush() {
	fl, ok := w.ResponseWriter.(http.Flusher)
	if !ok {
		return
	}
	fl.Flush()
}

// Unwrap exposes the wrapped ResponseWriter for http.ResponseController.
func (w *trackingWriter) Unwrap() http.ResponseWriter {
	return w.ResponseWriter
}

View File

@@ -0,0 +1,41 @@
package conntrack
import (
"net"
"net/http"
"sync"
)
// HijackTracker tracks connections that have been hijacked (e.g. WebSocket
// upgrades). http.Server.Shutdown does not close hijacked connections, so
// they must be tracked and closed explicitly during graceful shutdown.
//
// Use Middleware as the outermost HTTP middleware to ensure hijacked
// connections are tracked and automatically deregistered when closed.
type HijackTracker struct {
	conns sync.Map // net.Conn → struct{}
}

// Middleware returns an HTTP middleware that swaps the ResponseWriter for a
// tracking wrapper, so connections hijacked further down the chain register
// with this tracker and deregister themselves on Close. Install it as the
// outermost middleware.
func (t *HijackTracker) Middleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		wrapped := &trackingWriter{ResponseWriter: w, tracker: t}
		next.ServeHTTP(wrapped, r)
	})
}

// CloseAll closes every tracked hijacked connection, empties the tracker,
// and reports how many connections were closed.
func (t *HijackTracker) CloseAll() int {
	closed := 0
	t.conns.Range(func(key, _ any) bool {
		// Remove the entry regardless of type so the map always drains.
		t.conns.Delete(key)
		if conn, ok := key.(net.Conn); ok {
			_ = conn.Close()
			closed++
		}
		return true
	})
	return closed
}

View File

@@ -5,9 +5,11 @@ import (
"strconv" "strconv"
"time" "time"
"github.com/netbirdio/netbird/proxy/internal/proxy"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto" "github.com/prometheus/client_golang/prometheus/promauto"
"github.com/netbirdio/netbird/proxy/internal/proxy"
"github.com/netbirdio/netbird/proxy/internal/responsewriter"
) )
type Metrics struct { type Metrics struct {
@@ -60,18 +62,18 @@ func New(reg prometheus.Registerer) *Metrics {
} }
type responseInterceptor struct { type responseInterceptor struct {
http.ResponseWriter *responsewriter.PassthroughWriter
status int status int
size int size int
} }
func (w *responseInterceptor) WriteHeader(status int) { func (w *responseInterceptor) WriteHeader(status int) {
w.status = status w.status = status
w.ResponseWriter.WriteHeader(status) w.PassthroughWriter.WriteHeader(status)
} }
func (w *responseInterceptor) Write(b []byte) (int, error) { func (w *responseInterceptor) Write(b []byte) (int, error) {
size, err := w.ResponseWriter.Write(b) size, err := w.PassthroughWriter.Write(b)
w.size += size w.size += size
return size, err return size, err
} }
@@ -81,7 +83,7 @@ func (m *Metrics) Middleware(next http.Handler) http.Handler {
m.requestsTotal.Inc() m.requestsTotal.Inc()
m.activeRequests.Inc() m.activeRequests.Inc()
interceptor := &responseInterceptor{ResponseWriter: w} interceptor := &responseInterceptor{PassthroughWriter: responsewriter.New(w)}
start := time.Now() start := time.Now()
next.ServeHTTP(interceptor, r) next.ServeHTTP(interceptor, r)

View File

@@ -0,0 +1,53 @@
package responsewriter
import (
"bufio"
"net"
"net/http"
)
// PassthroughWriter decorates an http.ResponseWriter while keeping the
// optional http.Hijacker, http.Flusher, and http.Pusher behaviors reachable:
// each call is forwarded to the wrapped writer when it implements the
// corresponding interface.
//
// Middleware that wraps a ResponseWriter needs this pattern so protocol
// upgrades (e.g. WebSocket), response streaming, and HTTP/2 server push keep
// working through the wrapper.
type PassthroughWriter struct {
	http.ResponseWriter
}

// New wraps w in a PassthroughWriter.
func New(w http.ResponseWriter) *PassthroughWriter {
	pw := PassthroughWriter{ResponseWriter: w}
	return &pw
}

// Hijack forwards to the wrapped writer's http.Hijacker implementation, which
// WebSocket connections and other protocol upgrades require. It reports
// http.ErrNotSupported when the wrapped writer cannot hijack.
func (p *PassthroughWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	h, ok := p.ResponseWriter.(http.Hijacker)
	if !ok {
		return nil, nil, http.ErrNotSupported
	}
	return h.Hijack()
}

// Flush forwards to the wrapped writer's http.Flusher implementation, if any;
// otherwise it is a no-op.
func (p *PassthroughWriter) Flush() {
	f, ok := p.ResponseWriter.(http.Flusher)
	if !ok {
		return
	}
	f.Flush()
}

// Push forwards to the wrapped writer's http.Pusher implementation (HTTP/2
// server push). It reports http.ErrNotSupported when push is unavailable.
func (p *PassthroughWriter) Push(target string, opts *http.PushOptions) error {
	pusher, ok := p.ResponseWriter.(http.Pusher)
	if !ok {
		return http.ErrNotSupported
	}
	return pusher.Push(target, opts)
}

// Unwrap exposes the wrapped ResponseWriter so http.ResponseController
// (Go 1.20+) can reach through this wrapper.
func (p *PassthroughWriter) Unwrap() http.ResponseWriter {
	return p.ResponseWriter
}

106
proxy/proxyprotocol_test.go Normal file
View File

@@ -0,0 +1,106 @@
package proxy
import (
	"net"
	"net/netip"
	"strconv"
	"testing"
	"time"

	proxyproto "github.com/pires/go-proxyproto"
	log "github.com/sirupsen/logrus"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)
// TestWrapProxyProtocol_OverridesRemoteAddr verifies that a listener wrapped
// with PROXY protocol support reports the client address from the PROXY v2
// header (both IP and port) instead of the TCP peer address.
func TestWrapProxyProtocol_OverridesRemoteAddr(t *testing.T) {
	srv := &Server{
		Logger:         log.StandardLogger(),
		TrustedProxies: []netip.Prefix{netip.MustParsePrefix("127.0.0.1/32")},
		ProxyProtocol:  true,
	}

	raw, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)
	defer raw.Close()

	ln := srv.wrapProxyProtocol(raw)

	const (
		realClientIP   = "203.0.113.50"
		realClientPort = 54321
	)

	// Report both the accepted connection and any accept error, so a failing
	// Accept surfaces as a test failure rather than a silent timeout.
	type acceptResult struct {
		conn net.Conn
		err  error
	}
	results := make(chan acceptResult, 1)
	go func() {
		conn, err := ln.Accept()
		results <- acceptResult{conn: conn, err: err}
	}()

	// Connect and send a PROXY v2 header.
	conn, err := net.Dial("tcp", ln.Addr().String())
	require.NoError(t, err)
	defer conn.Close()

	header := &proxyproto.Header{
		Version:           2,
		Command:           proxyproto.PROXY,
		TransportProtocol: proxyproto.TCPv4,
		SourceAddr:        &net.TCPAddr{IP: net.ParseIP(realClientIP), Port: realClientPort},
		DestinationAddr:   &net.TCPAddr{IP: net.ParseIP("10.0.0.1"), Port: 443},
	}
	_, err = header.WriteTo(conn)
	require.NoError(t, err)

	select {
	case res := <-results:
		require.NoError(t, res.err, "accept should succeed")
		defer res.conn.Close()

		host, port, err := net.SplitHostPort(res.conn.RemoteAddr().String())
		require.NoError(t, err)
		assert.Equal(t, realClientIP, host, "RemoteAddr should reflect the PROXY header source IP")
		assert.Equal(t, strconv.Itoa(realClientPort), port, "RemoteAddr should reflect the PROXY header source port")
	case <-time.After(2 * time.Second):
		t.Fatal("timed out waiting for connection")
	}
}
// TestProxyProtocolPolicy_TrustedRequires checks that a connection whose
// source falls inside TrustedProxies must carry a PROXY header.
func TestProxyProtocolPolicy_TrustedRequires(t *testing.T) {
	server := &Server{
		Logger:         log.StandardLogger(),
		TrustedProxies: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
	}

	// Source address inside the trusted 10.0.0.0/8 range.
	upstream := &net.TCPAddr{IP: net.ParseIP("10.0.0.1"), Port: 1234}

	policy, err := server.proxyProtocolPolicy(proxyproto.ConnPolicyOptions{Upstream: upstream})
	require.NoError(t, err)
	assert.Equal(t, proxyproto.REQUIRE, policy, "trusted source should require PROXY header")
}
// TestProxyProtocolPolicy_UntrustedIgnores checks that a connection from
// outside TrustedProxies has any PROXY header ignored rather than honored.
func TestProxyProtocolPolicy_UntrustedIgnores(t *testing.T) {
	server := &Server{
		Logger:         log.StandardLogger(),
		TrustedProxies: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
	}

	// Source address outside the trusted 10.0.0.0/8 range.
	upstream := &net.TCPAddr{IP: net.ParseIP("203.0.113.50"), Port: 1234}

	policy, err := server.proxyProtocolPolicy(proxyproto.ConnPolicyOptions{Upstream: upstream})
	require.NoError(t, err)
	assert.Equal(t, proxyproto.IGNORE, policy, "untrusted source should have PROXY header ignored")
}
// TestProxyProtocolPolicy_InvalidIPRejects checks that a connection whose
// source address cannot be parsed as a TCP/IP address is rejected outright.
func TestProxyProtocolPolicy_InvalidIPRejects(t *testing.T) {
	server := &Server{
		Logger:         log.StandardLogger(),
		TrustedProxies: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
	}

	// A unix socket address is not a *net.TCPAddr and has no usable IP.
	upstream := &net.UnixAddr{Name: "/tmp/test.sock", Net: "unix"}

	policy, err := server.proxyProtocolPolicy(proxyproto.ConnPolicyOptions{Upstream: upstream})
	require.NoError(t, err)
	assert.Equal(t, proxyproto.REJECT, policy, "unparsable address should be rejected")
}

View File

@@ -23,6 +23,7 @@ import (
"time" "time"
"github.com/cenkalti/backoff/v4" "github.com/cenkalti/backoff/v4"
proxyproto "github.com/pires/go-proxyproto"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp" "github.com/prometheus/client_golang/prometheus/promhttp"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -36,6 +37,7 @@ import (
"github.com/netbirdio/netbird/proxy/internal/acme" "github.com/netbirdio/netbird/proxy/internal/acme"
"github.com/netbirdio/netbird/proxy/internal/auth" "github.com/netbirdio/netbird/proxy/internal/auth"
"github.com/netbirdio/netbird/proxy/internal/certwatch" "github.com/netbirdio/netbird/proxy/internal/certwatch"
"github.com/netbirdio/netbird/proxy/internal/conntrack"
"github.com/netbirdio/netbird/proxy/internal/debug" "github.com/netbirdio/netbird/proxy/internal/debug"
proxygrpc "github.com/netbirdio/netbird/proxy/internal/grpc" proxygrpc "github.com/netbirdio/netbird/proxy/internal/grpc"
"github.com/netbirdio/netbird/proxy/internal/health" "github.com/netbirdio/netbird/proxy/internal/health"
@@ -63,6 +65,11 @@ type Server struct {
healthChecker *health.Checker healthChecker *health.Checker
meter *metrics.Metrics meter *metrics.Metrics
// hijackTracker tracks hijacked connections (e.g. WebSocket upgrades)
// so they can be closed during graceful shutdown, since http.Server.Shutdown
// does not handle them.
hijackTracker conntrack.HijackTracker
// Mostly used for debugging on management. // Mostly used for debugging on management.
startTime time.Time startTime time.Time
@@ -92,7 +99,7 @@ type Server struct {
DebugEndpointEnabled bool DebugEndpointEnabled bool
// DebugEndpointAddress is the address for the debug HTTP endpoint (default: ":8444"). // DebugEndpointAddress is the address for the debug HTTP endpoint (default: ":8444").
DebugEndpointAddress string DebugEndpointAddress string
// HealthAddress is the address for the health probe endpoint (default: "localhost:8080"). // HealthAddress is the address for the health probe endpoint.
HealthAddress string HealthAddress string
// ProxyToken is the access token for authenticating with the management server. // ProxyToken is the access token for authenticating with the management server.
ProxyToken string ProxyToken string
@@ -107,6 +114,10 @@ type Server struct {
// random OS-assigned port. A fixed port only works with single-account // random OS-assigned port. A fixed port only works with single-account
// deployments; multiple accounts will fail to bind the same port. // deployments; multiple accounts will fail to bind the same port.
WireguardPort int WireguardPort int
// ProxyProtocol enables PROXY protocol (v1/v2) on TCP listeners.
// When enabled, the real client IP is extracted from the PROXY header
// sent by upstream L4 proxies that support PROXY protocol.
ProxyProtocol bool
} }
// NotifyStatus sends a status update to management about tunnel connectivity // NotifyStatus sends a status update to management about tunnel connectivity
@@ -137,23 +148,8 @@ func (s *Server) NotifyCertificateIssued(ctx context.Context, accountID, service
} }
func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) { func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
s.startTime = time.Now() s.initDefaults()
// If no ID is set then one can be generated.
if s.ID == "" {
s.ID = "netbird-proxy-" + s.startTime.Format("20060102150405")
}
// Fallback version option in case it is not set.
if s.Version == "" {
s.Version = "dev"
}
// If no logger is specified fallback to the standard logger.
if s.Logger == nil {
s.Logger = log.StandardLogger()
}
// Start up metrics gathering
reg := prometheus.NewRegistry() reg := prometheus.NewRegistry()
s.meter = metrics.New(reg) s.meter = metrics.New(reg)
@@ -189,53 +185,41 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
s.healthChecker = health.NewChecker(s.Logger, s.netbird) s.healthChecker = health.NewChecker(s.Logger, s.netbird)
if s.DebugEndpointEnabled { s.startDebugEndpoint()
debugAddr := debugEndpointAddr(s.DebugEndpointAddress)
debugHandler := debug.NewHandler(s.netbird, s.healthChecker, s.Logger) if err := s.startHealthServer(reg); err != nil {
if s.acme != nil { return err
debugHandler.SetCertStatus(s.acme)
}
s.debug = &http.Server{
Addr: debugAddr,
Handler: debugHandler,
ErrorLog: newHTTPServerLogger(s.Logger, logtagValueDebug),
}
go func() {
s.Logger.Infof("starting debug endpoint on %s", debugAddr)
if err := s.debug.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
s.Logger.Errorf("debug endpoint error: %v", err)
}
}()
} }
// Start health probe server. // Build the handler chain from inside out.
healthAddr := s.HealthAddress handler := http.Handler(s.proxy)
if healthAddr == "" { handler = s.auth.Protect(handler)
healthAddr = "localhost:8080" handler = web.AssetHandler(handler)
} handler = accessLog.Middleware(handler)
s.healthServer = health.NewServer(healthAddr, s.healthChecker, s.Logger, promhttp.HandlerFor(reg, promhttp.HandlerOpts{})) handler = s.meter.Middleware(handler)
healthListener, err := net.Listen("tcp", healthAddr) handler = s.hijackTracker.Middleware(handler)
if err != nil {
return fmt.Errorf("health probe server listen on %s: %w", healthAddr, err)
}
go func() {
if err := s.healthServer.Serve(healthListener); err != nil && !errors.Is(err, http.ErrServerClosed) {
s.Logger.Errorf("health probe server: %v", err)
}
}()
// Start the reverse proxy HTTPS server. // Start the reverse proxy HTTPS server.
s.https = &http.Server{ s.https = &http.Server{
Addr: addr, Addr: addr,
Handler: s.meter.Middleware(accessLog.Middleware(web.AssetHandler(s.auth.Protect(s.proxy)))), Handler: handler,
TLSConfig: tlsConfig, TLSConfig: tlsConfig,
ErrorLog: newHTTPServerLogger(s.Logger, logtagValueHTTPS), ErrorLog: newHTTPServerLogger(s.Logger, logtagValueHTTPS),
} }
lc := net.ListenConfig{}
ln, err := lc.Listen(ctx, "tcp", addr)
if err != nil {
return fmt.Errorf("listen on %s: %w", addr, err)
}
if s.ProxyProtocol {
ln = s.wrapProxyProtocol(ln)
}
httpsErr := make(chan error, 1) httpsErr := make(chan error, 1)
go func() { go func() {
s.Logger.Debugf("starting reverse proxy server on %s", addr) s.Logger.Debugf("starting reverse proxy server on %s", addr)
httpsErr <- s.https.ListenAndServeTLS("", "") httpsErr <- s.https.ServeTLS(ln, "", "")
}() }()
select { select {
@@ -251,7 +235,115 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
} }
} }
// initDefaults records the start time and fills in fallback values for any
// optional Server fields the caller left unset.
func (s *Server) initDefaults() {
	s.startTime = time.Now()

	// Fall back to the standard logger when none was provided.
	if s.Logger == nil {
		s.Logger = log.StandardLogger()
	}

	// Fall back to a development version string when none was provided.
	if s.Version == "" {
		s.Version = "dev"
	}

	// Derive a unique-enough ID from the start time when none was provided.
	if s.ID == "" {
		s.ID = "netbird-proxy-" + s.startTime.Format("20060102150405")
	}
}
// startDebugEndpoint launches the debug HTTP server in a background goroutine
// when DebugEndpointEnabled is set; otherwise it does nothing.
func (s *Server) startDebugEndpoint() {
	if !s.DebugEndpointEnabled {
		return
	}

	addr := debugEndpointAddr(s.DebugEndpointAddress)

	handler := debug.NewHandler(s.netbird, s.healthChecker, s.Logger)
	if s.acme != nil {
		// Expose certificate issuance status when ACME is configured.
		handler.SetCertStatus(s.acme)
	}

	s.debug = &http.Server{
		Addr:     addr,
		Handler:  handler,
		ErrorLog: newHTTPServerLogger(s.Logger, logtagValueDebug),
	}

	go func() {
		s.Logger.Infof("starting debug endpoint on %s", addr)
		err := s.debug.ListenAndServe()
		if err != nil && !errors.Is(err, http.ErrServerClosed) {
			s.Logger.Errorf("debug endpoint error: %v", err)
		}
	}()
}
// startHealthServer binds the health probe and Prometheus metrics server and
// serves it in a background goroutine. It returns an error only when the
// listen address cannot be bound.
func (s *Server) startHealthServer(reg *prometheus.Registry) error {
	addr := s.HealthAddress
	if addr == "" {
		addr = defaultHealthAddr
	}

	listener, err := net.Listen("tcp", addr)
	if err != nil {
		return fmt.Errorf("health probe server listen on %s: %w", addr, err)
	}

	s.healthServer = health.NewServer(addr, s.healthChecker, s.Logger, promhttp.HandlerFor(reg, promhttp.HandlerOpts{}))

	go func() {
		serveErr := s.healthServer.Serve(listener)
		if serveErr != nil && !errors.Is(serveErr, http.ErrServerClosed) {
			s.Logger.Errorf("health probe server: %v", serveErr)
		}
	}()
	return nil
}
// wrapProxyProtocol decorates ln with PROXY protocol (v1/v2) support.
// When TrustedProxies is configured, only those sources may send PROXY
// headers; headers from any other source are ignored. Without TrustedProxies,
// every peer may assert an arbitrary client address, so a warning is logged.
func (s *Server) wrapProxyProtocol(ln net.Listener) net.Listener {
	wrapped := &proxyproto.Listener{
		Listener:          ln,
		ReadHeaderTimeout: proxyProtoHeaderTimeout,
	}

	if len(s.TrustedProxies) == 0 {
		s.Logger.Warn("PROXY protocol enabled without trusted proxies; any source may send PROXY headers")
	} else {
		wrapped.ConnPolicy = s.proxyProtocolPolicy
	}

	s.Logger.Info("PROXY protocol enabled on listener")
	return wrapped
}
// proxyProtocolPolicy decides, per accepted connection, whether to require,
// ignore, or reject the PROXY header based on whether the connection source
// is listed in TrustedProxies.
func (s *Server) proxyProtocolPolicy(opts proxyproto.ConnPolicyOptions) (proxyproto.Policy, error) {
	// Deliberately no logging on the reject paths so a flood of bad
	// connections cannot be used to spam the logs.
	tcp, isTCP := opts.Upstream.(*net.TCPAddr)
	if !isTCP {
		return proxyproto.REJECT, nil
	}

	src, parsed := netip.AddrFromSlice(tcp.IP)
	if !parsed {
		return proxyproto.REJECT, nil
	}
	// Normalize IPv4-mapped IPv6 addresses so IPv4 prefixes match.
	src = src.Unmap()

	// Runs once per accept; TrustedProxies lists are expected to be small.
	for _, trusted := range s.TrustedProxies {
		if trusted.Contains(src) {
			return proxyproto.REQUIRE, nil
		}
	}
	return proxyproto.IGNORE, nil
}
const ( const (
defaultHealthAddr = "localhost:8080"
defaultDebugAddr = "localhost:8444"
// proxyProtoHeaderTimeout is the deadline for reading the PROXY protocol
// header after accepting a connection.
proxyProtoHeaderTimeout = 5 * time.Second
// shutdownPreStopDelay is the time to wait after receiving a shutdown signal // shutdownPreStopDelay is the time to wait after receiving a shutdown signal
// before draining connections. This allows the load balancer to propagate // before draining connections. This allows the load balancer to propagate
// the endpoint removal. // the endpoint removal.
@@ -379,7 +471,12 @@ func (s *Server) gracefulShutdown() {
s.Logger.Warnf("https server drain: %v", err) s.Logger.Warnf("https server drain: %v", err)
} }
// Step 4: Stop all remaining background services. // Step 4: Close hijacked connections (WebSocket) that Shutdown does not handle.
if n := s.hijackTracker.CloseAll(); n > 0 {
s.Logger.Infof("closed %d hijacked connection(s)", n)
}
// Step 5: Stop all remaining background services.
s.shutdownServices() s.shutdownServices()
s.Logger.Info("graceful shutdown complete") s.Logger.Info("graceful shutdown complete")
} }
@@ -647,7 +744,7 @@ func (s *Server) protoToMapping(mapping *proto.ProxyMapping) proxy.Mapping {
// If addr is empty, it defaults to localhost:8444 for security. // If addr is empty, it defaults to localhost:8444 for security.
func debugEndpointAddr(addr string) string { func debugEndpointAddr(addr string) string {
if addr == "" { if addr == "" {
return "localhost:8444" return defaultDebugAddr
} }
return addr return addr
} }