diff --git a/clients/clients.go b/clients/clients.go index 3e062b3..ccf41aa 100644 --- a/clients/clients.go +++ b/clients/clients.go @@ -37,11 +37,12 @@ type WgConfig struct { } type Target struct { - SourcePrefix string `json:"sourcePrefix"` - DestPrefix string `json:"destPrefix"` - RewriteTo string `json:"rewriteTo,omitempty"` - DisableIcmp bool `json:"disableIcmp,omitempty"` - PortRange []PortRange `json:"portRange,omitempty"` + SourcePrefix string `json:"sourcePrefix"` + SourcePrefixes []string `json:"sourcePrefixes"` + DestPrefix string `json:"destPrefix"` + RewriteTo string `json:"rewriteTo,omitempty"` + DisableIcmp bool `json:"disableIcmp,omitempty"` + PortRange []PortRange `json:"portRange,omitempty"` } type PortRange struct { @@ -112,8 +113,6 @@ func NewWireGuardService(interfaceName string, port uint16, mtu int, host string return nil, fmt.Errorf("failed to generate private key: %v", err) } - logger.Debug("+++++++++++++++++++++++++++++++= the port is %d", port) - if port == 0 { // Find an available port portRandom, err := util.FindAvailableUDPPort(49152, 65535) @@ -173,6 +172,7 @@ func NewWireGuardService(interfaceName string, port uint16, mtu int, host string wsClient.RegisterHandler("newt/wg/targets/add", service.handleAddTarget) wsClient.RegisterHandler("newt/wg/targets/remove", service.handleRemoveTarget) wsClient.RegisterHandler("newt/wg/targets/update", service.handleUpdateTarget) + wsClient.RegisterHandler("newt/wg/sync", service.handleSyncConfig) return service, nil } @@ -278,7 +278,7 @@ func (s *WireGuardService) StartHolepunch(publicKey string, endpoint string, rel } if relayPort == 0 { - relayPort = 21820 + relayPort = 21820 } // Convert websocket.ExitNode to holepunch.ExitNode @@ -493,6 +493,183 @@ func (s *WireGuardService) handleConfig(msg websocket.WSMessage) { logger.Info("Client connectivity setup. Ready to accept connections from clients!") } +// SyncConfig represents the configuration sent from server for syncing +type SyncConfig struct { + Targets []Target `json:"targets"` + Peers []Peer `json:"peers"` +} + +func (s *WireGuardService) handleSyncConfig(msg websocket.WSMessage) { + var syncConfig SyncConfig + + logger.Debug("Received sync message: %v", msg) + logger.Info("Received sync configuration from remote server") + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Error("Error marshaling sync data: %v", err) + return + } + + if err := json.Unmarshal(jsonData, &syncConfig); err != nil { + logger.Error("Error unmarshaling sync data: %v", err) + return + } + + // Sync peers + if err := s.syncPeers(syncConfig.Peers); err != nil { + logger.Error("Failed to sync peers: %v", err) + } + + // Sync targets + if err := s.syncTargets(syncConfig.Targets); err != nil { + logger.Error("Failed to sync targets: %v", err) + } +} + +// syncPeers synchronizes the current peers with the desired state +// It removes peers not in the desired list and adds missing ones +func (s *WireGuardService) syncPeers(desiredPeers []Peer) error { + if s.device == nil { + return fmt.Errorf("WireGuard device is not initialized") + } + + // Get current peers from the device + currentConfig, err := s.device.IpcGet() + if err != nil { + return fmt.Errorf("failed to get current device config: %v", err) + } + + // Parse current peer public keys + lines := strings.Split(currentConfig, "\n") + currentPeerKeys := make(map[string]bool) + for _, line := range lines { + if strings.HasPrefix(line, "public_key=") { + pubKey := strings.TrimPrefix(line, "public_key=") + currentPeerKeys[pubKey] = true + } + } + + // Build a map of desired peers by their public key (normalized) + desiredPeerMap := make(map[string]Peer) + for _, peer := range desiredPeers { + // Normalize the public key for comparison + pubKey, err := wgtypes.ParseKey(peer.PublicKey) + if err != nil { + logger.Warn("Invalid public key in desired peers: %s", peer.PublicKey) + continue + } + normalizedKey := util.FixKey(pubKey.String()) + desiredPeerMap[normalizedKey] = peer + } + + // Remove peers that are not in the desired list + for currentKey := range currentPeerKeys { + if _, exists := desiredPeerMap[currentKey]; !exists { + // Parse the key back to get the original format for removal + removeConfig := fmt.Sprintf("public_key=%s\nremove=true", currentKey) + if err := s.device.IpcSet(removeConfig); err != nil { + logger.Warn("Failed to remove peer %s during sync: %v", currentKey, err) + } else { + logger.Info("Removed peer %s during sync", currentKey) + } + } + } + + // Add peers that are missing + for normalizedKey, peer := range desiredPeerMap { + if _, exists := currentPeerKeys[normalizedKey]; !exists { + if err := s.addPeerToDevice(peer); err != nil { + logger.Warn("Failed to add peer %s during sync: %v", peer.PublicKey, err) + } else { + logger.Info("Added peer %s during sync", peer.PublicKey) + } + } + } + + return nil +} + +// syncTargets synchronizes the current targets with the desired state +// It removes targets not in the desired list and adds missing ones +func (s *WireGuardService) syncTargets(desiredTargets []Target) error { + if s.tnet == nil { + // Native interface mode - proxy features not available, skip silently + logger.Debug("Skipping target sync - using native interface (no proxy support)") + return nil + } + + // Get current rules from the proxy handler + currentRules := s.tnet.GetProxySubnetRules() + + // Build a map of current rules by source+dest prefix + type ruleKey struct { + sourcePrefix string + destPrefix string + } + currentRuleMap := make(map[ruleKey]bool) + for _, rule := range currentRules { + key := ruleKey{ + sourcePrefix: rule.SourcePrefix.String(), + destPrefix: rule.DestPrefix.String(), + } + currentRuleMap[key] = true + } + + // Build a map of desired targets + desiredTargetMap := make(map[ruleKey]Target) + for _, target := range desiredTargets { + key := ruleKey{ + sourcePrefix: target.SourcePrefix, + destPrefix: target.DestPrefix, + } + desiredTargetMap[key] = target + } + + // Remove targets that are not in the desired list + for _, rule := range currentRules { + key := ruleKey{ + sourcePrefix: rule.SourcePrefix.String(), + destPrefix: rule.DestPrefix.String(), + } + if _, exists := desiredTargetMap[key]; !exists { + s.tnet.RemoveProxySubnetRule(rule.SourcePrefix, rule.DestPrefix) + logger.Info("Removed target %s -> %s during sync", rule.SourcePrefix.String(), rule.DestPrefix.String()) + } + } + + // Add targets that are missing + for key, target := range desiredTargetMap { + if _, exists := currentRuleMap[key]; !exists { + sourcePrefix, err := netip.ParsePrefix(target.SourcePrefix) + if err != nil { + logger.Warn("Invalid source prefix %s during sync: %v", target.SourcePrefix, err) + continue + } + + destPrefix, err := netip.ParsePrefix(target.DestPrefix) + if err != nil { + logger.Warn("Invalid dest prefix %s during sync: %v", target.DestPrefix, err) + continue + } + + var portRanges []netstack2.PortRange + for _, pr := range target.PortRange { + portRanges = append(portRanges, netstack2.PortRange{ + Min: pr.Min, + Max: pr.Max, + Protocol: pr.Protocol, + }) + } + + s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp) + logger.Info("Added target %s -> %s during sync", target.SourcePrefix, target.DestPrefix) + } + } + + return nil +} + func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error { s.mu.Lock() @@ -696,6 +873,19 @@ func (s *WireGuardService) ensureWireguardPeers(peers []Peer) error { return nil } +// resolveSourcePrefixes returns the effective list of source prefixes for a target, +// supporting both the legacy single SourcePrefix field and the new SourcePrefixes array. +// If SourcePrefixes is non-empty it takes precedence; otherwise SourcePrefix is used. +func resolveSourcePrefixes(target Target) []string { + if len(target.SourcePrefixes) > 0 { + return target.SourcePrefixes + } + if target.SourcePrefix != "" { + return []string{target.SourcePrefix} + } + return nil +} + func (s *WireGuardService) ensureTargets(targets []Target) error { if s.tnet == nil { // Native interface mode - proxy features not available, skip silently @@ -704,11 +894,6 @@ func (s *WireGuardService) ensureTargets(targets []Target) error { } for _, target := range targets { - sourcePrefix, err := netip.ParsePrefix(target.SourcePrefix) - if err != nil { - return fmt.Errorf("invalid CIDR %s: %v", target.SourcePrefix, err) - } - destPrefix, err := netip.ParsePrefix(target.DestPrefix) if err != nil { return fmt.Errorf("invalid CIDR %s: %v", target.DestPrefix, err) @@ -723,9 +908,14 @@ func (s *WireGuardService) ensureTargets(targets []Target) error { }) } - s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp) - - logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v disableIcmp: %v", target.SourcePrefix, target.DestPrefix, target.RewriteTo, target.PortRange, target.DisableIcmp) + for _, sp := range resolveSourcePrefixes(target) { + sourcePrefix, err := netip.ParsePrefix(sp) + if err != nil { + return fmt.Errorf("invalid CIDR %s: %v", sp, err) + } + s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp) + logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", sp, target.DestPrefix, target.RewriteTo, target.PortRange) + } } return nil @@ -1044,7 +1234,7 @@ func (s *WireGuardService) processPeerBandwidth(publicKey string, rxBytes, txByt BytesOut: bytesOutMB, } } - + return nil } } @@ -1095,12 +1285,6 @@ func (s *WireGuardService) handleAddTarget(msg websocket.WSMessage) { // Process all targets for _, target := range targets { - sourcePrefix, err := netip.ParsePrefix(target.SourcePrefix) - if err != nil { - logger.Info("Invalid CIDR %s: %v", target.SourcePrefix, err) - continue - } - destPrefix, err := netip.ParsePrefix(target.DestPrefix) if err != nil { logger.Info("Invalid CIDR %s: %v", target.DestPrefix, err) @@ -1110,15 +1294,21 @@ func (s *WireGuardService) handleAddTarget(msg websocket.WSMessage) { var portRanges []netstack2.PortRange for _, pr := range target.PortRange { portRanges = append(portRanges, netstack2.PortRange{ - Min: pr.Min, - Max: pr.Max, - Protocol: pr.Protocol, + Min: pr.Min, + Max: pr.Max, + Protocol: pr.Protocol, }) } - s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp) - - logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v disableIcmp: %v", target.SourcePrefix, target.DestPrefix, target.RewriteTo, target.PortRange, target.DisableIcmp) + for _, sp := range resolveSourcePrefixes(target) { + sourcePrefix, err := netip.ParsePrefix(sp) + if err != nil { + logger.Info("Invalid CIDR %s: %v", sp, err) + continue + } + s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp) + logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", sp, target.DestPrefix, target.RewriteTo, target.PortRange) + } } } @@ -1147,21 +1337,21 @@ func (s *WireGuardService) handleRemoveTarget(msg websocket.WSMessage) { // Process all targets for _, target := range targets { - sourcePrefix, err := netip.ParsePrefix(target.SourcePrefix) - if err != nil { - logger.Info("Invalid CIDR %s: %v", target.SourcePrefix, err) - continue - } - destPrefix, err := netip.ParsePrefix(target.DestPrefix) if err != nil { logger.Info("Invalid CIDR %s: %v", target.DestPrefix, err) continue } - s.tnet.RemoveProxySubnetRule(sourcePrefix, destPrefix) - - logger.Info("Removed target subnet %s with destination %s", target.SourcePrefix, target.DestPrefix) + for _, sp := range resolveSourcePrefixes(target) { + sourcePrefix, err := netip.ParsePrefix(sp) + if err != nil { + logger.Info("Invalid CIDR %s: %v", sp, err) + continue + } + s.tnet.RemoveProxySubnetRule(sourcePrefix, destPrefix) + logger.Info("Removed target subnet %s with destination %s", sp, target.DestPrefix) + } } } @@ -1195,30 +1385,24 @@ func (s *WireGuardService) handleUpdateTarget(msg websocket.WSMessage) { // Process all update requests for _, target := range requests.OldTargets { - sourcePrefix, err := netip.ParsePrefix(target.SourcePrefix) - if err != nil { - logger.Info("Invalid CIDR %s: %v", target.SourcePrefix, err) - continue - } - destPrefix, err := netip.ParsePrefix(target.DestPrefix) if err != nil { logger.Info("Invalid CIDR %s: %v", target.DestPrefix, err) continue } - s.tnet.RemoveProxySubnetRule(sourcePrefix, destPrefix) - logger.Info("Removed target subnet %s with destination %s", target.SourcePrefix, target.DestPrefix) + for _, sp := range resolveSourcePrefixes(target) { + sourcePrefix, err := netip.ParsePrefix(sp) + if err != nil { + logger.Info("Invalid CIDR %s: %v", sp, err) + continue + } + s.tnet.RemoveProxySubnetRule(sourcePrefix, destPrefix) + logger.Info("Removed target subnet %s with destination %s", sp, target.DestPrefix) + } } for _, target := range requests.NewTargets { - // Now add the new target - sourcePrefix, err := netip.ParsePrefix(target.SourcePrefix) - if err != nil { - logger.Info("Invalid CIDR %s: %v", target.SourcePrefix, err) - continue - } - destPrefix, err := netip.ParsePrefix(target.DestPrefix) if err != nil { logger.Info("Invalid CIDR %s: %v", target.DestPrefix, err) @@ -1228,14 +1412,21 @@ func (s *WireGuardService) handleUpdateTarget(msg websocket.WSMessage) { var portRanges []netstack2.PortRange for _, pr := range target.PortRange { portRanges = append(portRanges, netstack2.PortRange{ - Min: pr.Min, - Max: pr.Max, - Protocol: pr.Protocol, + Min: pr.Min, + Max: pr.Max, + Protocol: pr.Protocol, }) } - s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp) - logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v disableIcmp: %v", target.SourcePrefix, target.DestPrefix, target.RewriteTo, target.PortRange, target.DisableIcmp) + for _, sp := range resolveSourcePrefixes(target) { + sourcePrefix, err := netip.ParsePrefix(sp) + if err != nil { + logger.Info("Invalid CIDR %s: %v", sp, err) + continue + } + s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp) + logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", sp, target.DestPrefix, target.RewriteTo, target.PortRange) + } } } diff --git a/common.go b/common.go index 5fe0645..4701411 100644 --- a/common.go +++ b/common.go @@ -5,6 +5,7 @@ import ( "context" "encoding/json" "fmt" + "net" "os" "os/exec" "strings" @@ -363,27 +364,62 @@ func parseTargetData(data interface{}) (TargetData, error) { return targetData, nil } +// parseTargetString parses a target string in the format "listenPort:host:targetPort" +// It properly handles IPv6 addresses which must be in brackets: "listenPort:[ipv6]:targetPort" +// Examples: +// - IPv4: "3001:192.168.1.1:80" +// - IPv6: "3001:[::1]:8080" or "3001:[fd70:1452:b736:4dd5:caca:7db9:c588:f5b3]:80" +// +// Returns listenPort, targetAddress (in host:port format suitable for net.Dial), and error +func parseTargetString(target string) (int, string, error) { + // Find the first colon to extract the listen port + firstColon := strings.Index(target, ":") + if firstColon == -1 { + return 0, "", fmt.Errorf("invalid target format, no colon found: %s", target) + } + + listenPortStr := target[:firstColon] + var listenPort int + _, err := fmt.Sscanf(listenPortStr, "%d", &listenPort) + if err != nil { + return 0, "", fmt.Errorf("invalid listen port: %s", listenPortStr) + } + if listenPort <= 0 || listenPort > 65535 { + return 0, "", fmt.Errorf("listen port out of range: %d", listenPort) + } + + // The remainder is host:targetPort - use net.SplitHostPort which handles IPv6 brackets + remainder := target[firstColon+1:] + host, targetPort, err := net.SplitHostPort(remainder) + if err != nil { + return 0, "", fmt.Errorf("invalid host:port format '%s': %w", remainder, err) + } + + // Reject empty host or target port + if host == "" { + return 0, "", fmt.Errorf("empty host in target: %s", target) + } + if targetPort == "" { + return 0, "", fmt.Errorf("empty target port in target: %s", target) + } + + // Reconstruct the target address using JoinHostPort (handles IPv6 properly) + targetAddr := net.JoinHostPort(host, targetPort) + + return listenPort, targetAddr, nil +} + func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto string, targetData TargetData) error { for _, t := range targetData.Targets { - // Split the first number off of the target with : separator and use as the port - parts := strings.Split(t, ":") - if len(parts) != 3 { - logger.Info("Invalid target format: %s", t) - continue - } - - // Get the port as an int - port := 0 - _, err := fmt.Sscanf(parts[0], "%d", &port) + // Parse the target string, handling both IPv4 and IPv6 addresses + port, target, err := parseTargetString(t) if err != nil { - logger.Info("Invalid port: %s", parts[0]) + logger.Info("Invalid target format: %s (%v)", t, err) continue } switch action { case "add": - target := parts[1] + ":" + parts[2] - // Call updown script if provided processedTarget := target if updownScript != "" { @@ -410,8 +446,6 @@ func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto case "remove": logger.Info("Removing target with port %d", port) - target := parts[1] + ":" + parts[2] - // Call updown script if provided if updownScript != "" { _, err := executeUpdownScript(action, proto, target) @@ -420,7 +454,7 @@ func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto } } - err := pm.RemoveTarget(proto, tunnelIP, port) + err = pm.RemoveTarget(proto, tunnelIP, port) if err != nil { logger.Error("Failed to remove target: %v", err) return err diff --git a/common_test.go b/common_test.go new file mode 100644 index 0000000..a7e659a --- /dev/null +++ b/common_test.go @@ -0,0 +1,212 @@ +package main + +import ( + "net" + "testing" +) + +func TestParseTargetString(t *testing.T) { + tests := []struct { + name string + input string + wantListenPort int + wantTargetAddr string + wantErr bool + }{ + // IPv4 test cases + { + name: "valid IPv4 basic", + input: "3001:192.168.1.1:80", + wantListenPort: 3001, + wantTargetAddr: "192.168.1.1:80", + wantErr: false, + }, + { + name: "valid IPv4 localhost", + input: "8080:127.0.0.1:3000", + wantListenPort: 8080, + wantTargetAddr: "127.0.0.1:3000", + wantErr: false, + }, + { + name: "valid IPv4 same ports", + input: "443:10.0.0.1:443", + wantListenPort: 443, + wantTargetAddr: "10.0.0.1:443", + wantErr: false, + }, + + // IPv6 test cases + { + name: "valid IPv6 loopback", + input: "3001:[::1]:8080", + wantListenPort: 3001, + wantTargetAddr: "[::1]:8080", + wantErr: false, + }, + { + name: "valid IPv6 full address", + input: "80:[fd70:1452:b736:4dd5:caca:7db9:c588:f5b3]:8080", + wantListenPort: 80, + wantTargetAddr: "[fd70:1452:b736:4dd5:caca:7db9:c588:f5b3]:8080", + wantErr: false, + }, + { + name: "valid IPv6 link-local", + input: "443:[fe80::1]:443", + wantListenPort: 443, + wantTargetAddr: "[fe80::1]:443", + wantErr: false, + }, + { + name: "valid IPv6 all zeros compressed", + input: "8000:[::]:9000", + wantListenPort: 8000, + wantTargetAddr: "[::]:9000", + wantErr: false, + }, + { + name: "valid IPv6 mixed notation", + input: "5000:[::ffff:192.168.1.1]:6000", + wantListenPort: 5000, + wantTargetAddr: "[::ffff:192.168.1.1]:6000", + wantErr: false, + }, + + // Hostname test cases + { + name: "valid hostname", + input: "8080:example.com:80", + wantListenPort: 8080, + wantTargetAddr: "example.com:80", + wantErr: false, + }, + { + name: "valid hostname with subdomain", + input: "443:api.example.com:8443", + wantListenPort: 443, + wantTargetAddr: "api.example.com:8443", + wantErr: false, + }, + { + name: "valid localhost hostname", + input: "3000:localhost:3000", + wantListenPort: 3000, + wantTargetAddr: "localhost:3000", + wantErr: false, + }, + + // Error cases + { + name: "invalid - no colons", + input: "invalid", + wantErr: true, + }, + { + name: "invalid - empty string", + input: "", + wantErr: true, + }, + { + name: "invalid - non-numeric listen port", + input: "abc:192.168.1.1:80", + wantErr: true, + }, + { + name: "invalid - missing target port", + input: "3001:192.168.1.1", + wantErr: true, + }, + { + name: "invalid - IPv6 without brackets", + input: "3001:fd70:1452:b736:4dd5:caca:7db9:c588:f5b3:80", + wantErr: true, + }, + { + name: "invalid - only listen port", + input: "3001:", + wantErr: true, + }, + { + name: "invalid - missing host", + input: "3001::80", + wantErr: true, + }, + { + name: "invalid - IPv6 unclosed bracket", + input: "3001:[::1:80", + wantErr: true, + }, + { + name: "invalid - listen port zero", + input: "0:192.168.1.1:80", + wantErr: true, + }, + { + name: "invalid - listen port negative", + input: "-1:192.168.1.1:80", + wantErr: true, + }, + { + name: "invalid - listen port out of range", + input: "70000:192.168.1.1:80", + wantErr: true, + }, + { + name: "invalid - empty target port", + input: "3001:192.168.1.1:", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + listenPort, targetAddr, err := parseTargetString(tt.input) + + if (err != nil) != tt.wantErr { + t.Errorf("parseTargetString(%q) error = %v, wantErr %v", tt.input, err, tt.wantErr) + return + } + + if tt.wantErr { + return // Don't check other values if we expected an error + } + + if listenPort != tt.wantListenPort { + t.Errorf("parseTargetString(%q) listenPort = %d, want %d", tt.input, listenPort, tt.wantListenPort) + } + + if targetAddr != tt.wantTargetAddr { + t.Errorf("parseTargetString(%q) targetAddr = %q, want %q", tt.input, targetAddr, tt.wantTargetAddr) + } + }) + } +} + +// TestParseTargetStringNetDialCompatibility verifies that the output is compatible with net.Dial +func TestParseTargetStringNetDialCompatibility(t *testing.T) { + tests := []struct { + name string + input string + }{ + {"IPv4", "8080:127.0.0.1:80"}, + {"IPv6 loopback", "8080:[::1]:80"}, + {"IPv6 full", "8080:[2001:db8::1]:80"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, targetAddr, err := parseTargetString(tt.input) + if err != nil { + t.Fatalf("parseTargetString(%q) unexpected error: %v", tt.input, err) + } + + // Verify the format is valid for net.Dial by checking it can be split back + // This doesn't actually dial, just validates the format + _, _, err = net.SplitHostPort(targetAddr) + if err != nil { + t.Errorf("parseTargetString(%q) produced invalid net.Dial format %q: %v", tt.input, targetAddr, err) + } + }) + } +} diff --git a/get-newt.sh b/get-newt.sh index d57f69a..d4ddd3f 100644 --- a/get-newt.sh +++ b/get-newt.sh @@ -1,7 +1,7 @@ -#!/bin/bash +#!/bin/sh # Get Newt - Cross-platform installation script -# Usage: curl -fsSL https://raw.githubusercontent.com/fosrl/newt/refs/heads/main/get-newt.sh | bash +# Usage: curl -fsSL https://raw.githubusercontent.com/fosrl/newt/refs/heads/main/get-newt.sh | sh set -e @@ -17,15 +17,15 @@ GITHUB_API_URL="https://api.github.com/repos/${REPO}/releases/latest" # Function to print colored output print_status() { - echo -e "${GREEN}[INFO]${NC} $1" + printf '%b[INFO]%b %s\n' "${GREEN}" "${NC}" "$1" } print_warning() { - echo -e "${YELLOW}[WARN]${NC} $1" + printf '%b[WARN]%b %s\n' "${YELLOW}" "${NC}" "$1" } print_error() { - echo -e "${RED}[ERROR]${NC} $1" + printf '%b[ERROR]%b %s\n' "${RED}" "${NC}" "$1" } # Function to get latest version from GitHub API @@ -113,16 +113,34 @@ get_install_dir() { if [ "$OS" = "windows" ]; then echo "$HOME/bin" else - # Try to use a directory in PATH, fallback to ~/.local/bin - if echo "$PATH" | grep -q "/usr/local/bin"; then - if [ -w "/usr/local/bin" ] 2>/dev/null; then - echo "/usr/local/bin" - else - echo "$HOME/.local/bin" - fi + # Prefer /usr/local/bin for system-wide installation + echo "/usr/local/bin" + fi +} + +# Check if we need sudo for installation +needs_sudo() { + local install_dir="$1" + if [ -w "$install_dir" ] 2>/dev/null; then + return 1 # No sudo needed + else + return 0 # Sudo needed + fi +} + +# Get the appropriate command prefix (sudo or empty) +get_sudo_cmd() { + local install_dir="$1" + if needs_sudo "$install_dir"; then + if command -v sudo >/dev/null 2>&1; then + echo "sudo" else - echo "$HOME/.local/bin" + print_error "Cannot write to ${install_dir} and sudo is not available." + print_error "Please run this script as root or install sudo." + exit 1 fi + else + echo "" fi } @@ -130,21 +148,24 @@ get_install_dir() { install_newt() { local platform="$1" local install_dir="$2" + local sudo_cmd="$3" local binary_name="newt_${platform}" local exe_suffix="" - + # Add .exe suffix for Windows - if [[ "$platform" == *"windows"* ]]; then - binary_name="${binary_name}.exe" - exe_suffix=".exe" - fi - + case "$platform" in + *windows*) + binary_name="${binary_name}.exe" + exe_suffix=".exe" + ;; + esac + local download_url="${BASE_URL}/${binary_name}" local temp_file="/tmp/newt${exe_suffix}" local final_path="${install_dir}/newt${exe_suffix}" - + print_status "Downloading newt from ${download_url}" - + # Download the binary if command -v curl >/dev/null 2>&1; then curl -fsSL "$download_url" -o "$temp_file" @@ -154,18 +175,22 @@ install_newt() { print_error "Neither curl nor wget is available. Please install one of them." exit 1 fi - + + # Make executable before moving + chmod +x "$temp_file" + # Create install directory if it doesn't exist - mkdir -p "$install_dir" - - # Move binary to install directory - mv "$temp_file" "$final_path" - - # Make executable (not needed on Windows, but doesn't hurt) - chmod +x "$final_path" - + if [ -n "$sudo_cmd" ]; then + $sudo_cmd mkdir -p "$install_dir" + print_status "Using sudo to install to ${install_dir}" + $sudo_cmd mv "$temp_file" "$final_path" + else + mkdir -p "$install_dir" + mv "$temp_file" "$final_path" + fi + print_status "newt installed to ${final_path}" - + # Check if install directory is in PATH if ! echo "$PATH" | grep -q "$install_dir"; then print_warning "Install directory ${install_dir} is not in your PATH." @@ -179,9 +204,9 @@ verify_installation() { local install_dir="$1" local exe_suffix="" - if [[ "$PLATFORM" == *"windows"* ]]; then - exe_suffix=".exe" - fi + case "$PLATFORM" in + *windows*) exe_suffix=".exe" ;; + esac local newt_path="${install_dir}/newt${exe_suffix}" @@ -198,34 +223,36 @@ verify_installation() { # Main installation process main() { print_status "Installing latest version of newt..." - + # Get latest version print_status "Fetching latest version from GitHub..." VERSION=$(get_latest_version) print_status "Latest version: v${VERSION}" - + # Set base URL with the fetched version BASE_URL="https://github.com/${REPO}/releases/download/${VERSION}" - + # Detect platform PLATFORM=$(detect_platform) print_status "Detected platform: ${PLATFORM}" - + # Get install directory INSTALL_DIR=$(get_install_dir) print_status "Install directory: ${INSTALL_DIR}" - + + # Check if we need sudo + SUDO_CMD=$(get_sudo_cmd "$INSTALL_DIR") + if [ -n "$SUDO_CMD" ]; then + print_status "Root privileges required for installation to ${INSTALL_DIR}" + fi + # Install newt - install_newt "$PLATFORM" "$INSTALL_DIR" - + install_newt "$PLATFORM" "$INSTALL_DIR" "$SUDO_CMD" + # Verify installation if verify_installation "$INSTALL_DIR"; then print_status "newt is ready to use!" - if [[ "$PLATFORM" == *"windows"* ]]; then - print_status "Run 'newt --help' to get started" - else - print_status "Run 'newt --help' to get started" - fi + print_status "Run 'newt --help' to get started" else exit 1 fi diff --git a/healthcheck/healthcheck.go b/healthcheck/healthcheck.go index 9b23479..9889cc6 100644 --- a/healthcheck/healthcheck.go +++ b/healthcheck/healthcheck.go @@ -521,3 +521,82 @@ func (m *Monitor) DisableTarget(id int) error { return nil } + +// GetTargetIDs returns a slice of all current target IDs +func (m *Monitor) GetTargetIDs() []int { + m.mutex.RLock() + defer m.mutex.RUnlock() + + ids := make([]int, 0, len(m.targets)) + for id := range m.targets { + ids = append(ids, id) + } + return ids +} + +// SyncTargets synchronizes the current targets to match the desired set. +// It removes targets not in the desired set and adds targets that are missing. +func (m *Monitor) SyncTargets(desiredConfigs []Config) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + logger.Info("Syncing health check targets: %d desired targets", len(desiredConfigs)) + + // Build a set of desired target IDs + desiredIDs := make(map[int]Config) + for _, config := range desiredConfigs { + desiredIDs[config.ID] = config + } + + // Find targets to remove (exist but not in desired set) + var toRemove []int + for id := range m.targets { + if _, exists := desiredIDs[id]; !exists { + toRemove = append(toRemove, id) + } + } + + // Remove targets that are not in the desired set + for _, id := range toRemove { + logger.Info("Sync: removing health check target %d", id) + if target, exists := m.targets[id]; exists { + target.cancel() + delete(m.targets, id) + } + } + + // Add or update targets from the desired set + var addedCount, updatedCount int + for id, config := range desiredIDs { + if existing, exists := m.targets[id]; exists { + // Target exists - check if config changed and update if needed + // For now, we'll replace it to ensure config is up to date + logger.Debug("Sync: updating health check target %d", id) + existing.cancel() + delete(m.targets, id) + if err := m.addTargetUnsafe(config); err != nil { + logger.Error("Sync: failed to update target %d: %v", id, err) + return fmt.Errorf("failed to update target %d: %v", id, err) + } + updatedCount++ + } else { + // Target doesn't exist - add it + logger.Debug("Sync: adding health check target %d", id) + if err := m.addTargetUnsafe(config); err != nil { + logger.Error("Sync: failed to add target %d: %v", id, err) + return fmt.Errorf("failed to add target %d: %v", id, err) + } + addedCount++ + } + } + + logger.Info("Sync complete: removed %d, added %d, updated %d targets", + len(toRemove), addedCount, updatedCount) + + // Notify callback if any changes were made + if (len(toRemove) > 0 || addedCount > 0 || updatedCount > 0) && m.callback != nil { + go m.callback(m.getAllTargetsUnsafe()) + } + + return nil +} diff --git a/main.go b/main.go index 9c373b0..c9e7d8d 100644 --- a/main.go +++ b/main.go @@ -302,10 +302,10 @@ func runNewtMain(ctx context.Context) { flag.StringVar(&dockerSocket, "docker-socket", "", "Path or address to Docker socket (typically unix:///var/run/docker.sock)") } if pingIntervalStr == "" { - flag.StringVar(&pingIntervalStr, "ping-interval", "3s", "Interval for pinging the server (default 3s)") + flag.StringVar(&pingIntervalStr, "ping-interval", "15s", "Interval for pinging the server (default 15s)") } if pingTimeoutStr == "" { - flag.StringVar(&pingTimeoutStr, "ping-timeout", "5s", " Timeout for each ping (default 5s)") + flag.StringVar(&pingTimeoutStr, "ping-timeout", "7s", " Timeout for each ping (default 7s)") } // load the prefer endpoint just as a flag flag.StringVar(&preferEndpoint, "prefer-endpoint", "", "Prefer this endpoint for the connection (if set, will override the endpoint from the server)") @@ -330,21 +330,21 @@ func runNewtMain(ctx context.Context) { if pingIntervalStr != "" { pingInterval, err = time.ParseDuration(pingIntervalStr) if err != nil { - fmt.Printf("Invalid PING_INTERVAL value: %s, using default 3 seconds\n", pingIntervalStr) - pingInterval = 3 * time.Second + fmt.Printf("Invalid PING_INTERVAL value: %s, using default 15 seconds\n", pingIntervalStr) + pingInterval = 15 * time.Second } } else { - pingInterval = 3 * time.Second + pingInterval = 15 * time.Second } if pingTimeoutStr != "" { pingTimeout, err = time.ParseDuration(pingTimeoutStr) if err != nil { - fmt.Printf("Invalid PING_TIMEOUT value: %s, using default 5 seconds\n", pingTimeoutStr) - pingTimeout = 5 * time.Second + fmt.Printf("Invalid PING_TIMEOUT value: %s, using default 7 seconds\n", pingTimeoutStr) + pingTimeout = 7 * time.Second } } else { - pingTimeout = 5 * time.Second + pingTimeout = 7 * time.Second } if dockerEnforceNetworkValidation == "" { @@ -565,8 +565,7 @@ func runNewtMain(ctx context.Context) { id, // CLI arg takes precedence secret, // CLI arg takes precedence endpoint, - pingInterval, - pingTimeout, + 30*time.Second, opt, ) if err != nil { @@ -618,8 +617,6 @@ func runNewtMain(ctx context.Context) { var connected bool var wgData WgData var dockerEventMonitor *docker.EventMonitor - - logger.Debug("++++++++++++++++++++++ the port is %d", port) if !disableClients { setupClients(client) @@ -959,7 +956,7 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey( "publicKey": publicKey.String(), "pingResults": pingResults, "newtVersion": newtVersion, - }, 1*time.Second) + }, 2*time.Second) return } @@ -1062,7 +1059,7 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey( "publicKey": publicKey.String(), "pingResults": pingResults, "newtVersion": newtVersion, - }, 1*time.Second) + }, 2*time.Second) logger.Debug("Sent exit node ping results to cloud for selection: pingResults=%+v", pingResults) }) @@ -1167,6 +1164,153 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey( } }) + // Register handler for syncing targets (TCP, UDP, and health checks) + client.RegisterHandler("newt/sync", func(msg websocket.WSMessage) { + logger.Info("Received sync message") + + // if there is no wgData or pm, we can't sync targets + if wgData.TunnelIP == "" || pm == nil { + logger.Info(msgNoTunnelOrProxy) + return + } + + // Define the sync data structure + type SyncData struct { + Targets TargetsByType `json:"targets"` + HealthCheckTargets []healthcheck.Config `json:"healthCheckTargets"` + } + + var syncData SyncData + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Error("Error marshaling sync data: %v", err) + return + } + + if err := json.Unmarshal(jsonData, &syncData); err != nil { + logger.Error("Error unmarshaling sync data: %v", err) + return + } + + logger.Debug("Sync data received: TCP targets=%d, UDP targets=%d, health check targets=%d", + len(syncData.Targets.TCP), len(syncData.Targets.UDP), len(syncData.HealthCheckTargets)) + + //TODO: TEST AND IMPLEMENT THIS + + // // Build sets of desired targets (port -> target string) + // desiredTCP := make(map[int]string) + // for _, t := range syncData.Targets.TCP { + // parts := strings.Split(t, ":") + // if len(parts) != 3 { + // logger.Warn("Invalid TCP target format: %s", t) + // continue + // } + // port := 0 + // if _, err := fmt.Sscanf(parts[0], "%d", &port); err != nil { + // logger.Warn("Invalid port in TCP target: %s", parts[0]) + // continue + // } + // desiredTCP[port] = parts[1] + ":" + parts[2] + // } + + // desiredUDP := make(map[int]string) + // for _, t := range syncData.Targets.UDP { + // parts := strings.Split(t, ":") + // if len(parts) != 3 { + // logger.Warn("Invalid UDP target format: %s", t) + // continue + // } + // port := 0 + // if _, err := fmt.Sscanf(parts[0], "%d", &port); err != nil { + // logger.Warn("Invalid port in UDP target: %s", parts[0]) + // continue + // } + // desiredUDP[port] = parts[1] + ":" + parts[2] + // } + + // // Get current targets from proxy manager + // currentTCP, currentUDP := pm.GetTargets() + + // // Sync TCP targets + // // Remove TCP targets not in desired set + // if tcpForIP, ok := currentTCP[wgData.TunnelIP]; ok { + // for port := range tcpForIP { + // if _, exists := desiredTCP[port]; !exists { + // logger.Info("Sync: removing TCP target on port %d", port) + // targetStr := fmt.Sprintf("%d:%s", port, tcpForIP[port]) + // updateTargets(pm, "remove", wgData.TunnelIP, "tcp", TargetData{Targets: []string{targetStr}}) + // } + // } + // } + + // // Add TCP targets that are missing + // for port, target := range desiredTCP { + // needsAdd := true + // if tcpForIP, ok := currentTCP[wgData.TunnelIP]; ok { + // if currentTarget, exists := tcpForIP[port]; exists { + // // Check if target address changed + // if currentTarget == target { + // needsAdd = false + // } else { + // // Target changed, remove old one first + // logger.Info("Sync: updating TCP target on port %d", port) + // targetStr := fmt.Sprintf("%d:%s", port, currentTarget) + // updateTargets(pm, "remove", wgData.TunnelIP, "tcp", TargetData{Targets: []string{targetStr}}) + // } + // } + // } + // if needsAdd { + // logger.Info("Sync: adding TCP target on port %d -> %s", port, target) + // targetStr := fmt.Sprintf("%d:%s", port, target) + // updateTargets(pm, "add", wgData.TunnelIP, "tcp", TargetData{Targets: []string{targetStr}}) + // } + // } + + // // Sync UDP targets + // // Remove UDP targets not in desired set + // if udpForIP, ok := currentUDP[wgData.TunnelIP]; ok { + // for port := range udpForIP { + // if _, exists := desiredUDP[port]; !exists { + // logger.Info("Sync: removing UDP target on port %d", port) + // targetStr := fmt.Sprintf("%d:%s", port, udpForIP[port]) + // updateTargets(pm, "remove", wgData.TunnelIP, "udp", TargetData{Targets: []string{targetStr}}) + // } + // } + // } + + // // Add UDP targets that are missing + // for port, target := range desiredUDP { + // needsAdd := true + // if udpForIP, ok := currentUDP[wgData.TunnelIP]; ok { + // if currentTarget, exists := udpForIP[port]; exists { + // // Check if target address changed + // if currentTarget == target { + // needsAdd = false + // } else { + // // Target changed, remove old one first + // logger.Info("Sync: updating UDP target on port %d", port) + // targetStr := fmt.Sprintf("%d:%s", port, currentTarget) + // updateTargets(pm, "remove", wgData.TunnelIP, "udp", TargetData{Targets: []string{targetStr}}) + // } + // } + // } + // if needsAdd { + // logger.Info("Sync: adding UDP target on port %d -> %s", port, target) + // targetStr := fmt.Sprintf("%d:%s", port, target) + // updateTargets(pm, "add", wgData.TunnelIP, "udp", TargetData{Targets: []string{targetStr}}) + // } + // } + + // // Sync health check targets + // if err := healthMonitor.SyncTargets(syncData.HealthCheckTargets); err != nil { + // logger.Error("Failed to sync health check targets: %v", err) + // } else { + // logger.Info("Successfully synced health check targets") + // } + + logger.Info("Sync complete") + }) + // Register handler for Docker socket check client.RegisterHandler("newt/socket/check", func(msg websocket.WSMessage) { logger.Debug("Received Docker socket check request") @@ -1649,6 +1793,8 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey( pm.Stop() } + client.SendMessage("newt/disconnecting", map[string]any{}) + if client != nil { client.Close() } diff --git a/netstack2/proxy.go b/netstack2/proxy.go index 388a3d1..1b34818 100644 --- a/netstack2/proxy.go +++ b/netstack2/proxy.go @@ -48,6 +48,23 @@ type SubnetRule struct { PortRanges []PortRange // empty slice means all ports allowed } +// GetAllRules returns a copy of all subnet rules +func (sl *SubnetLookup) GetAllRules() []SubnetRule { + sl.mu.RLock() + defer sl.mu.RUnlock() + + var rules []SubnetRule + for _, destTriePtr := range sl.sourceTrie.All() { + if destTriePtr == nil { + continue + } + for _, rule := range destTriePtr.rules { + rules = append(rules, *rule) + } + } + return rules +} + // connKey uniquely identifies a connection for NAT tracking type connKey struct { srcIP string @@ -200,6 +217,14 @@ func (p *ProxyHandler) RemoveSubnetRule(sourcePrefix, destPrefix netip.Prefix) { p.subnetLookup.RemoveSubnet(sourcePrefix, destPrefix) } +// GetAllRules returns all subnet rules from the proxy handler +func (p *ProxyHandler) GetAllRules() []SubnetRule { + if p == nil || !p.enabled { + return nil + } + return p.subnetLookup.GetAllRules() +} + // LookupDestinationRewrite looks up the rewritten destination for a connection // This is used by TCP/UDP handlers to find the actual target address func (p *ProxyHandler) LookupDestinationRewrite(srcIP, dstIP string, dstPort uint16, proto uint8) (netip.Addr, bool) { diff --git a/netstack2/tun.go b/netstack2/tun.go index e743f1e..b00faea 100644 --- a/netstack2/tun.go +++ b/netstack2/tun.go @@ -369,6 +369,15 @@ func (net *Net) RemoveProxySubnetRule(sourcePrefix, destPrefix netip.Prefix) { } } +// GetProxySubnetRules returns all subnet rules from the proxy handler +func (net *Net) GetProxySubnetRules() []SubnetRule { + tun := (*netTun)(net) + if tun.proxyHandler != nil { + return tun.proxyHandler.GetAllRules() + } + return nil +} + // GetProxyHandler returns the proxy handler (for advanced use cases) // Returns nil if proxy is not enabled func (net *Net) GetProxyHandler() *ProxyHandler { diff --git a/proxy/manager.go b/proxy/manager.go index cef5fa6..0619e80 100644 --- a/proxy/manager.go +++ b/proxy/manager.go @@ -736,3 +736,28 @@ func (pm *ProxyManager) PrintTargets() { } } } + +// GetTargets returns a copy of the current TCP and UDP targets +// Returns map[listenIP]map[port]targetAddress for both TCP and UDP +func (pm *ProxyManager) GetTargets() (tcpTargets map[string]map[int]string, udpTargets map[string]map[int]string) { + pm.mutex.RLock() + defer pm.mutex.RUnlock() + + tcpTargets = make(map[string]map[int]string) + for listenIP, targets := range pm.tcpTargets { + tcpTargets[listenIP] = make(map[int]string) + for port, targetAddr := range targets { + tcpTargets[listenIP][port] = targetAddr + } + } + + udpTargets = make(map[string]map[int]string) + for listenIP, targets := range pm.udpTargets { + udpTargets[listenIP] = make(map[int]string) + for port, targetAddr := range targets { + udpTargets[listenIP][port] = targetAddr + } + } + + return tcpTargets, udpTargets +} diff --git a/websocket/client.go b/websocket/client.go index da1fa88..533771b 100644 --- a/websocket/client.go +++ b/websocket/client.go @@ -2,6 +2,7 @@ package websocket import ( "bytes" + "compress/gzip" "crypto/tls" "crypto/x509" "encoding/json" @@ -37,7 +38,6 @@ type Client struct { isConnected bool reconnectMux sync.RWMutex pingInterval time.Duration - pingTimeout time.Duration onConnect func() error onTokenUpdate func(token string) writeMux sync.Mutex @@ -47,6 +47,11 @@ type Client struct { metricsCtx context.Context configNeedsSave bool // Flag to track if config needs to be saved serverVersion string + configVersion int64 // Latest config version received from server + configVersionMux sync.RWMutex + processingMessage bool // Flag to track if a message is currently being processed + processingMux sync.RWMutex // Protects processingMessage + processingWg sync.WaitGroup // WaitGroup to wait for message processing to complete } type ClientOption func(*Client) @@ -111,7 +116,7 @@ func (c *Client) MetricsContext() context.Context { } // NewClient creates a new websocket client -func NewClient(clientType string, ID, secret string, endpoint string, pingInterval time.Duration, pingTimeout time.Duration, opts ...ClientOption) (*Client, error) { +func NewClient(clientType string, ID, secret string, endpoint string, pingInterval time.Duration, opts ...ClientOption) (*Client, error) { config := &Config{ ID: ID, Secret: secret, @@ -126,7 +131,6 @@ func NewClient(clientType string, ID, secret string, endpoint string, pingInterv reconnectInterval: 3 * time.Second, isConnected: false, pingInterval: pingInterval, - pingTimeout: pingTimeout, clientType: clientType, } @@ -154,6 +158,20 @@ func (c *Client) GetServerVersion() string { return c.serverVersion } +// GetConfigVersion returns the latest config version received from server +func (c *Client) GetConfigVersion() int64 { + c.configVersionMux.RLock() + defer c.configVersionMux.RUnlock() + return c.configVersion +} + +// setConfigVersion updates the config version +func (c *Client) setConfigVersion(version int64) { + c.configVersionMux.Lock() + defer c.configVersionMux.Unlock() + c.configVersion = version +} + // Connect establishes the WebSocket connection func (c *Client) Connect() error { go c.connectWithRetry() @@ -641,7 +659,57 @@ func (c *Client) setupPKCS12TLS() (*tls.Config, error) { } // pingMonitor sends pings at a short interval and triggers reconnect on failure +func (c *Client) sendPing() { + if c.conn == nil { + return + } + + // Skip ping if a message is currently being processed + c.processingMux.RLock() + isProcessing := c.processingMessage + c.processingMux.RUnlock() + if isProcessing { + logger.Debug("Skipping ping, message is being processed") + return + } + + c.configVersionMux.RLock() + configVersion := c.configVersion + c.configVersionMux.RUnlock() + + pingMsg := WSMessage{ + Type: "newt/ping", + Data: map[string]interface{}{}, + ConfigVersion: configVersion, + } + + c.writeMux.Lock() + err := c.conn.WriteJSON(pingMsg) + if err == nil { + telemetry.IncWSMessage(c.metricsContext(), "out", "ping") + } + c.writeMux.Unlock() + + if err != nil { + // Check if we're shutting down before logging error and reconnecting + select { + case <-c.done: + // Expected during shutdown + return + default: + logger.Error("Ping failed: %v", err) + telemetry.IncWSKeepaliveFailure(c.metricsContext(), "ping_write") + telemetry.IncWSReconnect(c.metricsContext(), "ping_write") + c.reconnect() + return + } + } +} + func (c *Client) pingMonitor() { + // Send an immediate ping as soon as we connect + c.sendPing() + ticker := time.NewTicker(c.pingInterval) defer ticker.Stop() @@ -650,29 +718,7 @@ func (c *Client) pingMonitor() { case <-c.done: return case <-ticker.C: - if c.conn == nil { - return - } - c.writeMux.Lock() - err := c.conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(c.pingTimeout)) - if err == nil { - telemetry.IncWSMessage(c.metricsContext(), "out", "ping") - } - c.writeMux.Unlock() - if err != nil { - // Check if we're shutting down before logging error and reconnecting - select { - case <-c.done: - // Expected during shutdown - return - default: - logger.Error("Ping failed: %v", err) - telemetry.IncWSKeepaliveFailure(c.metricsContext(), "ping_write") - telemetry.IncWSReconnect(c.metricsContext(), "ping_write") - c.reconnect() - return - } - } + c.sendPing() } } } @@ -709,10 +755,13 @@ func (c *Client) readPumpWithDisconnectDetection(started time.Time) { disconnectResult = "success" return default: - var msg WSMessage - err := c.conn.ReadJSON(&msg) + msgType, p, err := c.conn.ReadMessage() if err == nil { - telemetry.IncWSMessage(c.metricsContext(), "in", "text") + if msgType == websocket.BinaryMessage { + telemetry.IncWSMessage(c.metricsContext(), "in", "binary") + } else { + telemetry.IncWSMessage(c.metricsContext(), "in", "text") + } } if err != nil { // Check if we're shutting down before logging error @@ -737,9 +786,47 @@ func (c *Client) readPumpWithDisconnectDetection(started time.Time) { } } + // Update config version from incoming message + var data []byte + if msgType == websocket.BinaryMessage { + gr, err := gzip.NewReader(bytes.NewReader(p)) + if err != nil { + logger.Error("WebSocket failed to create gzip reader: %v", err) + continue + } + data, err = io.ReadAll(gr) + gr.Close() + if err != nil { + logger.Error("WebSocket failed to decompress message: %v", err) + continue + } + } else { + data = p + } + + var msg WSMessage + if err = json.Unmarshal(data, &msg); err != nil { + logger.Error("WebSocket failed to parse message: %v", err) + continue + } + + c.setConfigVersion(msg.ConfigVersion) + c.handlersMux.RLock() if handler, ok := c.handlers[msg.Type]; ok { + // Mark that we're processing a message + c.processingMux.Lock() + c.processingMessage = true + c.processingMux.Unlock() + c.processingWg.Add(1) + handler(msg) + + // Mark that we're done processing + c.processingWg.Done() + c.processingMux.Lock() + c.processingMessage = false + c.processingMux.Unlock() } c.handlersMux.RUnlock() } diff --git a/websocket/types.go b/websocket/types.go index 1196d64..381f7a1 100644 --- a/websocket/types.go +++ b/websocket/types.go @@ -17,6 +17,7 @@ type TokenResponse struct { } type WSMessage struct { - Type string `json:"type"` - Data interface{} `json:"data"` + Type string `json:"type"` + Data interface{} `json:"data"` + ConfigVersion int64 `json:"configVersion,omitempty"` }