diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS new file mode 100644 index 0000000..c5f1403 --- /dev/null +++ b/.github/CODEOWNERS @@ -0,0 +1 @@ +* @oschwartz10612 @miloschwartz diff --git a/Makefile b/Makefile index c35bbbf..53c4bb2 100644 --- a/Makefile +++ b/Makefile @@ -6,7 +6,7 @@ VERSION ?= dev LDFLAGS = -X main.newtVersion=$(VERSION) local: - CGO_ENABLED=0 go build -o ./bin/newt + CGO_ENABLED=0 go build -ldflags "$(LDFLAGS)" -o ./bin/newt docker-build: docker build -t fosrl/newt:latest . diff --git a/blueprint.yaml b/blueprint.yaml deleted file mode 100644 index 0465f00..0000000 --- a/blueprint.yaml +++ /dev/null @@ -1,37 +0,0 @@ -resources: - resource-nice-id: - name: this is my resource - protocol: http - full-domain: level1.test3.example.com - host-header: example.com - tls-server-name: example.com - auth: - pincode: 123456 - password: sadfasdfadsf - sso-enabled: true - sso-roles: - - Member - sso-users: - - owen@pangolin.net - whitelist-users: - - owen@pangolin.net - targets: - # - site: glossy-plains-viscacha-rat - - hostname: localhost - method: http - port: 8000 - healthcheck: - port: 8000 - hostname: localhost - # - site: glossy-plains-viscacha-rat - - hostname: localhost - method: http - port: 8001 - resource-nice-id2: - name: this is other resource - protocol: tcp - proxy-port: 3000 - targets: - # - site: glossy-plains-viscacha-rat - - hostname: localhost - port: 3000 \ No newline at end of file diff --git a/clients/clients.go b/clients/clients.go index 78bc0c3..3862160 100644 --- a/clients/clients.go +++ b/clients/clients.go @@ -40,13 +40,17 @@ type WgConfig struct { } type Target struct { - SourcePrefix string `json:"sourcePrefix"` - SourcePrefixes []string `json:"sourcePrefixes"` - DestPrefix string `json:"destPrefix"` - RewriteTo string `json:"rewriteTo,omitempty"` - DisableIcmp bool `json:"disableIcmp,omitempty"` - PortRange []PortRange `json:"portRange,omitempty"` - ResourceId int `json:"resourceId,omitempty"` + SourcePrefix string `json:"sourcePrefix"` + SourcePrefixes []string `json:"sourcePrefixes"` + DestPrefix string `json:"destPrefix"` + RewriteTo string `json:"rewriteTo,omitempty"` + DisableIcmp bool `json:"disableIcmp,omitempty"` + PortRange []PortRange `json:"portRange,omitempty"` + ResourceId int `json:"resourceId,omitempty"` + Protocol string `json:"protocol,omitempty"` // for now practicably either http or https + HTTPTargets []netstack2.HTTPTarget `json:"httpTargets,omitempty"` // for http protocol, list of downstream services to load balance across + TLSCert string `json:"tlsCert,omitempty"` // PEM-encoded certificate for incoming HTTPS termination + TLSKey string `json:"tlsKey,omitempty"` // PEM-encoded private key for incoming HTTPS termination } type PortRange struct { @@ -74,18 +78,18 @@ type PeerReading struct { } type WireGuardService struct { - interfaceName string - mtu int - client *websocket.Client - config WgConfig - key wgtypes.Key - newtId string - lastReadings map[string]PeerReading - mu sync.Mutex - Port uint16 - host string - serverPubKey string - token string + interfaceName string + mtu int + client *websocket.Client + config WgConfig + key wgtypes.Key + newtId string + lastReadings map[string]PeerReading + mu sync.Mutex + Port uint16 + host string + serverPubKey string + token string stopGetConfig func() pendingConfigChainId string // Netstack fields @@ -697,7 +701,18 @@ func (s *WireGuardService) syncTargets(desiredTargets []Target) error { }) } - s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp, target.ResourceId) + s.tnet.AddProxySubnetRule(netstack2.SubnetRule{ + SourcePrefix: sourcePrefix, + DestPrefix: destPrefix, + RewriteTo: target.RewriteTo, + PortRanges: portRanges, + DisableIcmp: target.DisableIcmp, + ResourceId: target.ResourceId, + Protocol: target.Protocol, + HTTPTargets: target.HTTPTargets, + TLSCert: target.TLSCert, + TLSKey: target.TLSKey, + }) logger.Info("Added target %s -> %s during sync", target.SourcePrefix, target.DestPrefix) } } @@ -835,6 +850,13 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error { }) }) + // Configure the HTTP request log sender to ship compressed request logs via websocket + s.tnet.SetHTTPRequestLogSender(func(data string) error { + return s.client.SendMessageNoLog("newt/request-log", map[string]interface{}{ + "compressed": data, + }) + }) + // Create WireGuard device using the shared bind s.device = device.NewDevice(s.tun, s.sharedBind, device.NewLogger( device.LogLevelSilent, // Use silent logging by default - could be made configurable @@ -955,7 +977,18 @@ func (s *WireGuardService) ensureTargets(targets []Target) error { if err != nil { return fmt.Errorf("invalid CIDR %s: %v", sp, err) } - s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp, target.ResourceId) + s.tnet.AddProxySubnetRule(netstack2.SubnetRule{ + SourcePrefix: sourcePrefix, + DestPrefix: destPrefix, + RewriteTo: target.RewriteTo, + PortRanges: portRanges, + DisableIcmp: target.DisableIcmp, + ResourceId: target.ResourceId, + Protocol: target.Protocol, + HTTPTargets: target.HTTPTargets, + TLSCert: target.TLSCert, + TLSKey: target.TLSKey, + }) logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", sp, target.DestPrefix, target.RewriteTo, target.PortRange) } } @@ -1348,7 +1381,18 @@ func (s *WireGuardService) handleAddTarget(msg websocket.WSMessage) { logger.Info("Invalid CIDR %s: %v", sp, err) continue } - s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp, target.ResourceId) + s.tnet.AddProxySubnetRule(netstack2.SubnetRule{ + SourcePrefix: sourcePrefix, + DestPrefix: destPrefix, + RewriteTo: target.RewriteTo, + PortRanges: portRanges, + DisableIcmp: target.DisableIcmp, + ResourceId: target.ResourceId, + Protocol: target.Protocol, + HTTPTargets: target.HTTPTargets, + TLSCert: target.TLSCert, + TLSKey: target.TLSKey, + }) logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", sp, target.DestPrefix, target.RewriteTo, target.PortRange) } } @@ -1466,7 +1510,18 @@ func (s *WireGuardService) handleUpdateTarget(msg websocket.WSMessage) { logger.Info("Invalid CIDR %s: %v", sp, err) continue } - s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp, target.ResourceId) + s.tnet.AddProxySubnetRule(netstack2.SubnetRule{ + SourcePrefix: sourcePrefix, + DestPrefix: destPrefix, + RewriteTo: target.RewriteTo, + PortRanges: portRanges, + DisableIcmp: target.DisableIcmp, + ResourceId: target.ResourceId, + Protocol: target.Protocol, + HTTPTargets: target.HTTPTargets, + TLSCert: target.TLSCert, + TLSKey: target.TLSKey, + }) logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", sp, target.DestPrefix, target.RewriteTo, target.PortRange) } } diff --git a/get-newt.sh b/get-newt.sh index d4ddd3f..77df9ed 100644 --- a/get-newt.sh +++ b/get-newt.sh @@ -30,41 +30,38 @@ print_error() { # Function to get latest version from GitHub API get_latest_version() { - local latest_info - + latest_info="" + if command -v curl >/dev/null 2>&1; then latest_info=$(curl -fsSL "$GITHUB_API_URL" 2>/dev/null) elif command -v wget >/dev/null 2>&1; then latest_info=$(wget -qO- "$GITHUB_API_URL" 2>/dev/null) else - print_error "Neither curl nor wget is available. Please install one of them." >&2 + print_error "Neither curl nor wget is available." exit 1 fi - + if [ -z "$latest_info" ]; then - print_error "Failed to fetch latest version information" >&2 + print_error "Failed to fetch latest version info" exit 1 fi - - # Extract version from JSON response (works without jq) - local version=$(echo "$latest_info" | grep '"tag_name"' | head -1 | sed 's/.*"tag_name": *"\([^"]*\)".*/\1/') - + + version=$(printf '%s' "$latest_info" | grep '"tag_name"' | head -1 | sed 's/.*"tag_name": *"\([^"]*\)".*/\1/') + if [ -z "$version" ]; then - print_error "Could not parse version from GitHub API response" >&2 + print_error "Could not parse version from GitHub API response" exit 1 fi - - # Remove 'v' prefix if present - version=$(echo "$version" | sed 's/^v//') - - echo "$version" + + version=$(printf '%s' "$version" | sed 's/^v//') + printf '%s' "$version" } # Detect OS and architecture detect_platform() { - local os arch - - # Detect OS + os="" + arch="" + case "$(uname -s)" in Linux*) os="linux" ;; Darwin*) os="darwin" ;; @@ -75,12 +72,11 @@ detect_platform() { exit 1 ;; esac - - # Detect architecture + case "$(uname -m)" in x86_64|amd64) arch="amd64" ;; arm64|aarch64) arch="arm64" ;; - armv7l|armv6l) + armv7l|armv6l) if [ "$os" = "linux" ]; then if [ "$(uname -m)" = "armv6l" ]; then arch="arm32v6" @@ -88,10 +84,10 @@ detect_platform() { arch="arm32" fi else - arch="arm64" # Default for non-Linux ARM + arch="arm64" fi ;; - riscv64) + riscv64) if [ "$os" = "linux" ]; then arch="riscv64" else @@ -104,23 +100,68 @@ detect_platform() { exit 1 ;; esac - - echo "${os}_${arch}" + + printf '%s_%s' "$os" "$arch" } -# Get installation directory +# Determine installation directory (default fallback) get_install_dir() { - if [ "$OS" = "windows" ]; then - echo "$HOME/bin" - else - # Prefer /usr/local/bin for system-wide installation - echo "/usr/local/bin" + case "$PLATFORM" in + *windows*) + echo "$HOME/bin" + ;; + *) + echo "/usr/local/bin" + ;; + esac +} + +# Parse --path argument from args +# Returns the value after --path, or empty string if not provided +parse_path_arg() { + while [ $# -gt 0 ]; do + case "$1" in + --path) + if [ -n "$2" ]; then + printf '%s' "$2" + return + fi + ;; + --path=*) + printf '%s' "${1#--path=}" + return + ;; + esac + shift + done +} + +# Detect an existing newt binary location. +# Tries unprivileged which first, then sudo which (for binaries only visible to root). +# Returns the full path of the binary, or empty string if not found. +detect_existing_binary() { + existing="" + + # Try unprivileged which first + existing=$(command -v newt 2>/dev/null || true) + if [ -n "$existing" ]; then + printf '%s' "$existing" + return + fi + + # Try sudo which — some installations land in paths only root can see in $PATH + if command -v sudo >/dev/null 2>&1; then + existing=$(sudo which newt 2>/dev/null || true) + if [ -n "$existing" ]; then + printf '%s' "$existing" + return + fi fi } # Check if we need sudo for installation needs_sudo() { - local install_dir="$1" + install_dir="$1" if [ -w "$install_dir" ] 2>/dev/null; then return 1 # No sudo needed else @@ -130,7 +171,7 @@ needs_sudo() { # Get the appropriate command prefix (sudo or empty) get_sudo_cmd() { - local install_dir="$1" + install_dir="$1" if needs_sudo "$install_dir"; then if command -v sudo >/dev/null 2>&1; then echo "sudo" @@ -146,40 +187,46 @@ get_sudo_cmd() { # Download and install newt install_newt() { - local platform="$1" - local install_dir="$2" - local sudo_cmd="$3" - local binary_name="newt_${platform}" - local exe_suffix="" + platform="$1" + install_dir="$2" + sudo_cmd="$3" + custom_path="$4" + binary_name="newt_${platform}" + final_name="newt" - # Add .exe suffix for Windows case "$platform" in *windows*) binary_name="${binary_name}.exe" - exe_suffix=".exe" + final_name="newt.exe" ;; esac - local download_url="${BASE_URL}/${binary_name}" - local temp_file="/tmp/newt${exe_suffix}" - local final_path="${install_dir}/newt${exe_suffix}" + download_url="${BASE_URL}/${binary_name}" + temp_file="/tmp/${final_name}" + + # If a custom path is provided, use it directly; otherwise use install_dir/final_name + if [ -n "$custom_path" ]; then + final_path="$custom_path" + install_dir=$(dirname "$final_path") + else + final_path="${install_dir}/${final_name}" + fi print_status "Downloading newt from ${download_url}" - # Download the binary if command -v curl >/dev/null 2>&1; then curl -fsSL "$download_url" -o "$temp_file" elif command -v wget >/dev/null 2>&1; then wget -q "$download_url" -O "$temp_file" else - print_error "Neither curl nor wget is available. Please install one of them." + print_error "Neither curl nor wget is available." exit 1 fi # Make executable before moving chmod +x "$temp_file" - # Create install directory if it doesn't exist + # Create install directory if it doesn't exist and move binary if [ -n "$sudo_cmd" ]; then $sudo_cmd mkdir -p "$install_dir" print_status "Using sudo to install to ${install_dir}" @@ -194,25 +241,25 @@ install_newt() { # Check if install directory is in PATH if ! echo "$PATH" | grep -q "$install_dir"; then print_warning "Install directory ${install_dir} is not in your PATH." - print_warning "Add it to your PATH by adding this line to your shell profile:" + print_warning "Add it with:" print_warning " export PATH=\"${install_dir}:\$PATH\"" fi } # Verify installation verify_installation() { - local install_dir="$1" - local exe_suffix="" - + install_dir="$1" + exe_suffix="" + case "$PLATFORM" in *windows*) exe_suffix=".exe" ;; esac - - local newt_path="${install_dir}/newt${exe_suffix}" - - if [ -f "$newt_path" ] && [ -x "$newt_path" ]; then + + newt_path="${install_dir}/newt${exe_suffix}" + + if [ -x "$newt_path" ]; then print_status "Installation successful!" - print_status "newt version: $("$newt_path" --version 2>/dev/null || echo "unknown")" + print_status "newt version: $("$newt_path" --version 2>/dev/null || printf 'unknown')" return 0 else print_error "Installation failed. Binary not found or not executable." @@ -222,22 +269,40 @@ verify_installation() { # Main installation process main() { - print_status "Installing latest version of newt..." + # --path explicitly overrides everything + CUSTOM_PATH=$(parse_path_arg "$@") - # Get latest version - print_status "Fetching latest version from GitHub..." + if [ -n "$CUSTOM_PATH" ]; then + print_status "Installing latest version of newt to ${CUSTOM_PATH} (--path override)..." + else + print_status "Installing latest version of newt..." + fi + + print_status "Fetching latest version..." VERSION=$(get_latest_version) print_status "Latest version: v${VERSION}" - # Set base URL with the fetched version BASE_URL="https://github.com/${REPO}/releases/download/${VERSION}" - # Detect platform PLATFORM=$(detect_platform) print_status "Detected platform: ${PLATFORM}" - # Get install directory - INSTALL_DIR=$(get_install_dir) + if [ -n "$CUSTOM_PATH" ]; then + # --path wins; derive INSTALL_DIR from it + INSTALL_DIR=$(dirname "$CUSTOM_PATH") + else + # Try to find an existing installation so we update the right place + EXISTING_BINARY=$(detect_existing_binary) + if [ -n "$EXISTING_BINARY" ]; then + print_status "Found existing newt binary at ${EXISTING_BINARY}" + CUSTOM_PATH="$EXISTING_BINARY" + INSTALL_DIR=$(dirname "$EXISTING_BINARY") + print_status "Will update existing installation at ${INSTALL_DIR}" + else + INSTALL_DIR=$(get_install_dir) + fi + fi + print_status "Install directory: ${INSTALL_DIR}" # Check if we need sudo @@ -246,13 +311,20 @@ main() { print_status "Root privileges required for installation to ${INSTALL_DIR}" fi - # Install newt - install_newt "$PLATFORM" "$INSTALL_DIR" "$SUDO_CMD" + install_newt "$PLATFORM" "$INSTALL_DIR" "$SUDO_CMD" "$CUSTOM_PATH" - # Verify installation - if verify_installation "$INSTALL_DIR"; then + if [ -n "$CUSTOM_PATH" ]; then + if [ -x "$CUSTOM_PATH" ]; then + print_status "Installation successful!" + print_status "newt version: $("$CUSTOM_PATH" --version 2>/dev/null || printf 'unknown')" + print_status "newt is ready to use!" + else + print_error "Installation failed. Binary not found or not executable at ${CUSTOM_PATH}." + exit 1 + fi + elif verify_installation "$INSTALL_DIR"; then print_status "newt is ready to use!" - print_status "Run 'newt --help' to get started" + print_status "Run 'newt --help' to get started." else exit 1 fi diff --git a/healthcheck/healthcheck.go b/healthcheck/healthcheck.go index f618803..a7f0b6a 100644 --- a/healthcheck/healthcheck.go +++ b/healthcheck/healthcheck.go @@ -37,33 +37,38 @@ func (s Health) String() string { // Config holds the health check configuration for a target type Config struct { - ID int `json:"id"` - Enabled bool `json:"hcEnabled"` - Path string `json:"hcPath"` - Scheme string `json:"hcScheme"` - Mode string `json:"hcMode"` - Hostname string `json:"hcHostname"` - Port int `json:"hcPort"` - Interval int `json:"hcInterval"` // in seconds - UnhealthyInterval int `json:"hcUnhealthyInterval"` // in seconds - Timeout int `json:"hcTimeout"` // in seconds - Headers map[string]string `json:"hcHeaders"` - Method string `json:"hcMethod"` - Status int `json:"hcStatus"` // HTTP status code - TLSServerName string `json:"hcTlsServerName"` + ID int `json:"id"` + Enabled bool `json:"hcEnabled"` + Path string `json:"hcPath"` + Scheme string `json:"hcScheme"` + Mode string `json:"hcMode"` + Hostname string `json:"hcHostname"` + Port int `json:"hcPort"` + Interval int `json:"hcInterval"` // in seconds + UnhealthyInterval int `json:"hcUnhealthyInterval"` // in seconds + Timeout int `json:"hcTimeout"` // in seconds + FollowRedirects bool `json:"hcFollowRedirects"` + Headers map[string]string `json:"hcHeaders"` + Method string `json:"hcMethod"` + Status int `json:"hcStatus"` // HTTP status code + TLSServerName string `json:"hcTlsServerName"` + HealthyThreshold int `json:"hcHealthyThreshold"` // consecutive successes required to become healthy + UnhealthyThreshold int `json:"hcUnhealthyThreshold"` // consecutive failures required to become unhealthy } // Target represents a health check target with its current status type Target struct { - Config Config `json:"config"` - Status Health `json:"status"` - LastCheck time.Time `json:"lastCheck"` - LastError string `json:"lastError,omitempty"` - CheckCount int `json:"checkCount"` - timer *time.Timer - ctx context.Context - cancel context.CancelFunc - client *http.Client + Config Config `json:"config"` + Status Health `json:"status"` + LastCheck time.Time `json:"lastCheck"` + LastError string `json:"lastError,omitempty"` + CheckCount int `json:"checkCount"` + timer *time.Timer + ctx context.Context + cancel context.CancelFunc + client *http.Client + consecutiveSuccesses int + consecutiveFailures int } // StatusChangeCallback is called when any target's status changes @@ -165,9 +170,16 @@ func (m *Monitor) addTargetUnsafe(config Config) error { if config.Timeout == 0 { config.Timeout = 5 } + if config.HealthyThreshold == 0 { + config.HealthyThreshold = 1 + } + if config.UnhealthyThreshold == 0 { + config.UnhealthyThreshold = 1 + } - logger.Debug("Target %d configuration: scheme=%s, method=%s, interval=%ds, timeout=%ds", - config.ID, config.Scheme, config.Method, config.Interval, config.Timeout) + logger.Debug("Target %d configuration: mode=%s, scheme=%s, method=%s, interval=%ds, timeout=%ds, healthyThreshold=%d, unhealthyThreshold=%d", + config.ID, config.Mode, config.Scheme, config.Method, config.Interval, config.Timeout, + config.HealthyThreshold, config.UnhealthyThreshold) // Parse headers if provided as string if len(config.Headers) == 0 && config.Path != "" { @@ -189,6 +201,14 @@ func (m *Monitor) addTargetUnsafe(config Config) error { ctx: ctx, cancel: cancel, client: &http.Client{ + CheckRedirect: func() func(*http.Request, []*http.Request) error { + if !config.FollowRedirects { + return func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + } + } + return nil + }(), Transport: &http.Transport{ TLSClientConfig: &tls.Config{ // Configure TLS settings based on certificate enforcement @@ -361,12 +381,69 @@ func (m *Monitor) monitorTarget(target *Target) { } } -// performHealthCheck performs a health check on a target +// performHealthCheck performs a health check on a target and applies threshold logic func (m *Monitor) performHealthCheck(target *Target) { target.CheckCount++ target.LastCheck = time.Now() - target.LastError = "" + var passed bool + var checkErr string + + switch strings.ToLower(target.Config.Mode) { + case "tcp": + passed, checkErr = m.performTCPCheck(target) + default: + // "http", "https", or anything else falls through to HTTP + passed, checkErr = m.performHTTPCheck(target) + } + + if passed { + target.consecutiveFailures = 0 + target.consecutiveSuccesses++ + + logger.Debug("Target %d: check passed (consecutive successes: %d / threshold: %d)", + target.Config.ID, target.consecutiveSuccesses, target.Config.HealthyThreshold) + + if target.consecutiveSuccesses >= target.Config.HealthyThreshold { + target.Status = StatusHealthy + target.LastError = "" + } + } else { + target.consecutiveSuccesses = 0 + target.consecutiveFailures++ + target.LastError = checkErr + + logger.Debug("Target %d: check failed (consecutive failures: %d / threshold: %d): %s", + target.Config.ID, target.consecutiveFailures, target.Config.UnhealthyThreshold, checkErr) + + if target.consecutiveFailures >= target.Config.UnhealthyThreshold { + target.Status = StatusUnhealthy + } + } +} + +// performTCPCheck dials the target's host:port over TCP and returns whether it succeeded +func (m *Monitor) performTCPCheck(target *Target) (bool, string) { + address := net.JoinHostPort(target.Config.Hostname, strconv.Itoa(target.Config.Port)) + timeout := time.Duration(target.Config.Timeout) * time.Second + + logger.Debug("Target %d: performing TCP health check to %s (timeout: %v)", + target.Config.ID, address, timeout) + + conn, err := net.DialTimeout("tcp", address, timeout) + if err != nil { + msg := fmt.Sprintf("TCP dial failed: %v", err) + logger.Warn("Target %d: %s", target.Config.ID, msg) + return false, msg + } + conn.Close() + + logger.Debug("Target %d: TCP health check passed", target.Config.ID) + return true, "" +} + +// performHTTPCheck performs an HTTP/HTTPS health check and returns whether it succeeded +func (m *Monitor) performHTTPCheck(target *Target) (bool, string) { // Build URL (use net.JoinHostPort to properly handle IPv6 addresses with ports) host := target.Config.Hostname if target.Config.Port > 0 { @@ -380,7 +457,7 @@ func (m *Monitor) performHealthCheck(target *Target) { url += target.Config.Path } - logger.Debug("Target %d: performing health check %d to %s", + logger.Debug("Target %d: performing HTTP health check %d to %s", target.Config.ID, target.CheckCount, url) if target.Config.Scheme == "https" { @@ -388,16 +465,15 @@ func (m *Monitor) performHealthCheck(target *Target) { target.Config.ID, m.enforceCert) } - // Create request + // Create request with timeout context ctx, cancel := context.WithTimeout(context.Background(), time.Duration(target.Config.Timeout)*time.Second) defer cancel() req, err := http.NewRequestWithContext(ctx, target.Config.Method, url, nil) if err != nil { - target.Status = StatusUnhealthy - target.LastError = fmt.Sprintf("failed to create request: %v", err) - logger.Warn("Target %d: failed to create request: %v", target.Config.ID, err) - return + msg := fmt.Sprintf("failed to create request: %v", err) + logger.Warn("Target %d: %s", target.Config.ID, msg) + return false, msg } // Add headers @@ -413,43 +489,34 @@ func (m *Monitor) performHealthCheck(target *Target) { // Perform request resp, err := target.client.Do(req) if err != nil { - target.Status = StatusUnhealthy - target.LastError = fmt.Sprintf("request failed: %v", err) + msg := fmt.Sprintf("request failed: %v", err) logger.Warn("Target %d: health check failed: %v", target.Config.ID, err) - return + return false, msg } defer resp.Body.Close() // Check response status - var expectedStatus int if target.Config.Status > 0 { - expectedStatus = target.Config.Status - } else { - expectedStatus = 0 // Use range check for 200-299 + // Check for specific status code + logger.Debug("Target %d: checking status against expected code %d", target.Config.ID, target.Config.Status) + if resp.StatusCode == target.Config.Status { + logger.Debug("Target %d: health check passed (status: %d)", target.Config.ID, resp.StatusCode) + return true, "" + } + msg := fmt.Sprintf("unexpected status code: %d (expected: %d)", resp.StatusCode, target.Config.Status) + logger.Warn("Target %d: %s", target.Config.ID, msg) + return false, msg } - if expectedStatus > 0 { - logger.Debug("Target %d: checking health status against expected code %d", target.Config.ID, expectedStatus) - // Check for specific status code - if resp.StatusCode == expectedStatus { - target.Status = StatusHealthy - logger.Debug("Target %d: health check passed (status: %d, expected: %d)", target.Config.ID, resp.StatusCode, expectedStatus) - } else { - target.Status = StatusUnhealthy - target.LastError = fmt.Sprintf("unexpected status code: %d (expected: %d)", resp.StatusCode, expectedStatus) - logger.Warn("Target %d: health check failed with status code %d (expected: %d)", target.Config.ID, resp.StatusCode, expectedStatus) - } - } else { - // Check for 2xx range - if resp.StatusCode >= 200 && resp.StatusCode < 300 { - target.Status = StatusHealthy - logger.Debug("Target %d: health check passed (status: %d)", target.Config.ID, resp.StatusCode) - } else { - target.Status = StatusUnhealthy - target.LastError = fmt.Sprintf("unhealthy status code: %d", resp.StatusCode) - logger.Warn("Target %d: health check failed with status code %d", target.Config.ID, resp.StatusCode) - } + // Default: check for 2xx range + if resp.StatusCode >= 200 && resp.StatusCode < 300 { + logger.Debug("Target %d: health check passed (status: %d)", target.Config.ID, resp.StatusCode) + return true, "" } + + msg := fmt.Sprintf("unhealthy status code: %d", resp.StatusCode) + logger.Warn("Target %d: health check failed with status code %d", target.Config.ID, resp.StatusCode) + return false, msg } // Stop stops monitoring all targets diff --git a/main.go b/main.go index d5f2a96..7718c5d 100644 --- a/main.go +++ b/main.go @@ -129,6 +129,7 @@ var ( dockerEnforceNetworkValidationBool bool pingInterval time.Duration pingTimeout time.Duration + udpProxyIdleTimeout time.Duration publicKey wgtypes.Key pingStopChan chan struct{} stopFunc func() @@ -261,6 +262,7 @@ func runNewtMain(ctx context.Context) { dockerSocket = os.Getenv("DOCKER_SOCKET") pingIntervalStr := os.Getenv("PING_INTERVAL") pingTimeoutStr := os.Getenv("PING_TIMEOUT") + udpProxyIdleTimeoutStr := os.Getenv("NEWT_UDP_PROXY_IDLE_TIMEOUT") dockerEnforceNetworkValidation = os.Getenv("DOCKER_ENFORCE_NETWORK_VALIDATION") healthFile = os.Getenv("HEALTH_FILE") // authorizedKeysFile = os.Getenv("AUTHORIZED_KEYS_FILE") @@ -337,6 +339,9 @@ func runNewtMain(ctx context.Context) { if pingTimeoutStr == "" { flag.StringVar(&pingTimeoutStr, "ping-timeout", "7s", " Timeout for each ping (default 7s)") } + if udpProxyIdleTimeoutStr == "" { + flag.StringVar(&udpProxyIdleTimeoutStr, "udp-proxy-idle-timeout", "90s", "Idle timeout for UDP proxied client flows before cleanup") + } // load the prefer endpoint just as a flag flag.StringVar(&preferEndpoint, "prefer-endpoint", "", "Prefer this endpoint for the connection (if set, will override the endpoint from the server)") if provisioningKey == "" { @@ -386,6 +391,16 @@ func runNewtMain(ctx context.Context) { pingTimeout = 7 * time.Second } + if udpProxyIdleTimeoutStr != "" { + udpProxyIdleTimeout, err = time.ParseDuration(udpProxyIdleTimeoutStr) + if err != nil || udpProxyIdleTimeout <= 0 { + fmt.Printf("Invalid NEWT_UDP_PROXY_IDLE_TIMEOUT/--udp-proxy-idle-timeout value: %s, using default 90 seconds\n", udpProxyIdleTimeoutStr) + udpProxyIdleTimeout = 90 * time.Second + } + } else { + udpProxyIdleTimeout = 90 * time.Second + } + if dockerEnforceNetworkValidation == "" { flag.StringVar(&dockerEnforceNetworkValidation, "docker-enforce-network-validation", "false", "Enforce validation of container on newt network (true or false)") } @@ -896,6 +911,7 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey( // Create proxy manager pm = proxy.NewProxyManager(tnet) pm.SetAsyncBytes(metricsAsyncBytes) + pm.SetUDPIdleTimeout(udpProxyIdleTimeout) // Set tunnel_id for metrics (WireGuard peer public key) pm.SetTunnelID(wgData.PublicKey) diff --git a/netstack2/handlers.go b/netstack2/handlers.go index 07c235f..dabfee9 100644 --- a/netstack2/handlers.go +++ b/netstack2/handlers.go @@ -137,14 +137,31 @@ func (h *TCPHandler) InstallTCPHandler() error { // handleTCPConn handles a TCP connection by proxying it to the actual target func (h *TCPHandler) handleTCPConn(netstackConn *gonet.TCPConn, id stack.TransportEndpointID) { - defer netstackConn.Close() - - // Extract source and target address from the connection ID + // Extract source and target address from the connection ID first so they + // are available for HTTP routing before any defer is set up. srcIP := id.RemoteAddress.String() srcPort := id.RemotePort dstIP := id.LocalAddress.String() dstPort := id.LocalPort + // For HTTP/HTTPS ports, look up the matching subnet rule. If the rule has + // Protocol configured, hand the connection off to the HTTP handler which + // takes full ownership of the lifecycle (the defer close must not be + // installed before this point). + if (dstPort == 80 || dstPort == 443) && h.proxyHandler != nil && h.proxyHandler.httpHandler != nil { + srcAddr, _ := netip.ParseAddr(srcIP) + dstAddr, _ := netip.ParseAddr(dstIP) + rule := h.proxyHandler.subnetLookup.Match(srcAddr, dstAddr, dstPort, tcp.ProtocolNumber) + if rule != nil && rule.Protocol != "" { + logger.Info("TCP Forwarder: Routing %s:%d -> %s:%d to HTTP handler (%s)", + srcIP, srcPort, dstIP, dstPort, rule.Protocol) + h.proxyHandler.httpHandler.HandleConn(netstackConn, rule) + return + } + } + + defer netstackConn.Close() + logger.Info("TCP Forwarder: Handling connection %s:%d -> %s:%d", srcIP, srcPort, dstIP, dstPort) // Check if there's a destination rewrite for this connection (e.g., localhost targets) diff --git a/netstack2/http_handler.go b/netstack2/http_handler.go new file mode 100644 index 0000000..4ff7ed9 --- /dev/null +++ b/netstack2/http_handler.go @@ -0,0 +1,363 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. + */ + +package netstack2 + +import ( + "context" + "crypto/tls" + "fmt" + "net" + "net/http" + "net/http/httputil" + "net/url" + "sync" + "time" + + "github.com/fosrl/newt/logger" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +// --------------------------------------------------------------------------- +// HTTPTarget +// --------------------------------------------------------------------------- + +// HTTPTarget describes a single downstream HTTP or HTTPS service that the +// proxy should forward requests to. +type HTTPTarget struct { + DestAddr string `json:"destAddr"` // IP address or hostname of the downstream service + DestPort uint16 `json:"destPort"` // TCP port of the downstream service + Scheme string `json:"scheme"` // When true the outbound leg uses HTTPS +} + +// --------------------------------------------------------------------------- +// HTTPHandler +// --------------------------------------------------------------------------- + +// HTTPHandler intercepts TCP connections from the netstack forwarder on ports +// 80 and 443 and services them as HTTP or HTTPS, reverse-proxying each request +// to downstream targets specified by the matching SubnetRule. +// +// HTTP and raw TCP are fully separate: a connection is only routed here when +// its SubnetRule has Protocol set ("http" or "https"). All other connections +// on those ports fall through to the normal raw-TCP path. +// +// Incoming TLS termination (Protocol == "https") is performed per-connection +// using the certificate and key stored in the rule, so different subnet rules +// can present different certificates without sharing any state. +// +// Outbound connections to downstream targets honour HTTPTarget.UseHTTPS +// independently of the incoming protocol. +type HTTPHandler struct { + stack *stack.Stack + proxyHandler *ProxyHandler + requestLogger *HTTPRequestLogger + + listener *chanListener + server *http.Server + + // proxyCache holds pre-built *httputil.ReverseProxy values keyed by the + // canonical target URL string ("scheme://host:port"). Building a proxy is + // cheap, but reusing one preserves the underlying http.Transport connection + // pool, which matters for throughput. + proxyCache sync.Map // map[string]*httputil.ReverseProxy + + // tlsCache holds pre-parsed *tls.Config values keyed by the concatenation + // of the PEM certificate and key. Parsing a keypair is relatively expensive + // and the same cert is likely reused across many connections. + tlsCache sync.Map // map[string]*tls.Config +} + +// --------------------------------------------------------------------------- +// chanListener – net.Listener backed by a channel +// --------------------------------------------------------------------------- + +// chanListener implements net.Listener by receiving net.Conn values over a +// buffered channel. This lets the netstack TCP forwarder hand off connections +// directly to a running http.Server without any real OS socket. +type chanListener struct { + connCh chan net.Conn + closed chan struct{} + once sync.Once +} + +func newChanListener() *chanListener { + return &chanListener{ + connCh: make(chan net.Conn, 128), + closed: make(chan struct{}), + } +} + +// Accept blocks until a connection is available or the listener is closed. +func (l *chanListener) Accept() (net.Conn, error) { + select { + case conn, ok := <-l.connCh: + if !ok { + return nil, net.ErrClosed + } + return conn, nil + case <-l.closed: + return nil, net.ErrClosed + } +} + +// Close shuts down the listener; subsequent Accept calls return net.ErrClosed. +func (l *chanListener) Close() error { + l.once.Do(func() { close(l.closed) }) + return nil +} + +// Addr returns a placeholder address (the listener has no real OS socket). +func (l *chanListener) Addr() net.Addr { + return &net.TCPAddr{} +} + +// send delivers conn to the listener. Returns false if the listener is already +// closed, in which case the caller is responsible for closing conn. +func (l *chanListener) send(conn net.Conn) bool { + select { + case l.connCh <- conn: + return true + case <-l.closed: + return false + } +} + +// --------------------------------------------------------------------------- +// httpConnCtx – conn wrapper that carries a SubnetRule through the listener +// --------------------------------------------------------------------------- + +// httpConnCtx wraps a net.Conn so the matching SubnetRule can be passed +// through the chanListener into the http.Server's ConnContext callback, +// making it available to request handlers via the request context. +type httpConnCtx struct { + net.Conn + rule *SubnetRule +} + +// connCtxKey is the unexported context key used to store a *SubnetRule on the +// per-connection context created by http.Server.ConnContext. +type connCtxKey struct{} + +// --------------------------------------------------------------------------- +// Constructor and lifecycle +// --------------------------------------------------------------------------- + +// NewHTTPHandler creates an HTTPHandler attached to the given stack and +// ProxyHandler. Call Start to begin serving connections. +func NewHTTPHandler(s *stack.Stack, ph *ProxyHandler) *HTTPHandler { + return &HTTPHandler{ + stack: s, + proxyHandler: ph, + } +} + +// SetRequestLogger attaches an HTTPRequestLogger so that every proxied request +// is recorded and periodically shipped to the server. +func (h *HTTPHandler) SetRequestLogger(rl *HTTPRequestLogger) { + h.requestLogger = rl +} + +// Start launches the internal http.Server that services connections delivered +// via HandleConn. The server runs for the lifetime of the HTTPHandler; call +// Close to stop it. +func (h *HTTPHandler) Start() error { + h.listener = newChanListener() + + h.server = &http.Server{ + Handler: http.HandlerFunc(h.handleRequest), + // ConnContext runs once per accepted connection and attaches the + // SubnetRule carried by httpConnCtx to the connection's context so + // that handleRequest can retrieve it without any global state. + ConnContext: func(ctx context.Context, c net.Conn) context.Context { + if cc, ok := c.(*httpConnCtx); ok { + return context.WithValue(ctx, connCtxKey{}, cc.rule) + } + return ctx + }, + } + + go func() { + if err := h.server.Serve(h.listener); err != nil && err != http.ErrServerClosed { + logger.Error("HTTP handler: server exited unexpectedly: %v", err) + } + }() + + logger.Info("HTTP handler: ready — routing determined per SubnetRule on ports 80/443") + return nil +} + +// HandleConn accepts a TCP connection from the netstack forwarder together +// with the SubnetRule that matched it. The HTTP handler takes full ownership +// of the connection's lifecycle; the caller must NOT close conn after this call. +// +// When rule.Protocol is "https", TLS termination is performed on conn using +// the certificate and key stored in rule.TLSCert and rule.TLSKey before the +// connection is passed to the HTTP server. The HTTP server itself is always +// plain-HTTP; TLS is fully unwrapped at this layer. +func (h *HTTPHandler) HandleConn(conn net.Conn, rule *SubnetRule) { + var effectiveConn net.Conn = conn + + if rule.Protocol == "https" { + tlsCfg, err := h.getTLSConfig(rule) + if err != nil { + logger.Error("HTTP handler: cannot build TLS config for connection from %s: %v", + conn.RemoteAddr(), err) + conn.Close() + return + } + // tls.Server wraps the raw conn; the TLS handshake is deferred until + // the first Read, which the http.Server will trigger naturally. + effectiveConn = tls.Server(conn, tlsCfg) + } + + wrapped := &httpConnCtx{Conn: effectiveConn, rule: rule} + if !h.listener.send(wrapped) { + // Listener is already closed — clean up the orphaned connection. + effectiveConn.Close() + } +} + +// Close gracefully shuts down the HTTP server and the underlying channel +// listener, causing the goroutine started in Start to exit. +func (h *HTTPHandler) Close() error { + if h.server != nil { + if err := h.server.Close(); err != nil { + return err + } + } + if h.listener != nil { + h.listener.Close() + } + return nil +} + +// --------------------------------------------------------------------------- +// Internal helpers +// --------------------------------------------------------------------------- + +// getTLSConfig returns a *tls.Config for the cert/key pair in rule, using a +// cache to avoid re-parsing the same keypair on every connection. +// The cache key is the concatenation of the PEM cert and key strings, so +// different rules that happen to share the same material hit the same entry. +func (h *HTTPHandler) getTLSConfig(rule *SubnetRule) (*tls.Config, error) { + cacheKey := rule.TLSCert + "|" + rule.TLSKey + if v, ok := h.tlsCache.Load(cacheKey); ok { + return v.(*tls.Config), nil + } + + cert, err := tls.X509KeyPair([]byte(rule.TLSCert), []byte(rule.TLSKey)) + if err != nil { + return nil, fmt.Errorf("failed to parse TLS keypair: %w", err) + } + cfg := &tls.Config{ + Certificates: []tls.Certificate{cert}, + } + // LoadOrStore is safe under concurrent calls: if two goroutines race here + // both will produce a valid config; the loser's work is discarded. + actual, _ := h.tlsCache.LoadOrStore(cacheKey, cfg) + return actual.(*tls.Config), nil +} + +// getProxy returns a cached *httputil.ReverseProxy for the given target, +// creating one on first use. Reusing the proxy preserves its http.Transport +// connection pool, avoiding repeated TCP/TLS handshakes to the downstream. +func (h *HTTPHandler) getProxy(target HTTPTarget) *httputil.ReverseProxy { + scheme := target.Scheme + cacheKey := fmt.Sprintf("%s://%s:%d", scheme, target.DestAddr, target.DestPort) + + if v, ok := h.proxyCache.Load(cacheKey); ok { + return v.(*httputil.ReverseProxy) + } + + targetURL := &url.URL{ + Scheme: scheme, + Host: fmt.Sprintf("%s:%d", target.DestAddr, target.DestPort), + } + insecureTransport := (*http.Transport)(nil) + if target.Scheme == "https" { + // Allow self-signed certificates on downstream HTTPS targets. + insecureTransport = &http.Transport{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: true, //nolint:gosec // downstream self-signed certs are a supported configuration + }, + } + } + + proxy := &httputil.ReverseProxy{ + Rewrite: func(pr *httputil.ProxyRequest) { + pr.SetURL(targetURL) + // SetXForwarded sets X-Forwarded-For from the inbound request's + // RemoteAddr (the WireGuard/netstack client address), along with + // X-Forwarded-Host and X-Forwarded-Proto. Using Rewrite instead of + // Director means the proxy does not append its own automatic + // X-Forwarded-For entry, so the header is set exactly once. + pr.SetXForwarded() + }, + Transport: insecureTransport, + } + + proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) { + logger.Error("HTTP handler: upstream error (%s %s -> %s): %v", + r.Method, r.URL.RequestURI(), cacheKey, err) + http.Error(w, "Bad Gateway", http.StatusBadGateway) + } + + actual, _ := h.proxyCache.LoadOrStore(cacheKey, proxy) + return actual.(*httputil.ReverseProxy) +} + +// statusCapture wraps an http.ResponseWriter and records the HTTP status code +// written by the upstream handler. If WriteHeader is never called the status +// defaults to 200 (http.StatusOK), matching net/http semantics. +type statusCapture struct { + http.ResponseWriter + status int +} + +func (sc *statusCapture) WriteHeader(code int) { + sc.status = code + sc.ResponseWriter.WriteHeader(code) +} + +// handleRequest is the http.Handler entry point. It retrieves the SubnetRule +// attached to the connection by ConnContext, selects the first configured +// downstream target, and forwards the request via the cached ReverseProxy. +// +// TODO: add host/path-based routing across multiple HTTPTargets once the +// configuration model evolves beyond a single target per rule. +func (h *HTTPHandler) handleRequest(w http.ResponseWriter, r *http.Request) { + rule, _ := r.Context().Value(connCtxKey{}).(*SubnetRule) + if rule == nil || len(rule.HTTPTargets) == 0 { + logger.Error("HTTP handler: no downstream targets for request %s %s", r.Method, r.URL.RequestURI()) + http.Error(w, "no targets configured", http.StatusBadGateway) + return + } + + target := rule.HTTPTargets[0] + scheme := target.Scheme + logger.Info("HTTP handler: %s %s -> %s://%s:%d", + r.Method, r.URL.RequestURI(), scheme, target.DestAddr, target.DestPort) + + timestamp := time.Now() + sc := &statusCapture{ResponseWriter: w, status: http.StatusOK} + + h.getProxy(target).ServeHTTP(sc, r) + + if h.requestLogger != nil && rule.ResourceId != 0 { + h.requestLogger.LogRequest(HTTPRequestLog{ + ResourceID: rule.ResourceId, + Timestamp: timestamp, + Method: r.Method, + Scheme: rule.Protocol, + Host: r.Host, + Path: r.URL.Path, + RawQuery: r.URL.RawQuery, + UserAgent: r.UserAgent(), + SourceAddr: r.RemoteAddr, + TLS: rule.Protocol == "https", + }) + } +} diff --git a/netstack2/http_request_log.go b/netstack2/http_request_log.go new file mode 100644 index 0000000..85ab5db --- /dev/null +++ b/netstack2/http_request_log.go @@ -0,0 +1,175 @@ +package netstack2 + +import ( + "bytes" + "compress/zlib" + "encoding/base64" + "encoding/json" + "sync" + "time" + + "github.com/fosrl/newt/logger" +) + +// HTTPRequestLog represents a single HTTP/HTTPS request proxied through the handler. +type HTTPRequestLog struct { + RequestID string `json:"requestId"` + ResourceID int `json:"resourceId"` + Timestamp time.Time `json:"timestamp"` + Method string `json:"method"` + Scheme string `json:"scheme"` + Host string `json:"host"` + Path string `json:"path"` + RawQuery string `json:"rawQuery,omitempty"` + UserAgent string `json:"userAgent,omitempty"` + SourceAddr string `json:"sourceAddr"` + TLS bool `json:"tls"` +} + +// HTTPRequestLogger buffers HTTP request logs and periodically flushes them +// to the server via a configurable SendFunc. +type HTTPRequestLogger struct { + mu sync.Mutex + pending []HTTPRequestLog + sendFn SendFunc + stopCh chan struct{} + flushDone chan struct{} +} + +// NewHTTPRequestLogger creates a new HTTPRequestLogger and starts its background flush loop. +func NewHTTPRequestLogger() *HTTPRequestLogger { + rl := &HTTPRequestLogger{ + pending: make([]HTTPRequestLog, 0), + stopCh: make(chan struct{}), + flushDone: make(chan struct{}), + } + go rl.backgroundLoop() + return rl +} + +// SetSendFunc sets the callback used to send compressed HTTP request log batches +// to the server. This can be called after construction once the websocket +// client is available. +func (rl *HTTPRequestLogger) SetSendFunc(fn SendFunc) { + rl.mu.Lock() + defer rl.mu.Unlock() + rl.sendFn = fn +} + +// LogRequest adds an HTTP request log entry to the buffer. If the buffer +// reaches maxBufferedSessions entries a flush is triggered immediately. +func (rl *HTTPRequestLogger) LogRequest(log HTTPRequestLog) { + if log.RequestID == "" { + log.RequestID = generateSessionID() + } + + rl.mu.Lock() + rl.pending = append(rl.pending, log) + shouldFlush := len(rl.pending) >= maxBufferedSessions + rl.mu.Unlock() + + if shouldFlush { + rl.flush() + } +} + +// backgroundLoop handles periodic flushing of buffered request logs. +func (rl *HTTPRequestLogger) backgroundLoop() { + defer close(rl.flushDone) + + ticker := time.NewTicker(flushInterval) + defer ticker.Stop() + + for { + select { + case <-rl.stopCh: + return + case <-ticker.C: + rl.flush() + } + } +} + +// flush drains the pending buffer, compresses with zlib, and sends via the SendFunc. +// On send failure the batch is re-queued, capped at maxBufferedSessions*5 entries +// to prevent unbounded memory growth when the server is unreachable. +func (rl *HTTPRequestLogger) flush() { + rl.mu.Lock() + if len(rl.pending) == 0 { + rl.mu.Unlock() + return + } + batch := rl.pending + rl.pending = make([]HTTPRequestLog, 0) + sendFn := rl.sendFn + rl.mu.Unlock() + + if sendFn == nil { + logger.Debug("HTTP request logger: no send function configured, discarding %d requests", len(batch)) + return + } + + compressed, err := compressRequestLogs(batch) + if err != nil { + logger.Error("HTTP request logger: failed to compress %d requests: %v", len(batch), err) + return + } + + if err := sendFn(compressed); err != nil { + logger.Error("HTTP request logger: failed to send %d requests: %v", len(batch), err) + // Re-queue the batch so we don't lose data + rl.mu.Lock() + rl.pending = append(batch, rl.pending...) + // Cap re-queued data to prevent unbounded growth if server is unreachable + if len(rl.pending) > maxBufferedSessions*5 { + dropped := len(rl.pending) - maxBufferedSessions*5 + rl.pending = rl.pending[:maxBufferedSessions*5] + logger.Warn("HTTP request logger: buffer overflow, dropped %d oldest requests", dropped) + } + rl.mu.Unlock() + return + } + + logger.Info("HTTP request logger: sent %d requests to server", len(batch)) +} + +// compressRequestLogs JSON-encodes the request logs, compresses with zlib, and +// returns a base64-encoded string suitable for embedding in a JSON message. +func compressRequestLogs(logs []HTTPRequestLog) (string, error) { + jsonData, err := json.Marshal(logs) + if err != nil { + return "", err + } + + var buf bytes.Buffer + w, err := zlib.NewWriterLevel(&buf, zlib.BestCompression) + if err != nil { + return "", err + } + if _, err := w.Write(jsonData); err != nil { + w.Close() + return "", err + } + if err := w.Close(); err != nil { + return "", err + } + + return base64.StdEncoding.EncodeToString(buf.Bytes()), nil +} + +// Close shuts down the background loop and performs one final flush to send +// any remaining buffered requests to the server. +func (rl *HTTPRequestLogger) Close() { + select { + case <-rl.stopCh: + // Already closed + return + default: + close(rl.stopCh) + } + + // Wait for the background loop to exit so we don't race on flush + <-rl.flushDone + + rl.flush() +} \ No newline at end of file diff --git a/netstack2/proxy.go b/netstack2/proxy.go index e383fc0..b08eea3 100644 --- a/netstack2/proxy.go +++ b/netstack2/proxy.go @@ -53,6 +53,14 @@ type SubnetRule struct { RewriteTo string // Optional rewrite address for DNAT - can be IP/CIDR or domain name PortRanges []PortRange // empty slice means all ports allowed ResourceId int // Optional resource ID from the server for access logging + + // HTTP proxy configuration (optional). + // When Protocol is non-empty the TCP connection is handled by HTTPHandler + // instead of the raw TCP forwarder. + Protocol string // "", "http", or "https" — controls the incoming (client-facing) protocol + HTTPTargets []HTTPTarget // downstream services to proxy requests to + TLSCert string // PEM-encoded certificate for incoming HTTPS termination + TLSKey string // PEM-encoded private key for incoming HTTPS termination } // GetAllRules returns a copy of all subnet rules @@ -114,6 +122,7 @@ type ProxyHandler struct { tcpHandler *TCPHandler udpHandler *UDPHandler icmpHandler *ICMPHandler + httpHandler *HTTPHandler subnetLookup *SubnetLookup natTable map[connKey]*natState reverseNatTable map[reverseConnKey]*natState // Reverse lookup map for O(1) reply packet NAT @@ -124,6 +133,7 @@ type ProxyHandler struct { icmpReplies chan []byte // Channel for ICMP reply packets to be sent back through the tunnel notifiable channel.Notification // Notification handler for triggering reads accessLogger *AccessLogger // Access logger for tracking sessions + httpRequestLogger *HTTPRequestLogger // HTTP request logger for proxied HTTP/HTTPS requests } // ProxyHandlerOptions configures the proxy handler @@ -164,12 +174,24 @@ func NewProxyHandler(options ProxyHandlerOptions) (*ProxyHandler, error) { }), } - // Initialize TCP handler if enabled + // Initialize TCP handler if enabled. The HTTP handler piggybacks on the + // TCP forwarder — TCPHandler.handleTCPConn checks the subnet rule for + // ports 80/443 and routes matching connections to the HTTP handler, so + // the HTTP handler is always initialised alongside TCP. if options.EnableTCP { handler.tcpHandler = NewTCPHandler(handler.proxyStack, handler) if err := handler.tcpHandler.InstallTCPHandler(); err != nil { return nil, fmt.Errorf("failed to install TCP handler: %v", err) } + + handler.httpHandler = NewHTTPHandler(handler.proxyStack, handler) + if err := handler.httpHandler.Start(); err != nil { + return nil, fmt.Errorf("failed to start HTTP handler: %v", err) + } + + handler.httpRequestLogger = NewHTTPRequestLogger() + handler.httpHandler.SetRequestLogger(handler.httpRequestLogger) + logger.Debug("ProxyHandler: HTTP handler enabled") } // Initialize UDP handler if enabled @@ -208,16 +230,14 @@ func NewProxyHandler(options ProxyHandlerOptions) (*ProxyHandler, error) { return handler, nil } -// AddSubnetRule adds a subnet with optional port restrictions to the proxy handler -// sourcePrefix: The IP prefix of the peer sending the data -// destPrefix: The IP prefix of the destination -// rewriteTo: Optional address to rewrite destination to - can be IP/CIDR or domain name -// If portRanges is nil or empty, all ports are allowed for this subnet -func (p *ProxyHandler) AddSubnetRule(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange, disableIcmp bool, resourceId int) { +// AddSubnetRule adds a subnet rule to the proxy handler. +// HTTP proxy behaviour is configured via rule.Protocol, rule.HTTPTargets, +// rule.TLSCert, and rule.TLSKey; leave Protocol empty for raw TCP/UDP. +func (p *ProxyHandler) AddSubnetRule(rule SubnetRule) { if p == nil || !p.enabled { return } - p.subnetLookup.AddSubnet(sourcePrefix, destPrefix, rewriteTo, portRanges, disableIcmp, resourceId) + p.subnetLookup.AddSubnet(rule) } // RemoveSubnetRule removes a subnet from the proxy handler @@ -273,6 +293,24 @@ func (p *ProxyHandler) SetAccessLogSender(fn SendFunc) { p.accessLogger.SetSendFunc(fn) } +// GetHTTPRequestLogger returns the HTTP request logger. +func (p *ProxyHandler) GetHTTPRequestLogger() *HTTPRequestLogger { + if p == nil { + return nil + } + return p.httpRequestLogger +} + +// SetHTTPRequestLogSender configures the function used to send compressed HTTP +// request log batches to the server. This should be called once the websocket +// client is available. +func (p *ProxyHandler) SetHTTPRequestLogSender(fn SendFunc) { + if p == nil || !p.enabled || p.httpRequestLogger == nil { + return + } + p.httpRequestLogger.SetSendFunc(fn) +} + // LookupDestinationRewrite looks up the rewritten destination for a connection // This is used by TCP/UDP handlers to find the actual target address func (p *ProxyHandler) LookupDestinationRewrite(srcIP, dstIP string, dstPort uint16, proto uint8) (netip.Addr, bool) { @@ -794,6 +832,16 @@ func (p *ProxyHandler) Close() error { p.accessLogger.Close() } + // Shut down HTTP request logger + if p.httpRequestLogger != nil { + p.httpRequestLogger.Close() + } + + // Shut down HTTP handler + if p.httpHandler != nil { + p.httpHandler.Close() + } + // Close ICMP replies channel if p.icmpReplies != nil { close(p.icmpReplies) diff --git a/netstack2/subnet_lookup.go b/netstack2/subnet_lookup.go index 317f85c..757908a 100644 --- a/netstack2/subnet_lookup.go +++ b/netstack2/subnet_lookup.go @@ -44,24 +44,18 @@ func prefixEqual(a, b netip.Prefix) bool { return a.Masked() == b.Masked() } -// AddSubnet adds a subnet rule with source and destination prefixes and optional port restrictions -// If portRanges is nil or empty, all ports are allowed for this subnet -// rewriteTo can be either an IP/CIDR (e.g., "192.168.1.1/32") or a domain name (e.g., "example.com") -func (sl *SubnetLookup) AddSubnet(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange, disableIcmp bool, resourceId int) { +// AddSubnet adds a subnet rule to the lookup table. +// If rule.PortRanges is nil or empty, all ports are allowed. +// rule.RewriteTo can be either an IP/CIDR (e.g., "192.168.1.1/32") or a domain name (e.g., "example.com"). +// HTTP proxy behaviour is driven by rule.Protocol, rule.HTTPTargets, rule.TLSCert, and rule.TLSKey. +func (sl *SubnetLookup) AddSubnet(rule SubnetRule) { sl.mu.Lock() defer sl.mu.Unlock() - rule := &SubnetRule{ - SourcePrefix: sourcePrefix, - DestPrefix: destPrefix, - DisableIcmp: disableIcmp, - RewriteTo: rewriteTo, - PortRanges: portRanges, - ResourceId: resourceId, - } + rulePtr := &rule // Canonicalize source prefix to handle host bits correctly - canonicalSourcePrefix := sourcePrefix.Masked() + canonicalSourcePrefix := rule.SourcePrefix.Masked() // Get or create destination trie for this source prefix destTriePtr, exists := sl.sourceTrie.Get(canonicalSourcePrefix) @@ -76,12 +70,12 @@ func (sl *SubnetLookup) AddSubnet(sourcePrefix, destPrefix netip.Prefix, rewrite // Canonicalize destination prefix to handle host bits correctly // BART masks prefixes internally, so we need to match that behavior in our bookkeeping - canonicalDestPrefix := destPrefix.Masked() + canonicalDestPrefix := rule.DestPrefix.Masked() // Add rule to destination trie // Original behavior: overwrite if same (sourcePrefix, destPrefix) exists // Store as single-element slice to match original overwrite behavior - destTriePtr.trie.Insert(canonicalDestPrefix, []*SubnetRule{rule}) + destTriePtr.trie.Insert(canonicalDestPrefix, []*SubnetRule{rulePtr}) // Update destTriePtr.rules - remove old rule with same canonical prefix if exists, then add new one // Use canonical comparison to handle cases like 10.0.0.5/24 vs 10.0.0.0/24 @@ -91,7 +85,7 @@ func (sl *SubnetLookup) AddSubnet(sourcePrefix, destPrefix netip.Prefix, rewrite newRules = append(newRules, r) } } - newRules = append(newRules, rule) + newRules = append(newRules, rulePtr) destTriePtr.rules = newRules } diff --git a/netstack2/tun.go b/netstack2/tun.go index 3183c36..fae90dd 100644 --- a/netstack2/tun.go +++ b/netstack2/tun.go @@ -351,13 +351,13 @@ func (net *Net) ListenUDP(laddr *net.UDPAddr) (*gonet.UDPConn, error) { return net.DialUDP(laddr, nil) } -// AddProxySubnetRule adds a subnet rule to the proxy handler -// If portRanges is nil or empty, all ports are allowed for this subnet -// rewriteTo can be either an IP/CIDR (e.g., "192.168.1.1/32") or a domain name (e.g., "example.com") -func (net *Net) AddProxySubnetRule(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange, disableIcmp bool, resourceId int) { +// AddProxySubnetRule adds a subnet rule to the proxy handler. +// HTTP proxy behaviour is configured via rule.Protocol, rule.HTTPTargets, +// rule.TLSCert, and rule.TLSKey; leave Protocol empty for raw TCP/UDP. +func (net *Net) AddProxySubnetRule(rule SubnetRule) { tun := (*netTun)(net) if tun.proxyHandler != nil { - tun.proxyHandler.AddSubnetRule(sourcePrefix, destPrefix, rewriteTo, portRanges, disableIcmp, resourceId) + tun.proxyHandler.AddSubnetRule(rule) } } @@ -394,6 +394,16 @@ func (net *Net) SetAccessLogSender(fn SendFunc) { } } +// SetHTTPRequestLogSender configures the function used to send compressed HTTP +// request log batches to the server. This should be called once the websocket +// client is available. +func (net *Net) SetHTTPRequestLogSender(fn SendFunc) { + tun := (*netTun)(net) + if tun.proxyHandler != nil { + tun.proxyHandler.SetHTTPRequestLogSender(fn) + } +} + type PingConn struct { laddr PingAddr raddr PingAddr diff --git a/proxy/manager.go b/proxy/manager.go index 5566589..0d1f750 100644 --- a/proxy/manager.go +++ b/proxy/manager.go @@ -23,9 +23,31 @@ import ( const ( errUnsupportedProtoFmt = "unsupported protocol: %s" - maxUDPPacketSize = 65507 + maxUDPPacketSize = 65507 // Maximum UDP packet size + defaultUDPIdleTimeout = 90 * time.Second ) +// udpBufferPool provides reusable buffers for UDP packet handling. +// This reduces GC pressure from frequent large allocations. +var udpBufferPool = sync.Pool{ + New: func() any { + buf := make([]byte, maxUDPPacketSize) + return &buf + }, +} + +// getUDPBuffer retrieves a buffer from the pool. +func getUDPBuffer() *[]byte { + return udpBufferPool.Get().(*[]byte) +} + +// putUDPBuffer clears and returns a buffer to the pool. +func putUDPBuffer(buf *[]byte) { + // Clear the buffer to prevent data leakage + clear(*buf) + udpBufferPool.Put(buf) +} + // Target represents a proxy target with its address and port type Target struct { Address string @@ -47,6 +69,7 @@ type ProxyManager struct { tunnels map[string]*tunnelEntry asyncBytes bool flushStop chan struct{} + udpIdleTimeout time.Duration } // tunnelEntry holds per-tunnel attributes and (optional) async counters. @@ -132,6 +155,7 @@ func NewProxyManager(tnet *netstack.Net) *ProxyManager { listeners: make([]*gonet.TCPListener, 0), udpConns: make([]*gonet.UDPConn, 0), tunnels: make(map[string]*tunnelEntry), + udpIdleTimeout: defaultUDPIdleTimeout, } } @@ -209,6 +233,7 @@ func NewProxyManagerWithoutTNet() *ProxyManager { udpTargets: make(map[string]map[int]string), listeners: make([]*gonet.TCPListener, 0), udpConns: make([]*gonet.UDPConn, 0), + udpIdleTimeout: defaultUDPIdleTimeout, } } @@ -345,6 +370,17 @@ func (pm *ProxyManager) SetAsyncBytes(b bool) { go pm.flushLoop() } } + +// SetUDPIdleTimeout configures when idle UDP client flows are reclaimed. +func (pm *ProxyManager) SetUDPIdleTimeout(d time.Duration) { + pm.mutex.Lock() + defer pm.mutex.Unlock() + if d <= 0 { + pm.udpIdleTimeout = defaultUDPIdleTimeout + return + } + pm.udpIdleTimeout = d +} func (pm *ProxyManager) flushLoop() { flushInterval := 2 * time.Second if v := os.Getenv("OTEL_METRIC_EXPORT_INTERVAL"); v != "" { @@ -555,7 +591,9 @@ func (pm *ProxyManager) handleTCPProxy(listener net.Listener, targetAddr string) } func (pm *ProxyManager) handleUDPProxy(conn *gonet.UDPConn, targetAddr string) { - buffer := make([]byte, maxUDPPacketSize) // Max UDP packet size + bufPtr := getUDPBuffer() + defer putUDPBuffer(bufPtr) + buffer := *bufPtr clientConns := make(map[string]*net.UDPConn) var clientsMutex sync.RWMutex @@ -623,6 +661,9 @@ func (pm *ProxyManager) handleUDPProxy(conn *gonet.UDPConn, targetAddr string) { telemetry.IncProxyAccept(context.Background(), pm.currentTunnelID, "udp", "failure", classifyProxyError(err)) continue } + // Prevent idle UDP client goroutines from living forever and + // retaining large per-connection buffers. + _ = targetConn.SetReadDeadline(time.Now().Add(pm.udpIdleTimeout)) tunnelID := pm.currentTunnelID telemetry.IncProxyAccept(context.Background(), tunnelID, "udp", "success", "") telemetry.IncProxyConnectionEvent(context.Background(), tunnelID, "udp", telemetry.ProxyConnectionOpened) @@ -638,7 +679,10 @@ func (pm *ProxyManager) handleUDPProxy(conn *gonet.UDPConn, targetAddr string) { go func(clientKey string, targetConn *net.UDPConn, remoteAddr net.Addr, tunnelID string) { start := time.Now() result := "success" + bufPtr := getUDPBuffer() defer func() { + // Return buffer to pool first + putUDPBuffer(bufPtr) // Always clean up when this goroutine exits clientsMutex.Lock() if storedConn, exists := clientConns[clientKey]; exists && storedConn == targetConn { @@ -653,10 +697,14 @@ func (pm *ProxyManager) handleUDPProxy(conn *gonet.UDPConn, targetAddr string) { telemetry.IncProxyConnectionEvent(context.Background(), tunnelID, "udp", telemetry.ProxyConnectionClosed) }() - buffer := make([]byte, maxUDPPacketSize) + buffer := *bufPtr for { n, _, err := targetConn.ReadFromUDP(buffer) if err != nil { + var netErr net.Error + if errors.As(err, &netErr) && netErr.Timeout() { + return + } // Connection closed is normal during cleanup if errors.Is(err, net.ErrClosed) || errors.Is(err, io.EOF) { return // defer will handle cleanup, result stays "success" @@ -699,6 +747,8 @@ func (pm *ProxyManager) handleUDPProxy(conn *gonet.UDPConn, targetAddr string) { delete(clientConns, clientKey) clientsMutex.Unlock() } else if pm.currentTunnelID != "" && written > 0 { + // Extend idle timeout whenever client traffic is observed. + _ = targetConn.SetReadDeadline(time.Now().Add(pm.udpIdleTimeout)) if pm.asyncBytes { if e := pm.getEntry(pm.currentTunnelID); e != nil { e.bytesInUDP.Add(uint64(written)) diff --git a/websocket/client.go b/websocket/client.go index 6990bd2..67e23ec 100644 --- a/websocket/client.go +++ b/websocket/client.go @@ -707,6 +707,10 @@ func (c *Client) sendPing() { } c.writeMux.Lock() + if c.conn == nil { + c.writeMux.Unlock() + return + } err := c.conn.WriteJSON(pingMsg) if err == nil { telemetry.IncWSMessage(c.metricsContext(), "out", "ping") @@ -859,10 +863,12 @@ func (c *Client) readPumpWithDisconnectDetection(started time.Time) { func (c *Client) reconnect() { c.setConnected(false) telemetry.SetWSConnectionState(false) + c.writeMux.Lock() if c.conn != nil { c.conn.Close() c.conn = nil } + c.writeMux.Unlock() // Only reconnect if we're not shutting down select { diff --git a/websocket/config.go b/websocket/config.go index 39f1bd2..f24f65a 100644 --- a/websocket/config.go +++ b/websocket/config.go @@ -71,6 +71,11 @@ func (c *Client) loadConfig() error { } return err } + if len(bytes.TrimSpace(data)) == 0 { + logger.Info("Config file at %s is empty, will initialize it with provided values", configPath) + c.configNeedsSave = true + return nil + } var config Config if err := json.Unmarshal(data, &config); err != nil { diff --git a/websocket/config_test.go b/websocket/config_test.go new file mode 100644 index 0000000..b2d8a24 --- /dev/null +++ b/websocket/config_test.go @@ -0,0 +1,35 @@ +package websocket + +import ( + "os" + "path/filepath" + "testing" +) + +func TestLoadConfig_EmptyFileMarksConfigForSave(t *testing.T) { + t.Setenv("CONFIG_FILE", "") + + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "config.json") + if err := os.WriteFile(configPath, []byte(""), 0o644); err != nil { + t.Fatalf("failed to create empty config file: %v", err) + } + + client := &Client{ + config: &Config{ + Endpoint: "https://example.com", + ProvisioningKey: "spk-test", + }, + clientType: "newt", + configFilePath: configPath, + } + + if err := client.loadConfig(); err != nil { + t.Fatalf("loadConfig returned error for empty file: %v", err) + } + + if !client.configNeedsSave { + t.Fatal("expected empty config file to mark configNeedsSave") + } +} +