mirror of
https://github.com/fosrl/newt.git
synced 2026-03-06 18:56:41 +00:00
Add basic newt command relay to auth daemon
This commit is contained in:
148
main.go
148
main.go
@@ -1,7 +1,9 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"flag"
|
"flag"
|
||||||
@@ -11,7 +13,6 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"path/filepath"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"syscall"
|
"syscall"
|
||||||
@@ -130,6 +131,7 @@ var (
|
|||||||
preferEndpoint string
|
preferEndpoint string
|
||||||
healthMonitor *healthcheck.Monitor
|
healthMonitor *healthcheck.Monitor
|
||||||
enforceHealthcheckCert bool
|
enforceHealthcheckCert bool
|
||||||
|
authDaemonKey string
|
||||||
// Build/version (can be overridden via -ldflags "-X main.newtVersion=...")
|
// Build/version (can be overridden via -ldflags "-X main.newtVersion=...")
|
||||||
newtVersion = "version_replaceme"
|
newtVersion = "version_replaceme"
|
||||||
|
|
||||||
@@ -183,6 +185,7 @@ func runNewtMain(ctx context.Context) {
|
|||||||
updownScript = os.Getenv("UPDOWN_SCRIPT")
|
updownScript = os.Getenv("UPDOWN_SCRIPT")
|
||||||
interfaceName = os.Getenv("INTERFACE")
|
interfaceName = os.Getenv("INTERFACE")
|
||||||
portStr := os.Getenv("PORT")
|
portStr := os.Getenv("PORT")
|
||||||
|
authDaemonKey = os.Getenv("AUTH_DAEMON_KEY")
|
||||||
|
|
||||||
// Metrics/observability env mirrors
|
// Metrics/observability env mirrors
|
||||||
metricsEnabledEnv := os.Getenv("NEWT_METRICS_PROMETHEUS_ENABLED")
|
metricsEnabledEnv := os.Getenv("NEWT_METRICS_PROMETHEUS_ENABLED")
|
||||||
@@ -371,6 +374,11 @@ func runNewtMain(ctx context.Context) {
|
|||||||
region = regionEnv
|
region = regionEnv
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Auth daemon key flag
|
||||||
|
if authDaemonKey == "" {
|
||||||
|
flag.StringVar(&authDaemonKey, "auth-daemon-key", "", "Preshared key for auth daemon authentication")
|
||||||
|
}
|
||||||
|
|
||||||
// do a --version check
|
// do a --version check
|
||||||
version := flag.Bool("version", false, "Print the version")
|
version := flag.Bool("version", false, "Print the version")
|
||||||
|
|
||||||
@@ -686,8 +694,8 @@ func runNewtMain(ctx context.Context) {
|
|||||||
|
|
||||||
relayPort := wgData.RelayPort
|
relayPort := wgData.RelayPort
|
||||||
if relayPort == 0 {
|
if relayPort == 0 {
|
||||||
relayPort = 21820
|
relayPort = 21820
|
||||||
}
|
}
|
||||||
|
|
||||||
clientsHandleNewtConnection(wgData.PublicKey, endpoint, relayPort)
|
clientsHandleNewtConnection(wgData.PublicKey, endpoint, relayPort)
|
||||||
|
|
||||||
@@ -1315,6 +1323,140 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey(
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// Register handler for SSH certificate issued events
|
||||||
|
client.RegisterHandler("newt/pam/connection", func(msg websocket.WSMessage) {
|
||||||
|
logger.Debug("Received SSH certificate issued message")
|
||||||
|
|
||||||
|
// Define the structure of the incoming message
|
||||||
|
type SSHCertData struct {
|
||||||
|
TraceID string `json:"traceId"`
|
||||||
|
AgentPort int `json:"agentPort"`
|
||||||
|
AgentHost string `json:"agentHost"`
|
||||||
|
CACert string `json:"caCert"`
|
||||||
|
Username string `json:"username"`
|
||||||
|
NiceID string `json:"niceId"`
|
||||||
|
Metadata struct {
|
||||||
|
Sudo bool `json:"sudo"`
|
||||||
|
Homedir bool `json:"homedir"`
|
||||||
|
} `json:"metadata"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var certData SSHCertData
|
||||||
|
jsonData, err := json.Marshal(msg.Data)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Error marshaling SSH cert data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.Unmarshal(jsonData, &certData); err != nil {
|
||||||
|
logger.Error("Error unmarshaling SSH cert data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if auth daemon key is configured
|
||||||
|
if authDaemonKey == "" {
|
||||||
|
logger.Error("Auth daemon key not configured, cannot process SSH certificate")
|
||||||
|
// Send failure response back to cloud
|
||||||
|
err := client.SendMessage("newt/pam/connection/response", map[string]interface{}{
|
||||||
|
"traceId": certData.TraceID,
|
||||||
|
"success": false,
|
||||||
|
"error": "auth daemon key not configured",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to send SSH cert failure response: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prepare the request body for the auth daemon
|
||||||
|
requestBody := map[string]interface{}{
|
||||||
|
"caCert": certData.CACert,
|
||||||
|
"niceId": certData.NiceID,
|
||||||
|
"username": certData.Username,
|
||||||
|
"metadata": map[string]interface{}{
|
||||||
|
"sudo": certData.Metadata.Sudo,
|
||||||
|
"homedir": certData.Metadata.Homedir,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
requestJSON, err := json.Marshal(requestBody)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to marshal auth daemon request: %v", err)
|
||||||
|
// Send failure response
|
||||||
|
client.SendMessage("newt/pam/ssh-cert-response", map[string]interface{}{
|
||||||
|
"traceId": certData.TraceID,
|
||||||
|
"success": false,
|
||||||
|
"error": fmt.Sprintf("failed to marshal request: %v", err),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create HTTPS client that skips certificate verification
|
||||||
|
// (auth daemon uses self-signed cert)
|
||||||
|
httpClient := &http.Client{
|
||||||
|
Transport: &http.Transport{
|
||||||
|
TLSClientConfig: &tls.Config{
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Timeout: 10 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make the request to the auth daemon
|
||||||
|
url := fmt.Sprintf("https://%s:%d/connection", certData.AgentHost, certData.AgentPort)
|
||||||
|
req, err := http.NewRequest("POST", url, bytes.NewBuffer(requestJSON))
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to create auth daemon request: %v", err)
|
||||||
|
client.SendMessage("newt/pam/connection/response", map[string]interface{}{
|
||||||
|
"traceId": certData.TraceID,
|
||||||
|
"success": false,
|
||||||
|
"error": fmt.Sprintf("failed to create request: %v", err),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set headers
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("Authorization", "Bearer "+authDaemonKey)
|
||||||
|
|
||||||
|
logger.Debug("Sending SSH cert to auth daemon at %s", url)
|
||||||
|
|
||||||
|
// Send the request
|
||||||
|
resp, err := httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to connect to auth daemon: %v", err)
|
||||||
|
client.SendMessage("newt/pam/connection/response", map[string]interface{}{
|
||||||
|
"traceId": certData.TraceID,
|
||||||
|
"success": false,
|
||||||
|
"error": fmt.Sprintf("failed to connect to auth daemon: %v", err),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
// Check response status
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
logger.Error("Auth daemon returned non-OK status: %d", resp.StatusCode)
|
||||||
|
client.SendMessage("newt/pam/connection/response", map[string]interface{}{
|
||||||
|
"traceId": certData.TraceID,
|
||||||
|
"success": false,
|
||||||
|
"error": fmt.Sprintf("auth daemon returned status %d", resp.StatusCode),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Successfully registered SSH certificate with auth daemon for user %s", certData.Username)
|
||||||
|
|
||||||
|
// Send success response back to cloud
|
||||||
|
err = client.SendMessage("newt/pam/connection/response", map[string]interface{}{
|
||||||
|
"traceId": certData.TraceID,
|
||||||
|
"success": true,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to send SSH cert success response: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
client.OnConnect(func() error {
|
client.OnConnect(func() error {
|
||||||
publicKey = privateKey.PublicKey()
|
publicKey = privateKey.PublicKey()
|
||||||
logger.Debug("Public key: %s", publicKey)
|
logger.Debug("Public key: %s", publicKey)
|
||||||
|
|||||||
Reference in New Issue
Block a user