diff --git a/authdaemon.go b/authdaemon.go new file mode 100644 index 0000000..5de92ad --- /dev/null +++ b/authdaemon.go @@ -0,0 +1,151 @@ +package main + +import ( + "context" + "errors" + "fmt" + "os" + "runtime" + + "github.com/fosrl/newt/authdaemon" + "github.com/fosrl/newt/logger" +) + +const ( + defaultPrincipalsPath = "/var/run/auth-daemon/principals" + defaultCACertPath = "/etc/ssh/ca.pem" +) + +var ( + errPresharedKeyRequired = errors.New("auth-daemon-key is required when --auth-daemon is enabled") + errRootRequired = errors.New("auth-daemon must be run as root (use sudo)") + authDaemonServer *authdaemon.Server // Global auth daemon server instance +) + +// startAuthDaemon initializes and starts the auth daemon in the background. +// It validates requirements (Linux, root, preshared key) and starts the server +// in a goroutine so it runs alongside normal newt operation. +func startAuthDaemon(ctx context.Context) error { + // Validation + if runtime.GOOS != "linux" { + return fmt.Errorf("auth-daemon is only supported on Linux, not %s", runtime.GOOS) + } + if os.Geteuid() != 0 { + return errRootRequired + } + + // Use defaults if not set + principalsFile := authDaemonPrincipalsFile + if principalsFile == "" { + principalsFile = defaultPrincipalsPath + } + caCertPath := authDaemonCACertPath + if caCertPath == "" { + caCertPath = defaultCACertPath + } + + // Create auth daemon server + cfg := authdaemon.Config{ + DisableHTTPS: true, // We run without HTTP server in newt + PresharedKey: "this-key-is-not-used", // Not used in embedded mode, but set to non-empty to satisfy validation + PrincipalsFilePath: principalsFile, + CACertPath: caCertPath, + Force: true, + } + + srv, err := authdaemon.NewServer(cfg) + if err != nil { + return fmt.Errorf("create auth daemon server: %w", err) + } + + authDaemonServer = srv + + // Start the auth daemon in a goroutine so it runs alongside newt + go func() { + logger.Info("Auth daemon starting (native mode, no HTTP server)") + if err := srv.Run(ctx); err != nil { + logger.Error("Auth daemon error: %v", err) + } + logger.Info("Auth daemon stopped") + }() + + return nil +} + + + +// runPrincipalsCmd executes the principals subcommand logic +func runPrincipalsCmd(args []string) { + opts := struct { + PrincipalsFile string + Username string + }{ + PrincipalsFile: defaultPrincipalsPath, + } + + // Parse flags manually + for i := 0; i < len(args); i++ { + switch args[i] { + case "--principals-file": + if i+1 >= len(args) { + fmt.Fprintf(os.Stderr, "Error: --principals-file requires a value\n") + os.Exit(1) + } + opts.PrincipalsFile = args[i+1] + i++ + case "--username": + if i+1 >= len(args) { + fmt.Fprintf(os.Stderr, "Error: --username requires a value\n") + os.Exit(1) + } + opts.Username = args[i+1] + i++ + case "--help", "-h": + printPrincipalsHelp() + os.Exit(0) + default: + fmt.Fprintf(os.Stderr, "Error: unknown flag: %s\n", args[i]) + printPrincipalsHelp() + os.Exit(1) + } + } + + // Validation + if opts.Username == "" { + fmt.Fprintf(os.Stderr, "Error: username is required\n") + printPrincipalsHelp() + os.Exit(1) + } + + // Get principals + list, err := authdaemon.GetPrincipals(opts.PrincipalsFile, opts.Username) + if err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + os.Exit(1) + } + if len(list) == 0 { + fmt.Println("") + return + } + for _, principal := range list { + fmt.Println(principal) + } +} + +func printPrincipalsHelp() { + fmt.Fprintf(os.Stderr, `Usage: newt principals [flags] + +Output principals for a username (for AuthorizedPrincipalsCommand in sshd_config). +Read the principals file and print principals that match the given username, one per line. +Configure in sshd_config with AuthorizedPrincipalsCommand and %%u for the username. + +Flags: + --principals-file string Path to the principals file (default "%s") + --username string Username to look up (required) + --help, -h Show this help message + +Example: + newt principals --username alice + +`, defaultPrincipalsPath) +} \ No newline at end of file diff --git a/authdaemon/host_linux.go b/authdaemon/host_linux.go index 82834f3..76f8712 100644 --- a/authdaemon/host_linux.go +++ b/authdaemon/host_linux.go @@ -138,7 +138,7 @@ func ensureUser(username string, meta ConnectionMetadata) error { } func createUser(username string, meta ConnectionMetadata) error { - args := []string{} + args := []string{"-s", "/bin/bash"} if meta.Homedir { args = append(args, "-m") } else { diff --git a/main.go b/main.go index 24fd8bb..3ea52d4 100644 --- a/main.go +++ b/main.go @@ -18,6 +18,7 @@ import ( "syscall" "time" + "github.com/fosrl/newt/authdaemon" "github.com/fosrl/newt/docker" "github.com/fosrl/newt/healthcheck" "github.com/fosrl/newt/logger" @@ -132,6 +133,9 @@ var ( healthMonitor *healthcheck.Monitor enforceHealthcheckCert bool authDaemonKey string + authDaemonPrincipalsFile string + authDaemonCACertPath string + authDaemonEnabled bool // Build/version (can be overridden via -ldflags "-X main.newtVersion=...") newtVersion = "version_replaceme" @@ -154,6 +158,28 @@ var ( ) func main() { + // Check for subcommands first (only principals exits early) + if len(os.Args) > 1 { + switch os.Args[1] { + case "auth-daemon": + // Run principals subcommand only if the next argument is "principals" + if len(os.Args) > 2 && os.Args[2] == "principals" { + runPrincipalsCmd(os.Args[3:]) + return + } + + // auth-daemon subcommand without "principals" - show help + fmt.Println("Error: auth-daemon subcommand requires 'principals' argument") + fmt.Println() + fmt.Println("Usage:") + fmt.Println(" newt auth-daemon principals [options]") + fmt.Println() + + // If not "principals", exit the switch to continue with normal execution + return + } + } + // Check if we're running as a Windows service if isWindowsService() { runService("NewtWireguardService", false, os.Args[1:]) @@ -185,7 +211,10 @@ func runNewtMain(ctx context.Context) { updownScript = os.Getenv("UPDOWN_SCRIPT") interfaceName = os.Getenv("INTERFACE") portStr := os.Getenv("PORT") - authDaemonKey = os.Getenv("AUTH_DAEMON_KEY") + authDaemonKey = os.Getenv("AD_KEY") + authDaemonPrincipalsFile = os.Getenv("AD_PRINCIPALS_FILE") + authDaemonCACertPath = os.Getenv("AD_CA_CERT_PATH") + authDaemonEnabledEnv := os.Getenv("AUTH_DAEMON_ENABLED") // Metrics/observability env mirrors metricsEnabledEnv := os.Getenv("NEWT_METRICS_PROMETHEUS_ENABLED") @@ -374,9 +403,22 @@ func runNewtMain(ctx context.Context) { region = regionEnv } - // Auth daemon key flag + // Auth daemon flags if authDaemonKey == "" { - flag.StringVar(&authDaemonKey, "auth-daemon-key", "", "Preshared key for auth daemon authentication") + flag.StringVar(&authDaemonKey, "ad-preshared-key", "", "Preshared key for auth daemon authentication (required when --auth-daemon is true)") + } + if authDaemonPrincipalsFile == "" { + flag.StringVar(&authDaemonPrincipalsFile, "ad-principals-file", "/var/run/auth-daemon/principals", "Path to the principals file for auth daemon") + } + if authDaemonCACertPath == "" { + flag.StringVar(&authDaemonCACertPath, "ad-ca-cert-path", "/etc/ssh/ca.pem", "Path to the CA certificate file for auth daemon") + } + if authDaemonEnabledEnv == "" { + flag.BoolVar(&authDaemonEnabled, "auth-daemon", false, "Enable auth daemon mode (runs alongside normal newt operation)") + } else { + if v, err := strconv.ParseBool(authDaemonEnabledEnv); err == nil { + authDaemonEnabled = v + } } // do a --version check @@ -398,6 +440,13 @@ func runNewtMain(ctx context.Context) { logger.Init(nil) loggerLevel := util.ParseLogLevel(logLevel) + + // Start auth daemon if enabled + if authDaemonEnabled { + if err := startAuthDaemon(ctx); err != nil { + logger.Fatal("Failed to start auth daemon: %v", err) + } + } logger.GetLogger().SetLevel(loggerLevel) // Initialize telemetry after flags are parsed (so flags override env) @@ -1329,7 +1378,7 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey( // Define the structure of the incoming message type SSHCertData struct { - MessageId string `json:"messageId"` + MessageId int `json:"messageId"` AgentPort int `json:"agentPort"` AgentHost string `json:"agentHost"` CACert string `json:"caCert"` @@ -1348,109 +1397,137 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey( return } + // print the received data for debugging + logger.Debug("Received SSH cert data: %s", string(jsonData)) + if err := json.Unmarshal(jsonData, &certData); err != nil { logger.Error("Error unmarshaling SSH cert data: %v", err) return } - // Check if auth daemon key is configured - if authDaemonKey == "" { - logger.Error("Auth daemon key not configured, cannot process SSH certificate") - // Send failure response back to cloud - err := client.SendMessage("ws/round-trip/complete", map[string]interface{}{ - "messageId": certData.MessageId, - "complete": true, - "error": "auth daemon key not configured", - }) - if err != nil { - logger.Error("Failed to send SSH cert failure response: %v", err) - } - return - } + // Check if we're running the auth daemon internally + if authDaemonServer != nil { + // Call ProcessConnection directly when running internally + logger.Debug("Calling internal auth daemon ProcessConnection for user %s", certData.Username) - // Prepare the request body for the auth daemon - requestBody := map[string]interface{}{ - "caCert": certData.CACert, - "niceId": certData.NiceID, - "username": certData.Username, - "metadata": map[string]interface{}{ - "sudo": certData.Metadata.Sudo, - "homedir": certData.Metadata.Homedir, - }, - } - - requestJSON, err := json.Marshal(requestBody) - if err != nil { - logger.Error("Failed to marshal auth daemon request: %v", err) - // Send failure response - client.SendMessage("ws/round-trip/complete", map[string]interface{}{ - "messageId": certData.MessageId, - "complete": true, - "error": fmt.Sprintf("failed to marshal request: %v", err), - }) - return - } - - // Create HTTPS client that skips certificate verification - // (auth daemon uses self-signed cert) - httpClient := &http.Client{ - Transport: &http.Transport{ - TLSClientConfig: &tls.Config{ - InsecureSkipVerify: true, + authDaemonServer.ProcessConnection(authdaemon.ConnectionRequest{ + CaCert: certData.CACert, + NiceId: certData.NiceID, + Username: certData.Username, + Metadata: authdaemon.ConnectionMetadata{ + Sudo: certData.Metadata.Sudo, + Homedir: certData.Metadata.Homedir, }, - }, - Timeout: 10 * time.Second, - } - - // Make the request to the auth daemon - url := fmt.Sprintf("https://%s:%d/connection", certData.AgentHost, certData.AgentPort) - req, err := http.NewRequest("POST", url, bytes.NewBuffer(requestJSON)) - if err != nil { - logger.Error("Failed to create auth daemon request: %v", err) - client.SendMessage("ws/round-trip/complete", map[string]interface{}{ - "messageId": certData.MessageId, - "complete": true, - "error": fmt.Sprintf("failed to create request: %v", err), }) - return - } - // Set headers - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+authDaemonKey) - - logger.Debug("Sending SSH cert to auth daemon at %s", url) - - // Send the request - resp, err := httpClient.Do(req) - if err != nil { - logger.Error("Failed to connect to auth daemon: %v", err) - client.SendMessage("ws/round-trip/complete", map[string]interface{}{ + // Send success response back to cloud + err = client.SendMessage("ws/round-trip/complete", map[string]interface{}{ "messageId": certData.MessageId, - "complete": true, - "error": fmt.Sprintf("failed to connect to auth daemon: %v", err), + "complete": true, }) - return - } - defer resp.Body.Close() - // Check response status - if resp.StatusCode != http.StatusOK { - logger.Error("Auth daemon returned non-OK status: %d", resp.StatusCode) - client.SendMessage("ws/round-trip/complete", map[string]interface{}{ - "messageId": certData.MessageId, - "complete": true, - "error": fmt.Sprintf("auth daemon returned status %d", resp.StatusCode), - }) - return - } + logger.Info("Successfully processed connection via internal auth daemon for user %s", certData.Username) + } else { + // External auth daemon mode - make HTTP request + // Check if auth daemon key is configured + if authDaemonKey == "" { + logger.Error("Auth daemon key not configured, cannot communicate with daemon") + // Send failure response back to cloud + err := client.SendMessage("ws/round-trip/complete", map[string]interface{}{ + "messageId": certData.MessageId, + "complete": true, + "error": "auth daemon key not configured", + }) + if err != nil { + logger.Error("Failed to send SSH cert failure response: %v", err) + } + return + } - logger.Info("Successfully registered SSH certificate with auth daemon for user %s", certData.Username) + // Prepare the request body for the auth daemon + requestBody := map[string]interface{}{ + "caCert": certData.CACert, + "niceId": certData.NiceID, + "username": certData.Username, + "metadata": map[string]interface{}{ + "sudo": certData.Metadata.Sudo, + "homedir": certData.Metadata.Homedir, + }, + } + + requestJSON, err := json.Marshal(requestBody) + if err != nil { + logger.Error("Failed to marshal auth daemon request: %v", err) + // Send failure response + client.SendMessage("ws/round-trip/complete", map[string]interface{}{ + "messageId": certData.MessageId, + "complete": true, + "error": fmt.Sprintf("failed to marshal request: %v", err), + }) + return + } + + // Create HTTPS client that skips certificate verification + // (auth daemon uses self-signed cert) + httpClient := &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + }, + Timeout: 10 * time.Second, + } + + // Make the request to the auth daemon + url := fmt.Sprintf("https://%s:%d/connection", certData.AgentHost, certData.AgentPort) + req, err := http.NewRequest("POST", url, bytes.NewBuffer(requestJSON)) + if err != nil { + logger.Error("Failed to create auth daemon request: %v", err) + client.SendMessage("ws/round-trip/complete", map[string]interface{}{ + "messageId": certData.MessageId, + "complete": true, + "error": fmt.Sprintf("failed to create request: %v", err), + }) + return + } + + // Set headers + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+authDaemonKey) + + logger.Debug("Sending SSH cert to auth daemon at %s", url) + + // Send the request + resp, err := httpClient.Do(req) + if err != nil { + logger.Error("Failed to connect to auth daemon: %v", err) + client.SendMessage("ws/round-trip/complete", map[string]interface{}{ + "messageId": certData.MessageId, + "complete": true, + "error": fmt.Sprintf("failed to connect to auth daemon: %v", err), + }) + return + } + defer resp.Body.Close() + + // Check response status + if resp.StatusCode != http.StatusOK { + logger.Error("Auth daemon returned non-OK status: %d", resp.StatusCode) + client.SendMessage("ws/round-trip/complete", map[string]interface{}{ + "messageId": certData.MessageId, + "complete": true, + "error": fmt.Sprintf("auth daemon returned status %d", resp.StatusCode), + }) + return + } + + logger.Info("Successfully registered SSH certificate with external auth daemon for user %s", certData.Username) + } // Send success response back to cloud err = client.SendMessage("ws/round-trip/complete", map[string]interface{}{ "messageId": certData.MessageId, - "complete": true, + "complete": true, }) if err != nil { logger.Error("Failed to send SSH cert success response: %v", err)