diff --git a/authdaemon/host_linux.go b/authdaemon/host_linux.go new file mode 100644 index 0000000..4416dd1 --- /dev/null +++ b/authdaemon/host_linux.go @@ -0,0 +1,269 @@ +//go:build linux + +package authdaemon + +import ( + "bufio" + "encoding/json" + "fmt" + "os" + "os/exec" + "os/user" + "path/filepath" + "strconv" + "strings" + + "github.com/fosrl/newt/logger" +) + +// writeCACertIfNotExists writes contents to path only if the file does not exist. +func writeCACertIfNotExists(path, contents string) error { + if _, err := os.Stat(path); err == nil { + logger.Debug("auth-daemon: CA cert already exists at %s, skipping write", path) + return nil + } + dir := filepath.Dir(path) + if err := os.MkdirAll(dir, 0755); err != nil { + return fmt.Errorf("mkdir %s: %w", dir, err) + } + contents = strings.TrimSpace(contents) + if contents != "" && !strings.HasSuffix(contents, "\n") { + contents += "\n" + } + if err := os.WriteFile(path, []byte(contents), 0644); err != nil { + return fmt.Errorf("write CA cert: %w", err) + } + logger.Info("auth-daemon: wrote CA cert to %s", path) + return nil +} + +// ensureSSHDTrustedUserCAKeys ensures sshd_config contains TrustedUserCAKeys caCertPath. +func ensureSSHDTrustedUserCAKeys(sshdConfigPath, caCertPath string) error { + if sshdConfigPath == "" { + sshdConfigPath = "/etc/ssh/sshd_config" + } + data, err := os.ReadFile(sshdConfigPath) + if err != nil { + return fmt.Errorf("read sshd_config: %w", err) + } + directive := "TrustedUserCAKeys " + caCertPath + lines := strings.Split(string(data), "\n") + found := false + for i, line := range lines { + trimmed := strings.TrimSpace(line) + // strip inline comment + if idx := strings.Index(trimmed, "#"); idx >= 0 { + trimmed = strings.TrimSpace(trimmed[:idx]) + } + if trimmed == "" { + continue + } + if strings.HasPrefix(trimmed, "TrustedUserCAKeys") { + if strings.TrimSpace(trimmed) == directive { + logger.Debug("auth-daemon: sshd_config already has TrustedUserCAKeys %s", caCertPath) + return nil + } + lines[i] = directive + found = true + break + } + } + if !found { + lines = append(lines, directive) + } + out := strings.Join(lines, "\n") + if !strings.HasSuffix(out, "\n") { + out += "\n" + } + if err := os.WriteFile(sshdConfigPath, []byte(out), 0644); err != nil { + return fmt.Errorf("write sshd_config: %w", err) + } + logger.Info("auth-daemon: updated %s with TrustedUserCAKeys %s", sshdConfigPath, caCertPath) + return nil +} + +// reloadSSHD runs the given shell command to reload sshd (e.g. "systemctl reload sshd"). +func reloadSSHD(reloadCmd string) error { + if reloadCmd == "" { + return nil + } + cmd := exec.Command("sh", "-c", reloadCmd) + if out, err := cmd.CombinedOutput(); err != nil { + return fmt.Errorf("reload sshd %q: %w (output: %s)", reloadCmd, err, string(out)) + } + logger.Info("auth-daemon: reloaded sshd") + return nil +} + +// writePrincipals updates the principals file at path: JSON object keyed by username, value is array of principals. Adds username and niceId to that user's list (deduped). +func writePrincipals(path, username, niceId string) error { + if path == "" { + return nil + } + username = strings.TrimSpace(username) + niceId = strings.TrimSpace(niceId) + if username == "" { + return nil + } + data := make(map[string][]string) + if raw, err := os.ReadFile(path); err == nil { + _ = json.Unmarshal(raw, &data) + } + list := data[username] + seen := make(map[string]struct{}, len(list)+2) + for _, p := range list { + seen[p] = struct{}{} + } + for _, p := range []string{username, niceId} { + if p == "" { + continue + } + if _, ok := seen[p]; !ok { + seen[p] = struct{}{} + list = append(list, p) + } + } + data[username] = list + body, err := json.Marshal(data) + if err != nil { + return fmt.Errorf("marshal principals: %w", err) + } + dir := filepath.Dir(path) + if err := os.MkdirAll(dir, 0755); err != nil { + return fmt.Errorf("mkdir %s: %w", dir, err) + } + if err := os.WriteFile(path, body, 0644); err != nil { + return fmt.Errorf("write principals: %w", err) + } + logger.Debug("auth-daemon: wrote principals to %s", path) + return nil +} + +// sudoGroup returns the name of the sudo group (wheel or sudo) that exists on the system. Prefers wheel. +func sudoGroup() string { + f, err := os.Open("/etc/group") + if err != nil { + return "sudo" + } + defer f.Close() + sc := bufio.NewScanner(f) + hasWheel := false + hasSudo := false + for sc.Scan() { + line := sc.Text() + if strings.HasPrefix(line, "wheel:") { + hasWheel = true + } + if strings.HasPrefix(line, "sudo:") { + hasSudo = true + } + } + if hasWheel { + return "wheel" + } + if hasSudo { + return "sudo" + } + return "sudo" +} + +// ensureUser creates the system user if missing, or reconciles sudo and homedir to match meta. +func ensureUser(username string, meta ConnectionMetadata) error { + if username == "" { + return nil + } + u, err := user.Lookup(username) + if err != nil { + if _, ok := err.(user.UnknownUserError); !ok { + return fmt.Errorf("lookup user %s: %w", username, err) + } + return createUser(username, meta) + } + return reconcileUser(u, meta) +} + +func createUser(username string, meta ConnectionMetadata) error { + args := []string{} + if meta.Homedir { + args = append(args, "-m") + } else { + args = append(args, "-M") + } + args = append(args, username) + cmd := exec.Command("useradd", args...) + if out, err := cmd.CombinedOutput(); err != nil { + return fmt.Errorf("useradd %s: %w (output: %s)", username, err, string(out)) + } + logger.Info("auth-daemon: created user %s (homedir=%v)", username, meta.Homedir) + if meta.Sudo { + group := sudoGroup() + cmd := exec.Command("usermod", "-aG", group, username) + if out, err := cmd.CombinedOutput(); err != nil { + logger.Warn("auth-daemon: usermod -aG %s %s: %v (output: %s)", group, username, err, string(out)) + } else { + logger.Info("auth-daemon: added %s to %s", username, group) + } + } + return nil +} + +func mustAtoi(s string) int { + n, _ := strconv.Atoi(s) + return n +} + +func reconcileUser(u *user.User, meta ConnectionMetadata) error { + group := sudoGroup() + inGroup, err := userInGroup(u.Username, group) + if err != nil { + logger.Warn("auth-daemon: check group %s: %v", group, err) + inGroup = false + } + if meta.Sudo && !inGroup { + cmd := exec.Command("usermod", "-aG", group, u.Username) + if out, err := cmd.CombinedOutput(); err != nil { + logger.Warn("auth-daemon: usermod -aG %s %s: %v (output: %s)", group, u.Username, err, string(out)) + } else { + logger.Info("auth-daemon: added %s to %s", u.Username, group) + } + } else if !meta.Sudo && inGroup { + cmd := exec.Command("gpasswd", "-d", u.Username, group) + if out, err := cmd.CombinedOutput(); err != nil { + logger.Warn("auth-daemon: gpasswd -d %s %s: %v (output: %s)", u.Username, group, err, string(out)) + } else { + logger.Info("auth-daemon: removed %s from %s", u.Username, group) + } + } + if meta.Homedir && u.HomeDir != "" { + if st, err := os.Stat(u.HomeDir); err != nil || !st.IsDir() { + if err := os.MkdirAll(u.HomeDir, 0755); err != nil { + logger.Warn("auth-daemon: mkdir %s: %v", u.HomeDir, err) + } else { + uid, gid := mustAtoi(u.Uid), mustAtoi(u.Gid) + _ = os.Chown(u.HomeDir, uid, gid) + logger.Info("auth-daemon: created home %s for %s", u.HomeDir, u.Username) + } + } + } + return nil +} + +func userInGroup(username, groupName string) (bool, error) { + // getent group wheel returns "wheel:x:10:user1,user2" + cmd := exec.Command("getent", "group", groupName) + out, err := cmd.Output() + if err != nil { + return false, err + } + parts := strings.SplitN(strings.TrimSpace(string(out)), ":", 4) + if len(parts) < 4 { + return false, nil + } + members := strings.Split(parts[3], ",") + for _, m := range members { + if strings.TrimSpace(m) == username { + return true, nil + } + } + return false, nil +} diff --git a/authdaemon/host_stub.go b/authdaemon/host_stub.go new file mode 100644 index 0000000..dfd09a5 --- /dev/null +++ b/authdaemon/host_stub.go @@ -0,0 +1,32 @@ +//go:build !linux + +package authdaemon + +import "fmt" + +var errLinuxOnly = fmt.Errorf("auth-daemon PAM agent is only supported on Linux") + +// writeCACertIfNotExists returns an error on non-Linux. +func writeCACertIfNotExists(path, contents string) error { + return errLinuxOnly +} + +// ensureSSHDTrustedUserCAKeys returns an error on non-Linux. +func ensureSSHDTrustedUserCAKeys(sshdConfigPath, caCertPath string) error { + return errLinuxOnly +} + +// reloadSSHD returns an error on non-Linux. +func reloadSSHD(reloadCmd string) error { + return errLinuxOnly +} + +// ensureUser returns an error on non-Linux. +func ensureUser(username string, meta ConnectionMetadata) error { + return errLinuxOnly +} + +// writePrincipals returns an error on non-Linux. +func writePrincipals(path, username, niceId string) error { + return errLinuxOnly +} diff --git a/authdaemon/principals.go b/authdaemon/principals.go new file mode 100644 index 0000000..cbfca80 --- /dev/null +++ b/authdaemon/principals.go @@ -0,0 +1,28 @@ +package authdaemon + +import ( + "encoding/json" + "fmt" + "os" +) + +// GetPrincipals reads the principals data file at path, looks up the given user, and returns that user's principals as a string slice. +// The file format is JSON: object with username keys and array-of-principals values, e.g. {"alice":["alice","usr-123"],"bob":["bob","usr-456"]}. +// If the user is not found or the file is missing, returns nil and nil. +func GetPrincipals(path, user string) ([]string, error) { + if path == "" { + return nil, fmt.Errorf("principals file path is required") + } + data, err := os.ReadFile(path) + if err != nil { + if os.IsNotExist(err) { + return nil, nil + } + return nil, fmt.Errorf("read principals file: %w", err) + } + var m map[string][]string + if err := json.Unmarshal(data, &m); err != nil { + return nil, fmt.Errorf("parse principals file: %w", err) + } + return m[user], nil +} diff --git a/authdaemon/routes.go b/authdaemon/routes.go new file mode 100644 index 0000000..d7ce880 --- /dev/null +++ b/authdaemon/routes.go @@ -0,0 +1,92 @@ +package authdaemon + +import ( + "encoding/json" + "net/http" + + "github.com/fosrl/newt/logger" +) + +// registerRoutes registers all API routes. Add new endpoints here. +func (s *Server) registerRoutes() { + s.mux.HandleFunc("/health", s.handleHealth) + s.mux.HandleFunc("/connection", s.handleConnection) +} + +// ConnectionMetadata is the metadata object in POST /connection. +type ConnectionMetadata struct { + Sudo bool `json:"sudo"` + Homedir bool `json:"homedir"` +} + +// ConnectionRequest is the JSON body for POST /connection. +type ConnectionRequest struct { + CaCert string `json:"caCert"` + NiceId string `json:"niceId"` + Username string `json:"username"` + Metadata ConnectionMetadata `json:"metadata"` +} + +// healthResponse is the JSON body for GET /health. +type healthResponse struct { + Status string `json:"status"` +} + +// handleHealth responds with 200 and {"status":"ok"}. +func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed) + return + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(healthResponse{Status: "ok"}) +} + +// ProcessConnection runs the same logic as POST /connection: CA cert, sshd config, user create/reconcile, principals. +// Use this when DisableHTTPS is true (e.g. embedded in Newt) instead of calling the API. +func (s *Server) ProcessConnection(req ConnectionRequest) { + logger.Info("connection: niceId=%q username=%q metadata.sudo=%v metadata.homedir=%v", + req.NiceId, req.Username, req.Metadata.Sudo, req.Metadata.Homedir) + + cfg := &s.cfg + if cfg.CACertPath != "" { + if err := writeCACertIfNotExists(cfg.CACertPath, req.CaCert); err != nil { + logger.Warn("auth-daemon: write CA cert: %v", err) + } + sshdConfig := cfg.SSHDConfigPath + if sshdConfig == "" { + sshdConfig = "/etc/ssh/sshd_config" + } + if err := ensureSSHDTrustedUserCAKeys(sshdConfig, cfg.CACertPath); err != nil { + logger.Warn("auth-daemon: sshd_config: %v", err) + } + if cfg.ReloadSSHCommand != "" { + if err := reloadSSHD(cfg.ReloadSSHCommand); err != nil { + logger.Warn("auth-daemon: reload sshd: %v", err) + } + } + } + if err := ensureUser(req.Username, req.Metadata); err != nil { + logger.Warn("auth-daemon: ensure user: %v", err) + } + if cfg.PrincipalsFilePath != "" { + if err := writePrincipals(cfg.PrincipalsFilePath, req.Username, req.NiceId); err != nil { + logger.Warn("auth-daemon: write principals: %v", err) + } + } +} + +// handleConnection accepts POST with connection payload and delegates to ProcessConnection. +func (s *Server) handleConnection(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed) + return + } + var req ConnectionRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, "Bad Request", http.StatusBadRequest) + return + } + s.ProcessConnection(req) + w.WriteHeader(http.StatusOK) +} diff --git a/authdaemon/server.go b/authdaemon/server.go new file mode 100644 index 0000000..83cb480 --- /dev/null +++ b/authdaemon/server.go @@ -0,0 +1,174 @@ +package authdaemon + +import ( + "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/subtle" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "fmt" + "math/big" + "net/http" + "os" + "runtime" + "strings" + "time" + + "github.com/fosrl/newt/logger" +) + +type Config struct { + // DisableHTTPS: when true, Run() does not start the HTTPS server (for embedded use inside Newt). Call ProcessConnection directly for connection events. + DisableHTTPS bool + Port int // Listen port for the HTTPS server. Required when DisableHTTPS is false. + PresharedKey string // Required when DisableHTTPS is false; used for HTTP auth (Authorization: Bearer or X-Preshared-Key: ). + CACertPath string // Where to write the CA cert (e.g. /etc/ssh/ca.pem). + SSHDConfigPath string // Path to sshd_config (e.g. /etc/ssh/sshd_config). Defaults to /etc/ssh/sshd_config when CACertPath is set. + ReloadSSHCommand string // Command to reload sshd after config change (e.g. "systemctl reload sshd"). Empty = no reload. + PrincipalsFilePath string // Path to the principals data file (JSON: username -> array of principals). Empty = do not store principals. +} + +type Server struct { + cfg Config + addr string + presharedKey string + mux *http.ServeMux + tlsCert tls.Certificate +} + +// generateTLSCert creates a self-signed certificate and key in memory (no disk). +func generateTLSCert() (tls.Certificate, error) { + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return tls.Certificate{}, fmt.Errorf("generate key: %w", err) + } + serial, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) + if err != nil { + return tls.Certificate{}, fmt.Errorf("serial: %w", err) + } + tmpl := &x509.Certificate{ + SerialNumber: serial, + Subject: pkix.Name{ + CommonName: "localhost", + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(365 * 24 * time.Hour), + KeyUsage: x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + DNSNames: []string{"localhost", "127.0.0.1"}, + } + certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, key.Public(), key) + if err != nil { + return tls.Certificate{}, fmt.Errorf("create certificate: %w", err) + } + keyDER, err := x509.MarshalECPrivateKey(key) + if err != nil { + return tls.Certificate{}, fmt.Errorf("marshal key: %w", err) + } + certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}) + keyPEM := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER}) + cert, err := tls.X509KeyPair(certPEM, keyPEM) + if err != nil { + return tls.Certificate{}, fmt.Errorf("x509 key pair: %w", err) + } + return cert, nil +} + +// authMiddleware wraps next and requires a valid preshared key on every request. +// Accepts Authorization: Bearer or X-Preshared-Key: . +func (s *Server) authMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + key := "" + if v := r.Header.Get("Authorization"); strings.HasPrefix(v, "Bearer ") { + key = strings.TrimSpace(strings.TrimPrefix(v, "Bearer ")) + } + if key == "" { + key = strings.TrimSpace(r.Header.Get("X-Preshared-Key")) + } + if key == "" || subtle.ConstantTimeCompare([]byte(key), []byte(s.presharedKey)) != 1 { + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return + } + next.ServeHTTP(w, r) + }) +} + +// NewServer builds a new auth-daemon server from cfg. When DisableHTTPS is false, PresharedKey and Port are required. +func NewServer(cfg Config) (*Server, error) { + if runtime.GOOS != "linux" { + return nil, fmt.Errorf("auth-daemon is only supported on Linux, not %s", runtime.GOOS) + } + if !cfg.DisableHTTPS { + if cfg.PresharedKey == "" { + return nil, fmt.Errorf("preshared key is required when HTTPS is enabled") + } + if cfg.Port <= 0 { + return nil, fmt.Errorf("port must be positive when HTTPS is enabled") + } + } + s := &Server{cfg: cfg} + if !cfg.DisableHTTPS { + cert, err := generateTLSCert() + if err != nil { + return nil, err + } + s.addr = fmt.Sprintf(":%d", cfg.Port) + s.presharedKey = cfg.PresharedKey + s.mux = http.NewServeMux() + s.tlsCert = cert + s.registerRoutes() + } + return s, nil +} + +// Run starts the HTTPS server (unless DisableHTTPS) and blocks until ctx is cancelled or the server errors. +// When DisableHTTPS is true, Run() blocks on ctx only and does not listen; use ProcessConnection for connection events. +func (s *Server) Run(ctx context.Context) error { + if s.cfg.DisableHTTPS { + logger.Info("auth-daemon running (HTTPS disabled)") + <-ctx.Done() + s.cleanupPrincipalsFile() + return nil + } + tcfg := &tls.Config{ + Certificates: []tls.Certificate{s.tlsCert}, + MinVersion: tls.VersionTLS12, + } + handler := s.authMiddleware(s.mux) + srv := &http.Server{ + Addr: s.addr, + Handler: handler, + TLSConfig: tcfg, + ReadTimeout: 10 * time.Second, + WriteTimeout: 10 * time.Second, + ReadHeaderTimeout: 5 * time.Second, + IdleTimeout: 60 * time.Second, + } + go func() { + <-ctx.Done() + shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + if err := srv.Shutdown(shutdownCtx); err != nil { + logger.Warn("auth-daemon shutdown: %v", err) + } + }() + logger.Info("auth-daemon listening on https://127.0.0.1%s", s.addr) + if err := srv.ListenAndServeTLS("", ""); err != nil && err != http.ErrServerClosed { + return err + } + s.cleanupPrincipalsFile() + return nil +} + +func (s *Server) cleanupPrincipalsFile() { + if s.cfg.PrincipalsFilePath != "" { + if err := os.Remove(s.cfg.PrincipalsFilePath); err != nil && !os.IsNotExist(err) { + logger.Warn("auth-daemon: remove principals file: %v", err) + } + } +}