From e8612c7e6bb82722e58f91f60bc9cc4ebeb2cf45 Mon Sep 17 00:00:00 2001 From: Owen Date: Sun, 3 Aug 2025 17:02:15 -0700 Subject: [PATCH] Handle adding and removing healthchecks --- healthcheck/healthcheck.go | 79 ++++++++++++++++++++++++++++++++++---- main.go | 45 +++++++++++++++------- 2 files changed, 102 insertions(+), 22 deletions(-) diff --git a/healthcheck/healthcheck.go b/healthcheck/healthcheck.go index dafa19b..092025d 100644 --- a/healthcheck/healthcheck.go +++ b/healthcheck/healthcheck.go @@ -108,6 +108,30 @@ func (m *Monitor) AddTarget(config Config) error { m.mutex.Lock() defer m.mutex.Unlock() + return m.addTargetUnsafe(config) +} + +// AddTargets adds multiple health check targets in bulk +func (m *Monitor) AddTargets(configs []Config) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + for _, config := range configs { + if err := m.addTargetUnsafe(config); err != nil { + return fmt.Errorf("failed to add target %s: %v", config.ID, err) + } + } + + // Notify callback once after all targets are added + if m.callback != nil { + go m.callback(m.getAllTargetsUnsafe()) + } + + return nil +} + +// addTargetUnsafe adds a target without acquiring the mutex (internal method) +func (m *Monitor) addTargetUnsafe(config Config) error { // Set defaults if config.Scheme == "" { config.Scheme = "http" @@ -173,22 +197,56 @@ func (m *Monitor) RemoveTarget(id string) error { // Notify callback of status change if m.callback != nil { - go m.callback(m.getAllTargets()) + go m.callback(m.GetTargets()) } return nil } -// GetTargets returns a copy of all targets -func (m *Monitor) GetTargets() map[string]*Target { - return m.getAllTargets() +// RemoveTargets removes multiple health check targets +func (m *Monitor) RemoveTargets(ids []string) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + var notFound []string + + for _, id := range ids { + target, exists := m.targets[id] + if !exists { + notFound = append(notFound, id) + continue + } + + target.cancel() + delete(m.targets, id) + } + + // Notify callback of status change if any targets were removed + if len(notFound) != len(ids) && m.callback != nil { + go m.callback(m.GetTargets()) + } + + if len(notFound) > 0 { + return fmt.Errorf("targets not found: %v", notFound) + } + + return nil } -// getAllTargets returns a copy of all targets (internal method) -func (m *Monitor) getAllTargets() map[string]*Target { +// RemoveTargetsByID is a convenience method that accepts either a single ID or multiple IDs +func (m *Monitor) RemoveTargetsByID(ids ...string) error { + return m.RemoveTargets(ids) +} + +// GetTargets returns a copy of all targets +func (m *Monitor) GetTargets() map[string]*Target { m.mutex.RLock() defer m.mutex.RUnlock() + return m.getAllTargetsUnsafe() +} +// getAllTargetsUnsafe returns a copy of all targets without acquiring the mutex (internal method) +func (m *Monitor) getAllTargetsUnsafe() map[string]*Target { targets := make(map[string]*Target) for id, target := range m.targets { // Create a copy to avoid race conditions @@ -198,6 +256,11 @@ func (m *Monitor) getAllTargets() map[string]*Target { return targets } +// getAllTargets returns a copy of all targets (deprecated, use GetTargets) +func (m *Monitor) getAllTargets() map[string]*Target { + return m.GetTargets() +} + // monitorTarget monitors a single target func (m *Monitor) monitorTarget(target *Target) { // Initial check @@ -234,7 +297,7 @@ func (m *Monitor) monitorTarget(target *Target) { // Notify callback if status changed if oldStatus != target.Status && m.callback != nil { - go m.callback(m.getAllTargets()) + go m.callback(m.GetTargets()) } } } @@ -344,7 +407,7 @@ func (m *Monitor) DisableTarget(id string) error { // Notify callback of status change if m.callback != nil { - go m.callback(m.getAllTargets()) + go m.callback(m.GetTargets()) } } diff --git a/main.go b/main.go index e8e8896..440d802 100644 --- a/main.go +++ b/main.go @@ -30,11 +30,12 @@ import ( ) type WgData struct { - Endpoint string `json:"endpoint"` - PublicKey string `json:"publicKey"` - ServerIP string `json:"serverIP"` - TunnelIP string `json:"tunnelIP"` - Targets TargetsByType `json:"targets"` + Endpoint string `json:"endpoint"` + PublicKey string `json:"publicKey"` + ServerIP string `json:"serverIP"` + TunnelIP string `json:"tunnelIP"` + Targets TargetsByType `json:"targets"` + HealthCheckTargets []healthcheck.Config `json:"healthCheckTargets"` } type TargetsByType struct { @@ -449,6 +450,12 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub clientsAddProxyTarget(pm, wgData.TunnelIP) + if err := healthMonitor.AddTargets(wgData.HealthCheckTargets); err != nil { + logger.Error("Failed to bulk add health check targets: %v", err) + } else { + logger.Info("Successfully added %d health check targets", len(wgData.HealthCheckTargets)) + } + err = pm.Start() if err != nil { logger.Error("Failed to start proxy manager: %v", err) @@ -925,7 +932,12 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub client.RegisterHandler("newt/healthcheck/add", func(msg websocket.WSMessage) { logger.Debug("Received health check add request: %+v", msg) - var config healthcheck.Config + type HealthCheckConfig struct { + Targets []healthcheck.Config `json:"targets"` + } + + var config HealthCheckConfig + // add a bunch of targets at once jsonData, err := json.Marshal(msg.Data) if err != nil { logger.Error("Error marshaling health check data: %v", err) @@ -937,20 +949,24 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub return } - if err := healthMonitor.AddTarget(config); err != nil { - logger.Error("Failed to add health check target %s: %v", config.ID, err) + if err := healthMonitor.AddTargets(config.Targets); err != nil { + logger.Error("Failed to add health check targets: %v", err) } else { - logger.Info("Added health check target: %s", config.ID) + logger.Info("Added %d health check targets", len(config.Targets)) } + + logger.Debug("Health check targets added: %+v", config.Targets) }) // Register handler for removing health check targets client.RegisterHandler("newt/healthcheck/remove", func(msg websocket.WSMessage) { logger.Debug("Received health check remove request: %+v", msg) - var requestData struct { - ID string `json:"id"` + type HealthCheckConfig struct { + IDs []string `json:"ids"` } + + var requestData HealthCheckConfig jsonData, err := json.Marshal(msg.Data) if err != nil { logger.Error("Error marshaling health check remove data: %v", err) @@ -962,10 +978,11 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub return } - if err := healthMonitor.RemoveTarget(requestData.ID); err != nil { - logger.Error("Failed to remove health check target %s: %v", requestData.ID, err) + // Multiple target removal + if err := healthMonitor.RemoveTargets(requestData.IDs); err != nil { + logger.Error("Failed to remove health check targets %v: %v", requestData.IDs, err) } else { - logger.Info("Removed health check target: %s", requestData.ID) + logger.Info("Removed %d health check targets: %v", len(requestData.IDs), requestData.IDs) } })