[proxy] add pre-shared key support (#5377)

This commit is contained in:
Pascal Fischer
2026-02-23 16:31:29 +01:00
committed by GitHub
parent 5d171f181a
commit 9d123ec059
4 changed files with 39 additions and 12 deletions

View File

@@ -53,6 +53,7 @@ var (
certLockMethod string certLockMethod string
wgPort int wgPort int
proxyProtocol bool proxyProtocol bool
preSharedKey string
) )
var rootCmd = &cobra.Command{ var rootCmd = &cobra.Command{
@@ -84,6 +85,7 @@ func init() {
rootCmd.Flags().StringVar(&certLockMethod, "cert-lock-method", envStringOrDefault("NB_PROXY_CERT_LOCK_METHOD", "auto"), "Certificate lock method for cross-replica coordination: auto, flock, or k8s-lease") rootCmd.Flags().StringVar(&certLockMethod, "cert-lock-method", envStringOrDefault("NB_PROXY_CERT_LOCK_METHOD", "auto"), "Certificate lock method for cross-replica coordination: auto, flock, or k8s-lease")
rootCmd.Flags().IntVar(&wgPort, "wg-port", envIntOrDefault("NB_PROXY_WG_PORT", 0), "WireGuard listen port (0 = random). Fixed port only works with single-account deployments") rootCmd.Flags().IntVar(&wgPort, "wg-port", envIntOrDefault("NB_PROXY_WG_PORT", 0), "WireGuard listen port (0 = random). Fixed port only works with single-account deployments")
rootCmd.Flags().BoolVar(&proxyProtocol, "proxy-protocol", envBoolOrDefault("NB_PROXY_PROXY_PROTOCOL", false), "Enable PROXY protocol on TCP listeners to preserve client IPs behind L4 proxies") rootCmd.Flags().BoolVar(&proxyProtocol, "proxy-protocol", envBoolOrDefault("NB_PROXY_PROXY_PROTOCOL", false), "Enable PROXY protocol on TCP listeners to preserve client IPs behind L4 proxies")
rootCmd.Flags().StringVar(&preSharedKey, "preshared-key", envStringOrDefault("NB_PROXY_PRESHARED_KEY", ""), "Define a pre-shared key for the tunnel between proxy and peers")
} }
// Execute runs the root command. // Execute runs the root command.
@@ -156,6 +158,7 @@ func runServer(cmd *cobra.Command, args []string) error {
CertLockMethod: nbacme.CertLockMethod(certLockMethod), CertLockMethod: nbacme.CertLockMethod(certLockMethod),
WireguardPort: wgPort, WireguardPort: wgPort,
ProxyProtocol: proxyProtocol, ProxyProtocol: proxyProtocol,
PreSharedKey: preSharedKey,
} }
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT) ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT)

View File

