diff --git a/clients/clients.go b/clients/clients.go index b2dca47..71c421b 100644 --- a/clients/clients.go +++ b/clients/clients.go @@ -172,6 +172,7 @@ func NewWireGuardService(interfaceName string, port uint16, mtu int, host string wsClient.RegisterHandler("newt/wg/targets/add", service.handleAddTarget) wsClient.RegisterHandler("newt/wg/targets/remove", service.handleRemoveTarget) wsClient.RegisterHandler("newt/wg/targets/update", service.handleUpdateTarget) + wsClient.RegisterHandler("newt/wg/sync", service.handleSyncConfig) return service, nil } @@ -490,6 +491,183 @@ func (s *WireGuardService) handleConfig(msg websocket.WSMessage) { } } +// SyncConfig represents the configuration sent from server for syncing +type SyncConfig struct { + Targets []Target `json:"targets"` + Peers []Peer `json:"peers"` +} + +func (s *WireGuardService) handleSyncConfig(msg websocket.WSMessage) { + var syncConfig SyncConfig + + logger.Debug("Received sync message: %v", msg) + logger.Info("Received sync configuration from remote server") + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Error("Error marshaling sync data: %v", err) + return + } + + if err := json.Unmarshal(jsonData, &syncConfig); err != nil { + logger.Error("Error unmarshaling sync data: %v", err) + return + } + + // Sync peers + if err := s.syncPeers(syncConfig.Peers); err != nil { + logger.Error("Failed to sync peers: %v", err) + } + + // Sync targets + if err := s.syncTargets(syncConfig.Targets); err != nil { + logger.Error("Failed to sync targets: %v", err) + } +} + +// syncPeers synchronizes the current peers with the desired state +// It removes peers not in the desired list and adds missing ones +func (s *WireGuardService) syncPeers(desiredPeers []Peer) error { + if s.device == nil { + return fmt.Errorf("WireGuard device is not initialized") + } + + // Get current peers from the device + currentConfig, err := s.device.IpcGet() + if err != nil { + return fmt.Errorf("failed to get current device config: %v", err) + } + + // Parse current peer public keys + lines := strings.Split(currentConfig, "\n") + currentPeerKeys := make(map[string]bool) + for _, line := range lines { + if strings.HasPrefix(line, "public_key=") { + pubKey := strings.TrimPrefix(line, "public_key=") + currentPeerKeys[pubKey] = true + } + } + + // Build a map of desired peers by their public key (normalized) + desiredPeerMap := make(map[string]Peer) + for _, peer := range desiredPeers { + // Normalize the public key for comparison + pubKey, err := wgtypes.ParseKey(peer.PublicKey) + if err != nil { + logger.Warn("Invalid public key in desired peers: %s", peer.PublicKey) + continue + } + normalizedKey := util.FixKey(pubKey.String()) + desiredPeerMap[normalizedKey] = peer + } + + // Remove peers that are not in the desired list + for currentKey := range currentPeerKeys { + if _, exists := desiredPeerMap[currentKey]; !exists { + // Parse the key back to get the original format for removal + removeConfig := fmt.Sprintf("public_key=%s\nremove=true", currentKey) + if err := s.device.IpcSet(removeConfig); err != nil { + logger.Warn("Failed to remove peer %s during sync: %v", currentKey, err) + } else { + logger.Info("Removed peer %s during sync", currentKey) + } + } + } + + // Add peers that are missing + for normalizedKey, peer := range desiredPeerMap { + if _, exists := currentPeerKeys[normalizedKey]; !exists { + if err := s.addPeerToDevice(peer); err != nil { + logger.Warn("Failed to add peer %s during sync: %v", peer.PublicKey, err) + } else { + logger.Info("Added peer %s during sync", peer.PublicKey) + } + } + } + + return nil +} + +// syncTargets synchronizes the current targets with the desired state +// It removes targets not in the desired list and adds missing ones +func (s *WireGuardService) syncTargets(desiredTargets []Target) error { + if s.tnet == nil { + // Native interface mode - proxy features not available, skip silently + logger.Debug("Skipping target sync - using native interface (no proxy support)") + return nil + } + + // Get current rules from the proxy handler + currentRules := s.tnet.GetProxySubnetRules() + + // Build a map of current rules by source+dest prefix + type ruleKey struct { + sourcePrefix string + destPrefix string + } + currentRuleMap := make(map[ruleKey]bool) + for _, rule := range currentRules { + key := ruleKey{ + sourcePrefix: rule.SourcePrefix.String(), + destPrefix: rule.DestPrefix.String(), + } + currentRuleMap[key] = true + } + + // Build a map of desired targets + desiredTargetMap := make(map[ruleKey]Target) + for _, target := range desiredTargets { + key := ruleKey{ + sourcePrefix: target.SourcePrefix, + destPrefix: target.DestPrefix, + } + desiredTargetMap[key] = target + } + + // Remove targets that are not in the desired list + for _, rule := range currentRules { + key := ruleKey{ + sourcePrefix: rule.SourcePrefix.String(), + destPrefix: rule.DestPrefix.String(), + } + if _, exists := desiredTargetMap[key]; !exists { + s.tnet.RemoveProxySubnetRule(rule.SourcePrefix, rule.DestPrefix) + logger.Info("Removed target %s -> %s during sync", rule.SourcePrefix.String(), rule.DestPrefix.String()) + } + } + + // Add targets that are missing + for key, target := range desiredTargetMap { + if _, exists := currentRuleMap[key]; !exists { + sourcePrefix, err := netip.ParsePrefix(target.SourcePrefix) + if err != nil { + logger.Warn("Invalid source prefix %s during sync: %v", target.SourcePrefix, err) + continue + } + + destPrefix, err := netip.ParsePrefix(target.DestPrefix) + if err != nil { + logger.Warn("Invalid dest prefix %s during sync: %v", target.DestPrefix, err) + continue + } + + var portRanges []netstack2.PortRange + for _, pr := range target.PortRange { + portRanges = append(portRanges, netstack2.PortRange{ + Min: pr.Min, + Max: pr.Max, + Protocol: pr.Protocol, + }) + } + + s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp) + logger.Info("Added target %s -> %s during sync", target.SourcePrefix, target.DestPrefix) + } + } + + return nil +} + func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error { s.mu.Lock() diff --git a/healthcheck/healthcheck.go b/healthcheck/healthcheck.go index 8de3008..86bdc48 100644 --- a/healthcheck/healthcheck.go +++ b/healthcheck/healthcheck.go @@ -521,3 +521,82 @@ func (m *Monitor) DisableTarget(id int) error { return nil } + +// GetTargetIDs returns a slice of all current target IDs +func (m *Monitor) GetTargetIDs() []int { + m.mutex.RLock() + defer m.mutex.RUnlock() + + ids := make([]int, 0, len(m.targets)) + for id := range m.targets { + ids = append(ids, id) + } + return ids +} + +// SyncTargets synchronizes the current targets to match the desired set. +// It removes targets not in the desired set and adds targets that are missing. +func (m *Monitor) SyncTargets(desiredConfigs []Config) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + logger.Info("Syncing health check targets: %d desired targets", len(desiredConfigs)) + + // Build a set of desired target IDs + desiredIDs := make(map[int]Config) + for _, config := range desiredConfigs { + desiredIDs[config.ID] = config + } + + // Find targets to remove (exist but not in desired set) + var toRemove []int + for id := range m.targets { + if _, exists := desiredIDs[id]; !exists { + toRemove = append(toRemove, id) + } + } + + // Remove targets that are not in the desired set + for _, id := range toRemove { + logger.Info("Sync: removing health check target %d", id) + if target, exists := m.targets[id]; exists { + target.cancel() + delete(m.targets, id) + } + } + + // Add or update targets from the desired set + var addedCount, updatedCount int + for id, config := range desiredIDs { + if existing, exists := m.targets[id]; exists { + // Target exists - check if config changed and update if needed + // For now, we'll replace it to ensure config is up to date + logger.Debug("Sync: updating health check target %d", id) + existing.cancel() + delete(m.targets, id) + if err := m.addTargetUnsafe(config); err != nil { + logger.Error("Sync: failed to update target %d: %v", id, err) + return fmt.Errorf("failed to update target %d: %v", id, err) + } + updatedCount++ + } else { + // Target doesn't exist - add it + logger.Debug("Sync: adding health check target %d", id) + if err := m.addTargetUnsafe(config); err != nil { + logger.Error("Sync: failed to add target %d: %v", id, err) + return fmt.Errorf("failed to add target %d: %v", id, err) + } + addedCount++ + } + } + + logger.Info("Sync complete: removed %d, added %d, updated %d targets", + len(toRemove), addedCount, updatedCount) + + // Notify callback if any changes were made + if (len(toRemove) > 0 || addedCount > 0 || updatedCount > 0) && m.callback != nil { + go m.callback(m.getAllTargetsUnsafe()) + } + + return nil +} diff --git a/main.go b/main.go index c41ea35..cf4f509 100644 --- a/main.go +++ b/main.go @@ -1106,6 +1106,151 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey( } }) + // Register handler for syncing targets (TCP, UDP, and health checks) + client.RegisterHandler("newt/sync", func(msg websocket.WSMessage) { + logger.Info("Received sync message") + + // if there is no wgData or pm, we can't sync targets + if wgData.TunnelIP == "" || pm == nil { + logger.Info(msgNoTunnelOrProxy) + return + } + + // Define the sync data structure + type SyncData struct { + Targets TargetsByType `json:"targets"` + HealthCheckTargets []healthcheck.Config `json:"healthCheckTargets"` + } + + var syncData SyncData + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Error("Error marshaling sync data: %v", err) + return + } + + if err := json.Unmarshal(jsonData, &syncData); err != nil { + logger.Error("Error unmarshaling sync data: %v", err) + return + } + + logger.Debug("Sync data received: TCP targets=%d, UDP targets=%d, health check targets=%d", + len(syncData.Targets.TCP), len(syncData.Targets.UDP), len(syncData.HealthCheckTargets)) + + // Build sets of desired targets (port -> target string) + desiredTCP := make(map[int]string) + for _, t := range syncData.Targets.TCP { + parts := strings.Split(t, ":") + if len(parts) != 3 { + logger.Warn("Invalid TCP target format: %s", t) + continue + } + port := 0 + if _, err := fmt.Sscanf(parts[0], "%d", &port); err != nil { + logger.Warn("Invalid port in TCP target: %s", parts[0]) + continue + } + desiredTCP[port] = parts[1] + ":" + parts[2] + } + + desiredUDP := make(map[int]string) + for _, t := range syncData.Targets.UDP { + parts := strings.Split(t, ":") + if len(parts) != 3 { + logger.Warn("Invalid UDP target format: %s", t) + continue + } + port := 0 + if _, err := fmt.Sscanf(parts[0], "%d", &port); err != nil { + logger.Warn("Invalid port in UDP target: %s", parts[0]) + continue + } + desiredUDP[port] = parts[1] + ":" + parts[2] + } + + // Get current targets from proxy manager + currentTCP, currentUDP := pm.GetTargets() + + // Sync TCP targets + // Remove TCP targets not in desired set + if tcpForIP, ok := currentTCP[wgData.TunnelIP]; ok { + for port := range tcpForIP { + if _, exists := desiredTCP[port]; !exists { + logger.Info("Sync: removing TCP target on port %d", port) + targetStr := fmt.Sprintf("%d:%s", port, tcpForIP[port]) + updateTargets(pm, "remove", wgData.TunnelIP, "tcp", TargetData{Targets: []string{targetStr}}) + } + } + } + + // Add TCP targets that are missing + for port, target := range desiredTCP { + needsAdd := true + if tcpForIP, ok := currentTCP[wgData.TunnelIP]; ok { + if currentTarget, exists := tcpForIP[port]; exists { + // Check if target address changed + if currentTarget == target { + needsAdd = false + } else { + // Target changed, remove old one first + logger.Info("Sync: updating TCP target on port %d", port) + targetStr := fmt.Sprintf("%d:%s", port, currentTarget) + updateTargets(pm, "remove", wgData.TunnelIP, "tcp", TargetData{Targets: []string{targetStr}}) + } + } + } + if needsAdd { + logger.Info("Sync: adding TCP target on port %d -> %s", port, target) + targetStr := fmt.Sprintf("%d:%s", port, target) + updateTargets(pm, "add", wgData.TunnelIP, "tcp", TargetData{Targets: []string{targetStr}}) + } + } + + // Sync UDP targets + // Remove UDP targets not in desired set + if udpForIP, ok := currentUDP[wgData.TunnelIP]; ok { + for port := range udpForIP { + if _, exists := desiredUDP[port]; !exists { + logger.Info("Sync: removing UDP target on port %d", port) + targetStr := fmt.Sprintf("%d:%s", port, udpForIP[port]) + updateTargets(pm, "remove", wgData.TunnelIP, "udp", TargetData{Targets: []string{targetStr}}) + } + } + } + + // Add UDP targets that are missing + for port, target := range desiredUDP { + needsAdd := true + if udpForIP, ok := currentUDP[wgData.TunnelIP]; ok { + if currentTarget, exists := udpForIP[port]; exists { + // Check if target address changed + if currentTarget == target { + needsAdd = false + } else { + // Target changed, remove old one first + logger.Info("Sync: updating UDP target on port %d", port) + targetStr := fmt.Sprintf("%d:%s", port, currentTarget) + updateTargets(pm, "remove", wgData.TunnelIP, "udp", TargetData{Targets: []string{targetStr}}) + } + } + } + if needsAdd { + logger.Info("Sync: adding UDP target on port %d -> %s", port, target) + targetStr := fmt.Sprintf("%d:%s", port, target) + updateTargets(pm, "add", wgData.TunnelIP, "udp", TargetData{Targets: []string{targetStr}}) + } + } + + // Sync health check targets + if err := healthMonitor.SyncTargets(syncData.HealthCheckTargets); err != nil { + logger.Error("Failed to sync health check targets: %v", err) + } else { + logger.Info("Successfully synced health check targets") + } + + logger.Info("Sync complete") + }) + // Register handler for Docker socket check client.RegisterHandler("newt/socket/check", func(msg websocket.WSMessage) { logger.Debug("Received Docker socket check request") diff --git a/netstack2/proxy.go b/netstack2/proxy.go index fefb18d..33a232f 100644 --- a/netstack2/proxy.go +++ b/netstack2/proxy.go @@ -101,6 +101,18 @@ func (sl *SubnetLookup) RemoveSubnet(sourcePrefix, destPrefix netip.Prefix) { delete(sl.rules, key) } +// GetAllRules returns a copy of all subnet rules +func (sl *SubnetLookup) GetAllRules() []SubnetRule { + sl.mu.RLock() + defer sl.mu.RUnlock() + + rules := make([]SubnetRule, 0, len(sl.rules)) + for _, rule := range sl.rules { + rules = append(rules, *rule) + } + return rules +} + // Match checks if a source IP, destination IP, port, and protocol match any subnet rule // Returns the matched rule if ALL of these conditions are met: // - The source IP is in the rule's source prefix @@ -296,6 +308,14 @@ func (p *ProxyHandler) RemoveSubnetRule(sourcePrefix, destPrefix netip.Prefix) { p.subnetLookup.RemoveSubnet(sourcePrefix, destPrefix) } +// GetAllRules returns all subnet rules from the proxy handler +func (p *ProxyHandler) GetAllRules() []SubnetRule { + if p == nil || !p.enabled { + return nil + } + return p.subnetLookup.GetAllRules() +} + // LookupDestinationRewrite looks up the rewritten destination for a connection // This is used by TCP/UDP handlers to find the actual target address func (p *ProxyHandler) LookupDestinationRewrite(srcIP, dstIP string, dstPort uint16, proto uint8) (netip.Addr, bool) { diff --git a/netstack2/tun.go b/netstack2/tun.go index e743f1e..b00faea 100644 --- a/netstack2/tun.go +++ b/netstack2/tun.go @@ -369,6 +369,15 @@ func (net *Net) RemoveProxySubnetRule(sourcePrefix, destPrefix netip.Prefix) { } } +// GetProxySubnetRules returns all subnet rules from the proxy handler +func (net *Net) GetProxySubnetRules() []SubnetRule { + tun := (*netTun)(net) + if tun.proxyHandler != nil { + return tun.proxyHandler.GetAllRules() + } + return nil +} + // GetProxyHandler returns the proxy handler (for advanced use cases) // Returns nil if proxy is not enabled func (net *Net) GetProxyHandler() *ProxyHandler { diff --git a/proxy/manager.go b/proxy/manager.go index cef5fa6..0619e80 100644 --- a/proxy/manager.go +++ b/proxy/manager.go @@ -736,3 +736,28 @@ func (pm *ProxyManager) PrintTargets() { } } } + +// GetTargets returns a copy of the current TCP and UDP targets +// Returns map[listenIP]map[port]targetAddress for both TCP and UDP +func (pm *ProxyManager) GetTargets() (tcpTargets map[string]map[int]string, udpTargets map[string]map[int]string) { + pm.mutex.RLock() + defer pm.mutex.RUnlock() + + tcpTargets = make(map[string]map[int]string) + for listenIP, targets := range pm.tcpTargets { + tcpTargets[listenIP] = make(map[int]string) + for port, targetAddr := range targets { + tcpTargets[listenIP][port] = targetAddr + } + } + + udpTargets = make(map[string]map[int]string) + for listenIP, targets := range pm.udpTargets { + udpTargets[listenIP] = make(map[int]string) + for port, targetAddr := range targets { + udpTargets[listenIP][port] = targetAddr + } + } + + return tcpTargets, udpTargets +} diff --git a/websocket/client.go b/websocket/client.go index c0fea18..8703b51 100644 --- a/websocket/client.go +++ b/websocket/client.go @@ -671,22 +671,24 @@ func (c *Client) pingMonitor() { if c.conn == nil { return } - - // Send application-level ping with config version + + c.configVersionMux.RLock() + configVersion := c.configVersion + c.configVersionMux.RUnlock() + pingMsg := WSMessage{ - Type: "ping", - Data: map[string]interface{}{ - "configVersion": c.GetConfigVersion(), - }, + Type: "ping", + Data: map[string]interface{}{}, + ConfigVersion: configVersion, } - + c.writeMux.Lock() err := c.conn.WriteJSON(pingMsg) if err == nil { telemetry.IncWSMessage(c.metricsContext(), "out", "ping") } c.writeMux.Unlock() - + if err != nil { // Check if we're shutting down before logging error and reconnecting select {