Split mTLS client and CA certificates

This commit is contained in:
Pallavi
2025-08-05 01:08:29 +05:30
parent 151d0e38e6
commit d52f89f629
3 changed files with 287 additions and 28 deletions

View File

@@ -42,6 +42,9 @@ When Newt receives WireGuard control messages, it will use the information encod
- `ping-timeout` (optional): Timeout for each ping. Default: 5s
- `updown` (optional): A script to be called when targets are added or removed.
- `tls-client-cert` (optional): Client certificate (p12 or pfx) for mTLS. See [mTLS](#mtls)
- `tls-client-cert` (optional): Path to client certificate (PEM format, optional if using PKCS12). See [mTLS](#mtls)
- `tls-client-key` (optional): Path to private key for mTLS (PEM format, optional if using PKCS12)
- `tls-ca-cert` (optional): Path to CA certificate to verify server (PEM format, optional if using PKCS12)
- `docker-enforce-network-validation` (optional): Validate the container target is on the same network as the newt process. Default: false
- `health-file` (optional): Check if connection to WG server (pangolin) is ok. creates a file if ok, removes it if not ok. Can be used with docker healtcheck to restart newt
- `accept-clients` (optional): Enable WireGuard server mode to accept incoming olm client connections. Default: false
@@ -65,6 +68,9 @@ All CLI arguments can be set using environment variables as an alternative to co
- `PING_TIMEOUT`: Timeout for each ping. Default: 5s (equivalent to `--ping-timeout`)
- `UPDOWN_SCRIPT`: Path to updown script for target add/remove events (equivalent to `--updown`)
- `TLS_CLIENT_CERT`: Path to client certificate for mTLS (equivalent to `--tls-client-cert`)
- `TLS_CLIENT_CERT`: Path to client certificate for mTLS (equivalent to `--tls-client-cert`)
- `TLS_CLIENT_KEY`: Path to private key for mTLS (equivalent to `--tls-client-key`)
- `TLS_CA_CERT`: Path to CA certificate to verify server (equivalent to `--tls-ca-cert`)
- `DOCKER_ENFORCE_NETWORK_VALIDATION`: Validate container targets are on same network. Default: false (equivalent to `--docker-enforce-network-validation`)
- `HEALTH_FILE`: Path to health file for connection monitoring (equivalent to `--health-file`)
- `ACCEPT_CLIENTS`: Enable WireGuard server mode. Default: false (equivalent to `--accept-clients`)
@@ -259,16 +265,20 @@ You can look at updown.py as a reference script to get started!
### mTLS
Newt supports mutual TLS (mTLS) authentication, if the server has been configured to request a client certificate.
Newt supports mutual TLS (mTLS) authentication if the server is configured to request a client certificate. You can use either a PKCS12 (.p12/.pfx) file or split PEM files for the client cert, private key, and CA.
- Only PKCS12 (.p12 or .pfx) file format is accepted
- The PKCS12 file must contain:
- Private key
- Public certificate
- CA certificate
- Encrypted PKCS12 files are currently not supported
#### Option 1: PKCS12 (Legacy)
Examples:
> This is the original method and still supported.
* File must contain:
* Client private key
* Public certificate
* CA certificate
* Encrypted `.p12` files are **not supported**
Example:
```bash
newt \
@@ -278,6 +288,27 @@ newt \
--tls-client-cert ./client.p12
```
#### Option 2: Split PEM Files (Preferred)
You can now provide separate files for:
* `--tls-client-cert`: client certificate (`.crt` or `.pem`)
* `--tls-client-key`: client private key (`.key` or `.pem`)
* `--tls-ca-cert`: CA cert to verify the server
Example:
```bash
newt \
--id 31frd0uzbjvp721 \
--secret h51mmlknrvrwv8s4r1i210azhumt6isgbpyavxodibx1k2d6 \
--endpoint https://example.com \
--tls-client-cert ./client.crt \
--tls-client-key ./client.key \
--tls-ca-cert ./ca.crt
```
```yaml
services:
newt:

161
main.go
View File

@@ -72,6 +72,18 @@ type ExitNodePingResult struct {
WasPreviouslyConnected bool `json:"wasPreviouslyConnected"`
}
// Custom flag type for multiple CA files
type stringSlice []string
func (s *stringSlice) String() string {
return strings.Join(*s, ",")
}
func (s *stringSlice) Set(value string) error {
*s = append(*s, value)
return nil
}
var (
endpoint string
id string
@@ -87,7 +99,6 @@ var (
keepInterface bool
acceptClients bool
updownScript string
tlsPrivateKey string
dockerSocket string
dockerEnforceNetworkValidation string
dockerEnforceNetworkValidationBool bool
@@ -99,6 +110,14 @@ var (
healthFile string
useNativeInterface bool
authorizedKeysFile string
// New mTLS configuration variables
tlsClientCert string
tlsClientKey string
tlsClientCAs []string
// Legacy PKCS12 support (deprecated)
tlsPrivateKey string
)
func main() {
@@ -114,7 +133,6 @@ func main() {
generateAndSaveKeyTo = os.Getenv("GENERATE_AND_SAVE_KEY_TO")
keepInterface = os.Getenv("KEEP_INTERFACE") == "true"
acceptClients = os.Getenv("ACCEPT_CLIENTS") == "true"
tlsPrivateKey = os.Getenv("TLS_CLIENT_CERT")
dockerSocket = os.Getenv("DOCKER_SOCKET")
pingIntervalStr := os.Getenv("PING_INTERVAL")
pingTimeoutStr := os.Getenv("PING_TIMEOUT")
@@ -123,6 +141,25 @@ func main() {
useNativeInterface = os.Getenv("USE_NATIVE_INTERFACE") == "true"
// authorizedKeysFile = os.Getenv("AUTHORIZED_KEYS_FILE")
authorizedKeysFile = ""
// Read new mTLS environment variables
tlsClientCert = os.Getenv("TLS_CLIENT_CERT")
tlsClientKey = os.Getenv("TLS_CLIENT_KEY")
tlsClientCAsEnv := os.Getenv("TLS_CLIENT_CAS")
if tlsClientCAsEnv != "" {
tlsClientCAs = strings.Split(tlsClientCAsEnv, ",")
// Trim spaces from each CA file path
for i, ca := range tlsClientCAs {
tlsClientCAs[i] = strings.TrimSpace(ca)
}
}
// Legacy PKCS12 support (deprecated)
tlsPrivateKey = os.Getenv("TLS_CLIENT_CERT_PKCS12")
// Keep backward compatibility with old environment variable name
if tlsPrivateKey == "" {
tlsPrivateKey = os.Getenv("TLS_CLIENT_CERT")
}
if endpoint == "" {
flag.StringVar(&endpoint, "endpoint", "", "Endpoint of your pangolin server")
@@ -154,9 +191,6 @@ func main() {
flag.BoolVar(&keepInterface, "keep-interface", false, "Keep the WireGuard interface")
flag.BoolVar(&useNativeInterface, "native", false, "Use native WireGuard interface (requires WireGuard kernel module) and linux")
flag.BoolVar(&acceptClients, "accept-clients", false, "Accept clients on the WireGuard interface")
if tlsPrivateKey == "" {
flag.StringVar(&tlsPrivateKey, "tls-client-cert", "", "Path to client certificate used for mTLS")
}
if dockerSocket == "" {
flag.StringVar(&dockerSocket, "docker-socket", "", "Path to Docker socket (typically /var/run/docker.sock)")
}
@@ -173,6 +207,23 @@ func main() {
// flag.StringVar(&authorizedKeysFile, "authorized-keys-file", "~/.ssh/authorized_keys", "Path to authorized keys file (if unset, no keys will be authorized)")
// }
// Add new mTLS flags
if tlsClientCert == "" {
flag.StringVar(&tlsClientCert, "tls-client-cert-file", "", "Path to client certificate file (PEM/DER format)")
}
if tlsClientKey == "" {
flag.StringVar(&tlsClientKey, "tls-client-key", "", "Path to client private key file (PEM/DER format)")
}
// Handle multiple CA files
var tlsClientCAsFlag stringSlice
flag.Var(&tlsClientCAsFlag, "tls-client-ca", "Path to CA certificate file for validating remote certificates (can be specified multiple times)")
// Legacy PKCS12 flag (deprecated)
if tlsPrivateKey == "" {
flag.StringVar(&tlsPrivateKey, "tls-client-cert", "", "Path to client certificate (PKCS12 format) - DEPRECATED: use --tls-client-cert-file and --tls-client-key instead")
}
if pingIntervalStr != "" {
pingInterval, err = time.ParseDuration(pingIntervalStr)
if err != nil {
@@ -197,7 +248,7 @@ func main() {
flag.StringVar(&dockerEnforceNetworkValidation, "docker-enforce-network-validation", "false", "Enforce validation of container on newt network (true or false)")
}
if healthFile == "" {
flag.StringVar(&healthFile, "health-file", "", "Path to health file (if unset, health file wont be written)")
flag.StringVar(&healthFile, "health-file", "", "Path to health file (if unset, health file won't be written)")
}
// do a --version check
@@ -205,6 +256,11 @@ func main() {
flag.Parse()
// Merge command line CA flags with environment variable CAs
if len(tlsClientCAsFlag) > 0 {
tlsClientCAs = append(tlsClientCAs, tlsClientCAsFlag...)
}
logger.Init()
loggerLevel := parseLogLevel(logLevel)
logger.GetLogger().SetLevel(parseLogLevel(logLevel))
@@ -234,14 +290,42 @@ func main() {
dockerEnforceNetworkValidationBool = false
}
// Add TLS configuration validation
if err := validateTLSConfig(); err != nil {
logger.Fatal("TLS configuration error: %v", err)
}
// Show deprecation warning if using PKCS12
if tlsPrivateKey != "" {
logger.Warn("Using deprecated PKCS12 format for mTLS. Consider migrating to separate certificate files using --tls-client-cert-file, --tls-client-key, and --tls-client-ca")
}
privateKey, err = wgtypes.GeneratePrivateKey()
if err != nil {
logger.Fatal("Failed to generate private key: %v", err)
}
// Create client option based on TLS configuration
var opt websocket.ClientOption
if tlsPrivateKey != "" {
opt = websocket.WithTLSConfig(tlsPrivateKey)
if tlsClientCert != "" && tlsClientKey != "" {
// Use new separate certificate configuration
opt = websocket.WithTLSConfig(websocket.TLSConfig{
ClientCertFile: tlsClientCert,
ClientKeyFile: tlsClientKey,
CAFiles: tlsClientCAs,
})
logger.Debug("Using separate certificate files for mTLS")
logger.Debug("Client cert: %s", tlsClientCert)
logger.Debug("Client key: %s", tlsClientKey)
logger.Debug("CA files: %v", tlsClientCAs)
} else if tlsPrivateKey != "" {
// Use existing PKCS12 configuration for backward compatibility
opt = websocket.WithTLSConfig(websocket.TLSConfig{
PKCS12File: tlsPrivateKey,
})
logger.Debug("Using PKCS12 file for mTLS: %s", tlsPrivateKey)
}
// Create a new client
client, err := websocket.NewClient(
"newt",
@@ -262,7 +346,21 @@ func main() {
logger.Debug("Endpoint: %v", endpoint)
logger.Debug("Log Level: %v", logLevel)
logger.Debug("Docker Network Validation Enabled: %v", dockerEnforceNetworkValidationBool)
logger.Debug("TLS Private Key Set: %v", tlsPrivateKey != "")
// Add new TLS debug logging
if tlsClientCert != "" {
logger.Debug("TLS Client Cert File: %v", tlsClientCert)
}
if tlsClientKey != "" {
logger.Debug("TLS Client Key File: %v", tlsClientKey)
}
if len(tlsClientCAs) > 0 {
logger.Debug("TLS CA Files: %v", tlsClientCAs)
}
if tlsPrivateKey != "" {
logger.Debug("TLS PKCS12 File: %v", tlsPrivateKey)
}
if dns != "" {
logger.Debug("Dns: %v", dns)
}
@@ -950,3 +1048,48 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
logger.Info("Exiting...")
os.Exit(0)
}
// validateTLSConfig validates the TLS configuration
func validateTLSConfig() error {
// Check for conflicting configurations
pkcs12Specified := tlsPrivateKey != ""
separateFilesSpecified := tlsClientCert != "" || tlsClientKey != "" || len(tlsClientCAs) > 0
if pkcs12Specified && separateFilesSpecified {
return fmt.Errorf("cannot use both PKCS12 format (--tls-client-cert) and separate certificate files (--tls-client-cert-file, --tls-client-key, --tls-client-ca)")
}
// If using separate files, both cert and key are required
if (tlsClientCert != "" && tlsClientKey == "") || (tlsClientCert == "" && tlsClientKey != "") {
return fmt.Errorf("both --tls-client-cert-file and --tls-client-key must be specified together")
}
// Validate certificate files exist
if tlsClientCert != "" {
if _, err := os.Stat(tlsClientCert); os.IsNotExist(err) {
return fmt.Errorf("client certificate file does not exist: %s", tlsClientCert)
}
}
if tlsClientKey != "" {
if _, err := os.Stat(tlsClientKey); os.IsNotExist(err) {
return fmt.Errorf("client key file does not exist: %s", tlsClientKey)
}
}
// Validate CA files exist
for _, caFile := range tlsClientCAs {
if _, err := os.Stat(caFile); os.IsNotExist(err) {
return fmt.Errorf("CA certificate file does not exist: %s", caFile)
}
}
// Validate PKCS12 file exists if specified
if tlsPrivateKey != "" {
if _, err := os.Stat(tlsPrivateKey); os.IsNotExist(err) {
return fmt.Errorf("PKCS12 certificate file does not exist: %s", tlsPrivateKey)
}
}
return nil
}

View File

@@ -35,12 +35,24 @@ type Client struct {
onTokenUpdate func(token string)
writeMux sync.Mutex
clientType string // Type of client (e.g., "newt", "olm")
tlsConfig TLSConfig
}
type ClientOption func(*Client)
type MessageHandler func(message WSMessage)
// TLSConfig holds TLS configuration options
type TLSConfig struct {
// New separate certificate support
ClientCertFile string
ClientKeyFile string
CAFiles []string
// Existing PKCS12 support (deprecated)
PKCS12File string
}
// WithBaseURL sets the base URL for the client
func WithBaseURL(url string) ClientOption {
return func(c *Client) {
@@ -48,9 +60,14 @@ func WithBaseURL(url string) ClientOption {
}
}
func WithTLSConfig(tlsClientCertPath string) ClientOption {
// WithTLSConfig sets the TLS configuration for the client
func WithTLSConfig(config TLSConfig) ClientOption {
return func(c *Client) {
c.config.TlsClientCert = tlsClientCertPath
c.tlsConfig = config
// For backward compatibility, also set the legacy field
if config.PKCS12File != "" {
c.config.TlsClientCert = config.PKCS12File
}
}
}
@@ -198,10 +215,12 @@ func (c *Client) getToken() (string, error) {
baseEndpoint := strings.TrimRight(baseURL.String(), "/")
var tlsConfig *tls.Config = nil
if c.config.TlsClientCert != "" {
tlsConfig, err = loadClientCertificate(c.config.TlsClientCert)
// Use new TLS configuration method
if c.tlsConfig.ClientCertFile != "" || c.tlsConfig.ClientKeyFile != "" || len(c.tlsConfig.CAFiles) > 0 || c.tlsConfig.PKCS12File != "" {
tlsConfig, err = c.setupTLS()
if err != nil {
return "", fmt.Errorf("failed to load certificate %s: %w", c.config.TlsClientCert, err)
return "", fmt.Errorf("failed to setup TLS configuration: %w", err)
}
}
@@ -331,14 +350,17 @@ func (c *Client) establishConnection() error {
// Connect to WebSocket
dialer := websocket.DefaultDialer
if c.config.TlsClientCert != "" {
logger.Info("Adding tls to req")
tlsConfig, err := loadClientCertificate(c.config.TlsClientCert)
// Use new TLS configuration method
if c.tlsConfig.ClientCertFile != "" || c.tlsConfig.ClientKeyFile != "" || len(c.tlsConfig.CAFiles) > 0 || c.tlsConfig.PKCS12File != "" {
logger.Info("Setting up TLS configuration for WebSocket connection")
tlsConfig, err := c.setupTLS()
if err != nil {
return fmt.Errorf("failed to load certificate %s: %w", c.config.TlsClientCert, err)
return fmt.Errorf("failed to setup TLS configuration: %w", err)
}
dialer.TLSClientConfig = tlsConfig
}
conn, _, err := dialer.Dial(u.String(), nil)
if err != nil {
return fmt.Errorf("failed to connect to WebSocket: %w", err)
@@ -365,6 +387,69 @@ func (c *Client) establishConnection() error {
return nil
}
// setupTLS configures TLS based on the TLS configuration
func (c *Client) setupTLS() (*tls.Config, error) {
tlsConfig := &tls.Config{}
// Handle new separate certificate configuration
if c.tlsConfig.ClientCertFile != "" && c.tlsConfig.ClientKeyFile != "" {
logger.Info("Loading separate certificate files for mTLS")
logger.Debug("Client cert: %s", c.tlsConfig.ClientCertFile)
logger.Debug("Client key: %s", c.tlsConfig.ClientKeyFile)
// Load client certificate and key
cert, err := tls.LoadX509KeyPair(c.tlsConfig.ClientCertFile, c.tlsConfig.ClientKeyFile)
if err != nil {
return nil, fmt.Errorf("failed to load client certificate pair: %w", err)
}
tlsConfig.Certificates = []tls.Certificate{cert}
// Load CA certificates for remote validation if specified
if len(c.tlsConfig.CAFiles) > 0 {
logger.Debug("Loading CA certificates: %v", c.tlsConfig.CAFiles)
caCertPool := x509.NewCertPool()
for _, caFile := range c.tlsConfig.CAFiles {
caCert, err := os.ReadFile(caFile)
if err != nil {
return nil, fmt.Errorf("failed to read CA file %s: %w", caFile, err)
}
// Try to parse as PEM first, then DER
if !caCertPool.AppendCertsFromPEM(caCert) {
// If PEM parsing failed, try DER
cert, err := x509.ParseCertificate(caCert)
if err != nil {
return nil, fmt.Errorf("failed to parse CA certificate from %s: %w", caFile, err)
}
caCertPool.AddCert(cert)
}
}
tlsConfig.RootCAs = caCertPool
}
return tlsConfig, nil
}
// Fallback to existing PKCS12 implementation for backward compatibility
if c.tlsConfig.PKCS12File != "" {
logger.Info("Loading PKCS12 certificate for mTLS (deprecated)")
return c.setupPKCS12TLS()
}
// Legacy fallback using config.TlsClientCert
if c.config.TlsClientCert != "" {
logger.Info("Loading legacy PKCS12 certificate for mTLS (deprecated)")
return loadClientCertificate(c.config.TlsClientCert)
}
return nil, nil
}
// setupPKCS12TLS loads TLS configuration from PKCS12 file
func (c *Client) setupPKCS12TLS() (*tls.Config, error) {
return loadClientCertificate(c.tlsConfig.PKCS12File)
}
// pingMonitor sends pings at a short interval and triggers reconnect on failure
func (c *Client) pingMonitor() {
ticker := time.NewTicker(c.pingInterval)
@@ -469,7 +554,7 @@ func (c *Client) setConnected(status bool) {
c.isConnected = status
}
// LoadClientCertificate Helper method to load client certificates
// LoadClientCertificate Helper method to load client certificates (PKCS12 format)
func loadClientCertificate(p12Path string) (*tls.Config, error) {
logger.Info("Loading tls-client-cert %s", p12Path)
// Read the PKCS12 file
@@ -506,4 +591,4 @@ func loadClientCertificate(p12Path string) (*tls.Config, error) {
Certificates: []tls.Certificate{cert},
RootCAs: rootCAs,
}, nil
}
}