diff --git a/main.go b/main.go index cf02e27..ecf84eb 100644 --- a/main.go +++ b/main.go @@ -1,7 +1,9 @@ package main import ( + "bytes" "context" + "crypto/tls" "encoding/json" "errors" "flag" @@ -11,7 +13,6 @@ import ( "net/netip" "os" "os/signal" - "path/filepath" "strconv" "strings" "syscall" @@ -130,6 +131,7 @@ var ( preferEndpoint string healthMonitor *healthcheck.Monitor enforceHealthcheckCert bool + authDaemonKey string // Build/version (can be overridden via -ldflags "-X main.newtVersion=...") newtVersion = "version_replaceme" @@ -183,6 +185,7 @@ func runNewtMain(ctx context.Context) { updownScript = os.Getenv("UPDOWN_SCRIPT") interfaceName = os.Getenv("INTERFACE") portStr := os.Getenv("PORT") + authDaemonKey = os.Getenv("AUTH_DAEMON_KEY") // Metrics/observability env mirrors metricsEnabledEnv := os.Getenv("NEWT_METRICS_PROMETHEUS_ENABLED") @@ -371,6 +374,11 @@ func runNewtMain(ctx context.Context) { region = regionEnv } + // Auth daemon key flag + if authDaemonKey == "" { + flag.StringVar(&authDaemonKey, "auth-daemon-key", "", "Preshared key for auth daemon authentication") + } + // do a --version check version := flag.Bool("version", false, "Print the version") @@ -686,8 +694,8 @@ func runNewtMain(ctx context.Context) { relayPort := wgData.RelayPort if relayPort == 0 { - relayPort = 21820 - } + relayPort = 21820 + } clientsHandleNewtConnection(wgData.PublicKey, endpoint, relayPort) @@ -1315,6 +1323,140 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey( } }) + // Register handler for SSH certificate issued events + client.RegisterHandler("newt/pam/connection", func(msg websocket.WSMessage) { + logger.Debug("Received SSH certificate issued message") + + // Define the structure of the incoming message + type SSHCertData struct { + TraceID string `json:"traceId"` + AgentPort int `json:"agentPort"` + AgentHost string `json:"agentHost"` + CACert string `json:"caCert"` + Username string `json:"username"` + NiceID string `json:"niceId"` + Metadata struct { + Sudo bool `json:"sudo"` + Homedir bool `json:"homedir"` + } `json:"metadata"` + } + + var certData SSHCertData + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Error("Error marshaling SSH cert data: %v", err) + return + } + + if err := json.Unmarshal(jsonData, &certData); err != nil { + logger.Error("Error unmarshaling SSH cert data: %v", err) + return + } + + // Check if auth daemon key is configured + if authDaemonKey == "" { + logger.Error("Auth daemon key not configured, cannot process SSH certificate") + // Send failure response back to cloud + err := client.SendMessage("newt/pam/connection/response", map[string]interface{}{ + "traceId": certData.TraceID, + "success": false, + "error": "auth daemon key not configured", + }) + if err != nil { + logger.Error("Failed to send SSH cert failure response: %v", err) + } + return + } + + // Prepare the request body for the auth daemon + requestBody := map[string]interface{}{ + "caCert": certData.CACert, + "niceId": certData.NiceID, + "username": certData.Username, + "metadata": map[string]interface{}{ + "sudo": certData.Metadata.Sudo, + "homedir": certData.Metadata.Homedir, + }, + } + + requestJSON, err := json.Marshal(requestBody) + if err != nil { + logger.Error("Failed to marshal auth daemon request: %v", err) + // Send failure response + client.SendMessage("newt/pam/ssh-cert-response", map[string]interface{}{ + "traceId": certData.TraceID, + "success": false, + "error": fmt.Sprintf("failed to marshal request: %v", err), + }) + return + } + + // Create HTTPS client that skips certificate verification + // (auth daemon uses self-signed cert) + httpClient := &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + }, + Timeout: 10 * time.Second, + } + + // Make the request to the auth daemon + url := fmt.Sprintf("https://%s:%d/connection", certData.AgentHost, certData.AgentPort) + req, err := http.NewRequest("POST", url, bytes.NewBuffer(requestJSON)) + if err != nil { + logger.Error("Failed to create auth daemon request: %v", err) + client.SendMessage("newt/pam/connection/response", map[string]interface{}{ + "traceId": certData.TraceID, + "success": false, + "error": fmt.Sprintf("failed to create request: %v", err), + }) + return + } + + // Set headers + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+authDaemonKey) + + logger.Debug("Sending SSH cert to auth daemon at %s", url) + + // Send the request + resp, err := httpClient.Do(req) + if err != nil { + logger.Error("Failed to connect to auth daemon: %v", err) + client.SendMessage("newt/pam/connection/response", map[string]interface{}{ + "traceId": certData.TraceID, + "success": false, + "error": fmt.Sprintf("failed to connect to auth daemon: %v", err), + }) + return + } + defer resp.Body.Close() + + // Check response status + if resp.StatusCode != http.StatusOK { + logger.Error("Auth daemon returned non-OK status: %d", resp.StatusCode) + client.SendMessage("newt/pam/connection/response", map[string]interface{}{ + "traceId": certData.TraceID, + "success": false, + "error": fmt.Sprintf("auth daemon returned status %d", resp.StatusCode), + }) + return + } + + logger.Info("Successfully registered SSH certificate with auth daemon for user %s", certData.Username) + + // Send success response back to cloud + err = client.SendMessage("newt/pam/connection/response", map[string]interface{}{ + "traceId": certData.TraceID, + "success": true, + }) + if err != nil { + logger.Error("Failed to send SSH cert success response: %v", err) + } + }) + client.OnConnect(func() error { publicKey = privateKey.PublicKey() logger.Debug("Public key: %s", publicKey)