diff --git a/common.go b/common.go index 5fe0645..4701411 100644 --- a/common.go +++ b/common.go @@ -5,6 +5,7 @@ import ( "context" "encoding/json" "fmt" + "net" "os" "os/exec" "strings" @@ -363,27 +364,62 @@ func parseTargetData(data interface{}) (TargetData, error) { return targetData, nil } +// parseTargetString parses a target string in the format "listenPort:host:targetPort" +// It properly handles IPv6 addresses which must be in brackets: "listenPort:[ipv6]:targetPort" +// Examples: +// - IPv4: "3001:192.168.1.1:80" +// - IPv6: "3001:[::1]:8080" or "3001:[fd70:1452:b736:4dd5:caca:7db9:c588:f5b3]:80" +// +// Returns listenPort, targetAddress (in host:port format suitable for net.Dial), and error +func parseTargetString(target string) (int, string, error) { + // Find the first colon to extract the listen port + firstColon := strings.Index(target, ":") + if firstColon == -1 { + return 0, "", fmt.Errorf("invalid target format, no colon found: %s", target) + } + + listenPortStr := target[:firstColon] + var listenPort int + _, err := fmt.Sscanf(listenPortStr, "%d", &listenPort) + if err != nil { + return 0, "", fmt.Errorf("invalid listen port: %s", listenPortStr) + } + if listenPort <= 0 || listenPort > 65535 { + return 0, "", fmt.Errorf("listen port out of range: %d", listenPort) + } + + // The remainder is host:targetPort - use net.SplitHostPort which handles IPv6 brackets + remainder := target[firstColon+1:] + host, targetPort, err := net.SplitHostPort(remainder) + if err != nil { + return 0, "", fmt.Errorf("invalid host:port format '%s': %w", remainder, err) + } + + // Reject empty host or target port + if host == "" { + return 0, "", fmt.Errorf("empty host in target: %s", target) + } + if targetPort == "" { + return 0, "", fmt.Errorf("empty target port in target: %s", target) + } + + // Reconstruct the target address using JoinHostPort (handles IPv6 properly) + targetAddr := net.JoinHostPort(host, targetPort) + + return listenPort, targetAddr, nil +} + func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto string, targetData TargetData) error { for _, t := range targetData.Targets { - // Split the first number off of the target with : separator and use as the port - parts := strings.Split(t, ":") - if len(parts) != 3 { - logger.Info("Invalid target format: %s", t) - continue - } - - // Get the port as an int - port := 0 - _, err := fmt.Sscanf(parts[0], "%d", &port) + // Parse the target string, handling both IPv4 and IPv6 addresses + port, target, err := parseTargetString(t) if err != nil { - logger.Info("Invalid port: %s", parts[0]) + logger.Info("Invalid target format: %s (%v)", t, err) continue } switch action { case "add": - target := parts[1] + ":" + parts[2] - // Call updown script if provided processedTarget := target if updownScript != "" { @@ -410,8 +446,6 @@ func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto case "remove": logger.Info("Removing target with port %d", port) - target := parts[1] + ":" + parts[2] - // Call updown script if provided if updownScript != "" { _, err := executeUpdownScript(action, proto, target) @@ -420,7 +454,7 @@ func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto } } - err := pm.RemoveTarget(proto, tunnelIP, port) + err = pm.RemoveTarget(proto, tunnelIP, port) if err != nil { logger.Error("Failed to remove target: %v", err) return err diff --git a/common_test.go b/common_test.go new file mode 100644 index 0000000..a7e659a --- /dev/null +++ b/common_test.go @@ -0,0 +1,212 @@ +package main + +import ( + "net" + "testing" +) + +func TestParseTargetString(t *testing.T) { + tests := []struct { + name string + input string + wantListenPort int + wantTargetAddr string + wantErr bool + }{ + // IPv4 test cases + { + name: "valid IPv4 basic", + input: "3001:192.168.1.1:80", + wantListenPort: 3001, + wantTargetAddr: "192.168.1.1:80", + wantErr: false, + }, + { + name: "valid IPv4 localhost", + input: "8080:127.0.0.1:3000", + wantListenPort: 8080, + wantTargetAddr: "127.0.0.1:3000", + wantErr: false, + }, + { + name: "valid IPv4 same ports", + input: "443:10.0.0.1:443", + wantListenPort: 443, + wantTargetAddr: "10.0.0.1:443", + wantErr: false, + }, + + // IPv6 test cases + { + name: "valid IPv6 loopback", + input: "3001:[::1]:8080", + wantListenPort: 3001, + wantTargetAddr: "[::1]:8080", + wantErr: false, + }, + { + name: "valid IPv6 full address", + input: "80:[fd70:1452:b736:4dd5:caca:7db9:c588:f5b3]:8080", + wantListenPort: 80, + wantTargetAddr: "[fd70:1452:b736:4dd5:caca:7db9:c588:f5b3]:8080", + wantErr: false, + }, + { + name: "valid IPv6 link-local", + input: "443:[fe80::1]:443", + wantListenPort: 443, + wantTargetAddr: "[fe80::1]:443", + wantErr: false, + }, + { + name: "valid IPv6 all zeros compressed", + input: "8000:[::]:9000", + wantListenPort: 8000, + wantTargetAddr: "[::]:9000", + wantErr: false, + }, + { + name: "valid IPv6 mixed notation", + input: "5000:[::ffff:192.168.1.1]:6000", + wantListenPort: 5000, + wantTargetAddr: "[::ffff:192.168.1.1]:6000", + wantErr: false, + }, + + // Hostname test cases + { + name: "valid hostname", + input: "8080:example.com:80", + wantListenPort: 8080, + wantTargetAddr: "example.com:80", + wantErr: false, + }, + { + name: "valid hostname with subdomain", + input: "443:api.example.com:8443", + wantListenPort: 443, + wantTargetAddr: "api.example.com:8443", + wantErr: false, + }, + { + name: "valid localhost hostname", + input: "3000:localhost:3000", + wantListenPort: 3000, + wantTargetAddr: "localhost:3000", + wantErr: false, + }, + + // Error cases + { + name: "invalid - no colons", + input: "invalid", + wantErr: true, + }, + { + name: "invalid - empty string", + input: "", + wantErr: true, + }, + { + name: "invalid - non-numeric listen port", + input: "abc:192.168.1.1:80", + wantErr: true, + }, + { + name: "invalid - missing target port", + input: "3001:192.168.1.1", + wantErr: true, + }, + { + name: "invalid - IPv6 without brackets", + input: "3001:fd70:1452:b736:4dd5:caca:7db9:c588:f5b3:80", + wantErr: true, + }, + { + name: "invalid - only listen port", + input: "3001:", + wantErr: true, + }, + { + name: "invalid - missing host", + input: "3001::80", + wantErr: true, + }, + { + name: "invalid - IPv6 unclosed bracket", + input: "3001:[::1:80", + wantErr: true, + }, + { + name: "invalid - listen port zero", + input: "0:192.168.1.1:80", + wantErr: true, + }, + { + name: "invalid - listen port negative", + input: "-1:192.168.1.1:80", + wantErr: true, + }, + { + name: "invalid - listen port out of range", + input: "70000:192.168.1.1:80", + wantErr: true, + }, + { + name: "invalid - empty target port", + input: "3001:192.168.1.1:", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + listenPort, targetAddr, err := parseTargetString(tt.input) + + if (err != nil) != tt.wantErr { + t.Errorf("parseTargetString(%q) error = %v, wantErr %v", tt.input, err, tt.wantErr) + return + } + + if tt.wantErr { + return // Don't check other values if we expected an error + } + + if listenPort != tt.wantListenPort { + t.Errorf("parseTargetString(%q) listenPort = %d, want %d", tt.input, listenPort, tt.wantListenPort) + } + + if targetAddr != tt.wantTargetAddr { + t.Errorf("parseTargetString(%q) targetAddr = %q, want %q", tt.input, targetAddr, tt.wantTargetAddr) + } + }) + } +} + +// TestParseTargetStringNetDialCompatibility verifies that the output is compatible with net.Dial +func TestParseTargetStringNetDialCompatibility(t *testing.T) { + tests := []struct { + name string + input string + }{ + {"IPv4", "8080:127.0.0.1:80"}, + {"IPv6 loopback", "8080:[::1]:80"}, + {"IPv6 full", "8080:[2001:db8::1]:80"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, targetAddr, err := parseTargetString(tt.input) + if err != nil { + t.Fatalf("parseTargetString(%q) unexpected error: %v", tt.input, err) + } + + // Verify the format is valid for net.Dial by checking it can be split back + // This doesn't actually dial, just validates the format + _, _, err = net.SplitHostPort(targetAddr) + if err != nil { + t.Errorf("parseTargetString(%q) produced invalid net.Dial format %q: %v", tt.input, targetAddr, err) + } + }) + } +}