mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 07:16:38 +00:00
Merge branch 'main' into feature/disable-legacy-port
This commit is contained in:
@@ -358,9 +358,9 @@ func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 {
|
||||
// Fast path for IPv4 addresses (4 bytes) - most common case
|
||||
if len(oldBytes) == 4 && len(newBytes) == 4 {
|
||||
sum += uint32(^binary.BigEndian.Uint16(oldBytes[0:2]))
|
||||
sum += uint32(^binary.BigEndian.Uint16(oldBytes[2:4]))
|
||||
sum += uint32(^binary.BigEndian.Uint16(oldBytes[2:4])) //nolint:gosec // length checked above
|
||||
sum += uint32(binary.BigEndian.Uint16(newBytes[0:2]))
|
||||
sum += uint32(binary.BigEndian.Uint16(newBytes[2:4]))
|
||||
sum += uint32(binary.BigEndian.Uint16(newBytes[2:4])) //nolint:gosec // length checked above
|
||||
} else {
|
||||
// Fallback for other lengths
|
||||
for i := 0; i < len(oldBytes)-1; i += 2 {
|
||||
|
||||
@@ -410,7 +410,7 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
|
||||
conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr)
|
||||
}
|
||||
|
||||
func (conn *Conn) onICEStateDisconnected() {
|
||||
func (conn *Conn) onICEStateDisconnected(sessionChanged bool) {
|
||||
conn.mu.Lock()
|
||||
defer conn.mu.Unlock()
|
||||
|
||||
@@ -430,6 +430,10 @@ func (conn *Conn) onICEStateDisconnected() {
|
||||
if conn.isReadyToUpgrade() {
|
||||
conn.Log.Infof("ICE disconnected, set Relay to active connection")
|
||||
conn.dumpState.SwitchToRelay()
|
||||
if sessionChanged {
|
||||
conn.resetEndpoint()
|
||||
}
|
||||
|
||||
conn.wgProxyRelay.Work()
|
||||
|
||||
presharedKey := conn.presharedKey(conn.rosenpassRemoteKey)
|
||||
@@ -757,6 +761,17 @@ func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) {
|
||||
return wgProxy, nil
|
||||
}
|
||||
|
||||
func (conn *Conn) resetEndpoint() {
|
||||
if !isController(conn.config) {
|
||||
return
|
||||
}
|
||||
conn.Log.Infof("reset wg endpoint")
|
||||
conn.wgWatcher.Reset()
|
||||
if err := conn.endpointUpdater.RemoveEndpointAddress(); err != nil {
|
||||
conn.Log.Warnf("failed to remove endpoint address before update: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *Conn) isReadyToUpgrade() bool {
|
||||
return conn.wgProxyRelay != nil && conn.currentConnPriority != conntype.Relay
|
||||
}
|
||||
|
||||
@@ -66,6 +66,10 @@ func (e *EndpointUpdater) RemoveWgPeer() error {
|
||||
return e.wgConfig.WgInterface.RemovePeer(e.wgConfig.RemoteKey)
|
||||
}
|
||||
|
||||
func (e *EndpointUpdater) RemoveEndpointAddress() error {
|
||||
return e.wgConfig.WgInterface.RemoveEndpointAddress(e.wgConfig.RemoteKey)
|
||||
}
|
||||
|
||||
func (e *EndpointUpdater) waitForCloseTheDelayedUpdate() {
|
||||
if e.cancelFunc == nil {
|
||||
return
|
||||
|
||||
@@ -32,6 +32,8 @@ type WGWatcher struct {
|
||||
|
||||
enabled bool
|
||||
muEnabled sync.RWMutex
|
||||
|
||||
resetCh chan struct{}
|
||||
}
|
||||
|
||||
func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey string, stateDump *stateDump) *WGWatcher {
|
||||
@@ -40,6 +42,7 @@ func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey strin
|
||||
wgIfaceStater: wgIfaceStater,
|
||||
peerKey: peerKey,
|
||||
stateDump: stateDump,
|
||||
resetCh: make(chan struct{}, 1),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -76,6 +79,15 @@ func (w *WGWatcher) IsEnabled() bool {
|
||||
return w.enabled
|
||||
}
|
||||
|
||||
// Reset signals the watcher that the WireGuard peer has been reset and a new
|
||||
// handshake is expected. This restarts the handshake timeout from scratch.
|
||||
func (w *WGWatcher) Reset() {
|
||||
select {
|
||||
case w.resetCh <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// wgStateCheck help to check the state of the WireGuard handshake and relay connection
|
||||
func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, onDisconnectedFn func(), enabledTime time.Time, initialHandshake time.Time) {
|
||||
w.log.Infof("WireGuard watcher started")
|
||||
@@ -105,6 +117,12 @@ func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, onDisconnectedFn
|
||||
w.stateDump.WGcheckSuccess()
|
||||
|
||||
w.log.Debugf("WireGuard watcher reset timer: %v", resetTime)
|
||||
case <-w.resetCh:
|
||||
w.log.Infof("WireGuard watcher received peer reset, restarting handshake timeout")
|
||||
lastHandshake = time.Time{}
|
||||
enabledTime = time.Now()
|
||||
timer.Stop()
|
||||
timer.Reset(wgHandshakeOvertime)
|
||||
case <-ctx.Done():
|
||||
w.log.Infof("WireGuard watcher stopped")
|
||||
return
|
||||
|
||||
@@ -52,8 +52,9 @@ type WorkerICE struct {
|
||||
// increase by one when disconnecting the agent
|
||||
// with it the remote peer can discard the already deprecated offer/answer
|
||||
// Without it the remote peer may recreate a workable ICE connection
|
||||
sessionID ICESessionID
|
||||
muxAgent sync.Mutex
|
||||
sessionID ICESessionID
|
||||
remoteSessionChanged bool
|
||||
muxAgent sync.Mutex
|
||||
|
||||
localUfrag string
|
||||
localPwd string
|
||||
@@ -106,6 +107,7 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
||||
return
|
||||
}
|
||||
w.log.Debugf("agent already exists, recreate the connection")
|
||||
w.remoteSessionChanged = true
|
||||
w.agentDialerCancel()
|
||||
if w.agent != nil {
|
||||
if err := w.agent.Close(); err != nil {
|
||||
@@ -306,13 +308,17 @@ func (w *WorkerICE) connect(ctx context.Context, agent *icemaker.ThreadSafeAgent
|
||||
w.conn.onICEConnectionIsReady(selectedPriority(pair), ci)
|
||||
}
|
||||
|
||||
func (w *WorkerICE) closeAgent(agent *icemaker.ThreadSafeAgent, cancel context.CancelFunc) {
|
||||
func (w *WorkerICE) closeAgent(agent *icemaker.ThreadSafeAgent, cancel context.CancelFunc) bool {
|
||||
cancel()
|
||||
if err := agent.Close(); err != nil {
|
||||
w.log.Warnf("failed to close ICE agent: %s", err)
|
||||
}
|
||||
|
||||
w.muxAgent.Lock()
|
||||
defer w.muxAgent.Unlock()
|
||||
|
||||
sessionChanged := w.remoteSessionChanged
|
||||
w.remoteSessionChanged = false
|
||||
|
||||
if w.agent == agent {
|
||||
// consider to remove from here and move to the OnNewOffer
|
||||
@@ -325,7 +331,7 @@ func (w *WorkerICE) closeAgent(agent *icemaker.ThreadSafeAgent, cancel context.C
|
||||
w.agentConnecting = false
|
||||
w.remoteSessionID = ""
|
||||
}
|
||||
w.muxAgent.Unlock()
|
||||
return sessionChanged
|
||||
}
|
||||
|
||||
func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) {
|
||||
@@ -426,11 +432,11 @@ func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dia
|
||||
// ice.ConnectionStateClosed happens when we recreate the agent. For the P2P to TURN switch important to
|
||||
// notify the conn.onICEStateDisconnected changes to update the current used priority
|
||||
|
||||
w.closeAgent(agent, dialerCancel)
|
||||
sessionChanged := w.closeAgent(agent, dialerCancel)
|
||||
|
||||
if w.lastKnownState == ice.ConnectionStateConnected {
|
||||
w.lastKnownState = ice.ConnectionStateDisconnected
|
||||
w.conn.onICEStateDisconnected()
|
||||
w.conn.onICEStateDisconnected(sessionChanged)
|
||||
}
|
||||
default:
|
||||
return
|
||||
|
||||
1
go.mod
1
go.mod
@@ -83,6 +83,7 @@ require (
|
||||
github.com/pion/stun/v3 v3.1.0
|
||||
github.com/pion/transport/v3 v3.1.1
|
||||
github.com/pion/turn/v3 v3.0.1
|
||||
github.com/pires/go-proxyproto v0.11.0
|
||||
github.com/pkg/sftp v1.13.9
|
||||
github.com/prometheus/client_golang v1.23.2
|
||||
github.com/quic-go/quic-go v0.55.0
|
||||
|
||||
2
go.sum
2
go.sum
@@ -474,6 +474,8 @@ github.com/pion/turn/v3 v3.0.1 h1:wLi7BTQr6/Q20R0vt/lHbjv6y4GChFtC33nkYbasoT8=
|
||||
github.com/pion/turn/v3 v3.0.1/go.mod h1:MrJDKgqryDyWy1/4NT9TWfXWGMC7UHT6pJIv1+gMeNE=
|
||||
github.com/pion/turn/v4 v4.1.1 h1:9UnY2HB99tpDyz3cVVZguSxcqkJ1DsTSZ+8TGruh4fc=
|
||||
github.com/pion/turn/v4 v4.1.1/go.mod h1:2123tHk1O++vmjI5VSD0awT50NywDAq5A2NNNU4Jjs8=
|
||||
github.com/pires/go-proxyproto v0.11.0 h1:gUQpS85X/VJMdUsYyEgyn59uLJvGqPhJV5YvG68wXH4=
|
||||
github.com/pires/go-proxyproto v0.11.0/go.mod h1:ZKAAyp3cgy5Y5Mo4n9AlScrkCZwUy0g3Jf+slqQVcuU=
|
||||
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
|
||||
@@ -99,15 +99,16 @@ func NewProvider(ctx context.Context, config *Config) (*Provider, error) {
|
||||
|
||||
// Build Dex server config - use Dex's types directly
|
||||
dexConfig := server.Config{
|
||||
Issuer: issuer,
|
||||
Storage: stor,
|
||||
SkipApprovalScreen: true,
|
||||
SupportedResponseTypes: []string{"code"},
|
||||
Logger: logger,
|
||||
PrometheusRegistry: prometheus.NewRegistry(),
|
||||
RotateKeysAfter: 6 * time.Hour,
|
||||
IDTokensValidFor: 24 * time.Hour,
|
||||
RefreshTokenPolicy: refreshPolicy,
|
||||
Issuer: issuer,
|
||||
Storage: stor,
|
||||
SkipApprovalScreen: true,
|
||||
SupportedResponseTypes: []string{"code"},
|
||||
ContinueOnConnectorFailure: true,
|
||||
Logger: logger,
|
||||
PrometheusRegistry: prometheus.NewRegistry(),
|
||||
RotateKeysAfter: 6 * time.Hour,
|
||||
IDTokensValidFor: 24 * time.Hour,
|
||||
RefreshTokenPolicy: refreshPolicy,
|
||||
Web: server.WebConfig{
|
||||
Issuer: "NetBird",
|
||||
},
|
||||
@@ -260,6 +261,7 @@ func buildDexConfig(yamlConfig *YAMLConfig, stor storage.Storage, logger *slog.L
|
||||
if len(cfg.SupportedResponseTypes) == 0 {
|
||||
cfg.SupportedResponseTypes = []string{"code"}
|
||||
}
|
||||
cfg.ContinueOnConnectorFailure = true
|
||||
return cfg
|
||||
}
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ package dex
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
@@ -195,3 +196,64 @@ enablePasswordDB: true
|
||||
|
||||
t.Logf("User lookup successful: rawID=%s, connectorID=%s", rawID, connID)
|
||||
}
|
||||
|
||||
func TestNewProvider_ContinueOnConnectorFailure(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
tmpDir, err := os.MkdirTemp("", "dex-connector-failure-*")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
config := &Config{
|
||||
Issuer: "http://localhost:5556/dex",
|
||||
Port: 5556,
|
||||
DataDir: tmpDir,
|
||||
}
|
||||
|
||||
provider, err := NewProvider(ctx, config)
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = provider.Stop(ctx) }()
|
||||
|
||||
// The provider should have started successfully even though
|
||||
// ContinueOnConnectorFailure is an internal Dex config field.
|
||||
// We verify the provider is functional by performing a basic operation.
|
||||
assert.NotNil(t, provider.dexServer)
|
||||
assert.NotNil(t, provider.storage)
|
||||
}
|
||||
|
||||
func TestBuildDexConfig_ContinueOnConnectorFailure(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "dex-build-config-*")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
yamlContent := `
|
||||
issuer: http://localhost:5556/dex
|
||||
storage:
|
||||
type: sqlite3
|
||||
config:
|
||||
file: ` + filepath.Join(tmpDir, "dex.db") + `
|
||||
web:
|
||||
http: 127.0.0.1:5556
|
||||
enablePasswordDB: true
|
||||
`
|
||||
configPath := filepath.Join(tmpDir, "config.yaml")
|
||||
err = os.WriteFile(configPath, []byte(yamlContent), 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
yamlConfig, err := LoadConfig(configPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
stor, err := yamlConfig.Storage.OpenStorage(slog.New(slog.NewTextHandler(os.Stderr, nil)))
|
||||
require.NoError(t, err)
|
||||
defer stor.Close()
|
||||
|
||||
err = initializeStorage(ctx, stor, yamlConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
|
||||
cfg := buildDexConfig(yamlConfig, stor, logger)
|
||||
|
||||
assert.True(t, cfg.ContinueOnConnectorFailure,
|
||||
"buildDexConfig must set ContinueOnConnectorFailure to true so management starts even if an external IdP is down")
|
||||
}
|
||||
|
||||
@@ -329,6 +329,9 @@ initialize_default_values() {
|
||||
BIND_LOCALHOST_ONLY="true"
|
||||
EXTERNAL_PROXY_NETWORK=""
|
||||
|
||||
# Traefik static IP within the internal bridge network
|
||||
TRAEFIK_IP="172.30.0.10"
|
||||
|
||||
# NetBird Proxy configuration
|
||||
ENABLE_PROXY="false"
|
||||
PROXY_DOMAIN=""
|
||||
@@ -393,7 +396,7 @@ check_existing_installation() {
|
||||
echo "Generated files already exist, if you want to reinitialize the environment, please remove them first."
|
||||
echo "You can use the following commands:"
|
||||
echo " $DOCKER_COMPOSE_COMMAND down --volumes # to remove all containers and volumes"
|
||||
echo " rm -f docker-compose.yml dashboard.env config.yaml proxy.env nginx-netbird.conf caddyfile-netbird.txt npm-advanced-config.txt"
|
||||
echo " rm -f docker-compose.yml dashboard.env config.yaml proxy.env traefik-dynamic.yaml nginx-netbird.conf caddyfile-netbird.txt npm-advanced-config.txt"
|
||||
echo "Be aware that this will remove all data from the database, and you will have to reconfigure the dashboard."
|
||||
exit 1
|
||||
fi
|
||||
@@ -412,6 +415,8 @@ generate_configuration_files() {
|
||||
# This will be overwritten with the actual token after netbird-server starts
|
||||
echo "# Placeholder - will be updated with token after netbird-server starts" > proxy.env
|
||||
echo "NB_PROXY_TOKEN=placeholder" >> proxy.env
|
||||
# TCP ServersTransport for PROXY protocol v2 to the proxy backend
|
||||
render_traefik_dynamic > traefik-dynamic.yaml
|
||||
fi
|
||||
;;
|
||||
1)
|
||||
@@ -559,10 +564,14 @@ init_environment() {
|
||||
############################################
|
||||
|
||||
render_docker_compose_traefik_builtin() {
|
||||
# Generate proxy service section if enabled
|
||||
# Generate proxy service section and Traefik dynamic config if enabled
|
||||
local proxy_service=""
|
||||
local proxy_volumes=""
|
||||
local traefik_file_provider=""
|
||||
local traefik_dynamic_volume=""
|
||||
if [[ "$ENABLE_PROXY" == "true" ]]; then
|
||||
traefik_file_provider=' - "--providers.file.filename=/etc/traefik/dynamic.yaml"'
|
||||
traefik_dynamic_volume=" - ./traefik-dynamic.yaml:/etc/traefik/dynamic.yaml:ro"
|
||||
proxy_service="
|
||||
# NetBird Proxy - exposes internal resources to the internet
|
||||
proxy:
|
||||
@@ -570,7 +579,7 @@ render_docker_compose_traefik_builtin() {
|
||||
container_name: netbird-proxy
|
||||
# Hairpin NAT fix: route domain back to traefik's static IP within Docker
|
||||
extra_hosts:
|
||||
- \"$NETBIRD_DOMAIN:172.30.0.10\"
|
||||
- \"$NETBIRD_DOMAIN:$TRAEFIK_IP\"
|
||||
ports:
|
||||
- 51820:51820/udp
|
||||
restart: unless-stopped
|
||||
@@ -590,6 +599,7 @@ render_docker_compose_traefik_builtin() {
|
||||
- traefik.tcp.routers.proxy-passthrough.service=proxy-tls
|
||||
- traefik.tcp.routers.proxy-passthrough.priority=1
|
||||
- traefik.tcp.services.proxy-tls.loadbalancer.server.port=8443
|
||||
- traefik.tcp.services.proxy-tls.loadbalancer.serverstransport=pp-v2@file
|
||||
logging:
|
||||
driver: \"json-file\"
|
||||
options:
|
||||
@@ -609,7 +619,7 @@ services:
|
||||
restart: unless-stopped
|
||||
networks:
|
||||
netbird:
|
||||
ipv4_address: 172.30.0.10
|
||||
ipv4_address: $TRAEFIK_IP
|
||||
command:
|
||||
# Logging
|
||||
- "--log.level=INFO"
|
||||
@@ -636,12 +646,14 @@ services:
|
||||
# gRPC transport settings
|
||||
- "--serverstransport.forwardingtimeouts.responseheadertimeout=0s"
|
||||
- "--serverstransport.forwardingtimeouts.idleconntimeout=0s"
|
||||
$traefik_file_provider
|
||||
ports:
|
||||
- '443:443'
|
||||
- '80:80'
|
||||
volumes:
|
||||
- /var/run/docker.sock:/var/run/docker.sock:ro
|
||||
- netbird_traefik_letsencrypt:/letsencrypt
|
||||
$traefik_dynamic_volume
|
||||
logging:
|
||||
driver: "json-file"
|
||||
options:
|
||||
@@ -751,6 +763,10 @@ server:
|
||||
cliRedirectURIs:
|
||||
- "http://localhost:53000/"
|
||||
|
||||
reverseProxy:
|
||||
trustedHTTPProxies:
|
||||
- "$TRAEFIK_IP/32"
|
||||
|
||||
store:
|
||||
engine: "sqlite"
|
||||
encryptionKey: "$DATASTORE_ENCRYPTION_KEY"
|
||||
@@ -780,6 +796,17 @@ EOF
|
||||
return 0
|
||||
}
|
||||
|
||||
render_traefik_dynamic() {
|
||||
cat <<'EOF'
|
||||
tcp:
|
||||
serversTransports:
|
||||
pp-v2:
|
||||
proxyProtocol:
|
||||
version: 2
|
||||
EOF
|
||||
return 0
|
||||
}
|
||||
|
||||
render_proxy_env() {
|
||||
cat <<EOF
|
||||
# NetBird Proxy Configuration
|
||||
@@ -799,6 +826,10 @@ NB_PROXY_OIDC_CLIENT_ID=netbird-proxy
|
||||
NB_PROXY_OIDC_ENDPOINT=$NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN/oauth2
|
||||
NB_PROXY_OIDC_SCOPES=openid,profile,email
|
||||
NB_PROXY_FORWARDED_PROTO=https
|
||||
# Enable PROXY protocol to preserve client IPs through L4 proxies (Traefik TCP passthrough)
|
||||
NB_PROXY_PROXY_PROTOCOL=true
|
||||
# Trust Traefik's IP for PROXY protocol headers
|
||||
NB_PROXY_TRUSTED_PROXIES=$TRAEFIK_IP
|
||||
EOF
|
||||
return 0
|
||||
}
|
||||
@@ -1176,8 +1207,9 @@ print_builtin_traefik_instructions() {
|
||||
echo " The proxy service is enabled and running."
|
||||
echo " Any domain NOT matching $NETBIRD_DOMAIN will be passed through to the proxy."
|
||||
echo " The proxy handles its own TLS certificates via ACME TLS-ALPN-01 challenge."
|
||||
echo " Point your proxy domain to this server's domain address like in the example below:"
|
||||
echo " Point your proxy domain to this server's domain address like in the examples below:"
|
||||
echo ""
|
||||
echo " $PROXY_DOMAIN CNAME $NETBIRD_DOMAIN"
|
||||
echo " *.$PROXY_DOMAIN CNAME $NETBIRD_DOMAIN"
|
||||
echo ""
|
||||
fi
|
||||
|
||||
@@ -224,6 +224,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
||||
s.syncSem.Add(1)
|
||||
|
||||
reqStart := time.Now()
|
||||
syncStart := reqStart.UTC()
|
||||
|
||||
ctx := srv.Context()
|
||||
|
||||
@@ -300,7 +301,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
||||
metahash := metaHash(peerMeta, realIP.String())
|
||||
s.loginFilter.addLogin(peerKey.String(), metahash)
|
||||
|
||||
peer, netMap, postureChecks, dnsFwdPort, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP, reqStart)
|
||||
peer, netMap, postureChecks, dnsFwdPort, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP, syncStart)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("error while syncing peer %s: %v", peerKey.String(), err)
|
||||
s.syncSem.Add(-1)
|
||||
@@ -311,7 +312,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("error while sending initial sync for %s: %v", peerKey.String(), err)
|
||||
s.syncSem.Add(-1)
|
||||
s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, reqStart)
|
||||
s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, syncStart)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -319,7 +320,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("error while notify peer connected for %s: %v", peerKey.String(), err)
|
||||
s.syncSem.Add(-1)
|
||||
s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, reqStart)
|
||||
s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, syncStart)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -336,7 +337,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
||||
|
||||
s.syncSem.Add(-1)
|
||||
|
||||
return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv, reqStart)
|
||||
return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv, syncStart)
|
||||
}
|
||||
|
||||
func (s *Server) handleHandshake(ctx context.Context, srv proto.ManagementService_JobServer) (wgtypes.Key, error) {
|
||||
|
||||
@@ -56,6 +56,7 @@ var (
|
||||
certKeyFile string
|
||||
certLockMethod string
|
||||
wgPort int
|
||||
proxyProtocol bool
|
||||
)
|
||||
|
||||
var rootCmd = &cobra.Command{
|
||||
@@ -90,6 +91,7 @@ func init() {
|
||||
rootCmd.Flags().StringVar(&certKeyFile, "cert-key-file", envStringOrDefault("NB_PROXY_CERTIFICATE_KEY_FILE", "tls.key"), "TLS certificate key filename within the certificate directory")
|
||||
rootCmd.Flags().StringVar(&certLockMethod, "cert-lock-method", envStringOrDefault("NB_PROXY_CERT_LOCK_METHOD", "auto"), "Certificate lock method for cross-replica coordination: auto, flock, or k8s-lease")
|
||||
rootCmd.Flags().IntVar(&wgPort, "wg-port", envIntOrDefault("NB_PROXY_WG_PORT", 0), "WireGuard listen port (0 = random). Fixed port only works with single-account deployments")
|
||||
rootCmd.Flags().BoolVar(&proxyProtocol, "proxy-protocol", envBoolOrDefault("NB_PROXY_PROXY_PROTOCOL", false), "Enable PROXY protocol on TCP listeners to preserve client IPs behind L4 proxies")
|
||||
}
|
||||
|
||||
// Execute runs the root command.
|
||||
@@ -165,6 +167,7 @@ func runServer(cmd *cobra.Command, args []string) error {
|
||||
TrustedProxies: parsedTrustedProxies,
|
||||
CertLockMethod: nbacme.CertLockMethod(certLockMethod),
|
||||
WireguardPort: wgPort,
|
||||
ProxyProtocol: proxyProtocol,
|
||||
}
|
||||
|
||||
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT)
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"github.com/rs/xid"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/proxy"
|
||||
"github.com/netbirdio/netbird/proxy/internal/responsewriter"
|
||||
"github.com/netbirdio/netbird/proxy/web"
|
||||
)
|
||||
|
||||
@@ -27,8 +28,8 @@ func (l *Logger) Middleware(next http.Handler) http.Handler {
|
||||
|
||||
// Use a response writer wrapper so we can access the status code later.
|
||||
sw := &statusWriter{
|
||||
w: w,
|
||||
status: http.StatusOK,
|
||||
PassthroughWriter: responsewriter.New(w),
|
||||
status: http.StatusOK,
|
||||
}
|
||||
|
||||
// Resolve the source IP using trusted proxy configuration before passing
|
||||
|
||||
@@ -1,26 +1,18 @@
|
||||
package accesslog
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"github.com/netbirdio/netbird/proxy/internal/responsewriter"
|
||||
)
|
||||
|
||||
// statusWriter is a simple wrapper around an http.ResponseWriter
|
||||
// that captures the setting of the status code via the WriteHeader
|
||||
// function and stores it so that it can be retrieved later.
|
||||
// statusWriter captures the HTTP status code from WriteHeader calls.
|
||||
// It embeds responsewriter.PassthroughWriter which handles all the optional
|
||||
// interfaces (Hijacker, Flusher, Pusher) automatically.
|
||||
type statusWriter struct {
|
||||
w http.ResponseWriter
|
||||
*responsewriter.PassthroughWriter
|
||||
status int
|
||||
}
|
||||
|
||||
func (w *statusWriter) Header() http.Header {
|
||||
return w.w.Header()
|
||||
}
|
||||
|
||||
func (w *statusWriter) Write(data []byte) (int, error) {
|
||||
return w.w.Write(data)
|
||||
}
|
||||
|
||||
func (w *statusWriter) WriteHeader(status int) {
|
||||
w.status = status
|
||||
w.w.WriteHeader(status)
|
||||
w.PassthroughWriter.WriteHeader(status)
|
||||
}
|
||||
|
||||
49
proxy/internal/conntrack/conn.go
Normal file
49
proxy/internal/conntrack/conn.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"net"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// trackedConn wraps a net.Conn and removes itself from the tracker on Close.
|
||||
type trackedConn struct {
|
||||
net.Conn
|
||||
tracker *HijackTracker
|
||||
}
|
||||
|
||||
func (c *trackedConn) Close() error {
|
||||
c.tracker.conns.Delete(c)
|
||||
return c.Conn.Close()
|
||||
}
|
||||
|
||||
// trackingWriter wraps an http.ResponseWriter and intercepts Hijack calls
|
||||
// to replace the raw connection with a trackedConn that auto-deregisters.
|
||||
type trackingWriter struct {
|
||||
http.ResponseWriter
|
||||
tracker *HijackTracker
|
||||
}
|
||||
|
||||
func (w *trackingWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
hijacker, ok := w.ResponseWriter.(http.Hijacker)
|
||||
if !ok {
|
||||
return nil, nil, http.ErrNotSupported
|
||||
}
|
||||
conn, buf, err := hijacker.Hijack()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
tc := &trackedConn{Conn: conn, tracker: w.tracker}
|
||||
w.tracker.conns.Store(tc, struct{}{})
|
||||
return tc, buf, nil
|
||||
}
|
||||
|
||||
func (w *trackingWriter) Flush() {
|
||||
if flusher, ok := w.ResponseWriter.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
func (w *trackingWriter) Unwrap() http.ResponseWriter {
|
||||
return w.ResponseWriter
|
||||
}
|
||||
41
proxy/internal/conntrack/hijacked.go
Normal file
41
proxy/internal/conntrack/hijacked.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// HijackTracker tracks connections that have been hijacked (e.g. WebSocket
|
||||
// upgrades). http.Server.Shutdown does not close hijacked connections, so
|
||||
// they must be tracked and closed explicitly during graceful shutdown.
|
||||
//
|
||||
// Use Middleware as the outermost HTTP middleware to ensure hijacked
|
||||
// connections are tracked and automatically deregistered when closed.
|
||||
type HijackTracker struct {
|
||||
conns sync.Map // net.Conn → struct{}
|
||||
}
|
||||
|
||||
// Middleware returns an HTTP middleware that wraps the ResponseWriter so that
|
||||
// hijacked connections are tracked and automatically deregistered from the
|
||||
// tracker when closed. This should be the outermost middleware in the chain.
|
||||
func (t *HijackTracker) Middleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
next.ServeHTTP(&trackingWriter{ResponseWriter: w, tracker: t}, r)
|
||||
})
|
||||
}
|
||||
|
||||
// CloseAll closes all tracked hijacked connections and returns the number
|
||||
// of connections that were closed.
|
||||
func (t *HijackTracker) CloseAll() int {
|
||||
var count int
|
||||
t.conns.Range(func(key, _ any) bool {
|
||||
if conn, ok := key.(net.Conn); ok {
|
||||
_ = conn.Close()
|
||||
count++
|
||||
}
|
||||
t.conns.Delete(key)
|
||||
return true
|
||||
})
|
||||
return count
|
||||
}
|
||||
@@ -5,9 +5,11 @@ import (
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/proxy"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/promauto"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/proxy"
|
||||
"github.com/netbirdio/netbird/proxy/internal/responsewriter"
|
||||
)
|
||||
|
||||
type Metrics struct {
|
||||
@@ -60,18 +62,18 @@ func New(reg prometheus.Registerer) *Metrics {
|
||||
}
|
||||
|
||||
type responseInterceptor struct {
|
||||
http.ResponseWriter
|
||||
*responsewriter.PassthroughWriter
|
||||
status int
|
||||
size int
|
||||
}
|
||||
|
||||
func (w *responseInterceptor) WriteHeader(status int) {
|
||||
w.status = status
|
||||
w.ResponseWriter.WriteHeader(status)
|
||||
w.PassthroughWriter.WriteHeader(status)
|
||||
}
|
||||
|
||||
func (w *responseInterceptor) Write(b []byte) (int, error) {
|
||||
size, err := w.ResponseWriter.Write(b)
|
||||
size, err := w.PassthroughWriter.Write(b)
|
||||
w.size += size
|
||||
return size, err
|
||||
}
|
||||
@@ -81,7 +83,7 @@ func (m *Metrics) Middleware(next http.Handler) http.Handler {
|
||||
m.requestsTotal.Inc()
|
||||
m.activeRequests.Inc()
|
||||
|
||||
interceptor := &responseInterceptor{ResponseWriter: w}
|
||||
interceptor := &responseInterceptor{PassthroughWriter: responsewriter.New(w)}
|
||||
|
||||
start := time.Now()
|
||||
next.ServeHTTP(interceptor, r)
|
||||
|
||||
53
proxy/internal/responsewriter/responsewriter.go
Normal file
53
proxy/internal/responsewriter/responsewriter.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package responsewriter
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"net"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// PassthroughWriter wraps an http.ResponseWriter and preserves optional
|
||||
// interfaces like Hijacker, Flusher, and Pusher by delegating to the underlying
|
||||
// ResponseWriter if it supports them.
|
||||
//
|
||||
// This is the standard pattern for Go middleware that needs to wrap ResponseWriter
|
||||
// while maintaining support for protocol upgrades (WebSocket), streaming (Flusher),
|
||||
// and HTTP/2 server push.
|
||||
type PassthroughWriter struct {
|
||||
http.ResponseWriter
|
||||
}
|
||||
|
||||
// New creates a new wrapper around the given ResponseWriter.
|
||||
func New(w http.ResponseWriter) *PassthroughWriter {
|
||||
return &PassthroughWriter{ResponseWriter: w}
|
||||
}
|
||||
|
||||
// Hijack implements http.Hijacker interface if the underlying ResponseWriter supports it.
|
||||
// This is required for WebSocket connections and other protocol upgrades.
|
||||
func (w *PassthroughWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
if hijacker, ok := w.ResponseWriter.(http.Hijacker); ok {
|
||||
return hijacker.Hijack()
|
||||
}
|
||||
return nil, nil, http.ErrNotSupported
|
||||
}
|
||||
|
||||
// Flush implements http.Flusher interface if the underlying ResponseWriter supports it.
|
||||
func (w *PassthroughWriter) Flush() {
|
||||
if flusher, ok := w.ResponseWriter.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
// Push implements http.Pusher interface if the underlying ResponseWriter supports it.
|
||||
func (w *PassthroughWriter) Push(target string, opts *http.PushOptions) error {
|
||||
if pusher, ok := w.ResponseWriter.(http.Pusher); ok {
|
||||
return pusher.Push(target, opts)
|
||||
}
|
||||
return http.ErrNotSupported
|
||||
}
|
||||
|
||||
// Unwrap returns the underlying ResponseWriter.
|
||||
// This is required for http.ResponseController (Go 1.20+) to work correctly.
|
||||
func (w *PassthroughWriter) Unwrap() http.ResponseWriter {
|
||||
return w.ResponseWriter
|
||||
}
|
||||
106
proxy/proxyprotocol_test.go
Normal file
106
proxy/proxyprotocol_test.go
Normal file
@@ -0,0 +1,106 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
proxyproto "github.com/pires/go-proxyproto"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestWrapProxyProtocol_OverridesRemoteAddr(t *testing.T) {
|
||||
srv := &Server{
|
||||
Logger: log.StandardLogger(),
|
||||
TrustedProxies: []netip.Prefix{netip.MustParsePrefix("127.0.0.1/32")},
|
||||
ProxyProtocol: true,
|
||||
}
|
||||
|
||||
raw, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
defer raw.Close()
|
||||
|
||||
ln := srv.wrapProxyProtocol(raw)
|
||||
|
||||
realClientIP := "203.0.113.50"
|
||||
realClientPort := uint16(54321)
|
||||
|
||||
accepted := make(chan net.Conn, 1)
|
||||
go func() {
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
accepted <- conn
|
||||
}()
|
||||
|
||||
// Connect and send a PROXY v2 header.
|
||||
conn, err := net.Dial("tcp", ln.Addr().String())
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
header := &proxyproto.Header{
|
||||
Version: 2,
|
||||
Command: proxyproto.PROXY,
|
||||
TransportProtocol: proxyproto.TCPv4,
|
||||
SourceAddr: &net.TCPAddr{IP: net.ParseIP(realClientIP), Port: int(realClientPort)},
|
||||
DestinationAddr: &net.TCPAddr{IP: net.ParseIP("10.0.0.1"), Port: 443},
|
||||
}
|
||||
_, err = header.WriteTo(conn)
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case accepted := <-accepted:
|
||||
defer accepted.Close()
|
||||
host, _, err := net.SplitHostPort(accepted.RemoteAddr().String())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, realClientIP, host, "RemoteAddr should reflect the PROXY header source IP")
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timed out waiting for connection")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyProtocolPolicy_TrustedRequires(t *testing.T) {
|
||||
srv := &Server{
|
||||
Logger: log.StandardLogger(),
|
||||
TrustedProxies: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
|
||||
}
|
||||
|
||||
opts := proxyproto.ConnPolicyOptions{
|
||||
Upstream: &net.TCPAddr{IP: net.ParseIP("10.0.0.1"), Port: 1234},
|
||||
}
|
||||
policy, err := srv.proxyProtocolPolicy(opts)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, proxyproto.REQUIRE, policy, "trusted source should require PROXY header")
|
||||
}
|
||||
|
||||
func TestProxyProtocolPolicy_UntrustedIgnores(t *testing.T) {
|
||||
srv := &Server{
|
||||
Logger: log.StandardLogger(),
|
||||
TrustedProxies: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
|
||||
}
|
||||
|
||||
opts := proxyproto.ConnPolicyOptions{
|
||||
Upstream: &net.TCPAddr{IP: net.ParseIP("203.0.113.50"), Port: 1234},
|
||||
}
|
||||
policy, err := srv.proxyProtocolPolicy(opts)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, proxyproto.IGNORE, policy, "untrusted source should have PROXY header ignored")
|
||||
}
|
||||
|
||||
func TestProxyProtocolPolicy_InvalidIPRejects(t *testing.T) {
|
||||
srv := &Server{
|
||||
Logger: log.StandardLogger(),
|
||||
TrustedProxies: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
|
||||
}
|
||||
|
||||
opts := proxyproto.ConnPolicyOptions{
|
||||
Upstream: &net.UnixAddr{Name: "/tmp/test.sock", Net: "unix"},
|
||||
}
|
||||
policy, err := srv.proxyProtocolPolicy(opts)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, proxyproto.REJECT, policy, "unparsable address should be rejected")
|
||||
}
|
||||
203
proxy/server.go
203
proxy/server.go
@@ -23,6 +23,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
proxyproto "github.com/pires/go-proxyproto"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -36,6 +37,7 @@ import (
|
||||
"github.com/netbirdio/netbird/proxy/internal/acme"
|
||||
"github.com/netbirdio/netbird/proxy/internal/auth"
|
||||
"github.com/netbirdio/netbird/proxy/internal/certwatch"
|
||||
"github.com/netbirdio/netbird/proxy/internal/conntrack"
|
||||
"github.com/netbirdio/netbird/proxy/internal/debug"
|
||||
proxygrpc "github.com/netbirdio/netbird/proxy/internal/grpc"
|
||||
"github.com/netbirdio/netbird/proxy/internal/health"
|
||||
@@ -63,6 +65,11 @@ type Server struct {
|
||||
healthChecker *health.Checker
|
||||
meter *metrics.Metrics
|
||||
|
||||
// hijackTracker tracks hijacked connections (e.g. WebSocket upgrades)
|
||||
// so they can be closed during graceful shutdown, since http.Server.Shutdown
|
||||
// does not handle them.
|
||||
hijackTracker conntrack.HijackTracker
|
||||
|
||||
// Mostly used for debugging on management.
|
||||
startTime time.Time
|
||||
|
||||
@@ -92,7 +99,7 @@ type Server struct {
|
||||
DebugEndpointEnabled bool
|
||||
// DebugEndpointAddress is the address for the debug HTTP endpoint (default: ":8444").
|
||||
DebugEndpointAddress string
|
||||
// HealthAddress is the address for the health probe endpoint (default: "localhost:8080").
|
||||
// HealthAddress is the address for the health probe endpoint.
|
||||
HealthAddress string
|
||||
// ProxyToken is the access token for authenticating with the management server.
|
||||
ProxyToken string
|
||||
@@ -107,6 +114,10 @@ type Server struct {
|
||||
// random OS-assigned port. A fixed port only works with single-account
|
||||
// deployments; multiple accounts will fail to bind the same port.
|
||||
WireguardPort int
|
||||
// ProxyProtocol enables PROXY protocol (v1/v2) on TCP listeners.
|
||||
// When enabled, the real client IP is extracted from the PROXY header
|
||||
// sent by upstream L4 proxies that support PROXY protocol.
|
||||
ProxyProtocol bool
|
||||
}
|
||||
|
||||
// NotifyStatus sends a status update to management about tunnel connectivity
|
||||
@@ -137,23 +148,8 @@ func (s *Server) NotifyCertificateIssued(ctx context.Context, accountID, service
|
||||
}
|
||||
|
||||
func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
|
||||
s.startTime = time.Now()
|
||||
s.initDefaults()
|
||||
|
||||
// If no ID is set then one can be generated.
|
||||
if s.ID == "" {
|
||||
s.ID = "netbird-proxy-" + s.startTime.Format("20060102150405")
|
||||
}
|
||||
// Fallback version option in case it is not set.
|
||||
if s.Version == "" {
|
||||
s.Version = "dev"
|
||||
}
|
||||
|
||||
// If no logger is specified fallback to the standard logger.
|
||||
if s.Logger == nil {
|
||||
s.Logger = log.StandardLogger()
|
||||
}
|
||||
|
||||
// Start up metrics gathering
|
||||
reg := prometheus.NewRegistry()
|
||||
s.meter = metrics.New(reg)
|
||||
|
||||
@@ -189,53 +185,41 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
|
||||
|
||||
s.healthChecker = health.NewChecker(s.Logger, s.netbird)
|
||||
|
||||
if s.DebugEndpointEnabled {
|
||||
debugAddr := debugEndpointAddr(s.DebugEndpointAddress)
|
||||
debugHandler := debug.NewHandler(s.netbird, s.healthChecker, s.Logger)
|
||||
if s.acme != nil {
|
||||
debugHandler.SetCertStatus(s.acme)
|
||||
}
|
||||
s.debug = &http.Server{
|
||||
Addr: debugAddr,
|
||||
Handler: debugHandler,
|
||||
ErrorLog: newHTTPServerLogger(s.Logger, logtagValueDebug),
|
||||
}
|
||||
go func() {
|
||||
s.Logger.Infof("starting debug endpoint on %s", debugAddr)
|
||||
if err := s.debug.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
s.Logger.Errorf("debug endpoint error: %v", err)
|
||||
}
|
||||
}()
|
||||
s.startDebugEndpoint()
|
||||
|
||||
if err := s.startHealthServer(reg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Start health probe server.
|
||||
healthAddr := s.HealthAddress
|
||||
if healthAddr == "" {
|
||||
healthAddr = "localhost:8080"
|
||||
}
|
||||
s.healthServer = health.NewServer(healthAddr, s.healthChecker, s.Logger, promhttp.HandlerFor(reg, promhttp.HandlerOpts{}))
|
||||
healthListener, err := net.Listen("tcp", healthAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("health probe server listen on %s: %w", healthAddr, err)
|
||||
}
|
||||
go func() {
|
||||
if err := s.healthServer.Serve(healthListener); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
s.Logger.Errorf("health probe server: %v", err)
|
||||
}
|
||||
}()
|
||||
// Build the handler chain from inside out.
|
||||
handler := http.Handler(s.proxy)
|
||||
handler = s.auth.Protect(handler)
|
||||
handler = web.AssetHandler(handler)
|
||||
handler = accessLog.Middleware(handler)
|
||||
handler = s.meter.Middleware(handler)
|
||||
handler = s.hijackTracker.Middleware(handler)
|
||||
|
||||
// Start the reverse proxy HTTPS server.
|
||||
s.https = &http.Server{
|
||||
Addr: addr,
|
||||
Handler: s.meter.Middleware(accessLog.Middleware(web.AssetHandler(s.auth.Protect(s.proxy)))),
|
||||
Handler: handler,
|
||||
TLSConfig: tlsConfig,
|
||||
ErrorLog: newHTTPServerLogger(s.Logger, logtagValueHTTPS),
|
||||
}
|
||||
|
||||
lc := net.ListenConfig{}
|
||||
ln, err := lc.Listen(ctx, "tcp", addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("listen on %s: %w", addr, err)
|
||||
}
|
||||
if s.ProxyProtocol {
|
||||
ln = s.wrapProxyProtocol(ln)
|
||||
}
|
||||
|
||||
httpsErr := make(chan error, 1)
|
||||
go func() {
|
||||
s.Logger.Debugf("starting reverse proxy server on %s", addr)
|
||||
httpsErr <- s.https.ListenAndServeTLS("", "")
|
||||
httpsErr <- s.https.ServeTLS(ln, "", "")
|
||||
}()
|
||||
|
||||
select {
|
||||
@@ -251,7 +235,115 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
|
||||
}
|
||||
}
|
||||
|
||||
// initDefaults sets fallback values for optional Server fields.
|
||||
func (s *Server) initDefaults() {
|
||||
s.startTime = time.Now()
|
||||
|
||||
// If no ID is set then one can be generated.
|
||||
if s.ID == "" {
|
||||
s.ID = "netbird-proxy-" + s.startTime.Format("20060102150405")
|
||||
}
|
||||
// Fallback version option in case it is not set.
|
||||
if s.Version == "" {
|
||||
s.Version = "dev"
|
||||
}
|
||||
|
||||
// If no logger is specified fallback to the standard logger.
|
||||
if s.Logger == nil {
|
||||
s.Logger = log.StandardLogger()
|
||||
}
|
||||
}
|
||||
|
||||
// startDebugEndpoint launches the debug HTTP server if enabled.
|
||||
func (s *Server) startDebugEndpoint() {
|
||||
if !s.DebugEndpointEnabled {
|
||||
return
|
||||
}
|
||||
debugAddr := debugEndpointAddr(s.DebugEndpointAddress)
|
||||
debugHandler := debug.NewHandler(s.netbird, s.healthChecker, s.Logger)
|
||||
if s.acme != nil {
|
||||
debugHandler.SetCertStatus(s.acme)
|
||||
}
|
||||
s.debug = &http.Server{
|
||||
Addr: debugAddr,
|
||||
Handler: debugHandler,
|
||||
ErrorLog: newHTTPServerLogger(s.Logger, logtagValueDebug),
|
||||
}
|
||||
go func() {
|
||||
s.Logger.Infof("starting debug endpoint on %s", debugAddr)
|
||||
if err := s.debug.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
s.Logger.Errorf("debug endpoint error: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// startHealthServer launches the health probe and metrics server.
|
||||
func (s *Server) startHealthServer(reg *prometheus.Registry) error {
|
||||
healthAddr := s.HealthAddress
|
||||
if healthAddr == "" {
|
||||
healthAddr = defaultHealthAddr
|
||||
}
|
||||
s.healthServer = health.NewServer(healthAddr, s.healthChecker, s.Logger, promhttp.HandlerFor(reg, promhttp.HandlerOpts{}))
|
||||
healthListener, err := net.Listen("tcp", healthAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("health probe server listen on %s: %w", healthAddr, err)
|
||||
}
|
||||
go func() {
|
||||
if err := s.healthServer.Serve(healthListener); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
s.Logger.Errorf("health probe server: %v", err)
|
||||
}
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
||||
// wrapProxyProtocol wraps a listener with PROXY protocol support.
|
||||
// When TrustedProxies is configured, only those sources may send PROXY headers;
|
||||
// connections from untrusted sources have any PROXY header ignored.
|
||||
func (s *Server) wrapProxyProtocol(ln net.Listener) net.Listener {
|
||||
ppListener := &proxyproto.Listener{
|
||||
Listener: ln,
|
||||
ReadHeaderTimeout: proxyProtoHeaderTimeout,
|
||||
}
|
||||
if len(s.TrustedProxies) > 0 {
|
||||
ppListener.ConnPolicy = s.proxyProtocolPolicy
|
||||
} else {
|
||||
s.Logger.Warn("PROXY protocol enabled without trusted proxies; any source may send PROXY headers")
|
||||
}
|
||||
s.Logger.Info("PROXY protocol enabled on listener")
|
||||
return ppListener
|
||||
}
|
||||
|
||||
// proxyProtocolPolicy returns whether to require, skip, or reject the PROXY
|
||||
// header based on whether the connection source is in TrustedProxies.
|
||||
func (s *Server) proxyProtocolPolicy(opts proxyproto.ConnPolicyOptions) (proxyproto.Policy, error) {
|
||||
// No logging on reject to prevent abuse
|
||||
tcpAddr, ok := opts.Upstream.(*net.TCPAddr)
|
||||
if !ok {
|
||||
return proxyproto.REJECT, nil
|
||||
}
|
||||
addr, ok := netip.AddrFromSlice(tcpAddr.IP)
|
||||
if !ok {
|
||||
return proxyproto.REJECT, nil
|
||||
}
|
||||
addr = addr.Unmap()
|
||||
|
||||
// called per accept
|
||||
for _, prefix := range s.TrustedProxies {
|
||||
if prefix.Contains(addr) {
|
||||
return proxyproto.REQUIRE, nil
|
||||
}
|
||||
}
|
||||
return proxyproto.IGNORE, nil
|
||||
}
|
||||
|
||||
const (
|
||||
defaultHealthAddr = "localhost:8080"
|
||||
defaultDebugAddr = "localhost:8444"
|
||||
|
||||
// proxyProtoHeaderTimeout is the deadline for reading the PROXY protocol
|
||||
// header after accepting a connection.
|
||||
proxyProtoHeaderTimeout = 5 * time.Second
|
||||
|
||||
// shutdownPreStopDelay is the time to wait after receiving a shutdown signal
|
||||
// before draining connections. This allows the load balancer to propagate
|
||||
// the endpoint removal.
|
||||
@@ -379,7 +471,12 @@ func (s *Server) gracefulShutdown() {
|
||||
s.Logger.Warnf("https server drain: %v", err)
|
||||
}
|
||||
|
||||
// Step 4: Stop all remaining background services.
|
||||
// Step 4: Close hijacked connections (WebSocket) that Shutdown does not handle.
|
||||
if n := s.hijackTracker.CloseAll(); n > 0 {
|
||||
s.Logger.Infof("closed %d hijacked connection(s)", n)
|
||||
}
|
||||
|
||||
// Step 5: Stop all remaining background services.
|
||||
s.shutdownServices()
|
||||
s.Logger.Info("graceful shutdown complete")
|
||||
}
|
||||
@@ -647,7 +744,7 @@ func (s *Server) protoToMapping(mapping *proto.ProxyMapping) proxy.Mapping {
|
||||
// If addr is empty, it defaults to localhost:8444 for security.
|
||||
func debugEndpointAddr(addr string) string {
|
||||
if addr == "" {
|
||||
return "localhost:8444"
|
||||
return defaultDebugAddr
|
||||
}
|
||||
return addr
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user