Merge branch 'feat/Split-mTLS-client-and-CA-certificates' of github.com:Pallavikumarimdb/newt into Pallavikumarimdb-feat/Split-mTLS-client-and-CA-certificates

This commit is contained in:
Owen
2025-08-30 18:06:18 -07:00
3 changed files with 286 additions and 24 deletions

158
main.go
View File

@@ -74,6 +74,18 @@ type ExitNodePingResult struct {
WasPreviouslyConnected bool `json:"wasPreviouslyConnected"`
}
// Custom flag type for multiple CA files
type stringSlice []string
func (s *stringSlice) String() string {
return strings.Join(*s, ",")
}
func (s *stringSlice) Set(value string) error {
*s = append(*s, value)
return nil
}
var (
endpoint string
id string
@@ -89,7 +101,6 @@ var (
keepInterface bool
acceptClients bool
updownScript string
tlsPrivateKey string
dockerSocket string
dockerEnforceNetworkValidation string
dockerEnforceNetworkValidationBool bool
@@ -103,6 +114,14 @@ var (
authorizedKeysFile string
preferEndpoint string
healthMonitor *healthcheck.Monitor
// New mTLS configuration variables
tlsClientCert string
tlsClientKey string
tlsClientCAs []string
// Legacy PKCS12 support (deprecated)
tlsPrivateKey string
)
func main() {
@@ -124,7 +143,6 @@ func main() {
acceptClients = acceptClientsEnv == "true"
useNativeInterface = useNativeInterfaceEnv == "true"
tlsPrivateKey = os.Getenv("TLS_CLIENT_CERT")
dockerSocket = os.Getenv("DOCKER_SOCKET")
pingIntervalStr := os.Getenv("PING_INTERVAL")
pingTimeoutStr := os.Getenv("PING_TIMEOUT")
@@ -133,6 +151,25 @@ func main() {
// authorizedKeysFile = os.Getenv("AUTHORIZED_KEYS_FILE")
authorizedKeysFile = ""
// Read new mTLS environment variables
tlsClientCert = os.Getenv("TLS_CLIENT_CERT")
tlsClientKey = os.Getenv("TLS_CLIENT_KEY")
tlsClientCAsEnv := os.Getenv("TLS_CLIENT_CAS")
if tlsClientCAsEnv != "" {
tlsClientCAs = strings.Split(tlsClientCAsEnv, ",")
// Trim spaces from each CA file path
for i, ca := range tlsClientCAs {
tlsClientCAs[i] = strings.TrimSpace(ca)
}
}
// Legacy PKCS12 support (deprecated)
tlsPrivateKey = os.Getenv("TLS_CLIENT_CERT_PKCS12")
// Keep backward compatibility with old environment variable name
if tlsPrivateKey == "" {
tlsPrivateKey = os.Getenv("TLS_CLIENT_CERT")
}
if endpoint == "" {
flag.StringVar(&endpoint, "endpoint", "", "Endpoint of your pangolin server")
}
@@ -188,6 +225,23 @@ func main() {
// flag.StringVar(&authorizedKeysFile, "authorized-keys-file", "~/.ssh/authorized_keys", "Path to authorized keys file (if unset, no keys will be authorized)")
// }
// Add new mTLS flags
if tlsClientCert == "" {
flag.StringVar(&tlsClientCert, "tls-client-cert-file", "", "Path to client certificate file (PEM/DER format)")
}
if tlsClientKey == "" {
flag.StringVar(&tlsClientKey, "tls-client-key", "", "Path to client private key file (PEM/DER format)")
}
// Handle multiple CA files
var tlsClientCAsFlag stringSlice
flag.Var(&tlsClientCAsFlag, "tls-client-ca", "Path to CA certificate file for validating remote certificates (can be specified multiple times)")
// Legacy PKCS12 flag (deprecated)
if tlsPrivateKey == "" {
flag.StringVar(&tlsPrivateKey, "tls-client-cert", "", "Path to client certificate (PKCS12 format) - DEPRECATED: use --tls-client-cert-file and --tls-client-key instead")
}
if pingIntervalStr != "" {
pingInterval, err = time.ParseDuration(pingIntervalStr)
if err != nil {
@@ -212,7 +266,7 @@ func main() {
flag.StringVar(&dockerEnforceNetworkValidation, "docker-enforce-network-validation", "false", "Enforce validation of container on newt network (true or false)")
}
if healthFile == "" {
flag.StringVar(&healthFile, "health-file", "", "Path to health file (if unset, health file wont be written)")
flag.StringVar(&healthFile, "health-file", "", "Path to health file (if unset, health file won't be written)")
}
// do a --version check
@@ -220,6 +274,11 @@ func main() {
flag.Parse()
// Merge command line CA flags with environment variable CAs
if len(tlsClientCAsFlag) > 0 {
tlsClientCAs = append(tlsClientCAs, tlsClientCAsFlag...)
}
logger.Init()
loggerLevel := parseLogLevel(logLevel)
logger.GetLogger().SetLevel(parseLogLevel(logLevel))
@@ -249,14 +308,42 @@ func main() {
dockerEnforceNetworkValidationBool = false
}
// Add TLS configuration validation
if err := validateTLSConfig(); err != nil {
logger.Fatal("TLS configuration error: %v", err)
}
// Show deprecation warning if using PKCS12
if tlsPrivateKey != "" {
logger.Warn("Using deprecated PKCS12 format for mTLS. Consider migrating to separate certificate files using --tls-client-cert-file, --tls-client-key, and --tls-client-ca")
}
privateKey, err = wgtypes.GeneratePrivateKey()
if err != nil {
logger.Fatal("Failed to generate private key: %v", err)
}
// Create client option based on TLS configuration
var opt websocket.ClientOption
if tlsPrivateKey != "" {
opt = websocket.WithTLSConfig(tlsPrivateKey)
if tlsClientCert != "" && tlsClientKey != "" {
// Use new separate certificate configuration
opt = websocket.WithTLSConfig(websocket.TLSConfig{
ClientCertFile: tlsClientCert,
ClientKeyFile: tlsClientKey,
CAFiles: tlsClientCAs,
})
logger.Debug("Using separate certificate files for mTLS")
logger.Debug("Client cert: %s", tlsClientCert)
logger.Debug("Client key: %s", tlsClientKey)
logger.Debug("CA files: %v", tlsClientCAs)
} else if tlsPrivateKey != "" {
// Use existing PKCS12 configuration for backward compatibility
opt = websocket.WithTLSConfig(websocket.TLSConfig{
PKCS12File: tlsPrivateKey,
})
logger.Debug("Using PKCS12 file for mTLS: %s", tlsPrivateKey)
}
// Create a new client
client, err := websocket.NewClient(
"newt",
@@ -277,7 +364,21 @@ func main() {
logger.Debug("Endpoint: %v", endpoint)
logger.Debug("Log Level: %v", logLevel)
logger.Debug("Docker Network Validation Enabled: %v", dockerEnforceNetworkValidationBool)
logger.Debug("TLS Private Key Set: %v", tlsPrivateKey != "")
// Add new TLS debug logging
if tlsClientCert != "" {
logger.Debug("TLS Client Cert File: %v", tlsClientCert)
}
if tlsClientKey != "" {
logger.Debug("TLS Client Key File: %v", tlsClientKey)
}
if len(tlsClientCAs) > 0 {
logger.Debug("TLS CA Files: %v", tlsClientCAs)
}
if tlsPrivateKey != "" {
logger.Debug("TLS PKCS12 File: %v", tlsPrivateKey)
}
if dns != "" {
logger.Debug("Dns: %v", dns)
}
@@ -1147,3 +1248,48 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
logger.Info("Exiting...")
os.Exit(0)
}
// validateTLSConfig validates the TLS configuration
func validateTLSConfig() error {
// Check for conflicting configurations
pkcs12Specified := tlsPrivateKey != ""
separateFilesSpecified := tlsClientCert != "" || tlsClientKey != "" || len(tlsClientCAs) > 0
if pkcs12Specified && separateFilesSpecified {
return fmt.Errorf("cannot use both PKCS12 format (--tls-client-cert) and separate certificate files (--tls-client-cert-file, --tls-client-key, --tls-client-ca)")
}
// If using separate files, both cert and key are required
if (tlsClientCert != "" && tlsClientKey == "") || (tlsClientCert == "" && tlsClientKey != "") {
return fmt.Errorf("both --tls-client-cert-file and --tls-client-key must be specified together")
}
// Validate certificate files exist
if tlsClientCert != "" {
if _, err := os.Stat(tlsClientCert); os.IsNotExist(err) {
return fmt.Errorf("client certificate file does not exist: %s", tlsClientCert)
}
}
if tlsClientKey != "" {
if _, err := os.Stat(tlsClientKey); os.IsNotExist(err) {
return fmt.Errorf("client key file does not exist: %s", tlsClientKey)
}
}
// Validate CA files exist
for _, caFile := range tlsClientCAs {
if _, err := os.Stat(caFile); os.IsNotExist(err) {
return fmt.Errorf("CA certificate file does not exist: %s", caFile)
}
}
// Validate PKCS12 file exists if specified
if tlsPrivateKey != "" {
if _, err := os.Stat(tlsPrivateKey); os.IsNotExist(err) {
return fmt.Errorf("PKCS12 certificate file does not exist: %s", tlsPrivateKey)
}
}
return nil
}