diff --git a/README.md b/README.md index 69a4fa8..ee91356 100644 --- a/README.md +++ b/README.md @@ -42,6 +42,9 @@ When Newt receives WireGuard control messages, it will use the information encod - `ping-timeout` (optional): Timeout for each ping. Default: 5s - `updown` (optional): A script to be called when targets are added or removed. - `tls-client-cert` (optional): Client certificate (p12 or pfx) for mTLS. See [mTLS](#mtls) +- `tls-client-cert` (optional): Path to client certificate (PEM format, optional if using PKCS12). See [mTLS](#mtls) +- `tls-client-key` (optional): Path to private key for mTLS (PEM format, optional if using PKCS12) +- `tls-ca-cert` (optional): Path to CA certificate to verify server (PEM format, optional if using PKCS12) - `docker-enforce-network-validation` (optional): Validate the container target is on the same network as the newt process. Default: false - `health-file` (optional): Check if connection to WG server (pangolin) is ok. creates a file if ok, removes it if not ok. Can be used with docker healtcheck to restart newt - `accept-clients` (optional): Enable WireGuard server mode to accept incoming olm client connections. Default: false @@ -65,6 +68,9 @@ All CLI arguments can be set using environment variables as an alternative to co - `PING_TIMEOUT`: Timeout for each ping. Default: 5s (equivalent to `--ping-timeout`) - `UPDOWN_SCRIPT`: Path to updown script for target add/remove events (equivalent to `--updown`) - `TLS_CLIENT_CERT`: Path to client certificate for mTLS (equivalent to `--tls-client-cert`) +- `TLS_CLIENT_CERT`: Path to client certificate for mTLS (equivalent to `--tls-client-cert`) +- `TLS_CLIENT_KEY`: Path to private key for mTLS (equivalent to `--tls-client-key`) +- `TLS_CA_CERT`: Path to CA certificate to verify server (equivalent to `--tls-ca-cert`) - `DOCKER_ENFORCE_NETWORK_VALIDATION`: Validate container targets are on same network. Default: false (equivalent to `--docker-enforce-network-validation`) - `HEALTH_FILE`: Path to health file for connection monitoring (equivalent to `--health-file`) - `ACCEPT_CLIENTS`: Enable WireGuard server mode. Default: false (equivalent to `--accept-clients`) @@ -259,16 +265,20 @@ You can look at updown.py as a reference script to get started! ### mTLS -Newt supports mutual TLS (mTLS) authentication, if the server has been configured to request a client certificate. +Newt supports mutual TLS (mTLS) authentication if the server is configured to request a client certificate. You can use either a PKCS12 (.p12/.pfx) file or split PEM files for the client cert, private key, and CA. -- Only PKCS12 (.p12 or .pfx) file format is accepted -- The PKCS12 file must contain: - - Private key - - Public certificate - - CA certificate -- Encrypted PKCS12 files are currently not supported +#### Option 1: PKCS12 (Legacy) -Examples: +> This is the original method and still supported. + +* File must contain: + + * Client private key + * Public certificate + * CA certificate +* Encrypted `.p12` files are **not supported** + +Example: ```bash newt \ @@ -278,6 +288,27 @@ newt \ --tls-client-cert ./client.p12 ``` +#### Option 2: Split PEM Files (Preferred) + +You can now provide separate files for: + +* `--tls-client-cert`: client certificate (`.crt` or `.pem`) +* `--tls-client-key`: client private key (`.key` or `.pem`) +* `--tls-ca-cert`: CA cert to verify the server + +Example: + +```bash +newt \ +--id 31frd0uzbjvp721 \ +--secret h51mmlknrvrwv8s4r1i210azhumt6isgbpyavxodibx1k2d6 \ +--endpoint https://example.com \ +--tls-client-cert ./client.crt \ +--tls-client-key ./client.key \ +--tls-ca-cert ./ca.crt +``` + + ```yaml services: newt: diff --git a/main.go b/main.go index 483aa23..bd73263 100644 --- a/main.go +++ b/main.go @@ -72,6 +72,18 @@ type ExitNodePingResult struct { WasPreviouslyConnected bool `json:"wasPreviouslyConnected"` } +// Custom flag type for multiple CA files +type stringSlice []string + +func (s *stringSlice) String() string { + return strings.Join(*s, ",") +} + +func (s *stringSlice) Set(value string) error { + *s = append(*s, value) + return nil +} + var ( endpoint string id string @@ -87,7 +99,6 @@ var ( keepInterface bool acceptClients bool updownScript string - tlsPrivateKey string dockerSocket string dockerEnforceNetworkValidation string dockerEnforceNetworkValidationBool bool @@ -99,6 +110,14 @@ var ( healthFile string useNativeInterface bool authorizedKeysFile string + + // New mTLS configuration variables + tlsClientCert string + tlsClientKey string + tlsClientCAs []string + + // Legacy PKCS12 support (deprecated) + tlsPrivateKey string ) func main() { @@ -114,7 +133,6 @@ func main() { generateAndSaveKeyTo = os.Getenv("GENERATE_AND_SAVE_KEY_TO") keepInterface = os.Getenv("KEEP_INTERFACE") == "true" acceptClients = os.Getenv("ACCEPT_CLIENTS") == "true" - tlsPrivateKey = os.Getenv("TLS_CLIENT_CERT") dockerSocket = os.Getenv("DOCKER_SOCKET") pingIntervalStr := os.Getenv("PING_INTERVAL") pingTimeoutStr := os.Getenv("PING_TIMEOUT") @@ -123,6 +141,25 @@ func main() { useNativeInterface = os.Getenv("USE_NATIVE_INTERFACE") == "true" // authorizedKeysFile = os.Getenv("AUTHORIZED_KEYS_FILE") authorizedKeysFile = "" + + // Read new mTLS environment variables + tlsClientCert = os.Getenv("TLS_CLIENT_CERT") + tlsClientKey = os.Getenv("TLS_CLIENT_KEY") + tlsClientCAsEnv := os.Getenv("TLS_CLIENT_CAS") + if tlsClientCAsEnv != "" { + tlsClientCAs = strings.Split(tlsClientCAsEnv, ",") + // Trim spaces from each CA file path + for i, ca := range tlsClientCAs { + tlsClientCAs[i] = strings.TrimSpace(ca) + } + } + + // Legacy PKCS12 support (deprecated) + tlsPrivateKey = os.Getenv("TLS_CLIENT_CERT_PKCS12") + // Keep backward compatibility with old environment variable name + if tlsPrivateKey == "" { + tlsPrivateKey = os.Getenv("TLS_CLIENT_CERT") + } if endpoint == "" { flag.StringVar(&endpoint, "endpoint", "", "Endpoint of your pangolin server") @@ -154,9 +191,6 @@ func main() { flag.BoolVar(&keepInterface, "keep-interface", false, "Keep the WireGuard interface") flag.BoolVar(&useNativeInterface, "native", false, "Use native WireGuard interface (requires WireGuard kernel module) and linux") flag.BoolVar(&acceptClients, "accept-clients", false, "Accept clients on the WireGuard interface") - if tlsPrivateKey == "" { - flag.StringVar(&tlsPrivateKey, "tls-client-cert", "", "Path to client certificate used for mTLS") - } if dockerSocket == "" { flag.StringVar(&dockerSocket, "docker-socket", "", "Path to Docker socket (typically /var/run/docker.sock)") } @@ -173,6 +207,23 @@ func main() { // flag.StringVar(&authorizedKeysFile, "authorized-keys-file", "~/.ssh/authorized_keys", "Path to authorized keys file (if unset, no keys will be authorized)") // } + // Add new mTLS flags + if tlsClientCert == "" { + flag.StringVar(&tlsClientCert, "tls-client-cert-file", "", "Path to client certificate file (PEM/DER format)") + } + if tlsClientKey == "" { + flag.StringVar(&tlsClientKey, "tls-client-key", "", "Path to client private key file (PEM/DER format)") + } + + // Handle multiple CA files + var tlsClientCAsFlag stringSlice + flag.Var(&tlsClientCAsFlag, "tls-client-ca", "Path to CA certificate file for validating remote certificates (can be specified multiple times)") + + // Legacy PKCS12 flag (deprecated) + if tlsPrivateKey == "" { + flag.StringVar(&tlsPrivateKey, "tls-client-cert", "", "Path to client certificate (PKCS12 format) - DEPRECATED: use --tls-client-cert-file and --tls-client-key instead") + } + if pingIntervalStr != "" { pingInterval, err = time.ParseDuration(pingIntervalStr) if err != nil { @@ -197,7 +248,7 @@ func main() { flag.StringVar(&dockerEnforceNetworkValidation, "docker-enforce-network-validation", "false", "Enforce validation of container on newt network (true or false)") } if healthFile == "" { - flag.StringVar(&healthFile, "health-file", "", "Path to health file (if unset, health file won’t be written)") + flag.StringVar(&healthFile, "health-file", "", "Path to health file (if unset, health file won't be written)") } // do a --version check @@ -205,6 +256,11 @@ func main() { flag.Parse() + // Merge command line CA flags with environment variable CAs + if len(tlsClientCAsFlag) > 0 { + tlsClientCAs = append(tlsClientCAs, tlsClientCAsFlag...) + } + logger.Init() loggerLevel := parseLogLevel(logLevel) logger.GetLogger().SetLevel(parseLogLevel(logLevel)) @@ -234,14 +290,42 @@ func main() { dockerEnforceNetworkValidationBool = false } + // Add TLS configuration validation + if err := validateTLSConfig(); err != nil { + logger.Fatal("TLS configuration error: %v", err) + } + + // Show deprecation warning if using PKCS12 + if tlsPrivateKey != "" { + logger.Warn("Using deprecated PKCS12 format for mTLS. Consider migrating to separate certificate files using --tls-client-cert-file, --tls-client-key, and --tls-client-ca") + } + privateKey, err = wgtypes.GeneratePrivateKey() if err != nil { logger.Fatal("Failed to generate private key: %v", err) } + + // Create client option based on TLS configuration var opt websocket.ClientOption - if tlsPrivateKey != "" { - opt = websocket.WithTLSConfig(tlsPrivateKey) + if tlsClientCert != "" && tlsClientKey != "" { + // Use new separate certificate configuration + opt = websocket.WithTLSConfig(websocket.TLSConfig{ + ClientCertFile: tlsClientCert, + ClientKeyFile: tlsClientKey, + CAFiles: tlsClientCAs, + }) + logger.Debug("Using separate certificate files for mTLS") + logger.Debug("Client cert: %s", tlsClientCert) + logger.Debug("Client key: %s", tlsClientKey) + logger.Debug("CA files: %v", tlsClientCAs) + } else if tlsPrivateKey != "" { + // Use existing PKCS12 configuration for backward compatibility + opt = websocket.WithTLSConfig(websocket.TLSConfig{ + PKCS12File: tlsPrivateKey, + }) + logger.Debug("Using PKCS12 file for mTLS: %s", tlsPrivateKey) } + // Create a new client client, err := websocket.NewClient( "newt", @@ -262,7 +346,21 @@ func main() { logger.Debug("Endpoint: %v", endpoint) logger.Debug("Log Level: %v", logLevel) logger.Debug("Docker Network Validation Enabled: %v", dockerEnforceNetworkValidationBool) - logger.Debug("TLS Private Key Set: %v", tlsPrivateKey != "") + + // Add new TLS debug logging + if tlsClientCert != "" { + logger.Debug("TLS Client Cert File: %v", tlsClientCert) + } + if tlsClientKey != "" { + logger.Debug("TLS Client Key File: %v", tlsClientKey) + } + if len(tlsClientCAs) > 0 { + logger.Debug("TLS CA Files: %v", tlsClientCAs) + } + if tlsPrivateKey != "" { + logger.Debug("TLS PKCS12 File: %v", tlsPrivateKey) + } + if dns != "" { logger.Debug("Dns: %v", dns) } @@ -950,3 +1048,48 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub logger.Info("Exiting...") os.Exit(0) } + +// validateTLSConfig validates the TLS configuration +func validateTLSConfig() error { + // Check for conflicting configurations + pkcs12Specified := tlsPrivateKey != "" + separateFilesSpecified := tlsClientCert != "" || tlsClientKey != "" || len(tlsClientCAs) > 0 + + if pkcs12Specified && separateFilesSpecified { + return fmt.Errorf("cannot use both PKCS12 format (--tls-client-cert) and separate certificate files (--tls-client-cert-file, --tls-client-key, --tls-client-ca)") + } + + // If using separate files, both cert and key are required + if (tlsClientCert != "" && tlsClientKey == "") || (tlsClientCert == "" && tlsClientKey != "") { + return fmt.Errorf("both --tls-client-cert-file and --tls-client-key must be specified together") + } + + // Validate certificate files exist + if tlsClientCert != "" { + if _, err := os.Stat(tlsClientCert); os.IsNotExist(err) { + return fmt.Errorf("client certificate file does not exist: %s", tlsClientCert) + } + } + + if tlsClientKey != "" { + if _, err := os.Stat(tlsClientKey); os.IsNotExist(err) { + return fmt.Errorf("client key file does not exist: %s", tlsClientKey) + } + } + + // Validate CA files exist + for _, caFile := range tlsClientCAs { + if _, err := os.Stat(caFile); os.IsNotExist(err) { + return fmt.Errorf("CA certificate file does not exist: %s", caFile) + } + } + + // Validate PKCS12 file exists if specified + if tlsPrivateKey != "" { + if _, err := os.Stat(tlsPrivateKey); os.IsNotExist(err) { + return fmt.Errorf("PKCS12 certificate file does not exist: %s", tlsPrivateKey) + } + } + + return nil +} \ No newline at end of file diff --git a/websocket/client.go b/websocket/client.go index 8aea1f7..f85d81f 100644 --- a/websocket/client.go +++ b/websocket/client.go @@ -35,12 +35,24 @@ type Client struct { onTokenUpdate func(token string) writeMux sync.Mutex clientType string // Type of client (e.g., "newt", "olm") + tlsConfig TLSConfig } type ClientOption func(*Client) type MessageHandler func(message WSMessage) +// TLSConfig holds TLS configuration options +type TLSConfig struct { + // New separate certificate support + ClientCertFile string + ClientKeyFile string + CAFiles []string + + // Existing PKCS12 support (deprecated) + PKCS12File string +} + // WithBaseURL sets the base URL for the client func WithBaseURL(url string) ClientOption { return func(c *Client) { @@ -48,9 +60,14 @@ func WithBaseURL(url string) ClientOption { } } -func WithTLSConfig(tlsClientCertPath string) ClientOption { +// WithTLSConfig sets the TLS configuration for the client +func WithTLSConfig(config TLSConfig) ClientOption { return func(c *Client) { - c.config.TlsClientCert = tlsClientCertPath + c.tlsConfig = config + // For backward compatibility, also set the legacy field + if config.PKCS12File != "" { + c.config.TlsClientCert = config.PKCS12File + } } } @@ -198,10 +215,12 @@ func (c *Client) getToken() (string, error) { baseEndpoint := strings.TrimRight(baseURL.String(), "/") var tlsConfig *tls.Config = nil - if c.config.TlsClientCert != "" { - tlsConfig, err = loadClientCertificate(c.config.TlsClientCert) + + // Use new TLS configuration method + if c.tlsConfig.ClientCertFile != "" || c.tlsConfig.ClientKeyFile != "" || len(c.tlsConfig.CAFiles) > 0 || c.tlsConfig.PKCS12File != "" { + tlsConfig, err = c.setupTLS() if err != nil { - return "", fmt.Errorf("failed to load certificate %s: %w", c.config.TlsClientCert, err) + return "", fmt.Errorf("failed to setup TLS configuration: %w", err) } } @@ -331,14 +350,17 @@ func (c *Client) establishConnection() error { // Connect to WebSocket dialer := websocket.DefaultDialer - if c.config.TlsClientCert != "" { - logger.Info("Adding tls to req") - tlsConfig, err := loadClientCertificate(c.config.TlsClientCert) + + // Use new TLS configuration method + if c.tlsConfig.ClientCertFile != "" || c.tlsConfig.ClientKeyFile != "" || len(c.tlsConfig.CAFiles) > 0 || c.tlsConfig.PKCS12File != "" { + logger.Info("Setting up TLS configuration for WebSocket connection") + tlsConfig, err := c.setupTLS() if err != nil { - return fmt.Errorf("failed to load certificate %s: %w", c.config.TlsClientCert, err) + return fmt.Errorf("failed to setup TLS configuration: %w", err) } dialer.TLSClientConfig = tlsConfig } + conn, _, err := dialer.Dial(u.String(), nil) if err != nil { return fmt.Errorf("failed to connect to WebSocket: %w", err) @@ -365,6 +387,69 @@ func (c *Client) establishConnection() error { return nil } +// setupTLS configures TLS based on the TLS configuration +func (c *Client) setupTLS() (*tls.Config, error) { + tlsConfig := &tls.Config{} + + // Handle new separate certificate configuration + if c.tlsConfig.ClientCertFile != "" && c.tlsConfig.ClientKeyFile != "" { + logger.Info("Loading separate certificate files for mTLS") + logger.Debug("Client cert: %s", c.tlsConfig.ClientCertFile) + logger.Debug("Client key: %s", c.tlsConfig.ClientKeyFile) + + // Load client certificate and key + cert, err := tls.LoadX509KeyPair(c.tlsConfig.ClientCertFile, c.tlsConfig.ClientKeyFile) + if err != nil { + return nil, fmt.Errorf("failed to load client certificate pair: %w", err) + } + tlsConfig.Certificates = []tls.Certificate{cert} + + // Load CA certificates for remote validation if specified + if len(c.tlsConfig.CAFiles) > 0 { + logger.Debug("Loading CA certificates: %v", c.tlsConfig.CAFiles) + caCertPool := x509.NewCertPool() + for _, caFile := range c.tlsConfig.CAFiles { + caCert, err := os.ReadFile(caFile) + if err != nil { + return nil, fmt.Errorf("failed to read CA file %s: %w", caFile, err) + } + + // Try to parse as PEM first, then DER + if !caCertPool.AppendCertsFromPEM(caCert) { + // If PEM parsing failed, try DER + cert, err := x509.ParseCertificate(caCert) + if err != nil { + return nil, fmt.Errorf("failed to parse CA certificate from %s: %w", caFile, err) + } + caCertPool.AddCert(cert) + } + } + tlsConfig.RootCAs = caCertPool + } + + return tlsConfig, nil + } + + // Fallback to existing PKCS12 implementation for backward compatibility + if c.tlsConfig.PKCS12File != "" { + logger.Info("Loading PKCS12 certificate for mTLS (deprecated)") + return c.setupPKCS12TLS() + } + + // Legacy fallback using config.TlsClientCert + if c.config.TlsClientCert != "" { + logger.Info("Loading legacy PKCS12 certificate for mTLS (deprecated)") + return loadClientCertificate(c.config.TlsClientCert) + } + + return nil, nil +} + +// setupPKCS12TLS loads TLS configuration from PKCS12 file +func (c *Client) setupPKCS12TLS() (*tls.Config, error) { + return loadClientCertificate(c.tlsConfig.PKCS12File) +} + // pingMonitor sends pings at a short interval and triggers reconnect on failure func (c *Client) pingMonitor() { ticker := time.NewTicker(c.pingInterval) @@ -469,7 +554,7 @@ func (c *Client) setConnected(status bool) { c.isConnected = status } -// LoadClientCertificate Helper method to load client certificates +// LoadClientCertificate Helper method to load client certificates (PKCS12 format) func loadClientCertificate(p12Path string) (*tls.Config, error) { logger.Info("Loading tls-client-cert %s", p12Path) // Read the PKCS12 file @@ -506,4 +591,4 @@ func loadClientCertificate(p12Path string) (*tls.Config, error) { Certificates: []tls.Certificate{cert}, RootCAs: rootCAs, }, nil -} +} \ No newline at end of file