From 6820f8d23e1e6b7cce939cdb56f8040c65ecded4 Mon Sep 17 00:00:00 2001 From: Owen Date: Sun, 3 Aug 2025 16:12:00 -0700 Subject: [PATCH] Add basic heathchecks --- healthcheck/healthcheck.go | 352 +++++++++++++++++++++++++++++++++++++ main.go | 152 ++++++++++++++++ 2 files changed, 504 insertions(+) create mode 100644 healthcheck/healthcheck.go diff --git a/healthcheck/healthcheck.go b/healthcheck/healthcheck.go new file mode 100644 index 0000000..dafa19b --- /dev/null +++ b/healthcheck/healthcheck.go @@ -0,0 +1,352 @@ +package healthcheck + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "strings" + "sync" + "time" +) + +// Status represents the health status of a target +type Status int + +const ( + StatusUnknown Status = iota + StatusHealthy + StatusUnhealthy +) + +func (s Status) String() string { + switch s { + case StatusHealthy: + return "healthy" + case StatusUnhealthy: + return "unhealthy" + default: + return "unknown" + } +} + +// Config holds the health check configuration for a target +type Config struct { + ID string `json:"id"` + Enabled bool `json:"hcEnabled"` + Path string `json:"hcPath"` + Scheme string `json:"hcScheme"` + Mode string `json:"hcMode"` + Hostname string `json:"hcHostname"` + Port int `json:"hcPort"` + Interval int `json:"hcInterval"` // in seconds + UnhealthyInterval int `json:"hcUnhealthyInterval"` // in seconds + Timeout int `json:"hcTimeout"` // in seconds + Headers map[string]string `json:"hcHeaders"` + Method string `json:"hcMethod"` +} + +// Target represents a health check target with its current status +type Target struct { + Config Config `json:"config"` + Status Status `json:"status"` + LastCheck time.Time `json:"lastCheck"` + LastError string `json:"lastError,omitempty"` + CheckCount int `json:"checkCount"` + ticker *time.Ticker + ctx context.Context + cancel context.CancelFunc +} + +// StatusChangeCallback is called when any target's status changes +type StatusChangeCallback func(targets map[string]*Target) + +// Monitor manages health check targets and their monitoring +type Monitor struct { + targets map[string]*Target + mutex sync.RWMutex + callback StatusChangeCallback + client *http.Client +} + +// NewMonitor creates a new health check monitor +func NewMonitor(callback StatusChangeCallback) *Monitor { + return &Monitor{ + targets: make(map[string]*Target), + callback: callback, + client: &http.Client{ + Timeout: 30 * time.Second, + }, + } +} + +// parseHeaders parses the headers string into a map +func parseHeaders(headersStr string) map[string]string { + headers := make(map[string]string) + if headersStr == "" { + return headers + } + + // Try to parse as JSON first + if err := json.Unmarshal([]byte(headersStr), &headers); err == nil { + return headers + } + + // Fallback to simple key:value parsing + pairs := strings.Split(headersStr, ",") + for _, pair := range pairs { + kv := strings.SplitN(strings.TrimSpace(pair), ":", 2) + if len(kv) == 2 { + headers[strings.TrimSpace(kv[0])] = strings.TrimSpace(kv[1]) + } + } + return headers +} + +// AddTarget adds a new health check target +func (m *Monitor) AddTarget(config Config) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + // Set defaults + if config.Scheme == "" { + config.Scheme = "http" + } + if config.Mode == "" { + config.Mode = "http" + } + if config.Method == "" { + config.Method = "GET" + } + if config.Interval == 0 { + config.Interval = 30 + } + if config.UnhealthyInterval == 0 { + config.UnhealthyInterval = 30 + } + if config.Timeout == 0 { + config.Timeout = 5 + } + + // Parse headers if provided as string + if len(config.Headers) == 0 && config.Path != "" { + // This is a simplified header parsing - in real use you might want more robust parsing + config.Headers = make(map[string]string) + } + + // Remove existing target if it exists + if existing, exists := m.targets[config.ID]; exists { + existing.cancel() + } + + // Create new target + ctx, cancel := context.WithCancel(context.Background()) + target := &Target{ + Config: config, + Status: StatusUnknown, + ctx: ctx, + cancel: cancel, + } + + m.targets[config.ID] = target + + // Start monitoring if enabled + if config.Enabled { + go m.monitorTarget(target) + } + + return nil +} + +// RemoveTarget removes a health check target +func (m *Monitor) RemoveTarget(id string) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + target, exists := m.targets[id] + if !exists { + return fmt.Errorf("target with id %s not found", id) + } + + target.cancel() + delete(m.targets, id) + + // Notify callback of status change + if m.callback != nil { + go m.callback(m.getAllTargets()) + } + + return nil +} + +// GetTargets returns a copy of all targets +func (m *Monitor) GetTargets() map[string]*Target { + return m.getAllTargets() +} + +// getAllTargets returns a copy of all targets (internal method) +func (m *Monitor) getAllTargets() map[string]*Target { + m.mutex.RLock() + defer m.mutex.RUnlock() + + targets := make(map[string]*Target) + for id, target := range m.targets { + // Create a copy to avoid race conditions + targetCopy := *target + targets[id] = &targetCopy + } + return targets +} + +// monitorTarget monitors a single target +func (m *Monitor) monitorTarget(target *Target) { + // Initial check + m.performHealthCheck(target) + + // Set up ticker based on current status + interval := time.Duration(target.Config.Interval) * time.Second + if target.Status == StatusUnhealthy { + interval = time.Duration(target.Config.UnhealthyInterval) * time.Second + } + + target.ticker = time.NewTicker(interval) + defer target.ticker.Stop() + + for { + select { + case <-target.ctx.Done(): + return + case <-target.ticker.C: + oldStatus := target.Status + m.performHealthCheck(target) + + // Update ticker interval if status changed + newInterval := time.Duration(target.Config.Interval) * time.Second + if target.Status == StatusUnhealthy { + newInterval = time.Duration(target.Config.UnhealthyInterval) * time.Second + } + + if newInterval != interval { + target.ticker.Stop() + target.ticker = time.NewTicker(newInterval) + interval = newInterval + } + + // Notify callback if status changed + if oldStatus != target.Status && m.callback != nil { + go m.callback(m.getAllTargets()) + } + } + } +} + +// performHealthCheck performs a health check on a target +func (m *Monitor) performHealthCheck(target *Target) { + target.CheckCount++ + target.LastCheck = time.Now() + target.LastError = "" + + // Build URL + url := fmt.Sprintf("%s://%s", target.Config.Scheme, target.Config.Hostname) + if target.Config.Port > 0 { + url = fmt.Sprintf("%s:%d", url, target.Config.Port) + } + if target.Config.Path != "" { + if !strings.HasPrefix(target.Config.Path, "/") { + url += "/" + } + url += target.Config.Path + } + + // Create request + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(target.Config.Timeout)*time.Second) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, target.Config.Method, url, nil) + if err != nil { + target.Status = StatusUnhealthy + target.LastError = fmt.Sprintf("failed to create request: %v", err) + return + } + + // Add headers + for key, value := range target.Config.Headers { + req.Header.Set(key, value) + } + + // Perform request + resp, err := m.client.Do(req) + if err != nil { + target.Status = StatusUnhealthy + target.LastError = fmt.Sprintf("request failed: %v", err) + return + } + defer resp.Body.Close() + + // Check response status + if resp.StatusCode >= 200 && resp.StatusCode < 300 { + target.Status = StatusHealthy + } else { + target.Status = StatusUnhealthy + target.LastError = fmt.Sprintf("unhealthy status code: %d", resp.StatusCode) + } +} + +// Stop stops monitoring all targets +func (m *Monitor) Stop() { + m.mutex.Lock() + defer m.mutex.Unlock() + + for _, target := range m.targets { + target.cancel() + } + m.targets = make(map[string]*Target) +} + +// EnableTarget enables monitoring for a specific target +func (m *Monitor) EnableTarget(id string) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + target, exists := m.targets[id] + if !exists { + return fmt.Errorf("target with id %s not found", id) + } + + if !target.Config.Enabled { + target.Config.Enabled = true + target.cancel() // Stop existing monitoring + + ctx, cancel := context.WithCancel(context.Background()) + target.ctx = ctx + target.cancel = cancel + + go m.monitorTarget(target) + } + + return nil +} + +// DisableTarget disables monitoring for a specific target +func (m *Monitor) DisableTarget(id string) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + target, exists := m.targets[id] + if !exists { + return fmt.Errorf("target with id %s not found", id) + } + + if target.Config.Enabled { + target.Config.Enabled = false + target.cancel() + target.Status = StatusUnknown + + // Notify callback of status change + if m.callback != nil { + go m.callback(m.getAllTargets()) + } + } + + return nil +} diff --git a/main.go b/main.go index 483aa23..e8e8896 100644 --- a/main.go +++ b/main.go @@ -16,6 +16,7 @@ import ( "time" "github.com/fosrl/newt/docker" + "github.com/fosrl/newt/healthcheck" "github.com/fosrl/newt/logger" "github.com/fosrl/newt/proxy" "github.com/fosrl/newt/updates" @@ -99,6 +100,7 @@ var ( healthFile string useNativeInterface bool authorizedKeysFile string + healthMonitor *healthcheck.Monitor ) func main() { @@ -895,6 +897,152 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub logger.Info("SSH public key appended to authorized keys file") }) + // Initialize health check monitor with status change callback + healthMonitor = healthcheck.NewMonitor(func(targets map[string]*healthcheck.Target) { + logger.Debug("Health check status update for %d targets", len(targets)) + + // Send health status update to the server + healthStatuses := make(map[string]interface{}) + for id, target := range targets { + healthStatuses[id] = map[string]interface{}{ + "status": target.Status.String(), + "lastCheck": target.LastCheck.Format(time.RFC3339), + "checkCount": target.CheckCount, + "lastError": target.LastError, + "config": target.Config, + } + } + + err := client.SendMessage("newt/healthcheck/status", map[string]interface{}{ + "targets": healthStatuses, + }) + if err != nil { + logger.Error("Failed to send health check status update: %v", err) + } + }) + + // Register handler for adding health check targets + client.RegisterHandler("newt/healthcheck/add", func(msg websocket.WSMessage) { + logger.Debug("Received health check add request: %+v", msg) + + var config healthcheck.Config + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Error("Error marshaling health check data: %v", err) + return + } + + if err := json.Unmarshal(jsonData, &config); err != nil { + logger.Error("Error unmarshaling health check config: %v", err) + return + } + + if err := healthMonitor.AddTarget(config); err != nil { + logger.Error("Failed to add health check target %s: %v", config.ID, err) + } else { + logger.Info("Added health check target: %s", config.ID) + } + }) + + // Register handler for removing health check targets + client.RegisterHandler("newt/healthcheck/remove", func(msg websocket.WSMessage) { + logger.Debug("Received health check remove request: %+v", msg) + + var requestData struct { + ID string `json:"id"` + } + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Error("Error marshaling health check remove data: %v", err) + return + } + + if err := json.Unmarshal(jsonData, &requestData); err != nil { + logger.Error("Error unmarshaling health check remove request: %v", err) + return + } + + if err := healthMonitor.RemoveTarget(requestData.ID); err != nil { + logger.Error("Failed to remove health check target %s: %v", requestData.ID, err) + } else { + logger.Info("Removed health check target: %s", requestData.ID) + } + }) + + // Register handler for enabling health check targets + client.RegisterHandler("newt/healthcheck/enable", func(msg websocket.WSMessage) { + logger.Debug("Received health check enable request: %+v", msg) + + var requestData struct { + ID string `json:"id"` + } + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Error("Error marshaling health check enable data: %v", err) + return + } + + if err := json.Unmarshal(jsonData, &requestData); err != nil { + logger.Error("Error unmarshaling health check enable request: %v", err) + return + } + + if err := healthMonitor.EnableTarget(requestData.ID); err != nil { + logger.Error("Failed to enable health check target %s: %v", requestData.ID, err) + } else { + logger.Info("Enabled health check target: %s", requestData.ID) + } + }) + + // Register handler for disabling health check targets + client.RegisterHandler("newt/healthcheck/disable", func(msg websocket.WSMessage) { + logger.Debug("Received health check disable request: %+v", msg) + + var requestData struct { + ID string `json:"id"` + } + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Error("Error marshaling health check disable data: %v", err) + return + } + + if err := json.Unmarshal(jsonData, &requestData); err != nil { + logger.Error("Error unmarshaling health check disable request: %v", err) + return + } + + if err := healthMonitor.DisableTarget(requestData.ID); err != nil { + logger.Error("Failed to disable health check target %s: %v", requestData.ID, err) + } else { + logger.Info("Disabled health check target: %s", requestData.ID) + } + }) + + // Register handler for getting health check status + client.RegisterHandler("newt/healthcheck/status/request", func(msg websocket.WSMessage) { + logger.Debug("Received health check status request") + + targets := healthMonitor.GetTargets() + healthStatuses := make(map[string]interface{}) + for id, target := range targets { + healthStatuses[id] = map[string]interface{}{ + "status": target.Status.String(), + "lastCheck": target.LastCheck.Format(time.RFC3339), + "checkCount": target.CheckCount, + "lastError": target.LastError, + "config": target.Config, + } + } + + err := client.SendMessage("newt/healthcheck/status", map[string]interface{}{ + "targets": healthStatuses, + }) + if err != nil { + logger.Error("Failed to send health check status response: %v", err) + } + }) + client.OnConnect(func() error { publicKey = privateKey.PublicKey() logger.Debug("Public key: %s", publicKey) @@ -936,6 +1084,10 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub // Close clients first (including WGTester) closeClients() + if healthMonitor != nil { + healthMonitor.Stop() + } + if dev != nil { dev.Close() }