mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 16:26:38 +00:00
Merge branch 'main' into feature/disable-legacy-port
This commit is contained in:
@@ -358,9 +358,9 @@ func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 {
|
|||||||
// Fast path for IPv4 addresses (4 bytes) - most common case
|
// Fast path for IPv4 addresses (4 bytes) - most common case
|
||||||
if len(oldBytes) == 4 && len(newBytes) == 4 {
|
if len(oldBytes) == 4 && len(newBytes) == 4 {
|
||||||
sum += uint32(^binary.BigEndian.Uint16(oldBytes[0:2]))
|
sum += uint32(^binary.BigEndian.Uint16(oldBytes[0:2]))
|
||||||
sum += uint32(^binary.BigEndian.Uint16(oldBytes[2:4]))
|
sum += uint32(^binary.BigEndian.Uint16(oldBytes[2:4])) //nolint:gosec // length checked above
|
||||||
sum += uint32(binary.BigEndian.Uint16(newBytes[0:2]))
|
sum += uint32(binary.BigEndian.Uint16(newBytes[0:2]))
|
||||||
sum += uint32(binary.BigEndian.Uint16(newBytes[2:4]))
|
sum += uint32(binary.BigEndian.Uint16(newBytes[2:4])) //nolint:gosec // length checked above
|
||||||
} else {
|
} else {
|
||||||
// Fallback for other lengths
|
// Fallback for other lengths
|
||||||
for i := 0; i < len(oldBytes)-1; i += 2 {
|
for i := 0; i < len(oldBytes)-1; i += 2 {
|
||||||
|
|||||||
@@ -410,7 +410,7 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
|
|||||||
conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr)
|
conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (conn *Conn) onICEStateDisconnected() {
|
func (conn *Conn) onICEStateDisconnected(sessionChanged bool) {
|
||||||
conn.mu.Lock()
|
conn.mu.Lock()
|
||||||
defer conn.mu.Unlock()
|
defer conn.mu.Unlock()
|
||||||
|
|
||||||
@@ -430,6 +430,10 @@ func (conn *Conn) onICEStateDisconnected() {
|
|||||||
if conn.isReadyToUpgrade() {
|
if conn.isReadyToUpgrade() {
|
||||||
conn.Log.Infof("ICE disconnected, set Relay to active connection")
|
conn.Log.Infof("ICE disconnected, set Relay to active connection")
|
||||||
conn.dumpState.SwitchToRelay()
|
conn.dumpState.SwitchToRelay()
|
||||||
|
if sessionChanged {
|
||||||
|
conn.resetEndpoint()
|
||||||
|
}
|
||||||
|
|
||||||
conn.wgProxyRelay.Work()
|
conn.wgProxyRelay.Work()
|
||||||
|
|
||||||
presharedKey := conn.presharedKey(conn.rosenpassRemoteKey)
|
presharedKey := conn.presharedKey(conn.rosenpassRemoteKey)
|
||||||
@@ -757,6 +761,17 @@ func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) {
|
|||||||
return wgProxy, nil
|
return wgProxy, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (conn *Conn) resetEndpoint() {
|
||||||
|
if !isController(conn.config) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
conn.Log.Infof("reset wg endpoint")
|
||||||
|
conn.wgWatcher.Reset()
|
||||||
|
if err := conn.endpointUpdater.RemoveEndpointAddress(); err != nil {
|
||||||
|
conn.Log.Warnf("failed to remove endpoint address before update: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (conn *Conn) isReadyToUpgrade() bool {
|
func (conn *Conn) isReadyToUpgrade() bool {
|
||||||
return conn.wgProxyRelay != nil && conn.currentConnPriority != conntype.Relay
|
return conn.wgProxyRelay != nil && conn.currentConnPriority != conntype.Relay
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -66,6 +66,10 @@ func (e *EndpointUpdater) RemoveWgPeer() error {
|
|||||||
return e.wgConfig.WgInterface.RemovePeer(e.wgConfig.RemoteKey)
|
return e.wgConfig.WgInterface.RemovePeer(e.wgConfig.RemoteKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (e *EndpointUpdater) RemoveEndpointAddress() error {
|
||||||
|
return e.wgConfig.WgInterface.RemoveEndpointAddress(e.wgConfig.RemoteKey)
|
||||||
|
}
|
||||||
|
|
||||||
func (e *EndpointUpdater) waitForCloseTheDelayedUpdate() {
|
func (e *EndpointUpdater) waitForCloseTheDelayedUpdate() {
|
||||||
if e.cancelFunc == nil {
|
if e.cancelFunc == nil {
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -32,6 +32,8 @@ type WGWatcher struct {
|
|||||||
|
|
||||||
enabled bool
|
enabled bool
|
||||||
muEnabled sync.RWMutex
|
muEnabled sync.RWMutex
|
||||||
|
|
||||||
|
resetCh chan struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey string, stateDump *stateDump) *WGWatcher {
|
func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey string, stateDump *stateDump) *WGWatcher {
|
||||||
@@ -40,6 +42,7 @@ func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey strin
|
|||||||
wgIfaceStater: wgIfaceStater,
|
wgIfaceStater: wgIfaceStater,
|
||||||
peerKey: peerKey,
|
peerKey: peerKey,
|
||||||
stateDump: stateDump,
|
stateDump: stateDump,
|
||||||
|
resetCh: make(chan struct{}, 1),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -76,6 +79,15 @@ func (w *WGWatcher) IsEnabled() bool {
|
|||||||
return w.enabled
|
return w.enabled
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Reset signals the watcher that the WireGuard peer has been reset and a new
|
||||||
|
// handshake is expected. This restarts the handshake timeout from scratch.
|
||||||
|
func (w *WGWatcher) Reset() {
|
||||||
|
select {
|
||||||
|
case w.resetCh <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// wgStateCheck help to check the state of the WireGuard handshake and relay connection
|
// wgStateCheck help to check the state of the WireGuard handshake and relay connection
|
||||||
func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, onDisconnectedFn func(), enabledTime time.Time, initialHandshake time.Time) {
|
func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, onDisconnectedFn func(), enabledTime time.Time, initialHandshake time.Time) {
|
||||||
w.log.Infof("WireGuard watcher started")
|
w.log.Infof("WireGuard watcher started")
|
||||||
@@ -105,6 +117,12 @@ func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, onDisconnectedFn
|
|||||||
w.stateDump.WGcheckSuccess()
|
w.stateDump.WGcheckSuccess()
|
||||||
|
|
||||||
w.log.Debugf("WireGuard watcher reset timer: %v", resetTime)
|
w.log.Debugf("WireGuard watcher reset timer: %v", resetTime)
|
||||||
|
case <-w.resetCh:
|
||||||
|
w.log.Infof("WireGuard watcher received peer reset, restarting handshake timeout")
|
||||||
|
lastHandshake = time.Time{}
|
||||||
|
enabledTime = time.Now()
|
||||||
|
timer.Stop()
|
||||||
|
timer.Reset(wgHandshakeOvertime)
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
w.log.Infof("WireGuard watcher stopped")
|
w.log.Infof("WireGuard watcher stopped")
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -52,8 +52,9 @@ type WorkerICE struct {
|
|||||||
// increase by one when disconnecting the agent
|
// increase by one when disconnecting the agent
|
||||||
// with it the remote peer can discard the already deprecated offer/answer
|
// with it the remote peer can discard the already deprecated offer/answer
|
||||||
// Without it the remote peer may recreate a workable ICE connection
|
// Without it the remote peer may recreate a workable ICE connection
|
||||||
sessionID ICESessionID
|
sessionID ICESessionID
|
||||||
muxAgent sync.Mutex
|
remoteSessionChanged bool
|
||||||
|
muxAgent sync.Mutex
|
||||||
|
|
||||||
localUfrag string
|
localUfrag string
|
||||||
localPwd string
|
localPwd string
|
||||||
@@ -106,6 +107,7 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
w.log.Debugf("agent already exists, recreate the connection")
|
w.log.Debugf("agent already exists, recreate the connection")
|
||||||
|
w.remoteSessionChanged = true
|
||||||
w.agentDialerCancel()
|
w.agentDialerCancel()
|
||||||
if w.agent != nil {
|
if w.agent != nil {
|
||||||
if err := w.agent.Close(); err != nil {
|
if err := w.agent.Close(); err != nil {
|
||||||
@@ -306,13 +308,17 @@ func (w *WorkerICE) connect(ctx context.Context, agent *icemaker.ThreadSafeAgent
|
|||||||
w.conn.onICEConnectionIsReady(selectedPriority(pair), ci)
|
w.conn.onICEConnectionIsReady(selectedPriority(pair), ci)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *WorkerICE) closeAgent(agent *icemaker.ThreadSafeAgent, cancel context.CancelFunc) {
|
func (w *WorkerICE) closeAgent(agent *icemaker.ThreadSafeAgent, cancel context.CancelFunc) bool {
|
||||||
cancel()
|
cancel()
|
||||||
if err := agent.Close(); err != nil {
|
if err := agent.Close(); err != nil {
|
||||||
w.log.Warnf("failed to close ICE agent: %s", err)
|
w.log.Warnf("failed to close ICE agent: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
w.muxAgent.Lock()
|
w.muxAgent.Lock()
|
||||||
|
defer w.muxAgent.Unlock()
|
||||||
|
|
||||||
|
sessionChanged := w.remoteSessionChanged
|
||||||
|
w.remoteSessionChanged = false
|
||||||
|
|
||||||
if w.agent == agent {
|
if w.agent == agent {
|
||||||
// consider to remove from here and move to the OnNewOffer
|
// consider to remove from here and move to the OnNewOffer
|
||||||
@@ -325,7 +331,7 @@ func (w *WorkerICE) closeAgent(agent *icemaker.ThreadSafeAgent, cancel context.C
|
|||||||
w.agentConnecting = false
|
w.agentConnecting = false
|
||||||
w.remoteSessionID = ""
|
w.remoteSessionID = ""
|
||||||
}
|
}
|
||||||
w.muxAgent.Unlock()
|
return sessionChanged
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) {
|
func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) {
|
||||||
@@ -426,11 +432,11 @@ func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dia
|
|||||||
// ice.ConnectionStateClosed happens when we recreate the agent. For the P2P to TURN switch important to
|
// ice.ConnectionStateClosed happens when we recreate the agent. For the P2P to TURN switch important to
|
||||||
// notify the conn.onICEStateDisconnected changes to update the current used priority
|
// notify the conn.onICEStateDisconnected changes to update the current used priority
|
||||||
|
|
||||||
w.closeAgent(agent, dialerCancel)
|
sessionChanged := w.closeAgent(agent, dialerCancel)
|
||||||
|
|
||||||
if w.lastKnownState == ice.ConnectionStateConnected {
|
if w.lastKnownState == ice.ConnectionStateConnected {
|
||||||
w.lastKnownState = ice.ConnectionStateDisconnected
|
w.lastKnownState = ice.ConnectionStateDisconnected
|
||||||
w.conn.onICEStateDisconnected()
|
w.conn.onICEStateDisconnected(sessionChanged)
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
return
|
return
|
||||||
|
|||||||
1
go.mod
1
go.mod
@@ -83,6 +83,7 @@ require (
|
|||||||
github.com/pion/stun/v3 v3.1.0
|
github.com/pion/stun/v3 v3.1.0
|
||||||
github.com/pion/transport/v3 v3.1.1
|
github.com/pion/transport/v3 v3.1.1
|
||||||
github.com/pion/turn/v3 v3.0.1
|
github.com/pion/turn/v3 v3.0.1
|
||||||
|
github.com/pires/go-proxyproto v0.11.0
|
||||||
github.com/pkg/sftp v1.13.9
|
github.com/pkg/sftp v1.13.9
|
||||||
github.com/prometheus/client_golang v1.23.2
|
github.com/prometheus/client_golang v1.23.2
|
||||||
github.com/quic-go/quic-go v0.55.0
|
github.com/quic-go/quic-go v0.55.0
|
||||||
|
|||||||
2
go.sum
2
go.sum
@@ -474,6 +474,8 @@ github.com/pion/turn/v3 v3.0.1 h1:wLi7BTQr6/Q20R0vt/lHbjv6y4GChFtC33nkYbasoT8=
|
|||||||
github.com/pion/turn/v3 v3.0.1/go.mod h1:MrJDKgqryDyWy1/4NT9TWfXWGMC7UHT6pJIv1+gMeNE=
|
github.com/pion/turn/v3 v3.0.1/go.mod h1:MrJDKgqryDyWy1/4NT9TWfXWGMC7UHT6pJIv1+gMeNE=
|
||||||
github.com/pion/turn/v4 v4.1.1 h1:9UnY2HB99tpDyz3cVVZguSxcqkJ1DsTSZ+8TGruh4fc=
|
github.com/pion/turn/v4 v4.1.1 h1:9UnY2HB99tpDyz3cVVZguSxcqkJ1DsTSZ+8TGruh4fc=
|
||||||
github.com/pion/turn/v4 v4.1.1/go.mod h1:2123tHk1O++vmjI5VSD0awT50NywDAq5A2NNNU4Jjs8=
|
github.com/pion/turn/v4 v4.1.1/go.mod h1:2123tHk1O++vmjI5VSD0awT50NywDAq5A2NNNU4Jjs8=
|
||||||
|
github.com/pires/go-proxyproto v0.11.0 h1:gUQpS85X/VJMdUsYyEgyn59uLJvGqPhJV5YvG68wXH4=
|
||||||
|
github.com/pires/go-proxyproto v0.11.0/go.mod h1:ZKAAyp3cgy5Y5Mo4n9AlScrkCZwUy0g3Jf+slqQVcuU=
|
||||||
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||||
|
|||||||
@@ -99,15 +99,16 @@ func NewProvider(ctx context.Context, config *Config) (*Provider, error) {
|
|||||||
|
|
||||||
// Build Dex server config - use Dex's types directly
|
// Build Dex server config - use Dex's types directly
|
||||||
dexConfig := server.Config{
|
dexConfig := server.Config{
|
||||||
Issuer: issuer,
|
Issuer: issuer,
|
||||||
Storage: stor,
|
Storage: stor,
|
||||||
SkipApprovalScreen: true,
|
SkipApprovalScreen: true,
|
||||||
SupportedResponseTypes: []string{"code"},
|
SupportedResponseTypes: []string{"code"},
|
||||||
Logger: logger,
|
ContinueOnConnectorFailure: true,
|
||||||
PrometheusRegistry: prometheus.NewRegistry(),
|
Logger: logger,
|
||||||
RotateKeysAfter: 6 * time.Hour,
|
PrometheusRegistry: prometheus.NewRegistry(),
|
||||||
IDTokensValidFor: 24 * time.Hour,
|
RotateKeysAfter: 6 * time.Hour,
|
||||||
RefreshTokenPolicy: refreshPolicy,
|
IDTokensValidFor: 24 * time.Hour,
|
||||||
|
RefreshTokenPolicy: refreshPolicy,
|
||||||
Web: server.WebConfig{
|
Web: server.WebConfig{
|
||||||
Issuer: "NetBird",
|
Issuer: "NetBird",
|
||||||
},
|
},
|
||||||
@@ -260,6 +261,7 @@ func buildDexConfig(yamlConfig *YAMLConfig, stor storage.Storage, logger *slog.L
|
|||||||
if len(cfg.SupportedResponseTypes) == 0 {
|
if len(cfg.SupportedResponseTypes) == 0 {
|
||||||
cfg.SupportedResponseTypes = []string{"code"}
|
cfg.SupportedResponseTypes = []string{"code"}
|
||||||
}
|
}
|
||||||
|
cfg.ContinueOnConnectorFailure = true
|
||||||
return cfg
|
return cfg
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package dex
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"log/slog"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
@@ -195,3 +196,64 @@ enablePasswordDB: true
|
|||||||
|
|
||||||
t.Logf("User lookup successful: rawID=%s, connectorID=%s", rawID, connID)
|
t.Logf("User lookup successful: rawID=%s, connectorID=%s", rawID, connID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNewProvider_ContinueOnConnectorFailure(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
tmpDir, err := os.MkdirTemp("", "dex-connector-failure-*")
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer os.RemoveAll(tmpDir)
|
||||||
|
|
||||||
|
config := &Config{
|
||||||
|
Issuer: "http://localhost:5556/dex",
|
||||||
|
Port: 5556,
|
||||||
|
DataDir: tmpDir,
|
||||||
|
}
|
||||||
|
|
||||||
|
provider, err := NewProvider(ctx, config)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() { _ = provider.Stop(ctx) }()
|
||||||
|
|
||||||
|
// The provider should have started successfully even though
|
||||||
|
// ContinueOnConnectorFailure is an internal Dex config field.
|
||||||
|
// We verify the provider is functional by performing a basic operation.
|
||||||
|
assert.NotNil(t, provider.dexServer)
|
||||||
|
assert.NotNil(t, provider.storage)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildDexConfig_ContinueOnConnectorFailure(t *testing.T) {
|
||||||
|
tmpDir, err := os.MkdirTemp("", "dex-build-config-*")
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer os.RemoveAll(tmpDir)
|
||||||
|
|
||||||
|
yamlContent := `
|
||||||
|
issuer: http://localhost:5556/dex
|
||||||
|
storage:
|
||||||
|
type: sqlite3
|
||||||
|
config:
|
||||||
|
file: ` + filepath.Join(tmpDir, "dex.db") + `
|
||||||
|
web:
|
||||||
|
http: 127.0.0.1:5556
|
||||||
|
enablePasswordDB: true
|
||||||
|
`
|
||||||
|
configPath := filepath.Join(tmpDir, "config.yaml")
|
||||||
|
err = os.WriteFile(configPath, []byte(yamlContent), 0644)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
yamlConfig, err := LoadConfig(configPath)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
stor, err := yamlConfig.Storage.OpenStorage(slog.New(slog.NewTextHandler(os.Stderr, nil)))
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer stor.Close()
|
||||||
|
|
||||||
|
err = initializeStorage(ctx, stor, yamlConfig)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
|
||||||
|
cfg := buildDexConfig(yamlConfig, stor, logger)
|
||||||
|
|
||||||
|
assert.True(t, cfg.ContinueOnConnectorFailure,
|
||||||
|
"buildDexConfig must set ContinueOnConnectorFailure to true so management starts even if an external IdP is down")
|
||||||
|
}
|
||||||
|
|||||||
@@ -329,6 +329,9 @@ initialize_default_values() {
|
|||||||
BIND_LOCALHOST_ONLY="true"
|
BIND_LOCALHOST_ONLY="true"
|
||||||
EXTERNAL_PROXY_NETWORK=""
|
EXTERNAL_PROXY_NETWORK=""
|
||||||
|
|
||||||
|
# Traefik static IP within the internal bridge network
|
||||||
|
TRAEFIK_IP="172.30.0.10"
|
||||||
|
|
||||||
# NetBird Proxy configuration
|
# NetBird Proxy configuration
|
||||||
ENABLE_PROXY="false"
|
ENABLE_PROXY="false"
|
||||||
PROXY_DOMAIN=""
|
PROXY_DOMAIN=""
|
||||||
@@ -393,7 +396,7 @@ check_existing_installation() {
|
|||||||
echo "Generated files already exist, if you want to reinitialize the environment, please remove them first."
|
echo "Generated files already exist, if you want to reinitialize the environment, please remove them first."
|
||||||
echo "You can use the following commands:"
|
echo "You can use the following commands:"
|
||||||
echo " $DOCKER_COMPOSE_COMMAND down --volumes # to remove all containers and volumes"
|
echo " $DOCKER_COMPOSE_COMMAND down --volumes # to remove all containers and volumes"
|
||||||
echo " rm -f docker-compose.yml dashboard.env config.yaml proxy.env nginx-netbird.conf caddyfile-netbird.txt npm-advanced-config.txt"
|
echo " rm -f docker-compose.yml dashboard.env config.yaml proxy.env traefik-dynamic.yaml nginx-netbird.conf caddyfile-netbird.txt npm-advanced-config.txt"
|
||||||
echo "Be aware that this will remove all data from the database, and you will have to reconfigure the dashboard."
|
echo "Be aware that this will remove all data from the database, and you will have to reconfigure the dashboard."
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
@@ -412,6 +415,8 @@ generate_configuration_files() {
|
|||||||
# This will be overwritten with the actual token after netbird-server starts
|
# This will be overwritten with the actual token after netbird-server starts
|
||||||
echo "# Placeholder - will be updated with token after netbird-server starts" > proxy.env
|
echo "# Placeholder - will be updated with token after netbird-server starts" > proxy.env
|
||||||
echo "NB_PROXY_TOKEN=placeholder" >> proxy.env
|
echo "NB_PROXY_TOKEN=placeholder" >> proxy.env
|
||||||
|
# TCP ServersTransport for PROXY protocol v2 to the proxy backend
|
||||||
|
render_traefik_dynamic > traefik-dynamic.yaml
|
||||||
fi
|
fi
|
||||||
;;
|
;;
|
||||||
1)
|
1)
|
||||||
@@ -559,10 +564,14 @@ init_environment() {
|
|||||||
############################################
|
############################################
|
||||||
|
|
||||||
render_docker_compose_traefik_builtin() {
|
render_docker_compose_traefik_builtin() {
|
||||||
# Generate proxy service section if enabled
|
# Generate proxy service section and Traefik dynamic config if enabled
|
||||||
local proxy_service=""
|
local proxy_service=""
|
||||||
local proxy_volumes=""
|
local proxy_volumes=""
|
||||||
|
local traefik_file_provider=""
|
||||||
|
local traefik_dynamic_volume=""
|
||||||
if [[ "$ENABLE_PROXY" == "true" ]]; then
|
if [[ "$ENABLE_PROXY" == "true" ]]; then
|
||||||
|
traefik_file_provider=' - "--providers.file.filename=/etc/traefik/dynamic.yaml"'
|
||||||
|
traefik_dynamic_volume=" - ./traefik-dynamic.yaml:/etc/traefik/dynamic.yaml:ro"
|
||||||
proxy_service="
|
proxy_service="
|
||||||
# NetBird Proxy - exposes internal resources to the internet
|
# NetBird Proxy - exposes internal resources to the internet
|
||||||
proxy:
|
proxy:
|
||||||
@@ -570,7 +579,7 @@ render_docker_compose_traefik_builtin() {
|
|||||||
container_name: netbird-proxy
|
container_name: netbird-proxy
|
||||||
# Hairpin NAT fix: route domain back to traefik's static IP within Docker
|
# Hairpin NAT fix: route domain back to traefik's static IP within Docker
|
||||||
extra_hosts:
|
extra_hosts:
|
||||||
- \"$NETBIRD_DOMAIN:172.30.0.10\"
|
- \"$NETBIRD_DOMAIN:$TRAEFIK_IP\"
|
||||||
ports:
|
ports:
|
||||||
- 51820:51820/udp
|
- 51820:51820/udp
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
@@ -590,6 +599,7 @@ render_docker_compose_traefik_builtin() {
|
|||||||
- traefik.tcp.routers.proxy-passthrough.service=proxy-tls
|
- traefik.tcp.routers.proxy-passthrough.service=proxy-tls
|
||||||
- traefik.tcp.routers.proxy-passthrough.priority=1
|
- traefik.tcp.routers.proxy-passthrough.priority=1
|
||||||
- traefik.tcp.services.proxy-tls.loadbalancer.server.port=8443
|
- traefik.tcp.services.proxy-tls.loadbalancer.server.port=8443
|
||||||
|
- traefik.tcp.services.proxy-tls.loadbalancer.serverstransport=pp-v2@file
|
||||||
logging:
|
logging:
|
||||||
driver: \"json-file\"
|
driver: \"json-file\"
|
||||||
options:
|
options:
|
||||||
@@ -609,7 +619,7 @@ services:
|
|||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
networks:
|
networks:
|
||||||
netbird:
|
netbird:
|
||||||
ipv4_address: 172.30.0.10
|
ipv4_address: $TRAEFIK_IP
|
||||||
command:
|
command:
|
||||||
# Logging
|
# Logging
|
||||||
- "--log.level=INFO"
|
- "--log.level=INFO"
|
||||||
@@ -636,12 +646,14 @@ services:
|
|||||||
# gRPC transport settings
|
# gRPC transport settings
|
||||||
- "--serverstransport.forwardingtimeouts.responseheadertimeout=0s"
|
- "--serverstransport.forwardingtimeouts.responseheadertimeout=0s"
|
||||||
- "--serverstransport.forwardingtimeouts.idleconntimeout=0s"
|
- "--serverstransport.forwardingtimeouts.idleconntimeout=0s"
|
||||||
|
$traefik_file_provider
|
||||||
ports:
|
ports:
|
||||||
- '443:443'
|
- '443:443'
|
||||||
- '80:80'
|
- '80:80'
|
||||||
volumes:
|
volumes:
|
||||||
- /var/run/docker.sock:/var/run/docker.sock:ro
|
- /var/run/docker.sock:/var/run/docker.sock:ro
|
||||||
- netbird_traefik_letsencrypt:/letsencrypt
|
- netbird_traefik_letsencrypt:/letsencrypt
|
||||||
|
$traefik_dynamic_volume
|
||||||
logging:
|
logging:
|
||||||
driver: "json-file"
|
driver: "json-file"
|
||||||
options:
|
options:
|
||||||
@@ -751,6 +763,10 @@ server:
|
|||||||
cliRedirectURIs:
|
cliRedirectURIs:
|
||||||
- "http://localhost:53000/"
|
- "http://localhost:53000/"
|
||||||
|
|
||||||
|
reverseProxy:
|
||||||
|
trustedHTTPProxies:
|
||||||
|
- "$TRAEFIK_IP/32"
|
||||||
|
|
||||||
store:
|
store:
|
||||||
engine: "sqlite"
|
engine: "sqlite"
|
||||||
encryptionKey: "$DATASTORE_ENCRYPTION_KEY"
|
encryptionKey: "$DATASTORE_ENCRYPTION_KEY"
|
||||||
@@ -780,6 +796,17 @@ EOF
|
|||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
render_traefik_dynamic() {
|
||||||
|
cat <<'EOF'
|
||||||
|
tcp:
|
||||||
|
serversTransports:
|
||||||
|
pp-v2:
|
||||||
|
proxyProtocol:
|
||||||
|
version: 2
|
||||||
|
EOF
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
render_proxy_env() {
|
render_proxy_env() {
|
||||||
cat <<EOF
|
cat <<EOF
|
||||||
# NetBird Proxy Configuration
|
# NetBird Proxy Configuration
|
||||||
@@ -799,6 +826,10 @@ NB_PROXY_OIDC_CLIENT_ID=netbird-proxy
|
|||||||
NB_PROXY_OIDC_ENDPOINT=$NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN/oauth2
|
NB_PROXY_OIDC_ENDPOINT=$NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN/oauth2
|
||||||
NB_PROXY_OIDC_SCOPES=openid,profile,email
|
NB_PROXY_OIDC_SCOPES=openid,profile,email
|
||||||
NB_PROXY_FORWARDED_PROTO=https
|
NB_PROXY_FORWARDED_PROTO=https
|
||||||
|
# Enable PROXY protocol to preserve client IPs through L4 proxies (Traefik TCP passthrough)
|
||||||
|
NB_PROXY_PROXY_PROTOCOL=true
|
||||||
|
# Trust Traefik's IP for PROXY protocol headers
|
||||||
|
NB_PROXY_TRUSTED_PROXIES=$TRAEFIK_IP
|
||||||
EOF
|
EOF
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
@@ -1176,8 +1207,9 @@ print_builtin_traefik_instructions() {
|
|||||||
echo " The proxy service is enabled and running."
|
echo " The proxy service is enabled and running."
|
||||||
echo " Any domain NOT matching $NETBIRD_DOMAIN will be passed through to the proxy."
|
echo " Any domain NOT matching $NETBIRD_DOMAIN will be passed through to the proxy."
|
||||||
echo " The proxy handles its own TLS certificates via ACME TLS-ALPN-01 challenge."
|
echo " The proxy handles its own TLS certificates via ACME TLS-ALPN-01 challenge."
|
||||||
echo " Point your proxy domain to this server's domain address like in the example below:"
|
echo " Point your proxy domain to this server's domain address like in the examples below:"
|
||||||
echo ""
|
echo ""
|
||||||
|
echo " $PROXY_DOMAIN CNAME $NETBIRD_DOMAIN"
|
||||||
echo " *.$PROXY_DOMAIN CNAME $NETBIRD_DOMAIN"
|
echo " *.$PROXY_DOMAIN CNAME $NETBIRD_DOMAIN"
|
||||||
echo ""
|
echo ""
|
||||||
fi
|
fi
|
||||||
|
|||||||
@@ -224,6 +224,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
|||||||
s.syncSem.Add(1)
|
s.syncSem.Add(1)
|
||||||
|
|
||||||
reqStart := time.Now()
|
reqStart := time.Now()
|
||||||
|
syncStart := reqStart.UTC()
|
||||||
|
|
||||||
ctx := srv.Context()
|
ctx := srv.Context()
|
||||||
|
|
||||||
@@ -300,7 +301,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
|||||||
metahash := metaHash(peerMeta, realIP.String())
|
metahash := metaHash(peerMeta, realIP.String())
|
||||||
s.loginFilter.addLogin(peerKey.String(), metahash)
|
s.loginFilter.addLogin(peerKey.String(), metahash)
|
||||||
|
|
||||||
peer, netMap, postureChecks, dnsFwdPort, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP, reqStart)
|
peer, netMap, postureChecks, dnsFwdPort, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP, syncStart)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Debugf("error while syncing peer %s: %v", peerKey.String(), err)
|
log.WithContext(ctx).Debugf("error while syncing peer %s: %v", peerKey.String(), err)
|
||||||
s.syncSem.Add(-1)
|
s.syncSem.Add(-1)
|
||||||
@@ -311,7 +312,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Debugf("error while sending initial sync for %s: %v", peerKey.String(), err)
|
log.WithContext(ctx).Debugf("error while sending initial sync for %s: %v", peerKey.String(), err)
|
||||||
s.syncSem.Add(-1)
|
s.syncSem.Add(-1)
|
||||||
s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, reqStart)
|
s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, syncStart)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -319,7 +320,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Debugf("error while notify peer connected for %s: %v", peerKey.String(), err)
|
log.WithContext(ctx).Debugf("error while notify peer connected for %s: %v", peerKey.String(), err)
|
||||||
s.syncSem.Add(-1)
|
s.syncSem.Add(-1)
|
||||||
s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, reqStart)
|
s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, syncStart)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -336,7 +337,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
|||||||
|
|
||||||
s.syncSem.Add(-1)
|
s.syncSem.Add(-1)
|
||||||
|
|
||||||
return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv, reqStart)
|
return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv, syncStart)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handleHandshake(ctx context.Context, srv proto.ManagementService_JobServer) (wgtypes.Key, error) {
|
func (s *Server) handleHandshake(ctx context.Context, srv proto.ManagementService_JobServer) (wgtypes.Key, error) {
|
||||||
|
|||||||
@@ -56,6 +56,7 @@ var (
|
|||||||
certKeyFile string
|
certKeyFile string
|
||||||
certLockMethod string
|
certLockMethod string
|
||||||
wgPort int
|
wgPort int
|
||||||
|
proxyProtocol bool
|
||||||
)
|
)
|
||||||
|
|
||||||
var rootCmd = &cobra.Command{
|
var rootCmd = &cobra.Command{
|
||||||
@@ -90,6 +91,7 @@ func init() {
|
|||||||
rootCmd.Flags().StringVar(&certKeyFile, "cert-key-file", envStringOrDefault("NB_PROXY_CERTIFICATE_KEY_FILE", "tls.key"), "TLS certificate key filename within the certificate directory")
|
rootCmd.Flags().StringVar(&certKeyFile, "cert-key-file", envStringOrDefault("NB_PROXY_CERTIFICATE_KEY_FILE", "tls.key"), "TLS certificate key filename within the certificate directory")
|
||||||
rootCmd.Flags().StringVar(&certLockMethod, "cert-lock-method", envStringOrDefault("NB_PROXY_CERT_LOCK_METHOD", "auto"), "Certificate lock method for cross-replica coordination: auto, flock, or k8s-lease")
|
rootCmd.Flags().StringVar(&certLockMethod, "cert-lock-method", envStringOrDefault("NB_PROXY_CERT_LOCK_METHOD", "auto"), "Certificate lock method for cross-replica coordination: auto, flock, or k8s-lease")
|
||||||
rootCmd.Flags().IntVar(&wgPort, "wg-port", envIntOrDefault("NB_PROXY_WG_PORT", 0), "WireGuard listen port (0 = random). Fixed port only works with single-account deployments")
|
rootCmd.Flags().IntVar(&wgPort, "wg-port", envIntOrDefault("NB_PROXY_WG_PORT", 0), "WireGuard listen port (0 = random). Fixed port only works with single-account deployments")
|
||||||
|
rootCmd.Flags().BoolVar(&proxyProtocol, "proxy-protocol", envBoolOrDefault("NB_PROXY_PROXY_PROTOCOL", false), "Enable PROXY protocol on TCP listeners to preserve client IPs behind L4 proxies")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute runs the root command.
|
// Execute runs the root command.
|
||||||
@@ -165,6 +167,7 @@ func runServer(cmd *cobra.Command, args []string) error {
|
|||||||
TrustedProxies: parsedTrustedProxies,
|
TrustedProxies: parsedTrustedProxies,
|
||||||
CertLockMethod: nbacme.CertLockMethod(certLockMethod),
|
CertLockMethod: nbacme.CertLockMethod(certLockMethod),
|
||||||
WireguardPort: wgPort,
|
WireguardPort: wgPort,
|
||||||
|
ProxyProtocol: proxyProtocol,
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT)
|
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT)
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"github.com/rs/xid"
|
"github.com/rs/xid"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/proxy/internal/proxy"
|
"github.com/netbirdio/netbird/proxy/internal/proxy"
|
||||||
|
"github.com/netbirdio/netbird/proxy/internal/responsewriter"
|
||||||
"github.com/netbirdio/netbird/proxy/web"
|
"github.com/netbirdio/netbird/proxy/web"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -27,8 +28,8 @@ func (l *Logger) Middleware(next http.Handler) http.Handler {
|
|||||||
|
|
||||||
// Use a response writer wrapper so we can access the status code later.
|
// Use a response writer wrapper so we can access the status code later.
|
||||||
sw := &statusWriter{
|
sw := &statusWriter{
|
||||||
w: w,
|
PassthroughWriter: responsewriter.New(w),
|
||||||
status: http.StatusOK,
|
status: http.StatusOK,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Resolve the source IP using trusted proxy configuration before passing
|
// Resolve the source IP using trusted proxy configuration before passing
|
||||||
|
|||||||
@@ -1,26 +1,18 @@
|
|||||||
package accesslog
|
package accesslog
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"github.com/netbirdio/netbird/proxy/internal/responsewriter"
|
||||||
)
|
)
|
||||||
|
|
||||||
// statusWriter is a simple wrapper around an http.ResponseWriter
|
// statusWriter captures the HTTP status code from WriteHeader calls.
|
||||||
// that captures the setting of the status code via the WriteHeader
|
// It embeds responsewriter.PassthroughWriter which handles all the optional
|
||||||
// function and stores it so that it can be retrieved later.
|
// interfaces (Hijacker, Flusher, Pusher) automatically.
|
||||||
type statusWriter struct {
|
type statusWriter struct {
|
||||||
w http.ResponseWriter
|
*responsewriter.PassthroughWriter
|
||||||
status int
|
status int
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *statusWriter) Header() http.Header {
|
|
||||||
return w.w.Header()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *statusWriter) Write(data []byte) (int, error) {
|
|
||||||
return w.w.Write(data)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *statusWriter) WriteHeader(status int) {
|
func (w *statusWriter) WriteHeader(status int) {
|
||||||
w.status = status
|
w.status = status
|
||||||
w.w.WriteHeader(status)
|
w.PassthroughWriter.WriteHeader(status)
|
||||||
}
|
}
|
||||||
|
|||||||
49
proxy/internal/conntrack/conn.go
Normal file
49
proxy/internal/conntrack/conn.go
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
package conntrack
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
// trackedConn wraps a net.Conn and removes itself from the tracker on Close.
|
||||||
|
type trackedConn struct {
|
||||||
|
net.Conn
|
||||||
|
tracker *HijackTracker
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *trackedConn) Close() error {
|
||||||
|
c.tracker.conns.Delete(c)
|
||||||
|
return c.Conn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// trackingWriter wraps an http.ResponseWriter and intercepts Hijack calls
|
||||||
|
// to replace the raw connection with a trackedConn that auto-deregisters.
|
||||||
|
type trackingWriter struct {
|
||||||
|
http.ResponseWriter
|
||||||
|
tracker *HijackTracker
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *trackingWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||||
|
hijacker, ok := w.ResponseWriter.(http.Hijacker)
|
||||||
|
if !ok {
|
||||||
|
return nil, nil, http.ErrNotSupported
|
||||||
|
}
|
||||||
|
conn, buf, err := hijacker.Hijack()
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
tc := &trackedConn{Conn: conn, tracker: w.tracker}
|
||||||
|
w.tracker.conns.Store(tc, struct{}{})
|
||||||
|
return tc, buf, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *trackingWriter) Flush() {
|
||||||
|
if flusher, ok := w.ResponseWriter.(http.Flusher); ok {
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *trackingWriter) Unwrap() http.ResponseWriter {
|
||||||
|
return w.ResponseWriter
|
||||||
|
}
|
||||||
41
proxy/internal/conntrack/hijacked.go
Normal file
41
proxy/internal/conntrack/hijacked.go
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
package conntrack
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
// HijackTracker tracks connections that have been hijacked (e.g. WebSocket
|
||||||
|
// upgrades). http.Server.Shutdown does not close hijacked connections, so
|
||||||
|
// they must be tracked and closed explicitly during graceful shutdown.
|
||||||
|
//
|
||||||
|
// Use Middleware as the outermost HTTP middleware to ensure hijacked
|
||||||
|
// connections are tracked and automatically deregistered when closed.
|
||||||
|
type HijackTracker struct {
|
||||||
|
conns sync.Map // net.Conn → struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Middleware returns an HTTP middleware that wraps the ResponseWriter so that
|
||||||
|
// hijacked connections are tracked and automatically deregistered from the
|
||||||
|
// tracker when closed. This should be the outermost middleware in the chain.
|
||||||
|
func (t *HijackTracker) Middleware(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
next.ServeHTTP(&trackingWriter{ResponseWriter: w, tracker: t}, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// CloseAll closes all tracked hijacked connections and returns the number
|
||||||
|
// of connections that were closed.
|
||||||
|
func (t *HijackTracker) CloseAll() int {
|
||||||
|
var count int
|
||||||
|
t.conns.Range(func(key, _ any) bool {
|
||||||
|
if conn, ok := key.(net.Conn); ok {
|
||||||
|
_ = conn.Close()
|
||||||
|
count++
|
||||||
|
}
|
||||||
|
t.conns.Delete(key)
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
return count
|
||||||
|
}
|
||||||
@@ -5,9 +5,11 @@ import (
|
|||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/proxy/internal/proxy"
|
|
||||||
"github.com/prometheus/client_golang/prometheus"
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
"github.com/prometheus/client_golang/prometheus/promauto"
|
"github.com/prometheus/client_golang/prometheus/promauto"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/proxy/internal/proxy"
|
||||||
|
"github.com/netbirdio/netbird/proxy/internal/responsewriter"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Metrics struct {
|
type Metrics struct {
|
||||||
@@ -60,18 +62,18 @@ func New(reg prometheus.Registerer) *Metrics {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type responseInterceptor struct {
|
type responseInterceptor struct {
|
||||||
http.ResponseWriter
|
*responsewriter.PassthroughWriter
|
||||||
status int
|
status int
|
||||||
size int
|
size int
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *responseInterceptor) WriteHeader(status int) {
|
func (w *responseInterceptor) WriteHeader(status int) {
|
||||||
w.status = status
|
w.status = status
|
||||||
w.ResponseWriter.WriteHeader(status)
|
w.PassthroughWriter.WriteHeader(status)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *responseInterceptor) Write(b []byte) (int, error) {
|
func (w *responseInterceptor) Write(b []byte) (int, error) {
|
||||||
size, err := w.ResponseWriter.Write(b)
|
size, err := w.PassthroughWriter.Write(b)
|
||||||
w.size += size
|
w.size += size
|
||||||
return size, err
|
return size, err
|
||||||
}
|
}
|
||||||
@@ -81,7 +83,7 @@ func (m *Metrics) Middleware(next http.Handler) http.Handler {
|
|||||||
m.requestsTotal.Inc()
|
m.requestsTotal.Inc()
|
||||||
m.activeRequests.Inc()
|
m.activeRequests.Inc()
|
||||||
|
|
||||||
interceptor := &responseInterceptor{ResponseWriter: w}
|
interceptor := &responseInterceptor{PassthroughWriter: responsewriter.New(w)}
|
||||||
|
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
next.ServeHTTP(interceptor, r)
|
next.ServeHTTP(interceptor, r)
|
||||||
|
|||||||
53
proxy/internal/responsewriter/responsewriter.go
Normal file
53
proxy/internal/responsewriter/responsewriter.go
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
package responsewriter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
// PassthroughWriter wraps an http.ResponseWriter and preserves optional
|
||||||
|
// interfaces like Hijacker, Flusher, and Pusher by delegating to the underlying
|
||||||
|
// ResponseWriter if it supports them.
|
||||||
|
//
|
||||||
|
// This is the standard pattern for Go middleware that needs to wrap ResponseWriter
|
||||||
|
// while maintaining support for protocol upgrades (WebSocket), streaming (Flusher),
|
||||||
|
// and HTTP/2 server push.
|
||||||
|
type PassthroughWriter struct {
|
||||||
|
http.ResponseWriter
|
||||||
|
}
|
||||||
|
|
||||||
|
// New creates a new wrapper around the given ResponseWriter.
|
||||||
|
func New(w http.ResponseWriter) *PassthroughWriter {
|
||||||
|
return &PassthroughWriter{ResponseWriter: w}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Hijack implements http.Hijacker interface if the underlying ResponseWriter supports it.
|
||||||
|
// This is required for WebSocket connections and other protocol upgrades.
|
||||||
|
func (w *PassthroughWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||||
|
if hijacker, ok := w.ResponseWriter.(http.Hijacker); ok {
|
||||||
|
return hijacker.Hijack()
|
||||||
|
}
|
||||||
|
return nil, nil, http.ErrNotSupported
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flush implements http.Flusher interface if the underlying ResponseWriter supports it.
|
||||||
|
func (w *PassthroughWriter) Flush() {
|
||||||
|
if flusher, ok := w.ResponseWriter.(http.Flusher); ok {
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Push implements http.Pusher interface if the underlying ResponseWriter supports it.
|
||||||
|
func (w *PassthroughWriter) Push(target string, opts *http.PushOptions) error {
|
||||||
|
if pusher, ok := w.ResponseWriter.(http.Pusher); ok {
|
||||||
|
return pusher.Push(target, opts)
|
||||||
|
}
|
||||||
|
return http.ErrNotSupported
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unwrap returns the underlying ResponseWriter.
|
||||||
|
// This is required for http.ResponseController (Go 1.20+) to work correctly.
|
||||||
|
func (w *PassthroughWriter) Unwrap() http.ResponseWriter {
|
||||||
|
return w.ResponseWriter
|
||||||
|
}
|
||||||
106
proxy/proxyprotocol_test.go
Normal file
106
proxy/proxyprotocol_test.go
Normal file
@@ -0,0 +1,106 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
proxyproto "github.com/pires/go-proxyproto"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestWrapProxyProtocol_OverridesRemoteAddr(t *testing.T) {
|
||||||
|
srv := &Server{
|
||||||
|
Logger: log.StandardLogger(),
|
||||||
|
TrustedProxies: []netip.Prefix{netip.MustParsePrefix("127.0.0.1/32")},
|
||||||
|
ProxyProtocol: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
raw, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer raw.Close()
|
||||||
|
|
||||||
|
ln := srv.wrapProxyProtocol(raw)
|
||||||
|
|
||||||
|
realClientIP := "203.0.113.50"
|
||||||
|
realClientPort := uint16(54321)
|
||||||
|
|
||||||
|
accepted := make(chan net.Conn, 1)
|
||||||
|
go func() {
|
||||||
|
conn, err := ln.Accept()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
accepted <- conn
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Connect and send a PROXY v2 header.
|
||||||
|
conn, err := net.Dial("tcp", ln.Addr().String())
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
header := &proxyproto.Header{
|
||||||
|
Version: 2,
|
||||||
|
Command: proxyproto.PROXY,
|
||||||
|
TransportProtocol: proxyproto.TCPv4,
|
||||||
|
SourceAddr: &net.TCPAddr{IP: net.ParseIP(realClientIP), Port: int(realClientPort)},
|
||||||
|
DestinationAddr: &net.TCPAddr{IP: net.ParseIP("10.0.0.1"), Port: 443},
|
||||||
|
}
|
||||||
|
_, err = header.WriteTo(conn)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case accepted := <-accepted:
|
||||||
|
defer accepted.Close()
|
||||||
|
host, _, err := net.SplitHostPort(accepted.RemoteAddr().String())
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, realClientIP, host, "RemoteAddr should reflect the PROXY header source IP")
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("timed out waiting for connection")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProxyProtocolPolicy_TrustedRequires(t *testing.T) {
|
||||||
|
srv := &Server{
|
||||||
|
Logger: log.StandardLogger(),
|
||||||
|
TrustedProxies: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
|
||||||
|
}
|
||||||
|
|
||||||
|
opts := proxyproto.ConnPolicyOptions{
|
||||||
|
Upstream: &net.TCPAddr{IP: net.ParseIP("10.0.0.1"), Port: 1234},
|
||||||
|
}
|
||||||
|
policy, err := srv.proxyProtocolPolicy(opts)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, proxyproto.REQUIRE, policy, "trusted source should require PROXY header")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProxyProtocolPolicy_UntrustedIgnores(t *testing.T) {
|
||||||
|
srv := &Server{
|
||||||
|
Logger: log.StandardLogger(),
|
||||||
|
TrustedProxies: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
|
||||||
|
}
|
||||||
|
|
||||||
|
opts := proxyproto.ConnPolicyOptions{
|
||||||
|
Upstream: &net.TCPAddr{IP: net.ParseIP("203.0.113.50"), Port: 1234},
|
||||||
|
}
|
||||||
|
policy, err := srv.proxyProtocolPolicy(opts)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, proxyproto.IGNORE, policy, "untrusted source should have PROXY header ignored")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProxyProtocolPolicy_InvalidIPRejects(t *testing.T) {
|
||||||
|
srv := &Server{
|
||||||
|
Logger: log.StandardLogger(),
|
||||||
|
TrustedProxies: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
|
||||||
|
}
|
||||||
|
|
||||||
|
opts := proxyproto.ConnPolicyOptions{
|
||||||
|
Upstream: &net.UnixAddr{Name: "/tmp/test.sock", Net: "unix"},
|
||||||
|
}
|
||||||
|
policy, err := srv.proxyProtocolPolicy(opts)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, proxyproto.REJECT, policy, "unparsable address should be rejected")
|
||||||
|
}
|
||||||
203
proxy/server.go
203
proxy/server.go
@@ -23,6 +23,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/cenkalti/backoff/v4"
|
"github.com/cenkalti/backoff/v4"
|
||||||
|
proxyproto "github.com/pires/go-proxyproto"
|
||||||
"github.com/prometheus/client_golang/prometheus"
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
@@ -36,6 +37,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/proxy/internal/acme"
|
"github.com/netbirdio/netbird/proxy/internal/acme"
|
||||||
"github.com/netbirdio/netbird/proxy/internal/auth"
|
"github.com/netbirdio/netbird/proxy/internal/auth"
|
||||||
"github.com/netbirdio/netbird/proxy/internal/certwatch"
|
"github.com/netbirdio/netbird/proxy/internal/certwatch"
|
||||||
|
"github.com/netbirdio/netbird/proxy/internal/conntrack"
|
||||||
"github.com/netbirdio/netbird/proxy/internal/debug"
|
"github.com/netbirdio/netbird/proxy/internal/debug"
|
||||||
proxygrpc "github.com/netbirdio/netbird/proxy/internal/grpc"
|
proxygrpc "github.com/netbirdio/netbird/proxy/internal/grpc"
|
||||||
"github.com/netbirdio/netbird/proxy/internal/health"
|
"github.com/netbirdio/netbird/proxy/internal/health"
|
||||||
@@ -63,6 +65,11 @@ type Server struct {
|
|||||||
healthChecker *health.Checker
|
healthChecker *health.Checker
|
||||||
meter *metrics.Metrics
|
meter *metrics.Metrics
|
||||||
|
|
||||||
|
// hijackTracker tracks hijacked connections (e.g. WebSocket upgrades)
|
||||||
|
// so they can be closed during graceful shutdown, since http.Server.Shutdown
|
||||||
|
// does not handle them.
|
||||||
|
hijackTracker conntrack.HijackTracker
|
||||||
|
|
||||||
// Mostly used for debugging on management.
|
// Mostly used for debugging on management.
|
||||||
startTime time.Time
|
startTime time.Time
|
||||||
|
|
||||||
@@ -92,7 +99,7 @@ type Server struct {
|
|||||||
DebugEndpointEnabled bool
|
DebugEndpointEnabled bool
|
||||||
// DebugEndpointAddress is the address for the debug HTTP endpoint (default: ":8444").
|
// DebugEndpointAddress is the address for the debug HTTP endpoint (default: ":8444").
|
||||||
DebugEndpointAddress string
|
DebugEndpointAddress string
|
||||||
// HealthAddress is the address for the health probe endpoint (default: "localhost:8080").
|
// HealthAddress is the address for the health probe endpoint.
|
||||||
HealthAddress string
|
HealthAddress string
|
||||||
// ProxyToken is the access token for authenticating with the management server.
|
// ProxyToken is the access token for authenticating with the management server.
|
||||||
ProxyToken string
|
ProxyToken string
|
||||||
@@ -107,6 +114,10 @@ type Server struct {
|
|||||||
// random OS-assigned port. A fixed port only works with single-account
|
// random OS-assigned port. A fixed port only works with single-account
|
||||||
// deployments; multiple accounts will fail to bind the same port.
|
// deployments; multiple accounts will fail to bind the same port.
|
||||||
WireguardPort int
|
WireguardPort int
|
||||||
|
// ProxyProtocol enables PROXY protocol (v1/v2) on TCP listeners.
|
||||||
|
// When enabled, the real client IP is extracted from the PROXY header
|
||||||
|
// sent by upstream L4 proxies that support PROXY protocol.
|
||||||
|
ProxyProtocol bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// NotifyStatus sends a status update to management about tunnel connectivity
|
// NotifyStatus sends a status update to management about tunnel connectivity
|
||||||
@@ -137,23 +148,8 @@ func (s *Server) NotifyCertificateIssued(ctx context.Context, accountID, service
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
|
func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
|
||||||
s.startTime = time.Now()
|
s.initDefaults()
|
||||||
|
|
||||||
// If no ID is set then one can be generated.
|
|
||||||
if s.ID == "" {
|
|
||||||
s.ID = "netbird-proxy-" + s.startTime.Format("20060102150405")
|
|
||||||
}
|
|
||||||
// Fallback version option in case it is not set.
|
|
||||||
if s.Version == "" {
|
|
||||||
s.Version = "dev"
|
|
||||||
}
|
|
||||||
|
|
||||||
// If no logger is specified fallback to the standard logger.
|
|
||||||
if s.Logger == nil {
|
|
||||||
s.Logger = log.StandardLogger()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start up metrics gathering
|
|
||||||
reg := prometheus.NewRegistry()
|
reg := prometheus.NewRegistry()
|
||||||
s.meter = metrics.New(reg)
|
s.meter = metrics.New(reg)
|
||||||
|
|
||||||
@@ -189,53 +185,41 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
|
|||||||
|
|
||||||
s.healthChecker = health.NewChecker(s.Logger, s.netbird)
|
s.healthChecker = health.NewChecker(s.Logger, s.netbird)
|
||||||
|
|
||||||
if s.DebugEndpointEnabled {
|
s.startDebugEndpoint()
|
||||||
debugAddr := debugEndpointAddr(s.DebugEndpointAddress)
|
|
||||||
debugHandler := debug.NewHandler(s.netbird, s.healthChecker, s.Logger)
|
if err := s.startHealthServer(reg); err != nil {
|
||||||
if s.acme != nil {
|
return err
|
||||||
debugHandler.SetCertStatus(s.acme)
|
|
||||||
}
|
|
||||||
s.debug = &http.Server{
|
|
||||||
Addr: debugAddr,
|
|
||||||
Handler: debugHandler,
|
|
||||||
ErrorLog: newHTTPServerLogger(s.Logger, logtagValueDebug),
|
|
||||||
}
|
|
||||||
go func() {
|
|
||||||
s.Logger.Infof("starting debug endpoint on %s", debugAddr)
|
|
||||||
if err := s.debug.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
|
||||||
s.Logger.Errorf("debug endpoint error: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start health probe server.
|
// Build the handler chain from inside out.
|
||||||
healthAddr := s.HealthAddress
|
handler := http.Handler(s.proxy)
|
||||||
if healthAddr == "" {
|
handler = s.auth.Protect(handler)
|
||||||
healthAddr = "localhost:8080"
|
handler = web.AssetHandler(handler)
|
||||||
}
|
handler = accessLog.Middleware(handler)
|
||||||
s.healthServer = health.NewServer(healthAddr, s.healthChecker, s.Logger, promhttp.HandlerFor(reg, promhttp.HandlerOpts{}))
|
handler = s.meter.Middleware(handler)
|
||||||
healthListener, err := net.Listen("tcp", healthAddr)
|
handler = s.hijackTracker.Middleware(handler)
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("health probe server listen on %s: %w", healthAddr, err)
|
|
||||||
}
|
|
||||||
go func() {
|
|
||||||
if err := s.healthServer.Serve(healthListener); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
|
||||||
s.Logger.Errorf("health probe server: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Start the reverse proxy HTTPS server.
|
// Start the reverse proxy HTTPS server.
|
||||||
s.https = &http.Server{
|
s.https = &http.Server{
|
||||||
Addr: addr,
|
Addr: addr,
|
||||||
Handler: s.meter.Middleware(accessLog.Middleware(web.AssetHandler(s.auth.Protect(s.proxy)))),
|
Handler: handler,
|
||||||
TLSConfig: tlsConfig,
|
TLSConfig: tlsConfig,
|
||||||
ErrorLog: newHTTPServerLogger(s.Logger, logtagValueHTTPS),
|
ErrorLog: newHTTPServerLogger(s.Logger, logtagValueHTTPS),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
lc := net.ListenConfig{}
|
||||||
|
ln, err := lc.Listen(ctx, "tcp", addr)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("listen on %s: %w", addr, err)
|
||||||
|
}
|
||||||
|
if s.ProxyProtocol {
|
||||||
|
ln = s.wrapProxyProtocol(ln)
|
||||||
|
}
|
||||||
|
|
||||||
httpsErr := make(chan error, 1)
|
httpsErr := make(chan error, 1)
|
||||||
go func() {
|
go func() {
|
||||||
s.Logger.Debugf("starting reverse proxy server on %s", addr)
|
s.Logger.Debugf("starting reverse proxy server on %s", addr)
|
||||||
httpsErr <- s.https.ListenAndServeTLS("", "")
|
httpsErr <- s.https.ServeTLS(ln, "", "")
|
||||||
}()
|
}()
|
||||||
|
|
||||||
select {
|
select {
|
||||||
@@ -251,7 +235,115 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// initDefaults sets fallback values for optional Server fields.
|
||||||
|
func (s *Server) initDefaults() {
|
||||||
|
s.startTime = time.Now()
|
||||||
|
|
||||||
|
// If no ID is set then one can be generated.
|
||||||
|
if s.ID == "" {
|
||||||
|
s.ID = "netbird-proxy-" + s.startTime.Format("20060102150405")
|
||||||
|
}
|
||||||
|
// Fallback version option in case it is not set.
|
||||||
|
if s.Version == "" {
|
||||||
|
s.Version = "dev"
|
||||||
|
}
|
||||||
|
|
||||||
|
// If no logger is specified fallback to the standard logger.
|
||||||
|
if s.Logger == nil {
|
||||||
|
s.Logger = log.StandardLogger()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// startDebugEndpoint launches the debug HTTP server if enabled.
|
||||||
|
func (s *Server) startDebugEndpoint() {
|
||||||
|
if !s.DebugEndpointEnabled {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
debugAddr := debugEndpointAddr(s.DebugEndpointAddress)
|
||||||
|
debugHandler := debug.NewHandler(s.netbird, s.healthChecker, s.Logger)
|
||||||
|
if s.acme != nil {
|
||||||
|
debugHandler.SetCertStatus(s.acme)
|
||||||
|
}
|
||||||
|
s.debug = &http.Server{
|
||||||
|
Addr: debugAddr,
|
||||||
|
Handler: debugHandler,
|
||||||
|
ErrorLog: newHTTPServerLogger(s.Logger, logtagValueDebug),
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
s.Logger.Infof("starting debug endpoint on %s", debugAddr)
|
||||||
|
if err := s.debug.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||||
|
s.Logger.Errorf("debug endpoint error: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// startHealthServer launches the health probe and metrics server.
|
||||||
|
func (s *Server) startHealthServer(reg *prometheus.Registry) error {
|
||||||
|
healthAddr := s.HealthAddress
|
||||||
|
if healthAddr == "" {
|
||||||
|
healthAddr = defaultHealthAddr
|
||||||
|
}
|
||||||
|
s.healthServer = health.NewServer(healthAddr, s.healthChecker, s.Logger, promhttp.HandlerFor(reg, promhttp.HandlerOpts{}))
|
||||||
|
healthListener, err := net.Listen("tcp", healthAddr)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("health probe server listen on %s: %w", healthAddr, err)
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
if err := s.healthServer.Serve(healthListener); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||||
|
s.Logger.Errorf("health probe server: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// wrapProxyProtocol wraps a listener with PROXY protocol support.
|
||||||
|
// When TrustedProxies is configured, only those sources may send PROXY headers;
|
||||||
|
// connections from untrusted sources have any PROXY header ignored.
|
||||||
|
func (s *Server) wrapProxyProtocol(ln net.Listener) net.Listener {
|
||||||
|
ppListener := &proxyproto.Listener{
|
||||||
|
Listener: ln,
|
||||||
|
ReadHeaderTimeout: proxyProtoHeaderTimeout,
|
||||||
|
}
|
||||||
|
if len(s.TrustedProxies) > 0 {
|
||||||
|
ppListener.ConnPolicy = s.proxyProtocolPolicy
|
||||||
|
} else {
|
||||||
|
s.Logger.Warn("PROXY protocol enabled without trusted proxies; any source may send PROXY headers")
|
||||||
|
}
|
||||||
|
s.Logger.Info("PROXY protocol enabled on listener")
|
||||||
|
return ppListener
|
||||||
|
}
|
||||||
|
|
||||||
|
// proxyProtocolPolicy returns whether to require, skip, or reject the PROXY
|
||||||
|
// header based on whether the connection source is in TrustedProxies.
|
||||||
|
func (s *Server) proxyProtocolPolicy(opts proxyproto.ConnPolicyOptions) (proxyproto.Policy, error) {
|
||||||
|
// No logging on reject to prevent abuse
|
||||||
|
tcpAddr, ok := opts.Upstream.(*net.TCPAddr)
|
||||||
|
if !ok {
|
||||||
|
return proxyproto.REJECT, nil
|
||||||
|
}
|
||||||
|
addr, ok := netip.AddrFromSlice(tcpAddr.IP)
|
||||||
|
if !ok {
|
||||||
|
return proxyproto.REJECT, nil
|
||||||
|
}
|
||||||
|
addr = addr.Unmap()
|
||||||
|
|
||||||
|
// called per accept
|
||||||
|
for _, prefix := range s.TrustedProxies {
|
||||||
|
if prefix.Contains(addr) {
|
||||||
|
return proxyproto.REQUIRE, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return proxyproto.IGNORE, nil
|
||||||
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
defaultHealthAddr = "localhost:8080"
|
||||||
|
defaultDebugAddr = "localhost:8444"
|
||||||
|
|
||||||
|
// proxyProtoHeaderTimeout is the deadline for reading the PROXY protocol
|
||||||
|
// header after accepting a connection.
|
||||||
|
proxyProtoHeaderTimeout = 5 * time.Second
|
||||||
|
|
||||||
// shutdownPreStopDelay is the time to wait after receiving a shutdown signal
|
// shutdownPreStopDelay is the time to wait after receiving a shutdown signal
|
||||||
// before draining connections. This allows the load balancer to propagate
|
// before draining connections. This allows the load balancer to propagate
|
||||||
// the endpoint removal.
|
// the endpoint removal.
|
||||||
@@ -379,7 +471,12 @@ func (s *Server) gracefulShutdown() {
|
|||||||
s.Logger.Warnf("https server drain: %v", err)
|
s.Logger.Warnf("https server drain: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Step 4: Stop all remaining background services.
|
// Step 4: Close hijacked connections (WebSocket) that Shutdown does not handle.
|
||||||
|
if n := s.hijackTracker.CloseAll(); n > 0 {
|
||||||
|
s.Logger.Infof("closed %d hijacked connection(s)", n)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 5: Stop all remaining background services.
|
||||||
s.shutdownServices()
|
s.shutdownServices()
|
||||||
s.Logger.Info("graceful shutdown complete")
|
s.Logger.Info("graceful shutdown complete")
|
||||||
}
|
}
|
||||||
@@ -647,7 +744,7 @@ func (s *Server) protoToMapping(mapping *proto.ProxyMapping) proxy.Mapping {
|
|||||||
// If addr is empty, it defaults to localhost:8444 for security.
|
// If addr is empty, it defaults to localhost:8444 for security.
|
||||||
func debugEndpointAddr(addr string) string {
|
func debugEndpointAddr(addr string) string {
|
||||||
if addr == "" {
|
if addr == "" {
|
||||||
return "localhost:8444"
|
return defaultDebugAddr
|
||||||
}
|
}
|
||||||
return addr
|
return addr
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user