mirror of
https://github.com/fosrl/newt.git
synced 2026-02-23 21:36:37 +00:00
add auth daemon
This commit is contained in:
committed by
Owen Schwartz
parent
d98eaa88b3
commit
2cc957d55f
269
authdaemon/host_linux.go
Normal file
269
authdaemon/host_linux.go
Normal file
@@ -0,0 +1,269 @@
|
|||||||
|
//go:build linux
|
||||||
|
|
||||||
|
package authdaemon
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"os/user"
|
||||||
|
"path/filepath"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/fosrl/newt/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
// writeCACertIfNotExists writes contents to path only if the file does not exist.
|
||||||
|
func writeCACertIfNotExists(path, contents string) error {
|
||||||
|
if _, err := os.Stat(path); err == nil {
|
||||||
|
logger.Debug("auth-daemon: CA cert already exists at %s, skipping write", path)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
dir := filepath.Dir(path)
|
||||||
|
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||||
|
return fmt.Errorf("mkdir %s: %w", dir, err)
|
||||||
|
}
|
||||||
|
contents = strings.TrimSpace(contents)
|
||||||
|
if contents != "" && !strings.HasSuffix(contents, "\n") {
|
||||||
|
contents += "\n"
|
||||||
|
}
|
||||||
|
if err := os.WriteFile(path, []byte(contents), 0644); err != nil {
|
||||||
|
return fmt.Errorf("write CA cert: %w", err)
|
||||||
|
}
|
||||||
|
logger.Info("auth-daemon: wrote CA cert to %s", path)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ensureSSHDTrustedUserCAKeys ensures sshd_config contains TrustedUserCAKeys caCertPath.
|
||||||
|
func ensureSSHDTrustedUserCAKeys(sshdConfigPath, caCertPath string) error {
|
||||||
|
if sshdConfigPath == "" {
|
||||||
|
sshdConfigPath = "/etc/ssh/sshd_config"
|
||||||
|
}
|
||||||
|
data, err := os.ReadFile(sshdConfigPath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("read sshd_config: %w", err)
|
||||||
|
}
|
||||||
|
directive := "TrustedUserCAKeys " + caCertPath
|
||||||
|
lines := strings.Split(string(data), "\n")
|
||||||
|
found := false
|
||||||
|
for i, line := range lines {
|
||||||
|
trimmed := strings.TrimSpace(line)
|
||||||
|
// strip inline comment
|
||||||
|
if idx := strings.Index(trimmed, "#"); idx >= 0 {
|
||||||
|
trimmed = strings.TrimSpace(trimmed[:idx])
|
||||||
|
}
|
||||||
|
if trimmed == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(trimmed, "TrustedUserCAKeys") {
|
||||||
|
if strings.TrimSpace(trimmed) == directive {
|
||||||
|
logger.Debug("auth-daemon: sshd_config already has TrustedUserCAKeys %s", caCertPath)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
lines[i] = directive
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
lines = append(lines, directive)
|
||||||
|
}
|
||||||
|
out := strings.Join(lines, "\n")
|
||||||
|
if !strings.HasSuffix(out, "\n") {
|
||||||
|
out += "\n"
|
||||||
|
}
|
||||||
|
if err := os.WriteFile(sshdConfigPath, []byte(out), 0644); err != nil {
|
||||||
|
return fmt.Errorf("write sshd_config: %w", err)
|
||||||
|
}
|
||||||
|
logger.Info("auth-daemon: updated %s with TrustedUserCAKeys %s", sshdConfigPath, caCertPath)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// reloadSSHD runs the given shell command to reload sshd (e.g. "systemctl reload sshd").
|
||||||
|
func reloadSSHD(reloadCmd string) error {
|
||||||
|
if reloadCmd == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
cmd := exec.Command("sh", "-c", reloadCmd)
|
||||||
|
if out, err := cmd.CombinedOutput(); err != nil {
|
||||||
|
return fmt.Errorf("reload sshd %q: %w (output: %s)", reloadCmd, err, string(out))
|
||||||
|
}
|
||||||
|
logger.Info("auth-daemon: reloaded sshd")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// writePrincipals updates the principals file at path: JSON object keyed by username, value is array of principals. Adds username and niceId to that user's list (deduped).
|
||||||
|
func writePrincipals(path, username, niceId string) error {
|
||||||
|
if path == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
username = strings.TrimSpace(username)
|
||||||
|
niceId = strings.TrimSpace(niceId)
|
||||||
|
if username == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
data := make(map[string][]string)
|
||||||
|
if raw, err := os.ReadFile(path); err == nil {
|
||||||
|
_ = json.Unmarshal(raw, &data)
|
||||||
|
}
|
||||||
|
list := data[username]
|
||||||
|
seen := make(map[string]struct{}, len(list)+2)
|
||||||
|
for _, p := range list {
|
||||||
|
seen[p] = struct{}{}
|
||||||
|
}
|
||||||
|
for _, p := range []string{username, niceId} {
|
||||||
|
if p == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, ok := seen[p]; !ok {
|
||||||
|
seen[p] = struct{}{}
|
||||||
|
list = append(list, p)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
data[username] = list
|
||||||
|
body, err := json.Marshal(data)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("marshal principals: %w", err)
|
||||||
|
}
|
||||||
|
dir := filepath.Dir(path)
|
||||||
|
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||||
|
return fmt.Errorf("mkdir %s: %w", dir, err)
|
||||||
|
}
|
||||||
|
if err := os.WriteFile(path, body, 0644); err != nil {
|
||||||
|
return fmt.Errorf("write principals: %w", err)
|
||||||
|
}
|
||||||
|
logger.Debug("auth-daemon: wrote principals to %s", path)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// sudoGroup returns the name of the sudo group (wheel or sudo) that exists on the system. Prefers wheel.
|
||||||
|
func sudoGroup() string {
|
||||||
|
f, err := os.Open("/etc/group")
|
||||||
|
if err != nil {
|
||||||
|
return "sudo"
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
sc := bufio.NewScanner(f)
|
||||||
|
hasWheel := false
|
||||||
|
hasSudo := false
|
||||||
|
for sc.Scan() {
|
||||||
|
line := sc.Text()
|
||||||
|
if strings.HasPrefix(line, "wheel:") {
|
||||||
|
hasWheel = true
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(line, "sudo:") {
|
||||||
|
hasSudo = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if hasWheel {
|
||||||
|
return "wheel"
|
||||||
|
}
|
||||||
|
if hasSudo {
|
||||||
|
return "sudo"
|
||||||
|
}
|
||||||
|
return "sudo"
|
||||||
|
}
|
||||||
|
|
||||||
|
// ensureUser creates the system user if missing, or reconciles sudo and homedir to match meta.
|
||||||
|
func ensureUser(username string, meta ConnectionMetadata) error {
|
||||||
|
if username == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
u, err := user.Lookup(username)
|
||||||
|
if err != nil {
|
||||||
|
if _, ok := err.(user.UnknownUserError); !ok {
|
||||||
|
return fmt.Errorf("lookup user %s: %w", username, err)
|
||||||
|
}
|
||||||
|
return createUser(username, meta)
|
||||||
|
}
|
||||||
|
return reconcileUser(u, meta)
|
||||||
|
}
|
||||||
|
|
||||||
|
func createUser(username string, meta ConnectionMetadata) error {
|
||||||
|
args := []string{}
|
||||||
|
if meta.Homedir {
|
||||||
|
args = append(args, "-m")
|
||||||
|
} else {
|
||||||
|
args = append(args, "-M")
|
||||||
|
}
|
||||||
|
args = append(args, username)
|
||||||
|
cmd := exec.Command("useradd", args...)
|
||||||
|
if out, err := cmd.CombinedOutput(); err != nil {
|
||||||
|
return fmt.Errorf("useradd %s: %w (output: %s)", username, err, string(out))
|
||||||
|
}
|
||||||
|
logger.Info("auth-daemon: created user %s (homedir=%v)", username, meta.Homedir)
|
||||||
|
if meta.Sudo {
|
||||||
|
group := sudoGroup()
|
||||||
|
cmd := exec.Command("usermod", "-aG", group, username)
|
||||||
|
if out, err := cmd.CombinedOutput(); err != nil {
|
||||||
|
logger.Warn("auth-daemon: usermod -aG %s %s: %v (output: %s)", group, username, err, string(out))
|
||||||
|
} else {
|
||||||
|
logger.Info("auth-daemon: added %s to %s", username, group)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func mustAtoi(s string) int {
|
||||||
|
n, _ := strconv.Atoi(s)
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
|
||||||
|
func reconcileUser(u *user.User, meta ConnectionMetadata) error {
|
||||||
|
group := sudoGroup()
|
||||||
|
inGroup, err := userInGroup(u.Username, group)
|
||||||
|
if err != nil {
|
||||||
|
logger.Warn("auth-daemon: check group %s: %v", group, err)
|
||||||
|
inGroup = false
|
||||||
|
}
|
||||||
|
if meta.Sudo && !inGroup {
|
||||||
|
cmd := exec.Command("usermod", "-aG", group, u.Username)
|
||||||
|
if out, err := cmd.CombinedOutput(); err != nil {
|
||||||
|
logger.Warn("auth-daemon: usermod -aG %s %s: %v (output: %s)", group, u.Username, err, string(out))
|
||||||
|
} else {
|
||||||
|
logger.Info("auth-daemon: added %s to %s", u.Username, group)
|
||||||
|
}
|
||||||
|
} else if !meta.Sudo && inGroup {
|
||||||
|
cmd := exec.Command("gpasswd", "-d", u.Username, group)
|
||||||
|
if out, err := cmd.CombinedOutput(); err != nil {
|
||||||
|
logger.Warn("auth-daemon: gpasswd -d %s %s: %v (output: %s)", u.Username, group, err, string(out))
|
||||||
|
} else {
|
||||||
|
logger.Info("auth-daemon: removed %s from %s", u.Username, group)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if meta.Homedir && u.HomeDir != "" {
|
||||||
|
if st, err := os.Stat(u.HomeDir); err != nil || !st.IsDir() {
|
||||||
|
if err := os.MkdirAll(u.HomeDir, 0755); err != nil {
|
||||||
|
logger.Warn("auth-daemon: mkdir %s: %v", u.HomeDir, err)
|
||||||
|
} else {
|
||||||
|
uid, gid := mustAtoi(u.Uid), mustAtoi(u.Gid)
|
||||||
|
_ = os.Chown(u.HomeDir, uid, gid)
|
||||||
|
logger.Info("auth-daemon: created home %s for %s", u.HomeDir, u.Username)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func userInGroup(username, groupName string) (bool, error) {
|
||||||
|
// getent group wheel returns "wheel:x:10:user1,user2"
|
||||||
|
cmd := exec.Command("getent", "group", groupName)
|
||||||
|
out, err := cmd.Output()
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
parts := strings.SplitN(strings.TrimSpace(string(out)), ":", 4)
|
||||||
|
if len(parts) < 4 {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
members := strings.Split(parts[3], ",")
|
||||||
|
for _, m := range members {
|
||||||
|
if strings.TrimSpace(m) == username {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
32
authdaemon/host_stub.go
Normal file
32
authdaemon/host_stub.go
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
//go:build !linux
|
||||||
|
|
||||||
|
package authdaemon
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
var errLinuxOnly = fmt.Errorf("auth-daemon PAM agent is only supported on Linux")
|
||||||
|
|
||||||
|
// writeCACertIfNotExists returns an error on non-Linux.
|
||||||
|
func writeCACertIfNotExists(path, contents string) error {
|
||||||
|
return errLinuxOnly
|
||||||
|
}
|
||||||
|
|
||||||
|
// ensureSSHDTrustedUserCAKeys returns an error on non-Linux.
|
||||||
|
func ensureSSHDTrustedUserCAKeys(sshdConfigPath, caCertPath string) error {
|
||||||
|
return errLinuxOnly
|
||||||
|
}
|
||||||
|
|
||||||
|
// reloadSSHD returns an error on non-Linux.
|
||||||
|
func reloadSSHD(reloadCmd string) error {
|
||||||
|
return errLinuxOnly
|
||||||
|
}
|
||||||
|
|
||||||
|
// ensureUser returns an error on non-Linux.
|
||||||
|
func ensureUser(username string, meta ConnectionMetadata) error {
|
||||||
|
return errLinuxOnly
|
||||||
|
}
|
||||||
|
|
||||||
|
// writePrincipals returns an error on non-Linux.
|
||||||
|
func writePrincipals(path, username, niceId string) error {
|
||||||
|
return errLinuxOnly
|
||||||
|
}
|
||||||
28
authdaemon/principals.go
Normal file
28
authdaemon/principals.go
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
package authdaemon
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetPrincipals reads the principals data file at path, looks up the given user, and returns that user's principals as a string slice.
|
||||||
|
// The file format is JSON: object with username keys and array-of-principals values, e.g. {"alice":["alice","usr-123"],"bob":["bob","usr-456"]}.
|
||||||
|
// If the user is not found or the file is missing, returns nil and nil.
|
||||||
|
func GetPrincipals(path, user string) ([]string, error) {
|
||||||
|
if path == "" {
|
||||||
|
return nil, fmt.Errorf("principals file path is required")
|
||||||
|
}
|
||||||
|
data, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("read principals file: %w", err)
|
||||||
|
}
|
||||||
|
var m map[string][]string
|
||||||
|
if err := json.Unmarshal(data, &m); err != nil {
|
||||||
|
return nil, fmt.Errorf("parse principals file: %w", err)
|
||||||
|
}
|
||||||
|
return m[user], nil
|
||||||
|
}
|
||||||
92
authdaemon/routes.go
Normal file
92
authdaemon/routes.go
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
package authdaemon
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/fosrl/newt/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
// registerRoutes registers all API routes. Add new endpoints here.
|
||||||
|
func (s *Server) registerRoutes() {
|
||||||
|
s.mux.HandleFunc("/health", s.handleHealth)
|
||||||
|
s.mux.HandleFunc("/connection", s.handleConnection)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConnectionMetadata is the metadata object in POST /connection.
|
||||||
|
type ConnectionMetadata struct {
|
||||||
|
Sudo bool `json:"sudo"`
|
||||||
|
Homedir bool `json:"homedir"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConnectionRequest is the JSON body for POST /connection.
|
||||||
|
type ConnectionRequest struct {
|
||||||
|
CaCert string `json:"caCert"`
|
||||||
|
NiceId string `json:"niceId"`
|
||||||
|
Username string `json:"username"`
|
||||||
|
Metadata ConnectionMetadata `json:"metadata"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// healthResponse is the JSON body for GET /health.
|
||||||
|
type healthResponse struct {
|
||||||
|
Status string `json:"status"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleHealth responds with 200 and {"status":"ok"}.
|
||||||
|
func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodGet {
|
||||||
|
http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_ = json.NewEncoder(w).Encode(healthResponse{Status: "ok"})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProcessConnection runs the same logic as POST /connection: CA cert, sshd config, user create/reconcile, principals.
|
||||||
|
// Use this when DisableHTTPS is true (e.g. embedded in Newt) instead of calling the API.
|
||||||
|
func (s *Server) ProcessConnection(req ConnectionRequest) {
|
||||||
|
logger.Info("connection: niceId=%q username=%q metadata.sudo=%v metadata.homedir=%v",
|
||||||
|
req.NiceId, req.Username, req.Metadata.Sudo, req.Metadata.Homedir)
|
||||||
|
|
||||||
|
cfg := &s.cfg
|
||||||
|
if cfg.CACertPath != "" {
|
||||||
|
if err := writeCACertIfNotExists(cfg.CACertPath, req.CaCert); err != nil {
|
||||||
|
logger.Warn("auth-daemon: write CA cert: %v", err)
|
||||||
|
}
|
||||||
|
sshdConfig := cfg.SSHDConfigPath
|
||||||
|
if sshdConfig == "" {
|
||||||
|
sshdConfig = "/etc/ssh/sshd_config"
|
||||||
|
}
|
||||||
|
if err := ensureSSHDTrustedUserCAKeys(sshdConfig, cfg.CACertPath); err != nil {
|
||||||
|
logger.Warn("auth-daemon: sshd_config: %v", err)
|
||||||
|
}
|
||||||
|
if cfg.ReloadSSHCommand != "" {
|
||||||
|
if err := reloadSSHD(cfg.ReloadSSHCommand); err != nil {
|
||||||
|
logger.Warn("auth-daemon: reload sshd: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := ensureUser(req.Username, req.Metadata); err != nil {
|
||||||
|
logger.Warn("auth-daemon: ensure user: %v", err)
|
||||||
|
}
|
||||||
|
if cfg.PrincipalsFilePath != "" {
|
||||||
|
if err := writePrincipals(cfg.PrincipalsFilePath, req.Username, req.NiceId); err != nil {
|
||||||
|
logger.Warn("auth-daemon: write principals: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleConnection accepts POST with connection payload and delegates to ProcessConnection.
|
||||||
|
func (s *Server) handleConnection(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodPost {
|
||||||
|
http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var req ConnectionRequest
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
|
http.Error(w, "Bad Request", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.ProcessConnection(req)
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}
|
||||||
174
authdaemon/server.go
Normal file
174
authdaemon/server.go
Normal file
@@ -0,0 +1,174 @@
|
|||||||
|
package authdaemon
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/ecdsa"
|
||||||
|
"crypto/elliptic"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/subtle"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"crypto/x509/pkix"
|
||||||
|
"encoding/pem"
|
||||||
|
"fmt"
|
||||||
|
"math/big"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"runtime"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/fosrl/newt/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Config struct {
|
||||||
|
// DisableHTTPS: when true, Run() does not start the HTTPS server (for embedded use inside Newt). Call ProcessConnection directly for connection events.
|
||||||
|
DisableHTTPS bool
|
||||||
|
Port int // Listen port for the HTTPS server. Required when DisableHTTPS is false.
|
||||||
|
PresharedKey string // Required when DisableHTTPS is false; used for HTTP auth (Authorization: Bearer <key> or X-Preshared-Key: <key>).
|
||||||
|
CACertPath string // Where to write the CA cert (e.g. /etc/ssh/ca.pem).
|
||||||
|
SSHDConfigPath string // Path to sshd_config (e.g. /etc/ssh/sshd_config). Defaults to /etc/ssh/sshd_config when CACertPath is set.
|
||||||
|
ReloadSSHCommand string // Command to reload sshd after config change (e.g. "systemctl reload sshd"). Empty = no reload.
|
||||||
|
PrincipalsFilePath string // Path to the principals data file (JSON: username -> array of principals). Empty = do not store principals.
|
||||||
|
}
|
||||||
|
|
||||||
|
type Server struct {
|
||||||
|
cfg Config
|
||||||
|
addr string
|
||||||
|
presharedKey string
|
||||||
|
mux *http.ServeMux
|
||||||
|
tlsCert tls.Certificate
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateTLSCert creates a self-signed certificate and key in memory (no disk).
|
||||||
|
func generateTLSCert() (tls.Certificate, error) {
|
||||||
|
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
return tls.Certificate{}, fmt.Errorf("generate key: %w", err)
|
||||||
|
}
|
||||||
|
serial, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
|
||||||
|
if err != nil {
|
||||||
|
return tls.Certificate{}, fmt.Errorf("serial: %w", err)
|
||||||
|
}
|
||||||
|
tmpl := &x509.Certificate{
|
||||||
|
SerialNumber: serial,
|
||||||
|
Subject: pkix.Name{
|
||||||
|
CommonName: "localhost",
|
||||||
|
},
|
||||||
|
NotBefore: time.Now(),
|
||||||
|
NotAfter: time.Now().Add(365 * 24 * time.Hour),
|
||||||
|
KeyUsage: x509.KeyUsageDigitalSignature,
|
||||||
|
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||||
|
BasicConstraintsValid: true,
|
||||||
|
DNSNames: []string{"localhost", "127.0.0.1"},
|
||||||
|
}
|
||||||
|
certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, key.Public(), key)
|
||||||
|
if err != nil {
|
||||||
|
return tls.Certificate{}, fmt.Errorf("create certificate: %w", err)
|
||||||
|
}
|
||||||
|
keyDER, err := x509.MarshalECPrivateKey(key)
|
||||||
|
if err != nil {
|
||||||
|
return tls.Certificate{}, fmt.Errorf("marshal key: %w", err)
|
||||||
|
}
|
||||||
|
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
|
||||||
|
keyPEM := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})
|
||||||
|
cert, err := tls.X509KeyPair(certPEM, keyPEM)
|
||||||
|
if err != nil {
|
||||||
|
return tls.Certificate{}, fmt.Errorf("x509 key pair: %w", err)
|
||||||
|
}
|
||||||
|
return cert, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// authMiddleware wraps next and requires a valid preshared key on every request.
|
||||||
|
// Accepts Authorization: Bearer <key> or X-Preshared-Key: <key>.
|
||||||
|
func (s *Server) authMiddleware(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
key := ""
|
||||||
|
if v := r.Header.Get("Authorization"); strings.HasPrefix(v, "Bearer ") {
|
||||||
|
key = strings.TrimSpace(strings.TrimPrefix(v, "Bearer "))
|
||||||
|
}
|
||||||
|
if key == "" {
|
||||||
|
key = strings.TrimSpace(r.Header.Get("X-Preshared-Key"))
|
||||||
|
}
|
||||||
|
if key == "" || subtle.ConstantTimeCompare([]byte(key), []byte(s.presharedKey)) != 1 {
|
||||||
|
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewServer builds a new auth-daemon server from cfg. When DisableHTTPS is false, PresharedKey and Port are required.
|
||||||
|
func NewServer(cfg Config) (*Server, error) {
|
||||||
|
if runtime.GOOS != "linux" {
|
||||||
|
return nil, fmt.Errorf("auth-daemon is only supported on Linux, not %s", runtime.GOOS)
|
||||||
|
}
|
||||||
|
if !cfg.DisableHTTPS {
|
||||||
|
if cfg.PresharedKey == "" {
|
||||||
|
return nil, fmt.Errorf("preshared key is required when HTTPS is enabled")
|
||||||
|
}
|
||||||
|
if cfg.Port <= 0 {
|
||||||
|
return nil, fmt.Errorf("port must be positive when HTTPS is enabled")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s := &Server{cfg: cfg}
|
||||||
|
if !cfg.DisableHTTPS {
|
||||||
|
cert, err := generateTLSCert()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
s.addr = fmt.Sprintf(":%d", cfg.Port)
|
||||||
|
s.presharedKey = cfg.PresharedKey
|
||||||
|
s.mux = http.NewServeMux()
|
||||||
|
s.tlsCert = cert
|
||||||
|
s.registerRoutes()
|
||||||
|
}
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run starts the HTTPS server (unless DisableHTTPS) and blocks until ctx is cancelled or the server errors.
|
||||||
|
// When DisableHTTPS is true, Run() blocks on ctx only and does not listen; use ProcessConnection for connection events.
|
||||||
|
func (s *Server) Run(ctx context.Context) error {
|
||||||
|
if s.cfg.DisableHTTPS {
|
||||||
|
logger.Info("auth-daemon running (HTTPS disabled)")
|
||||||
|
<-ctx.Done()
|
||||||
|
s.cleanupPrincipalsFile()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
tcfg := &tls.Config{
|
||||||
|
Certificates: []tls.Certificate{s.tlsCert},
|
||||||
|
MinVersion: tls.VersionTLS12,
|
||||||
|
}
|
||||||
|
handler := s.authMiddleware(s.mux)
|
||||||
|
srv := &http.Server{
|
||||||
|
Addr: s.addr,
|
||||||
|
Handler: handler,
|
||||||
|
TLSConfig: tcfg,
|
||||||
|
ReadTimeout: 10 * time.Second,
|
||||||
|
WriteTimeout: 10 * time.Second,
|
||||||
|
ReadHeaderTimeout: 5 * time.Second,
|
||||||
|
IdleTimeout: 60 * time.Second,
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
<-ctx.Done()
|
||||||
|
shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
if err := srv.Shutdown(shutdownCtx); err != nil {
|
||||||
|
logger.Warn("auth-daemon shutdown: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
logger.Info("auth-daemon listening on https://127.0.0.1%s", s.addr)
|
||||||
|
if err := srv.ListenAndServeTLS("", ""); err != nil && err != http.ErrServerClosed {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
s.cleanupPrincipalsFile()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) cleanupPrincipalsFile() {
|
||||||
|
if s.cfg.PrincipalsFilePath != "" {
|
||||||
|
if err := os.Remove(s.cfg.PrincipalsFilePath); err != nil && !os.IsNotExist(err) {
|
||||||
|
logger.Warn("auth-daemon: remove principals file: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user