Add basic newt command relay to auth daemon

This commit is contained in:
Owen
2026-02-16 20:04:24 -08:00
committed by Owen Schwartz
parent 2265b61381
commit 5b884042cd

148
main.go
View File

@@ -1,7 +1,9 @@
package main
import (
"bytes"
"context"
"crypto/tls"
"encoding/json"
"errors"
"flag"
@@ -11,7 +13,6 @@ import (
"net/netip"
"os"
"os/signal"
"path/filepath"
"strconv"
"strings"
"syscall"
@@ -130,6 +131,7 @@ var (
preferEndpoint string
healthMonitor *healthcheck.Monitor
enforceHealthcheckCert bool
authDaemonKey string
// Build/version (can be overridden via -ldflags "-X main.newtVersion=...")
newtVersion = "version_replaceme"
@@ -183,6 +185,7 @@ func runNewtMain(ctx context.Context) {
updownScript = os.Getenv("UPDOWN_SCRIPT")
interfaceName = os.Getenv("INTERFACE")
portStr := os.Getenv("PORT")
authDaemonKey = os.Getenv("AUTH_DAEMON_KEY")
// Metrics/observability env mirrors
metricsEnabledEnv := os.Getenv("NEWT_METRICS_PROMETHEUS_ENABLED")
@@ -371,6 +374,11 @@ func runNewtMain(ctx context.Context) {
region = regionEnv
}
// Auth daemon key flag
if authDaemonKey == "" {
flag.StringVar(&authDaemonKey, "auth-daemon-key", "", "Preshared key for auth daemon authentication")
}
// do a --version check
version := flag.Bool("version", false, "Print the version")
@@ -686,8 +694,8 @@ func runNewtMain(ctx context.Context) {
relayPort := wgData.RelayPort
if relayPort == 0 {
relayPort = 21820
}
relayPort = 21820
}
clientsHandleNewtConnection(wgData.PublicKey, endpoint, relayPort)
@@ -1315,6 +1323,140 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey(
}
})
// Register handler for SSH certificate issued events
client.RegisterHandler("newt/pam/connection", func(msg websocket.WSMessage) {
logger.Debug("Received SSH certificate issued message")
// Define the structure of the incoming message
type SSHCertData struct {
TraceID string `json:"traceId"`
AgentPort int `json:"agentPort"`
AgentHost string `json:"agentHost"`
CACert string `json:"caCert"`
Username string `json:"username"`
NiceID string `json:"niceId"`
Metadata struct {
Sudo bool `json:"sudo"`
Homedir bool `json:"homedir"`
} `json:"metadata"`
}
var certData SSHCertData
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling SSH cert data: %v", err)
return
}
if err := json.Unmarshal(jsonData, &certData); err != nil {
logger.Error("Error unmarshaling SSH cert data: %v", err)
return
}
// Check if auth daemon key is configured
if authDaemonKey == "" {
logger.Error("Auth daemon key not configured, cannot process SSH certificate")
// Send failure response back to cloud
err := client.SendMessage("newt/pam/connection/response", map[string]interface{}{
"traceId": certData.TraceID,
"success": false,
"error": "auth daemon key not configured",
})
if err != nil {
logger.Error("Failed to send SSH cert failure response: %v", err)
}
return
}
// Prepare the request body for the auth daemon
requestBody := map[string]interface{}{
"caCert": certData.CACert,
"niceId": certData.NiceID,
"username": certData.Username,
"metadata": map[string]interface{}{
"sudo": certData.Metadata.Sudo,
"homedir": certData.Metadata.Homedir,
},
}
requestJSON, err := json.Marshal(requestBody)
if err != nil {
logger.Error("Failed to marshal auth daemon request: %v", err)
// Send failure response
client.SendMessage("newt/pam/ssh-cert-response", map[string]interface{}{
"traceId": certData.TraceID,
"success": false,
"error": fmt.Sprintf("failed to marshal request: %v", err),
})
return
}
// Create HTTPS client that skips certificate verification
// (auth daemon uses self-signed cert)
httpClient := &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true,
},
},
Timeout: 10 * time.Second,
}
// Make the request to the auth daemon
url := fmt.Sprintf("https://%s:%d/connection", certData.AgentHost, certData.AgentPort)
req, err := http.NewRequest("POST", url, bytes.NewBuffer(requestJSON))
if err != nil {
logger.Error("Failed to create auth daemon request: %v", err)
client.SendMessage("newt/pam/connection/response", map[string]interface{}{
"traceId": certData.TraceID,
"success": false,
"error": fmt.Sprintf("failed to create request: %v", err),
})
return
}
// Set headers
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+authDaemonKey)
logger.Debug("Sending SSH cert to auth daemon at %s", url)
// Send the request
resp, err := httpClient.Do(req)
if err != nil {
logger.Error("Failed to connect to auth daemon: %v", err)
client.SendMessage("newt/pam/connection/response", map[string]interface{}{
"traceId": certData.TraceID,
"success": false,
"error": fmt.Sprintf("failed to connect to auth daemon: %v", err),
})
return
}
defer resp.Body.Close()
// Check response status
if resp.StatusCode != http.StatusOK {
logger.Error("Auth daemon returned non-OK status: %d", resp.StatusCode)
client.SendMessage("newt/pam/connection/response", map[string]interface{}{
"traceId": certData.TraceID,
"success": false,
"error": fmt.Sprintf("auth daemon returned status %d", resp.StatusCode),
})
return
}
logger.Info("Successfully registered SSH certificate with auth daemon for user %s", certData.Username)
// Send success response back to cloud
err = client.SendMessage("newt/pam/connection/response", map[string]interface{}{
"traceId": certData.TraceID,
"success": true,
})
if err != nil {
logger.Error("Failed to send SSH cert success response: %v", err)
}
})
client.OnConnect(func() error {
publicKey = privateKey.PublicKey()
logger.Debug("Public key: %s", publicKey)