@@ -86,6 +86,13 @@ func (e *clientEntry) acquireInflight(backend backendKey) (release func(), ok bo
} }
} }
// ClientConfig holds configuration for the embedded NetBird client.
type ClientConfig struct {
MgmtAddr string
WGPort int
PreSharedKey string
}
type statusNotifier interface { type statusNotifier interface {
NotifyStatus(ctx context.Context, accountID, serviceID, domain string, connected bool) error NotifyStatus(ctx context.Context, accountID, serviceID, domain string, connected bool) error
} }
@@ -98,10 +105,9 @@ type managementClient interface {
// backed by underlying NetBird connections. // backed by underlying NetBird connections.
// Clients are keyed by AccountID, allowing multiple domains to share the same connection. // Clients are keyed by AccountID, allowing multiple domains to share the same connection.
type NetBird struct { type NetBird struct {
mgmtAddr string
proxyID string proxyID string
proxyAddr string proxyAddr string
wgPort int clientCfg ClientConfig
logger *log.Logger logger *log.Logger
mgmtClient managementClient mgmtClient managementClient
transportCfg transportConfig transportCfg transportConfig
@@ -229,11 +235,12 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account
// The peer has already been created via CreateProxyPeer RPC with the public key. // The peer has already been created via CreateProxyPeer RPC with the public key.
client, err := embed.New(embed.Options{ client, err := embed.New(embed.Options{
DeviceName: deviceNamePrefix + n.proxyID, DeviceName: deviceNamePrefix + n.proxyID,
ManagementURL: n.mgmtAddr, ManagementURL: n.clientCfg.MgmtAddr,
PrivateKey: privateKey.String(), PrivateKey: privateKey.String(),
LogLevel: log.WarnLevel.String(), LogLevel: log.WarnLevel.String(),
BlockInbound: true, BlockInbound: true,
WireguardPort: &n.wgPort, WireguardPort: &n.clientCfg.WGPort,
PreSharedKey: n.clientCfg.PreSharedKey,
}) })
if err != nil { if err != nil {
return nil, fmt.Errorf("create netbird client: %w", err) return nil, fmt.Errorf("create netbird client: %w", err)
@@ -536,18 +543,17 @@ func (n *NetBird) ListClientsForStartup() map[types.AccountID]*embed.Client {
return result return result
} }
// NewNetBird creates a new NetBird transport. Set wgPort to 0 for a random // NewNetBird creates a new NetBird transport. Set clientCfg.WGPort to 0 for a random
// OS-assigned port. A fixed port only works with single-account deployments; // OS-assigned port. A fixed port only works with single-account deployments;
// multiple accounts will fail to bind the same port. // multiple accounts will fail to bind the same port.
func NewNetBird(mgmtAddr, proxyID, proxyAddr string, wgPort int, logger *log.Logger, notifier statusNotifier, mgmtClient managementClient) *NetBird { func NewNetBird(proxyID, proxyAddr string, clientCfg ClientConfig, logger *log.Logger, notifier statusNotifier, mgmtClient managementClient) *NetBird {
if logger == nil { if logger == nil {
logger = log.StandardLogger() logger = log.StandardLogger()
} }
return &NetBird{ return &NetBird{
mgmtAddr: mgmtAddr,
proxyID: proxyID, proxyID: proxyID,
proxyAddr: proxyAddr, proxyAddr: proxyAddr,
wgPort: wgPort, clientCfg: clientCfg,
logger: logger, logger: logger,
clients: make(map[types.AccountID]*clientEntry), clients: make(map[types.AccountID]*clientEntry),
statusNotifier: notifier, statusNotifier: notifier,

View File

@@ -49,7 +49,11 @@ func (m *mockStatusNotifier) calls() []statusCall {
// mockNetBird creates a NetBird instance for testing without actually connecting. // mockNetBird creates a NetBird instance for testing without actually connecting.
// It uses an invalid management URL to prevent real connections. // It uses an invalid management URL to prevent real connections.
func mockNetBird() *NetBird { func mockNetBird() *NetBird {
return NewNetBird("http://invalid.test:9999", "test-proxy", "invalid.test", 0, nil, nil, &mockMgmtClient{}) return NewNetBird("test-proxy", "invalid.test", ClientConfig{
MgmtAddr: "http://invalid.test:9999",
WGPort: 0,
PreSharedKey: "",
}, nil, nil, &mockMgmtClient{})
} }
func TestNetBird_AddPeer_CreatesClientForNewAccount(t *testing.T) { func TestNetBird_AddPeer_CreatesClientForNewAccount(t *testing.T) {
@@ -282,7 +286,11 @@ func TestNetBird_RoundTrip_RequiresExistingClient(t *testing.T) {
func TestNetBird_AddPeer_ExistingStartedClient_NotifiesStatus(t *testing.T) { func TestNetBird_AddPeer_ExistingStartedClient_NotifiesStatus(t *testing.T) {
notifier := &mockStatusNotifier{} notifier := &mockStatusNotifier{}
nb := NewNetBird("http://invalid.test:9999", "test-proxy", "invalid.test", 0, nil, notifier, &mockMgmtClient{}) nb := NewNetBird("test-proxy", "invalid.test", ClientConfig{
MgmtAddr: "http://invalid.test:9999",
WGPort: 0,
PreSharedKey: "",
}, nil, notifier, &mockMgmtClient{})
accountID := types.AccountID("account-1") accountID := types.AccountID("account-1")
// Add first domain — creates a new client entry. // Add first domain — creates a new client entry.
@@ -308,7 +316,11 @@ func TestNetBird_AddPeer_ExistingStartedClient_NotifiesStatus(t *testing.T) {
func TestNetBird_RemovePeer_NotifiesDisconnection(t *testing.T) { func TestNetBird_RemovePeer_NotifiesDisconnection(t *testing.T) {
notifier := &mockStatusNotifier{} notifier := &mockStatusNotifier{}
nb := NewNetBird("http://invalid.test:9999", "test-proxy", "invalid.test", 0, nil, notifier, &mockMgmtClient{}) nb := NewNetBird("test-proxy", "invalid.test", ClientConfig{
MgmtAddr: "http://invalid.test:9999",
WGPort: 0,
PreSharedKey: "",
}, nil, notifier, &mockMgmtClient{})
accountID := types.AccountID("account-1") accountID := types.AccountID("account-1")
err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "key-1", "svc-1") err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "key-1", "svc-1")

View File

@@ -114,6 +114,8 @@ type Server struct {
// When enabled, the real client IP is extracted from the PROXY header // When enabled, the real client IP is extracted from the PROXY header
// sent by upstream L4 proxies that support PROXY protocol. // sent by upstream L4 proxies that support PROXY protocol.
ProxyProtocol bool ProxyProtocol bool
// PreSharedKey used for tunnel between proxy and peers (set globally not per account)
PreSharedKey string
} }
// NotifyStatus sends a status update to management about tunnel connectivity // NotifyStatus sends a status update to management about tunnel connectivity
@@ -163,7 +165,11 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
// Initialize the netbird client, this is required to build peer connections // Initialize the netbird client, this is required to build peer connections
// to proxy over. // to proxy over.
s.netbird = roundtrip.NewNetBird(s.ManagementAddress, s.ID, s.ProxyURL, s.WireguardPort, s.Logger, s, s.mgmtClient) s.netbird = roundtrip.NewNetBird(s.ID, s.ProxyURL, roundtrip.ClientConfig{
MgmtAddr: s.ManagementAddress,
WGPort: s.WireguardPort,
PreSharedKey: s.PreSharedKey,
}, s.Logger, s, s.mgmtClient)
tlsConfig, err := s.configureTLS(ctx) tlsConfig, err := s.configureTLS(ctx)
if err != nil { if err != nil {