Compare commits

..

2 Commits

Author SHA1 Message Date
Owen
44ca592a5c Set newt version in dockerfile 2026-03-08 11:28:56 -07:00
Owen
e1edbcea07 Make sure to set version and fix prepare issue 2026-03-07 12:34:55 -08:00
18 changed files with 203 additions and 2737 deletions

View File

@@ -38,12 +38,10 @@ type WgConfig struct {
type Target struct { type Target struct {
SourcePrefix string `json:"sourcePrefix"` SourcePrefix string `json:"sourcePrefix"`
SourcePrefixes []string `json:"sourcePrefixes"`
DestPrefix string `json:"destPrefix"` DestPrefix string `json:"destPrefix"`
RewriteTo string `json:"rewriteTo,omitempty"` RewriteTo string `json:"rewriteTo,omitempty"`
DisableIcmp bool `json:"disableIcmp,omitempty"` DisableIcmp bool `json:"disableIcmp,omitempty"`
PortRange []PortRange `json:"portRange,omitempty"` PortRange []PortRange `json:"portRange,omitempty"`
ResourceId int `json:"resourceId,omitempty"`
} }
type PortRange struct { type PortRange struct {
@@ -114,6 +112,8 @@ func NewWireGuardService(interfaceName string, port uint16, mtu int, host string
return nil, fmt.Errorf("failed to generate private key: %v", err) return nil, fmt.Errorf("failed to generate private key: %v", err)
} }
logger.Debug("+++++++++++++++++++++++++++++++= the port is %d", port)
if port == 0 { if port == 0 {
// Find an available port // Find an available port
portRandom, err := util.FindAvailableUDPPort(49152, 65535) portRandom, err := util.FindAvailableUDPPort(49152, 65535)
@@ -174,7 +174,6 @@ func NewWireGuardService(interfaceName string, port uint16, mtu int, host string
wsClient.RegisterHandler("newt/wg/targets/add", service.handleAddTarget) wsClient.RegisterHandler("newt/wg/targets/add", service.handleAddTarget)
wsClient.RegisterHandler("newt/wg/targets/remove", service.handleRemoveTarget) wsClient.RegisterHandler("newt/wg/targets/remove", service.handleRemoveTarget)
wsClient.RegisterHandler("newt/wg/targets/update", service.handleUpdateTarget) wsClient.RegisterHandler("newt/wg/targets/update", service.handleUpdateTarget)
wsClient.RegisterHandler("newt/wg/sync", service.handleSyncConfig)
return service, nil return service, nil
} }
@@ -197,15 +196,6 @@ func (s *WireGuardService) Close() {
s.stopGetConfig = nil s.stopGetConfig = nil
} }
// Flush access logs before tearing down the tunnel
if s.tnet != nil {
if ph := s.tnet.GetProxyHandler(); ph != nil {
if al := ph.GetAccessLogger(); al != nil {
al.Close()
}
}
}
// Stop the direct UDP relay first // Stop the direct UDP relay first
s.StopDirectUDPRelay() s.StopDirectUDPRelay()
@@ -504,183 +494,6 @@ func (s *WireGuardService) handleConfig(msg websocket.WSMessage) {
logger.Info("Client connectivity setup. Ready to accept connections from clients!") logger.Info("Client connectivity setup. Ready to accept connections from clients!")
} }
// SyncConfig represents the configuration sent from server for syncing
type SyncConfig struct {
Targets []Target `json:"targets"`
Peers []Peer `json:"peers"`
}
func (s *WireGuardService) handleSyncConfig(msg websocket.WSMessage) {
var syncConfig SyncConfig
logger.Debug("Received sync message: %v", msg)
logger.Info("Received sync configuration from remote server")
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling sync data: %v", err)
return
}
if err := json.Unmarshal(jsonData, &syncConfig); err != nil {
logger.Error("Error unmarshaling sync data: %v", err)
return
}
// Sync peers
if err := s.syncPeers(syncConfig.Peers); err != nil {
logger.Error("Failed to sync peers: %v", err)
}
// Sync targets
if err := s.syncTargets(syncConfig.Targets); err != nil {
logger.Error("Failed to sync targets: %v", err)
}
}
// syncPeers synchronizes the current peers with the desired state
// It removes peers not in the desired list and adds missing ones
func (s *WireGuardService) syncPeers(desiredPeers []Peer) error {
if s.device == nil {
return fmt.Errorf("WireGuard device is not initialized")
}
// Get current peers from the device
currentConfig, err := s.device.IpcGet()
if err != nil {
return fmt.Errorf("failed to get current device config: %v", err)
}
// Parse current peer public keys
lines := strings.Split(currentConfig, "\n")
currentPeerKeys := make(map[string]bool)
for _, line := range lines {
if strings.HasPrefix(line, "public_key=") {
pubKey := strings.TrimPrefix(line, "public_key=")
currentPeerKeys[pubKey] = true
}
}
// Build a map of desired peers by their public key (normalized)
desiredPeerMap := make(map[string]Peer)
for _, peer := range desiredPeers {
// Normalize the public key for comparison
pubKey, err := wgtypes.ParseKey(peer.PublicKey)
if err != nil {
logger.Warn("Invalid public key in desired peers: %s", peer.PublicKey)
continue
}
normalizedKey := util.FixKey(pubKey.String())
desiredPeerMap[normalizedKey] = peer
}
// Remove peers that are not in the desired list
for currentKey := range currentPeerKeys {
if _, exists := desiredPeerMap[currentKey]; !exists {
// Parse the key back to get the original format for removal
removeConfig := fmt.Sprintf("public_key=%s\nremove=true", currentKey)
if err := s.device.IpcSet(removeConfig); err != nil {
logger.Warn("Failed to remove peer %s during sync: %v", currentKey, err)
} else {
logger.Info("Removed peer %s during sync", currentKey)
}
}
}
// Add peers that are missing
for normalizedKey, peer := range desiredPeerMap {
if _, exists := currentPeerKeys[normalizedKey]; !exists {
if err := s.addPeerToDevice(peer); err != nil {
logger.Warn("Failed to add peer %s during sync: %v", peer.PublicKey, err)
} else {
logger.Info("Added peer %s during sync", peer.PublicKey)
}
}
}
return nil
}
// syncTargets synchronizes the current targets with the desired state
// It removes targets not in the desired list and adds missing ones
func (s *WireGuardService) syncTargets(desiredTargets []Target) error {
if s.tnet == nil {
// Native interface mode - proxy features not available, skip silently
logger.Debug("Skipping target sync - using native interface (no proxy support)")
return nil
}
// Get current rules from the proxy handler
currentRules := s.tnet.GetProxySubnetRules()
// Build a map of current rules by source+dest prefix
type ruleKey struct {
sourcePrefix string
destPrefix string
}
currentRuleMap := make(map[ruleKey]bool)
for _, rule := range currentRules {
key := ruleKey{
sourcePrefix: rule.SourcePrefix.String(),
destPrefix: rule.DestPrefix.String(),
}
currentRuleMap[key] = true
}
// Build a map of desired targets
desiredTargetMap := make(map[ruleKey]Target)
for _, target := range desiredTargets {
key := ruleKey{
sourcePrefix: target.SourcePrefix,
destPrefix: target.DestPrefix,
}
desiredTargetMap[key] = target
}
// Remove targets that are not in the desired list
for _, rule := range currentRules {
key := ruleKey{
sourcePrefix: rule.SourcePrefix.String(),
destPrefix: rule.DestPrefix.String(),
}
if _, exists := desiredTargetMap[key]; !exists {
s.tnet.RemoveProxySubnetRule(rule.SourcePrefix, rule.DestPrefix)
logger.Info("Removed target %s -> %s during sync", rule.SourcePrefix.String(), rule.DestPrefix.String())
}
}
// Add targets that are missing
for key, target := range desiredTargetMap {
if _, exists := currentRuleMap[key]; !exists {
sourcePrefix, err := netip.ParsePrefix(target.SourcePrefix)
if err != nil {
logger.Warn("Invalid source prefix %s during sync: %v", target.SourcePrefix, err)
continue
}
destPrefix, err := netip.ParsePrefix(target.DestPrefix)
if err != nil {
logger.Warn("Invalid dest prefix %s during sync: %v", target.DestPrefix, err)
continue
}
var portRanges []netstack2.PortRange
for _, pr := range target.PortRange {
portRanges = append(portRanges, netstack2.PortRange{
Min: pr.Min,
Max: pr.Max,
Protocol: pr.Protocol,
})
}
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp, target.ResourceId)
logger.Info("Added target %s -> %s during sync", target.SourcePrefix, target.DestPrefix)
}
}
return nil
}
func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error { func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
s.mu.Lock() s.mu.Lock()
@@ -804,13 +617,6 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
s.TunnelIP = tunnelIP.String() s.TunnelIP = tunnelIP.String()
// Configure the access log sender to ship compressed session logs via websocket
s.tnet.SetAccessLogSender(func(data string) error {
return s.client.SendMessageNoLog("newt/access-log", map[string]interface{}{
"compressed": data,
})
})
// Create WireGuard device using the shared bind // Create WireGuard device using the shared bind
s.device = device.NewDevice(s.tun, s.sharedBind, device.NewLogger( s.device = device.NewDevice(s.tun, s.sharedBind, device.NewLogger(
device.LogLevelSilent, // Use silent logging by default - could be made configurable device.LogLevelSilent, // Use silent logging by default - could be made configurable
@@ -891,19 +697,6 @@ func (s *WireGuardService) ensureWireguardPeers(peers []Peer) error {
return nil return nil
} }
// resolveSourcePrefixes returns the effective list of source prefixes for a target,
// supporting both the legacy single SourcePrefix field and the new SourcePrefixes array.
// If SourcePrefixes is non-empty it takes precedence; otherwise SourcePrefix is used.
func resolveSourcePrefixes(target Target) []string {
if len(target.SourcePrefixes) > 0 {
return target.SourcePrefixes
}
if target.SourcePrefix != "" {
return []string{target.SourcePrefix}
}
return nil
}
func (s *WireGuardService) ensureTargets(targets []Target) error { func (s *WireGuardService) ensureTargets(targets []Target) error {
if s.tnet == nil { if s.tnet == nil {
// Native interface mode - proxy features not available, skip silently // Native interface mode - proxy features not available, skip silently
@@ -912,6 +705,11 @@ func (s *WireGuardService) ensureTargets(targets []Target) error {
} }
for _, target := range targets { for _, target := range targets {
sourcePrefix, err := netip.ParsePrefix(target.SourcePrefix)
if err != nil {
return fmt.Errorf("invalid CIDR %s: %v", target.SourcePrefix, err)
}
destPrefix, err := netip.ParsePrefix(target.DestPrefix) destPrefix, err := netip.ParsePrefix(target.DestPrefix)
if err != nil { if err != nil {
return fmt.Errorf("invalid CIDR %s: %v", target.DestPrefix, err) return fmt.Errorf("invalid CIDR %s: %v", target.DestPrefix, err)
@@ -926,14 +724,9 @@ func (s *WireGuardService) ensureTargets(targets []Target) error {
}) })
} }
for _, sp := range resolveSourcePrefixes(target) { s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp)
sourcePrefix, err := netip.ParsePrefix(sp)
if err != nil { logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v disableIcmp: %v", target.SourcePrefix, target.DestPrefix, target.RewriteTo, target.PortRange, target.DisableIcmp)
return fmt.Errorf("invalid CIDR %s: %v", sp, err)
}
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp, target.ResourceId)
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", sp, target.DestPrefix, target.RewriteTo, target.PortRange)
}
} }
return nil return nil
@@ -1303,6 +1096,12 @@ func (s *WireGuardService) handleAddTarget(msg websocket.WSMessage) {
// Process all targets // Process all targets
for _, target := range targets { for _, target := range targets {
sourcePrefix, err := netip.ParsePrefix(target.SourcePrefix)
if err != nil {
logger.Info("Invalid CIDR %s: %v", target.SourcePrefix, err)
continue
}
destPrefix, err := netip.ParsePrefix(target.DestPrefix) destPrefix, err := netip.ParsePrefix(target.DestPrefix)
if err != nil { if err != nil {
logger.Info("Invalid CIDR %s: %v", target.DestPrefix, err) logger.Info("Invalid CIDR %s: %v", target.DestPrefix, err)
@@ -1318,15 +1117,9 @@ func (s *WireGuardService) handleAddTarget(msg websocket.WSMessage) {
}) })
} }
for _, sp := range resolveSourcePrefixes(target) { s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp)
sourcePrefix, err := netip.ParsePrefix(sp)
if err != nil { logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v disableIcmp: %v", target.SourcePrefix, target.DestPrefix, target.RewriteTo, target.PortRange, target.DisableIcmp)
logger.Info("Invalid CIDR %s: %v", sp, err)
continue
}
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp, target.ResourceId)
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", sp, target.DestPrefix, target.RewriteTo, target.PortRange)
}
} }
} }
@@ -1355,21 +1148,21 @@ func (s *WireGuardService) handleRemoveTarget(msg websocket.WSMessage) {
// Process all targets // Process all targets
for _, target := range targets { for _, target := range targets {
sourcePrefix, err := netip.ParsePrefix(target.SourcePrefix)
if err != nil {
logger.Info("Invalid CIDR %s: %v", target.SourcePrefix, err)
continue
}
destPrefix, err := netip.ParsePrefix(target.DestPrefix) destPrefix, err := netip.ParsePrefix(target.DestPrefix)
if err != nil { if err != nil {
logger.Info("Invalid CIDR %s: %v", target.DestPrefix, err) logger.Info("Invalid CIDR %s: %v", target.DestPrefix, err)
continue continue
} }
for _, sp := range resolveSourcePrefixes(target) {
sourcePrefix, err := netip.ParsePrefix(sp)
if err != nil {
logger.Info("Invalid CIDR %s: %v", sp, err)
continue
}
s.tnet.RemoveProxySubnetRule(sourcePrefix, destPrefix) s.tnet.RemoveProxySubnetRule(sourcePrefix, destPrefix)
logger.Info("Removed target subnet %s with destination %s", sp, target.DestPrefix)
} logger.Info("Removed target subnet %s with destination %s", target.SourcePrefix, target.DestPrefix)
} }
} }
@@ -1403,24 +1196,30 @@ func (s *WireGuardService) handleUpdateTarget(msg websocket.WSMessage) {
// Process all update requests // Process all update requests
for _, target := range requests.OldTargets { for _, target := range requests.OldTargets {
sourcePrefix, err := netip.ParsePrefix(target.SourcePrefix)
if err != nil {
logger.Info("Invalid CIDR %s: %v", target.SourcePrefix, err)
continue
}
destPrefix, err := netip.ParsePrefix(target.DestPrefix) destPrefix, err := netip.ParsePrefix(target.DestPrefix)
if err != nil { if err != nil {
logger.Info("Invalid CIDR %s: %v", target.DestPrefix, err) logger.Info("Invalid CIDR %s: %v", target.DestPrefix, err)
continue continue
} }
for _, sp := range resolveSourcePrefixes(target) {
sourcePrefix, err := netip.ParsePrefix(sp)
if err != nil {
logger.Info("Invalid CIDR %s: %v", sp, err)
continue
}
s.tnet.RemoveProxySubnetRule(sourcePrefix, destPrefix) s.tnet.RemoveProxySubnetRule(sourcePrefix, destPrefix)
logger.Info("Removed target subnet %s with destination %s", sp, target.DestPrefix) logger.Info("Removed target subnet %s with destination %s", target.SourcePrefix, target.DestPrefix)
}
} }
for _, target := range requests.NewTargets { for _, target := range requests.NewTargets {
// Now add the new target
sourcePrefix, err := netip.ParsePrefix(target.SourcePrefix)
if err != nil {
logger.Info("Invalid CIDR %s: %v", target.SourcePrefix, err)
continue
}
destPrefix, err := netip.ParsePrefix(target.DestPrefix) destPrefix, err := netip.ParsePrefix(target.DestPrefix)
if err != nil { if err != nil {
logger.Info("Invalid CIDR %s: %v", target.DestPrefix, err) logger.Info("Invalid CIDR %s: %v", target.DestPrefix, err)
@@ -1436,15 +1235,8 @@ func (s *WireGuardService) handleUpdateTarget(msg websocket.WSMessage) {
}) })
} }
for _, sp := range resolveSourcePrefixes(target) { s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp)
sourcePrefix, err := netip.ParsePrefix(sp) logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v disableIcmp: %v", target.SourcePrefix, target.DestPrefix, target.RewriteTo, target.PortRange, target.DisableIcmp)
if err != nil {
logger.Info("Invalid CIDR %s: %v", sp, err)
continue
}
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp, target.ResourceId)
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", sp, target.DestPrefix, target.RewriteTo, target.PortRange)
}
} }
} }

View File

@@ -5,10 +5,8 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net"
"os" "os"
"os/exec" "os/exec"
"regexp"
"strings" "strings"
"time" "time"
@@ -365,62 +363,27 @@ func parseTargetData(data interface{}) (TargetData, error) {
return targetData, nil return targetData, nil
} }
// parseTargetString parses a target string in the format "listenPort:host:targetPort"
// It properly handles IPv6 addresses which must be in brackets: "listenPort:[ipv6]:targetPort"
// Examples:
// - IPv4: "3001:192.168.1.1:80"
// - IPv6: "3001:[::1]:8080" or "3001:[fd70:1452:b736:4dd5:caca:7db9:c588:f5b3]:80"
//
// Returns listenPort, targetAddress (in host:port format suitable for net.Dial), and error
func parseTargetString(target string) (int, string, error) {
// Find the first colon to extract the listen port
firstColon := strings.Index(target, ":")
if firstColon == -1 {
return 0, "", fmt.Errorf("invalid target format, no colon found: %s", target)
}
listenPortStr := target[:firstColon]
var listenPort int
_, err := fmt.Sscanf(listenPortStr, "%d", &listenPort)
if err != nil {
return 0, "", fmt.Errorf("invalid listen port: %s", listenPortStr)
}
if listenPort <= 0 || listenPort > 65535 {
return 0, "", fmt.Errorf("listen port out of range: %d", listenPort)
}
// The remainder is host:targetPort - use net.SplitHostPort which handles IPv6 brackets
remainder := target[firstColon+1:]
host, targetPort, err := net.SplitHostPort(remainder)
if err != nil {
return 0, "", fmt.Errorf("invalid host:port format '%s': %w", remainder, err)
}
// Reject empty host or target port
if host == "" {
return 0, "", fmt.Errorf("empty host in target: %s", target)
}
if targetPort == "" {
return 0, "", fmt.Errorf("empty target port in target: %s", target)
}
// Reconstruct the target address using JoinHostPort (handles IPv6 properly)
targetAddr := net.JoinHostPort(host, targetPort)
return listenPort, targetAddr, nil
}
func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto string, targetData TargetData) error { func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto string, targetData TargetData) error {
for _, t := range targetData.Targets { for _, t := range targetData.Targets {
// Parse the target string, handling both IPv4 and IPv6 addresses // Split the first number off of the target with : separator and use as the port
port, target, err := parseTargetString(t) parts := strings.Split(t, ":")
if len(parts) != 3 {
logger.Info("Invalid target format: %s", t)
continue
}
// Get the port as an int
port := 0
_, err := fmt.Sscanf(parts[0], "%d", &port)
if err != nil { if err != nil {
logger.Info("Invalid target format: %s (%v)", t, err) logger.Info("Invalid port: %s", parts[0])
continue continue
} }
switch action { switch action {
case "add": case "add":
target := parts[1] + ":" + parts[2]
// Call updown script if provided // Call updown script if provided
processedTarget := target processedTarget := target
if updownScript != "" { if updownScript != "" {
@@ -447,6 +410,8 @@ func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto
case "remove": case "remove":
logger.Info("Removing target with port %d", port) logger.Info("Removing target with port %d", port)
target := parts[1] + ":" + parts[2]
// Call updown script if provided // Call updown script if provided
if updownScript != "" { if updownScript != "" {
_, err := executeUpdownScript(action, proto, target) _, err := executeUpdownScript(action, proto, target)
@@ -455,7 +420,7 @@ func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto
} }
} }
err = pm.RemoveTarget(proto, tunnelIP, port) err := pm.RemoveTarget(proto, tunnelIP, port)
if err != nil { if err != nil {
logger.Error("Failed to remove target: %v", err) logger.Error("Failed to remove target: %v", err)
return err return err
@@ -510,29 +475,6 @@ func executeUpdownScript(action, proto, target string) (string, error) {
return target, nil return target, nil
} }
// interpolateBlueprint finds all {{...}} tokens in the raw blueprint bytes and
// replaces recognised schemes with their resolved values. Currently supported:
//
// - env.<VAR> replaced with the value of the named environment variable
//
// Any token that does not match a supported scheme is left as-is so that
// future schemes (e.g. tag., api.) are preserved rather than silently dropped.
func interpolateBlueprint(data []byte) []byte {
re := regexp.MustCompile(`\{\{([^}]+)\}\}`)
return re.ReplaceAllFunc(data, func(match []byte) []byte {
// strip the surrounding {{ }}
inner := strings.TrimSpace(string(match[2 : len(match)-2]))
if strings.HasPrefix(inner, "env.") {
varName := strings.TrimPrefix(inner, "env.")
return []byte(os.Getenv(varName))
}
// unrecognised scheme leave the token untouched
return match
})
}
func sendBlueprint(client *websocket.Client) error { func sendBlueprint(client *websocket.Client) error {
if blueprintFile == "" { if blueprintFile == "" {
return nil return nil
@@ -542,9 +484,6 @@ func sendBlueprint(client *websocket.Client) error {
if err != nil { if err != nil {
logger.Error("Failed to read blueprint file: %v", err) logger.Error("Failed to read blueprint file: %v", err)
} else { } else {
// interpolate {{env.VAR}} (and any future schemes) before parsing
blueprintData = interpolateBlueprint(blueprintData)
// first we should convert the yaml to json and error if the yaml is bad // first we should convert the yaml to json and error if the yaml is bad
var yamlObj interface{} var yamlObj interface{}
var blueprintJsonData string var blueprintJsonData string

View File

@@ -1,212 +0,0 @@
package main
import (
"net"
"testing"
)
func TestParseTargetString(t *testing.T) {
tests := []struct {
name string
input string
wantListenPort int
wantTargetAddr string
wantErr bool
}{
// IPv4 test cases
{
name: "valid IPv4 basic",
input: "3001:192.168.1.1:80",
wantListenPort: 3001,
wantTargetAddr: "192.168.1.1:80",
wantErr: false,
},
{
name: "valid IPv4 localhost",
input: "8080:127.0.0.1:3000",
wantListenPort: 8080,
wantTargetAddr: "127.0.0.1:3000",
wantErr: false,
},
{
name: "valid IPv4 same ports",
input: "443:10.0.0.1:443",
wantListenPort: 443,
wantTargetAddr: "10.0.0.1:443",
wantErr: false,
},
// IPv6 test cases
{
name: "valid IPv6 loopback",
input: "3001:[::1]:8080",
wantListenPort: 3001,
wantTargetAddr: "[::1]:8080",
wantErr: false,
},
{
name: "valid IPv6 full address",
input: "80:[fd70:1452:b736:4dd5:caca:7db9:c588:f5b3]:8080",
wantListenPort: 80,
wantTargetAddr: "[fd70:1452:b736:4dd5:caca:7db9:c588:f5b3]:8080",
wantErr: false,
},
{
name: "valid IPv6 link-local",
input: "443:[fe80::1]:443",
wantListenPort: 443,
wantTargetAddr: "[fe80::1]:443",
wantErr: false,
},
{
name: "valid IPv6 all zeros compressed",
input: "8000:[::]:9000",
wantListenPort: 8000,
wantTargetAddr: "[::]:9000",
wantErr: false,
},
{
name: "valid IPv6 mixed notation",
input: "5000:[::ffff:192.168.1.1]:6000",
wantListenPort: 5000,
wantTargetAddr: "[::ffff:192.168.1.1]:6000",
wantErr: false,
},
// Hostname test cases
{
name: "valid hostname",
input: "8080:example.com:80",
wantListenPort: 8080,
wantTargetAddr: "example.com:80",
wantErr: false,
},
{
name: "valid hostname with subdomain",
input: "443:api.example.com:8443",
wantListenPort: 443,
wantTargetAddr: "api.example.com:8443",
wantErr: false,
},
{
name: "valid localhost hostname",
input: "3000:localhost:3000",
wantListenPort: 3000,
wantTargetAddr: "localhost:3000",
wantErr: false,
},
// Error cases
{
name: "invalid - no colons",
input: "invalid",
wantErr: true,
},
{
name: "invalid - empty string",
input: "",
wantErr: true,
},
{
name: "invalid - non-numeric listen port",
input: "abc:192.168.1.1:80",
wantErr: true,
},
{
name: "invalid - missing target port",
input: "3001:192.168.1.1",
wantErr: true,
},
{
name: "invalid - IPv6 without brackets",
input: "3001:fd70:1452:b736:4dd5:caca:7db9:c588:f5b3:80",
wantErr: true,
},
{
name: "invalid - only listen port",
input: "3001:",
wantErr: true,
},
{
name: "invalid - missing host",
input: "3001::80",
wantErr: true,
},
{
name: "invalid - IPv6 unclosed bracket",
input: "3001:[::1:80",
wantErr: true,
},
{
name: "invalid - listen port zero",
input: "0:192.168.1.1:80",
wantErr: true,
},
{
name: "invalid - listen port negative",
input: "-1:192.168.1.1:80",
wantErr: true,
},
{
name: "invalid - listen port out of range",
input: "70000:192.168.1.1:80",
wantErr: true,
},
{
name: "invalid - empty target port",
input: "3001:192.168.1.1:",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
listenPort, targetAddr, err := parseTargetString(tt.input)
if (err != nil) != tt.wantErr {
t.Errorf("parseTargetString(%q) error = %v, wantErr %v", tt.input, err, tt.wantErr)
return
}
if tt.wantErr {
return // Don't check other values if we expected an error
}
if listenPort != tt.wantListenPort {
t.Errorf("parseTargetString(%q) listenPort = %d, want %d", tt.input, listenPort, tt.wantListenPort)
}
if targetAddr != tt.wantTargetAddr {
t.Errorf("parseTargetString(%q) targetAddr = %q, want %q", tt.input, targetAddr, tt.wantTargetAddr)
}
})
}
}
// TestParseTargetStringNetDialCompatibility verifies that the output is compatible with net.Dial
func TestParseTargetStringNetDialCompatibility(t *testing.T) {
tests := []struct {
name string
input string
}{
{"IPv4", "8080:127.0.0.1:80"},
{"IPv6 loopback", "8080:[::1]:80"},
{"IPv6 full", "8080:[2001:db8::1]:80"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, targetAddr, err := parseTargetString(tt.input)
if err != nil {
t.Fatalf("parseTargetString(%q) unexpected error: %v", tt.input, err)
}
// Verify the format is valid for net.Dial by checking it can be split back
// This doesn't actually dial, just validates the format
_, _, err = net.SplitHostPort(targetAddr)
if err != nil {
t.Errorf("parseTargetString(%q) produced invalid net.Dial format %q: %v", tt.input, targetAddr, err)
}
})
}
}

View File

@@ -1,4 +0,0 @@
{
"endpoint": "http://you.fosrl.io",
"provisioningKey": "spk-xt1opb0fkoqb7qb.hi44jciamqcrdaja4lvz3kp52pl3lssamp6asuyx"
}

View File

@@ -1,4 +0,0 @@
{
"endpoint": "http://you.fosrl.io",
"provisioningKey": "spk-xt1opb0fkoqb7qb.hi44jciamqcrdaja4lvz3kp52pl3lssamp6asuyx"
}

View File

@@ -1,7 +1,7 @@
#!/bin/sh #!/bin/bash
# Get Newt - Cross-platform installation script # Get Newt - Cross-platform installation script
# Usage: curl -fsSL https://raw.githubusercontent.com/fosrl/newt/refs/heads/main/get-newt.sh | sh # Usage: curl -fsSL https://raw.githubusercontent.com/fosrl/newt/refs/heads/main/get-newt.sh | bash
set -e set -e
@@ -17,15 +17,15 @@ GITHUB_API_URL="https://api.github.com/repos/${REPO}/releases/latest"
# Function to print colored output # Function to print colored output
print_status() { print_status() {
printf '%b[INFO]%b %s\n' "${GREEN}" "${NC}" "$1" echo -e "${GREEN}[INFO]${NC} $1"
} }
print_warning() { print_warning() {
printf '%b[WARN]%b %s\n' "${YELLOW}" "${NC}" "$1" echo -e "${YELLOW}[WARN]${NC} $1"
} }
print_error() { print_error() {
printf '%b[ERROR]%b %s\n' "${RED}" "${NC}" "$1" echo -e "${RED}[ERROR]${NC} $1"
} }
# Function to get latest version from GitHub API # Function to get latest version from GitHub API
@@ -113,34 +113,16 @@ get_install_dir() {
if [ "$OS" = "windows" ]; then if [ "$OS" = "windows" ]; then
echo "$HOME/bin" echo "$HOME/bin"
else else
# Prefer /usr/local/bin for system-wide installation # Try to use a directory in PATH, fallback to ~/.local/bin
if echo "$PATH" | grep -q "/usr/local/bin"; then
if [ -w "/usr/local/bin" ] 2>/dev/null; then
echo "/usr/local/bin" echo "/usr/local/bin"
fi
}
# Check if we need sudo for installation
needs_sudo() {
local install_dir="$1"
if [ -w "$install_dir" ] 2>/dev/null; then
return 1 # No sudo needed
else else
return 0 # Sudo needed echo "$HOME/.local/bin"
fi
}
# Get the appropriate command prefix (sudo or empty)
get_sudo_cmd() {
local install_dir="$1"
if needs_sudo "$install_dir"; then
if command -v sudo >/dev/null 2>&1; then
echo "sudo"
else
print_error "Cannot write to ${install_dir} and sudo is not available."
print_error "Please run this script as root or install sudo."
exit 1
fi fi
else else
echo "" echo "$HOME/.local/bin"
fi
fi fi
} }
@@ -148,17 +130,14 @@ get_sudo_cmd() {
install_newt() { install_newt() {
local platform="$1" local platform="$1"
local install_dir="$2" local install_dir="$2"
local sudo_cmd="$3"
local binary_name="newt_${platform}" local binary_name="newt_${platform}"
local exe_suffix="" local exe_suffix=""
# Add .exe suffix for Windows # Add .exe suffix for Windows
case "$platform" in if [[ "$platform" == *"windows"* ]]; then
*windows*)
binary_name="${binary_name}.exe" binary_name="${binary_name}.exe"
exe_suffix=".exe" exe_suffix=".exe"
;; fi
esac
local download_url="${BASE_URL}/${binary_name}" local download_url="${BASE_URL}/${binary_name}"
local temp_file="/tmp/newt${exe_suffix}" local temp_file="/tmp/newt${exe_suffix}"
@@ -176,18 +155,14 @@ install_newt() {
exit 1 exit 1
fi fi
# Make executable before moving
chmod +x "$temp_file"
# Create install directory if it doesn't exist # Create install directory if it doesn't exist
if [ -n "$sudo_cmd" ]; then
$sudo_cmd mkdir -p "$install_dir"
print_status "Using sudo to install to ${install_dir}"
$sudo_cmd mv "$temp_file" "$final_path"
else
mkdir -p "$install_dir" mkdir -p "$install_dir"
# Move binary to install directory
mv "$temp_file" "$final_path" mv "$temp_file" "$final_path"
fi
# Make executable (not needed on Windows, but doesn't hurt)
chmod +x "$final_path"
print_status "newt installed to ${final_path}" print_status "newt installed to ${final_path}"
@@ -204,9 +179,9 @@ verify_installation() {
local install_dir="$1" local install_dir="$1"
local exe_suffix="" local exe_suffix=""
case "$PLATFORM" in if [[ "$PLATFORM" == *"windows"* ]]; then
*windows*) exe_suffix=".exe" ;; exe_suffix=".exe"
esac fi
local newt_path="${install_dir}/newt${exe_suffix}" local newt_path="${install_dir}/newt${exe_suffix}"
@@ -240,19 +215,17 @@ main() {
INSTALL_DIR=$(get_install_dir) INSTALL_DIR=$(get_install_dir)
print_status "Install directory: ${INSTALL_DIR}" print_status "Install directory: ${INSTALL_DIR}"
# Check if we need sudo
SUDO_CMD=$(get_sudo_cmd "$INSTALL_DIR")
if [ -n "$SUDO_CMD" ]; then
print_status "Root privileges required for installation to ${INSTALL_DIR}"
fi
# Install newt # Install newt
install_newt "$PLATFORM" "$INSTALL_DIR" "$SUDO_CMD" install_newt "$PLATFORM" "$INSTALL_DIR"
# Verify installation # Verify installation
if verify_installation "$INSTALL_DIR"; then if verify_installation "$INSTALL_DIR"; then
print_status "newt is ready to use!" print_status "newt is ready to use!"
if [[ "$PLATFORM" == *"windows"* ]]; then
print_status "Run 'newt --help' to get started" print_status "Run 'newt --help' to get started"
else
print_status "Run 'newt --help' to get started"
fi
else else
exit 1 exit 1
fi fi

View File

@@ -5,9 +5,7 @@ import (
"crypto/tls" "crypto/tls"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net"
"net/http" "net/http"
"strconv"
"strings" "strings"
"sync" "sync"
"time" "time"
@@ -367,12 +365,11 @@ func (m *Monitor) performHealthCheck(target *Target) {
target.LastCheck = time.Now() target.LastCheck = time.Now()
target.LastError = "" target.LastError = ""
// Build URL (use net.JoinHostPort to properly handle IPv6 addresses with ports) // Build URL
host := target.Config.Hostname url := fmt.Sprintf("%s://%s", target.Config.Scheme, target.Config.Hostname)
if target.Config.Port > 0 { if target.Config.Port > 0 {
host = net.JoinHostPort(target.Config.Hostname, strconv.Itoa(target.Config.Port)) url = fmt.Sprintf("%s:%d", url, target.Config.Port)
} }
url := fmt.Sprintf("%s://%s", target.Config.Scheme, host)
if target.Config.Path != "" { if target.Config.Path != "" {
if !strings.HasPrefix(target.Config.Path, "/") { if !strings.HasPrefix(target.Config.Path, "/") {
url += "/" url += "/"
@@ -524,82 +521,3 @@ func (m *Monitor) DisableTarget(id int) error {
return nil return nil
} }
// GetTargetIDs returns a slice of all current target IDs
func (m *Monitor) GetTargetIDs() []int {
m.mutex.RLock()
defer m.mutex.RUnlock()
ids := make([]int, 0, len(m.targets))
for id := range m.targets {
ids = append(ids, id)
}
return ids
}
// SyncTargets synchronizes the current targets to match the desired set.
// It removes targets not in the desired set and adds targets that are missing.
func (m *Monitor) SyncTargets(desiredConfigs []Config) error {
m.mutex.Lock()
defer m.mutex.Unlock()
logger.Info("Syncing health check targets: %d desired targets", len(desiredConfigs))
// Build a set of desired target IDs
desiredIDs := make(map[int]Config)
for _, config := range desiredConfigs {
desiredIDs[config.ID] = config
}
// Find targets to remove (exist but not in desired set)
var toRemove []int
for id := range m.targets {
if _, exists := desiredIDs[id]; !exists {
toRemove = append(toRemove, id)
}
}
// Remove targets that are not in the desired set
for _, id := range toRemove {
logger.Info("Sync: removing health check target %d", id)
if target, exists := m.targets[id]; exists {
target.cancel()
delete(m.targets, id)
}
}
// Add or update targets from the desired set
var addedCount, updatedCount int
for id, config := range desiredIDs {
if existing, exists := m.targets[id]; exists {
// Target exists - check if config changed and update if needed
// For now, we'll replace it to ensure config is up to date
logger.Debug("Sync: updating health check target %d", id)
existing.cancel()
delete(m.targets, id)
if err := m.addTargetUnsafe(config); err != nil {
logger.Error("Sync: failed to update target %d: %v", id, err)
return fmt.Errorf("failed to update target %d: %v", id, err)
}
updatedCount++
} else {
// Target doesn't exist - add it
logger.Debug("Sync: adding health check target %d", id)
if err := m.addTargetUnsafe(config); err != nil {
logger.Error("Sync: failed to add target %d: %v", id, err)
return fmt.Errorf("failed to add target %d: %v", id, err)
}
addedCount++
}
}
logger.Info("Sync complete: removed %d, added %d, updated %d targets",
len(toRemove), addedCount, updatedCount)
// Notify callback if any changes were made
if (len(toRemove) > 0 || addedCount > 0 || updatedCount > 0) && m.callback != nil {
go m.callback(m.getAllTargetsUnsafe())
}
return nil
}

214
main.go
View File

@@ -10,7 +10,6 @@ import (
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
"net/http/pprof"
"net/netip" "net/netip"
"os" "os"
"os/signal" "os/signal"
@@ -148,7 +147,6 @@ var (
adminAddr string adminAddr string
region string region string
metricsAsyncBytes bool metricsAsyncBytes bool
pprofEnabled bool
blueprintFile string blueprintFile string
noCloud bool noCloud bool
@@ -159,12 +157,6 @@ var (
// Legacy PKCS12 support (deprecated) // Legacy PKCS12 support (deprecated)
tlsPrivateKey string tlsPrivateKey string
// Provisioning key exchanged once for a permanent newt ID + secret
provisioningKey string
// Path to config file (overrides CONFIG_FILE env var and default location)
configFile string
) )
func main() { func main() {
@@ -233,7 +225,6 @@ func runNewtMain(ctx context.Context) {
adminAddrEnv := os.Getenv("NEWT_ADMIN_ADDR") adminAddrEnv := os.Getenv("NEWT_ADMIN_ADDR")
regionEnv := os.Getenv("NEWT_REGION") regionEnv := os.Getenv("NEWT_REGION")
asyncBytesEnv := os.Getenv("NEWT_METRICS_ASYNC_BYTES") asyncBytesEnv := os.Getenv("NEWT_METRICS_ASYNC_BYTES")
pprofEnabledEnv := os.Getenv("NEWT_PPROF_ENABLED")
disableClientsEnv := os.Getenv("DISABLE_CLIENTS") disableClientsEnv := os.Getenv("DISABLE_CLIENTS")
disableClients = disableClientsEnv == "true" disableClients = disableClientsEnv == "true"
@@ -270,8 +261,6 @@ func runNewtMain(ctx context.Context) {
blueprintFile = os.Getenv("BLUEPRINT_FILE") blueprintFile = os.Getenv("BLUEPRINT_FILE")
noCloudEnv := os.Getenv("NO_CLOUD") noCloudEnv := os.Getenv("NO_CLOUD")
noCloud = noCloudEnv == "true" noCloud = noCloudEnv == "true"
provisioningKey = os.Getenv("NEWT_PROVISIONING_KEY")
configFile = os.Getenv("CONFIG_FILE")
if endpoint == "" { if endpoint == "" {
flag.StringVar(&endpoint, "endpoint", "", "Endpoint of your pangolin server") flag.StringVar(&endpoint, "endpoint", "", "Endpoint of your pangolin server")
@@ -313,19 +302,13 @@ func runNewtMain(ctx context.Context) {
flag.StringVar(&dockerSocket, "docker-socket", "", "Path or address to Docker socket (typically unix:///var/run/docker.sock)") flag.StringVar(&dockerSocket, "docker-socket", "", "Path or address to Docker socket (typically unix:///var/run/docker.sock)")
} }
if pingIntervalStr == "" { if pingIntervalStr == "" {
flag.StringVar(&pingIntervalStr, "ping-interval", "15s", "Interval for pinging the server (default 15s)") flag.StringVar(&pingIntervalStr, "ping-interval", "3s", "Interval for pinging the server (default 3s)")
} }
if pingTimeoutStr == "" { if pingTimeoutStr == "" {
flag.StringVar(&pingTimeoutStr, "ping-timeout", "7s", " Timeout for each ping (default 7s)") flag.StringVar(&pingTimeoutStr, "ping-timeout", "5s", " Timeout for each ping (default 5s)")
} }
// load the prefer endpoint just as a flag // load the prefer endpoint just as a flag
flag.StringVar(&preferEndpoint, "prefer-endpoint", "", "Prefer this endpoint for the connection (if set, will override the endpoint from the server)") flag.StringVar(&preferEndpoint, "prefer-endpoint", "", "Prefer this endpoint for the connection (if set, will override the endpoint from the server)")
if provisioningKey == "" {
flag.StringVar(&provisioningKey, "provisioning-key", "", "One-time provisioning key used to obtain a newt ID and secret from the server")
}
if configFile == "" {
flag.StringVar(&configFile, "config-file", "", "Path to config file (overrides CONFIG_FILE env var and default location)")
}
// Add new mTLS flags // Add new mTLS flags
if tlsClientCert == "" { if tlsClientCert == "" {
@@ -347,21 +330,21 @@ func runNewtMain(ctx context.Context) {
if pingIntervalStr != "" { if pingIntervalStr != "" {
pingInterval, err = time.ParseDuration(pingIntervalStr) pingInterval, err = time.ParseDuration(pingIntervalStr)
if err != nil { if err != nil {
fmt.Printf("Invalid PING_INTERVAL value: %s, using default 15 seconds\n", pingIntervalStr) fmt.Printf("Invalid PING_INTERVAL value: %s, using default 3 seconds\n", pingIntervalStr)
pingInterval = 15 * time.Second pingInterval = 3 * time.Second
} }
} else { } else {
pingInterval = 15 * time.Second pingInterval = 3 * time.Second
} }
if pingTimeoutStr != "" { if pingTimeoutStr != "" {
pingTimeout, err = time.ParseDuration(pingTimeoutStr) pingTimeout, err = time.ParseDuration(pingTimeoutStr)
if err != nil { if err != nil {
fmt.Printf("Invalid PING_TIMEOUT value: %s, using default 7 seconds\n", pingTimeoutStr) fmt.Printf("Invalid PING_TIMEOUT value: %s, using default 5 seconds\n", pingTimeoutStr)
pingTimeout = 7 * time.Second pingTimeout = 5 * time.Second
} }
} else { } else {
pingTimeout = 7 * time.Second pingTimeout = 5 * time.Second
} }
if dockerEnforceNetworkValidation == "" { if dockerEnforceNetworkValidation == "" {
@@ -407,14 +390,6 @@ func runNewtMain(ctx context.Context) {
metricsAsyncBytes = v metricsAsyncBytes = v
} }
} }
// pprof debug endpoint toggle
if pprofEnabledEnv == "" {
flag.BoolVar(&pprofEnabled, "pprof", false, "Enable pprof debug endpoints on admin server")
} else {
if v, err := strconv.ParseBool(pprofEnabledEnv); err == nil {
pprofEnabled = v
}
}
// Optional region flag (resource attribute) // Optional region flag (resource attribute)
if regionEnv == "" { if regionEnv == "" {
flag.StringVar(&region, "region", "", "Optional region resource attribute (also NEWT_REGION)") flag.StringVar(&region, "region", "", "Optional region resource attribute (also NEWT_REGION)")
@@ -510,14 +485,6 @@ func runNewtMain(ctx context.Context) {
if tel.PrometheusHandler != nil { if tel.PrometheusHandler != nil {
mux.Handle("/metrics", tel.PrometheusHandler) mux.Handle("/metrics", tel.PrometheusHandler)
} }
if pprofEnabled {
mux.HandleFunc("/debug/pprof/", pprof.Index)
mux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline)
mux.HandleFunc("/debug/pprof/profile", pprof.Profile)
mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol)
mux.HandleFunc("/debug/pprof/trace", pprof.Trace)
logger.Info("pprof debugging enabled on %s/debug/pprof/", tcfg.AdminAddr)
}
admin := &http.Server{ admin := &http.Server{
Addr: tcfg.AdminAddr, Addr: tcfg.AdminAddr,
Handler: otelhttp.NewHandler(mux, "newt-admin"), Handler: otelhttp.NewHandler(mux, "newt-admin"),
@@ -598,19 +565,13 @@ func runNewtMain(ctx context.Context) {
id, // CLI arg takes precedence id, // CLI arg takes precedence
secret, // CLI arg takes precedence secret, // CLI arg takes precedence
endpoint, endpoint,
30*time.Second, pingInterval,
pingTimeout,
opt, opt,
websocket.WithConfigFile(configFile),
) )
if err != nil { if err != nil {
logger.Fatal("Failed to create client: %v", err) logger.Fatal("Failed to create client: %v", err)
} }
// If a provisioning key was supplied via CLI / env and the config file did
// not already carry one, inject it now so provisionIfNeeded() can use it.
if provisioningKey != "" && client.GetConfig().ProvisioningKey == "" {
client.GetConfig().ProvisioningKey = provisioningKey
}
endpoint = client.GetConfig().Endpoint // Update endpoint from config endpoint = client.GetConfig().Endpoint // Update endpoint from config
id = client.GetConfig().ID // Update ID from config id = client.GetConfig().ID // Update ID from config
// Update site labels for metrics with the resolved ID // Update site labels for metrics with the resolved ID
@@ -658,6 +619,8 @@ func runNewtMain(ctx context.Context) {
var wgData WgData var wgData WgData
var dockerEventMonitor *docker.EventMonitor var dockerEventMonitor *docker.EventMonitor
logger.Debug("++++++++++++++++++++++ the port is %d", port)
if !disableClients { if !disableClients {
setupClients(client) setupClients(client)
} }
@@ -996,7 +959,7 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey(
"publicKey": publicKey.String(), "publicKey": publicKey.String(),
"pingResults": pingResults, "pingResults": pingResults,
"newtVersion": newtVersion, "newtVersion": newtVersion,
}, 2*time.Second) }, 1*time.Second)
return return
} }
@@ -1099,7 +1062,7 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey(
"publicKey": publicKey.String(), "publicKey": publicKey.String(),
"pingResults": pingResults, "pingResults": pingResults,
"newtVersion": newtVersion, "newtVersion": newtVersion,
}, 2*time.Second) }, 1*time.Second)
logger.Debug("Sent exit node ping results to cloud for selection: pingResults=%+v", pingResults) logger.Debug("Sent exit node ping results to cloud for selection: pingResults=%+v", pingResults)
}) })
@@ -1204,153 +1167,6 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey(
} }
}) })
// Register handler for syncing targets (TCP, UDP, and health checks)
client.RegisterHandler("newt/sync", func(msg websocket.WSMessage) {
logger.Info("Received sync message")
// if there is no wgData or pm, we can't sync targets
if wgData.TunnelIP == "" || pm == nil {
logger.Info(msgNoTunnelOrProxy)
return
}
// Define the sync data structure
type SyncData struct {
Targets TargetsByType `json:"targets"`
HealthCheckTargets []healthcheck.Config `json:"healthCheckTargets"`
}
var syncData SyncData
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling sync data: %v", err)
return
}
if err := json.Unmarshal(jsonData, &syncData); err != nil {
logger.Error("Error unmarshaling sync data: %v", err)
return
}
logger.Debug("Sync data received: TCP targets=%d, UDP targets=%d, health check targets=%d",
len(syncData.Targets.TCP), len(syncData.Targets.UDP), len(syncData.HealthCheckTargets))
//TODO: TEST AND IMPLEMENT THIS
// // Build sets of desired targets (port -> target string)
// desiredTCP := make(map[int]string)
// for _, t := range syncData.Targets.TCP {
// parts := strings.Split(t, ":")
// if len(parts) != 3 {
// logger.Warn("Invalid TCP target format: %s", t)
// continue
// }
// port := 0
// if _, err := fmt.Sscanf(parts[0], "%d", &port); err != nil {
// logger.Warn("Invalid port in TCP target: %s", parts[0])
// continue
// }
// desiredTCP[port] = parts[1] + ":" + parts[2]
// }
// desiredUDP := make(map[int]string)
// for _, t := range syncData.Targets.UDP {
// parts := strings.Split(t, ":")
// if len(parts) != 3 {
// logger.Warn("Invalid UDP target format: %s", t)
// continue
// }
// port := 0
// if _, err := fmt.Sscanf(parts[0], "%d", &port); err != nil {
// logger.Warn("Invalid port in UDP target: %s", parts[0])
// continue
// }
// desiredUDP[port] = parts[1] + ":" + parts[2]
// }
// // Get current targets from proxy manager
// currentTCP, currentUDP := pm.GetTargets()
// // Sync TCP targets
// // Remove TCP targets not in desired set
// if tcpForIP, ok := currentTCP[wgData.TunnelIP]; ok {
// for port := range tcpForIP {
// if _, exists := desiredTCP[port]; !exists {
// logger.Info("Sync: removing TCP target on port %d", port)
// targetStr := fmt.Sprintf("%d:%s", port, tcpForIP[port])
// updateTargets(pm, "remove", wgData.TunnelIP, "tcp", TargetData{Targets: []string{targetStr}})
// }
// }
// }
// // Add TCP targets that are missing
// for port, target := range desiredTCP {
// needsAdd := true
// if tcpForIP, ok := currentTCP[wgData.TunnelIP]; ok {
// if currentTarget, exists := tcpForIP[port]; exists {
// // Check if target address changed
// if currentTarget == target {
// needsAdd = false
// } else {
// // Target changed, remove old one first
// logger.Info("Sync: updating TCP target on port %d", port)
// targetStr := fmt.Sprintf("%d:%s", port, currentTarget)
// updateTargets(pm, "remove", wgData.TunnelIP, "tcp", TargetData{Targets: []string{targetStr}})
// }
// }
// }
// if needsAdd {
// logger.Info("Sync: adding TCP target on port %d -> %s", port, target)
// targetStr := fmt.Sprintf("%d:%s", port, target)
// updateTargets(pm, "add", wgData.TunnelIP, "tcp", TargetData{Targets: []string{targetStr}})
// }
// }
// // Sync UDP targets
// // Remove UDP targets not in desired set
// if udpForIP, ok := currentUDP[wgData.TunnelIP]; ok {
// for port := range udpForIP {
// if _, exists := desiredUDP[port]; !exists {
// logger.Info("Sync: removing UDP target on port %d", port)
// targetStr := fmt.Sprintf("%d:%s", port, udpForIP[port])
// updateTargets(pm, "remove", wgData.TunnelIP, "udp", TargetData{Targets: []string{targetStr}})
// }
// }
// }
// // Add UDP targets that are missing
// for port, target := range desiredUDP {
// needsAdd := true
// if udpForIP, ok := currentUDP[wgData.TunnelIP]; ok {
// if currentTarget, exists := udpForIP[port]; exists {
// // Check if target address changed
// if currentTarget == target {
// needsAdd = false
// } else {
// // Target changed, remove old one first
// logger.Info("Sync: updating UDP target on port %d", port)
// targetStr := fmt.Sprintf("%d:%s", port, currentTarget)
// updateTargets(pm, "remove", wgData.TunnelIP, "udp", TargetData{Targets: []string{targetStr}})
// }
// }
// }
// if needsAdd {
// logger.Info("Sync: adding UDP target on port %d -> %s", port, target)
// targetStr := fmt.Sprintf("%d:%s", port, target)
// updateTargets(pm, "add", wgData.TunnelIP, "udp", TargetData{Targets: []string{targetStr}})
// }
// }
// // Sync health check targets
// if err := healthMonitor.SyncTargets(syncData.HealthCheckTargets); err != nil {
// logger.Error("Failed to sync health check targets: %v", err)
// } else {
// logger.Info("Successfully synced health check targets")
// }
logger.Info("Sync complete")
})
// Register handler for Docker socket check // Register handler for Docker socket check
client.RegisterHandler("newt/socket/check", func(msg websocket.WSMessage) { client.RegisterHandler("newt/socket/check", func(msg websocket.WSMessage) {
logger.Debug("Received Docker socket check request") logger.Debug("Received Docker socket check request")
@@ -1833,8 +1649,6 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey(
pm.Stop() pm.Stop()
} }
client.SendMessage("newt/disconnecting", map[string]any{})
if client != nil { if client != nil {
client.Close() client.Close()
} }

View File

@@ -1,514 +0,0 @@
package netstack2
import (
"bytes"
"compress/zlib"
"crypto/rand"
"encoding/base64"
"encoding/hex"
"encoding/json"
"net"
"sort"
"sync"
"time"
"github.com/fosrl/newt/logger"
)
const (
// flushInterval is how often the access logger flushes completed sessions to the server
flushInterval = 60 * time.Second
// maxBufferedSessions is the max number of completed sessions to buffer before forcing a flush
maxBufferedSessions = 100
// sessionGapThreshold is the maximum gap between the end of one connection
// and the start of the next for them to be considered part of the same session.
// If the gap exceeds this, a new consolidated session is created.
sessionGapThreshold = 5 * time.Second
// minConnectionsToConsolidate is the minimum number of connections in a group
// before we bother consolidating. Groups smaller than this are sent as-is.
minConnectionsToConsolidate = 2
)
// SendFunc is a callback that sends compressed access log data to the server.
// The data is a base64-encoded zlib-compressed JSON array of AccessSession objects.
type SendFunc func(data string) error
// AccessSession represents a tracked access session through the proxy
type AccessSession struct {
SessionID string `json:"sessionId"`
ResourceID int `json:"resourceId"`
SourceAddr string `json:"sourceAddr"`
DestAddr string `json:"destAddr"`
Protocol string `json:"protocol"`
StartedAt time.Time `json:"startedAt"`
EndedAt time.Time `json:"endedAt,omitempty"`
BytesTx int64 `json:"bytesTx"`
BytesRx int64 `json:"bytesRx"`
ConnectionCount int `json:"connectionCount,omitempty"` // number of raw connections merged into this session (0 or 1 = single)
}
// udpSessionKey identifies a unique UDP "session" by src -> dst
type udpSessionKey struct {
srcAddr string
dstAddr string
protocol string
}
// consolidationKey groups connections that may be part of the same logical session.
// Source port is intentionally excluded so that many ephemeral-port connections
// from the same source IP to the same destination are grouped together.
type consolidationKey struct {
sourceIP string // IP only, no port
destAddr string // full host:port of the destination
protocol string
resourceID int
}
// AccessLogger tracks access sessions for resources and periodically
// flushes completed sessions to the server via a configurable SendFunc.
type AccessLogger struct {
mu sync.Mutex
sessions map[string]*AccessSession // active sessions: sessionID -> session
udpSessions map[udpSessionKey]*AccessSession // active UDP sessions for dedup
completedSessions []*AccessSession // completed sessions waiting to be flushed
udpTimeout time.Duration
sendFn SendFunc
stopCh chan struct{}
flushDone chan struct{} // closed after the flush goroutine exits
}
// NewAccessLogger creates a new access logger.
// udpTimeout controls how long a UDP session is kept alive without traffic before being ended.
func NewAccessLogger(udpTimeout time.Duration) *AccessLogger {
al := &AccessLogger{
sessions: make(map[string]*AccessSession),
udpSessions: make(map[udpSessionKey]*AccessSession),
completedSessions: make([]*AccessSession, 0),
udpTimeout: udpTimeout,
stopCh: make(chan struct{}),
flushDone: make(chan struct{}),
}
go al.backgroundLoop()
return al
}
// SetSendFunc sets the callback used to send compressed access log batches
// to the server. This can be called after construction once the websocket
// client is available.
func (al *AccessLogger) SetSendFunc(fn SendFunc) {
al.mu.Lock()
defer al.mu.Unlock()
al.sendFn = fn
}
// generateSessionID creates a random session identifier
func generateSessionID() string {
b := make([]byte, 8)
rand.Read(b)
return hex.EncodeToString(b)
}
// StartTCPSession logs the start of a TCP session and returns a session ID.
func (al *AccessLogger) StartTCPSession(resourceID int, srcAddr, dstAddr string) string {
sessionID := generateSessionID()
now := time.Now()
session := &AccessSession{
SessionID: sessionID,
ResourceID: resourceID,
SourceAddr: srcAddr,
DestAddr: dstAddr,
Protocol: "tcp",
StartedAt: now,
}
al.mu.Lock()
al.sessions[sessionID] = session
al.mu.Unlock()
logger.Info("ACCESS START session=%s resource=%d proto=tcp src=%s dst=%s time=%s",
sessionID, resourceID, srcAddr, dstAddr, now.Format(time.RFC3339))
return sessionID
}
// EndTCPSession logs the end of a TCP session and queues it for sending.
func (al *AccessLogger) EndTCPSession(sessionID string) {
now := time.Now()
al.mu.Lock()
session, ok := al.sessions[sessionID]
if ok {
session.EndedAt = now
delete(al.sessions, sessionID)
al.completedSessions = append(al.completedSessions, session)
}
shouldFlush := len(al.completedSessions) >= maxBufferedSessions
al.mu.Unlock()
if ok {
duration := now.Sub(session.StartedAt)
logger.Info("ACCESS END session=%s resource=%d proto=tcp src=%s dst=%s started=%s ended=%s duration=%s",
sessionID, session.ResourceID, session.SourceAddr, session.DestAddr,
session.StartedAt.Format(time.RFC3339), now.Format(time.RFC3339), duration)
}
if shouldFlush {
al.flush()
}
}
// TrackUDPSession starts or returns an existing UDP session. Returns the session ID.
func (al *AccessLogger) TrackUDPSession(resourceID int, srcAddr, dstAddr string) string {
key := udpSessionKey{
srcAddr: srcAddr,
dstAddr: dstAddr,
protocol: "udp",
}
al.mu.Lock()
defer al.mu.Unlock()
if existing, ok := al.udpSessions[key]; ok {
return existing.SessionID
}
sessionID := generateSessionID()
now := time.Now()
session := &AccessSession{
SessionID: sessionID,
ResourceID: resourceID,
SourceAddr: srcAddr,
DestAddr: dstAddr,
Protocol: "udp",
StartedAt: now,
}
al.sessions[sessionID] = session
al.udpSessions[key] = session
logger.Info("ACCESS START session=%s resource=%d proto=udp src=%s dst=%s time=%s",
sessionID, resourceID, srcAddr, dstAddr, now.Format(time.RFC3339))
return sessionID
}
// EndUDPSession ends a UDP session and queues it for sending.
func (al *AccessLogger) EndUDPSession(sessionID string) {
now := time.Now()
al.mu.Lock()
session, ok := al.sessions[sessionID]
if ok {
session.EndedAt = now
delete(al.sessions, sessionID)
key := udpSessionKey{
srcAddr: session.SourceAddr,
dstAddr: session.DestAddr,
protocol: "udp",
}
delete(al.udpSessions, key)
al.completedSessions = append(al.completedSessions, session)
}
shouldFlush := len(al.completedSessions) >= maxBufferedSessions
al.mu.Unlock()
if ok {
duration := now.Sub(session.StartedAt)
logger.Info("ACCESS END session=%s resource=%d proto=udp src=%s dst=%s started=%s ended=%s duration=%s",
sessionID, session.ResourceID, session.SourceAddr, session.DestAddr,
session.StartedAt.Format(time.RFC3339), now.Format(time.RFC3339), duration)
}
if shouldFlush {
al.flush()
}
}
// backgroundLoop handles periodic flushing and stale session reaping.
func (al *AccessLogger) backgroundLoop() {
defer close(al.flushDone)
flushTicker := time.NewTicker(flushInterval)
defer flushTicker.Stop()
reapTicker := time.NewTicker(30 * time.Second)
defer reapTicker.Stop()
for {
select {
case <-al.stopCh:
return
case <-flushTicker.C:
al.flush()
case <-reapTicker.C:
al.reapStaleSessions()
}
}
}
// reapStaleSessions cleans up UDP sessions that were not properly ended.
func (al *AccessLogger) reapStaleSessions() {
al.mu.Lock()
defer al.mu.Unlock()
staleThreshold := time.Now().Add(-5 * time.Minute)
for key, session := range al.udpSessions {
if session.StartedAt.Before(staleThreshold) && session.EndedAt.IsZero() {
now := time.Now()
session.EndedAt = now
duration := now.Sub(session.StartedAt)
logger.Info("ACCESS END (reaped) session=%s resource=%d proto=udp src=%s dst=%s started=%s ended=%s duration=%s",
session.SessionID, session.ResourceID, session.SourceAddr, session.DestAddr,
session.StartedAt.Format(time.RFC3339), now.Format(time.RFC3339), duration)
al.completedSessions = append(al.completedSessions, session)
delete(al.sessions, session.SessionID)
delete(al.udpSessions, key)
}
}
}
// extractIP strips the port from an address string and returns just the IP.
// If the address has no port component it is returned as-is.
func extractIP(addr string) string {
host, _, err := net.SplitHostPort(addr)
if err != nil {
// Might already be a bare IP
return addr
}
return host
}
// consolidateSessions takes a slice of completed sessions and merges bursts of
// short-lived connections from the same source IP to the same destination into
// single higher-level session entries.
//
// The algorithm:
// 1. Group sessions by (sourceIP, destAddr, protocol, resourceID).
// 2. Within each group, sort by StartedAt.
// 3. Walk through the sorted list and merge consecutive sessions whose gap
// (previous EndedAt → next StartedAt) is ≤ sessionGapThreshold.
// 4. For merged sessions the earliest StartedAt and latest EndedAt are kept,
// bytes are summed, and ConnectionCount records how many raw connections
// were folded in. If the merged connections used more than one source port,
// SourceAddr is set to just the IP (port omitted).
// 5. Groups with fewer than minConnectionsToConsolidate members are passed
// through unmodified.
func consolidateSessions(sessions []*AccessSession) []*AccessSession {
if len(sessions) <= 1 {
return sessions
}
// Group sessions by consolidation key
groups := make(map[consolidationKey][]*AccessSession)
for _, s := range sessions {
key := consolidationKey{
sourceIP: extractIP(s.SourceAddr),
destAddr: s.DestAddr,
protocol: s.Protocol,
resourceID: s.ResourceID,
}
groups[key] = append(groups[key], s)
}
result := make([]*AccessSession, 0, len(sessions))
for key, group := range groups {
// Small groups don't need consolidation
if len(group) < minConnectionsToConsolidate {
result = append(result, group...)
continue
}
// Sort the group by start time so we can detect gaps
sort.Slice(group, func(i, j int) bool {
return group[i].StartedAt.Before(group[j].StartedAt)
})
// Walk through and merge runs that are within the gap threshold
var merged []*AccessSession
cur := cloneSession(group[0])
cur.ConnectionCount = 1
sourcePorts := make(map[string]struct{})
sourcePorts[cur.SourceAddr] = struct{}{}
for i := 1; i < len(group); i++ {
s := group[i]
// Determine the gap: from the latest end time we've seen so far to the
// start of the next connection.
gapRef := cur.EndedAt
if gapRef.IsZero() {
gapRef = cur.StartedAt
}
gap := s.StartedAt.Sub(gapRef)
if gap <= sessionGapThreshold {
// Merge into the current consolidated session
cur.ConnectionCount++
cur.BytesTx += s.BytesTx
cur.BytesRx += s.BytesRx
sourcePorts[s.SourceAddr] = struct{}{}
// Extend EndedAt to the latest time
endTime := s.EndedAt
if endTime.IsZero() {
endTime = s.StartedAt
}
if endTime.After(cur.EndedAt) {
cur.EndedAt = endTime
}
} else {
// Gap exceeded — finalize the current session and start a new one
finalizeMergedSourceAddr(cur, key.sourceIP, sourcePorts)
merged = append(merged, cur)
cur = cloneSession(s)
cur.ConnectionCount = 1
sourcePorts = make(map[string]struct{})
sourcePorts[s.SourceAddr] = struct{}{}
}
}
// Finalize the last accumulated session
finalizeMergedSourceAddr(cur, key.sourceIP, sourcePorts)
merged = append(merged, cur)
result = append(result, merged...)
}
return result
}
// cloneSession creates a shallow copy of an AccessSession.
func cloneSession(s *AccessSession) *AccessSession {
cp := *s
return &cp
}
// finalizeMergedSourceAddr sets the SourceAddr on a consolidated session.
// If multiple distinct source addresses (ports) were seen, the port is
// stripped and only the IP is kept so the log isn't misleading.
func finalizeMergedSourceAddr(s *AccessSession, sourceIP string, ports map[string]struct{}) {
if len(ports) > 1 {
// Multiple source ports — just report the IP
s.SourceAddr = sourceIP
}
// Otherwise keep the original SourceAddr which already has ip:port
}
// flush drains the completed sessions buffer, consolidates bursts of
// short-lived connections, compresses with zlib, and sends via the SendFunc.
func (al *AccessLogger) flush() {
al.mu.Lock()
if len(al.completedSessions) == 0 {
al.mu.Unlock()
return
}
batch := al.completedSessions
al.completedSessions = make([]*AccessSession, 0)
sendFn := al.sendFn
al.mu.Unlock()
if sendFn == nil {
logger.Debug("Access logger: no send function configured, discarding %d sessions", len(batch))
return
}
// Consolidate bursts of short-lived connections into higher-level sessions
originalCount := len(batch)
batch = consolidateSessions(batch)
if len(batch) != originalCount {
logger.Info("Access logger: consolidated %d raw connections into %d sessions", originalCount, len(batch))
}
compressed, err := compressSessions(batch)
if err != nil {
logger.Error("Access logger: failed to compress %d sessions: %v", len(batch), err)
return
}
if err := sendFn(compressed); err != nil {
logger.Error("Access logger: failed to send %d sessions: %v", len(batch), err)
// Re-queue the batch so we don't lose data
al.mu.Lock()
al.completedSessions = append(batch, al.completedSessions...)
// Cap re-queued data to prevent unbounded growth if server is unreachable
if len(al.completedSessions) > maxBufferedSessions*5 {
dropped := len(al.completedSessions) - maxBufferedSessions*5
al.completedSessions = al.completedSessions[:maxBufferedSessions*5]
logger.Warn("Access logger: buffer overflow, dropped %d oldest sessions", dropped)
}
al.mu.Unlock()
return
}
logger.Info("Access logger: sent %d sessions to server", len(batch))
}
// compressSessions JSON-encodes the sessions, compresses with zlib, and returns
// a base64-encoded string suitable for embedding in a JSON message.
func compressSessions(sessions []*AccessSession) (string, error) {
jsonData, err := json.Marshal(sessions)
if err != nil {
return "", err
}
var buf bytes.Buffer
w, err := zlib.NewWriterLevel(&buf, zlib.BestCompression)
if err != nil {
return "", err
}
if _, err := w.Write(jsonData); err != nil {
w.Close()
return "", err
}
if err := w.Close(); err != nil {
return "", err
}
return base64.StdEncoding.EncodeToString(buf.Bytes()), nil
}
// Close shuts down the background loop, ends all active sessions,
// and performs one final flush to send everything to the server.
func (al *AccessLogger) Close() {
// Signal the background loop to stop
select {
case <-al.stopCh:
// Already closed
return
default:
close(al.stopCh)
}
// Wait for the background loop to exit so we don't race on flush
<-al.flushDone
al.mu.Lock()
now := time.Now()
// End all active sessions and move them to the completed buffer
for _, session := range al.sessions {
if session.EndedAt.IsZero() {
session.EndedAt = now
duration := now.Sub(session.StartedAt)
logger.Info("ACCESS END (shutdown) session=%s resource=%d proto=%s src=%s dst=%s started=%s ended=%s duration=%s",
session.SessionID, session.ResourceID, session.Protocol, session.SourceAddr, session.DestAddr,
session.StartedAt.Format(time.RFC3339), now.Format(time.RFC3339), duration)
al.completedSessions = append(al.completedSessions, session)
}
}
al.sessions = make(map[string]*AccessSession)
al.udpSessions = make(map[udpSessionKey]*AccessSession)
al.mu.Unlock()
// Final flush to send all remaining sessions to the server
al.flush()
}

View File

@@ -1,811 +0,0 @@
package netstack2
import (
"testing"
"time"
)
func TestExtractIP(t *testing.T) {
tests := []struct {
name string
addr string
expected string
}{
{"ipv4 with port", "192.168.1.1:12345", "192.168.1.1"},
{"ipv4 without port", "192.168.1.1", "192.168.1.1"},
{"ipv6 with port", "[::1]:12345", "::1"},
{"ipv6 without port", "::1", "::1"},
{"empty string", "", ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := extractIP(tt.addr)
if result != tt.expected {
t.Errorf("extractIP(%q) = %q, want %q", tt.addr, result, tt.expected)
}
})
}
}
func TestConsolidateSessions_Empty(t *testing.T) {
result := consolidateSessions(nil)
if result != nil {
t.Errorf("expected nil, got %v", result)
}
result = consolidateSessions([]*AccessSession{})
if len(result) != 0 {
t.Errorf("expected empty slice, got %d items", len(result))
}
}
func TestConsolidateSessions_SingleSession(t *testing.T) {
now := time.Now()
sessions := []*AccessSession{
{
SessionID: "abc123",
ResourceID: 1,
SourceAddr: "10.0.0.1:5000",
DestAddr: "192.168.1.100:443",
Protocol: "tcp",
StartedAt: now,
EndedAt: now.Add(1 * time.Second),
},
}
result := consolidateSessions(sessions)
if len(result) != 1 {
t.Fatalf("expected 1 session, got %d", len(result))
}
if result[0].SourceAddr != "10.0.0.1:5000" {
t.Errorf("expected source addr preserved, got %q", result[0].SourceAddr)
}
}
func TestConsolidateSessions_MergesBurstFromSameSourceIP(t *testing.T) {
now := time.Now()
sessions := []*AccessSession{
{
SessionID: "s1",
ResourceID: 1,
SourceAddr: "10.0.0.1:5000",
DestAddr: "192.168.1.100:443",
Protocol: "tcp",
StartedAt: now,
EndedAt: now.Add(100 * time.Millisecond),
BytesTx: 100,
BytesRx: 200,
},
{
SessionID: "s2",
ResourceID: 1,
SourceAddr: "10.0.0.1:5001",
DestAddr: "192.168.1.100:443",
Protocol: "tcp",
StartedAt: now.Add(200 * time.Millisecond),
EndedAt: now.Add(300 * time.Millisecond),
BytesTx: 150,
BytesRx: 250,
},
{
SessionID: "s3",
ResourceID: 1,
SourceAddr: "10.0.0.1:5002",
DestAddr: "192.168.1.100:443",
Protocol: "tcp",
StartedAt: now.Add(400 * time.Millisecond),
EndedAt: now.Add(500 * time.Millisecond),
BytesTx: 50,
BytesRx: 75,
},
}
result := consolidateSessions(sessions)
if len(result) != 1 {
t.Fatalf("expected 1 consolidated session, got %d", len(result))
}
s := result[0]
if s.ConnectionCount != 3 {
t.Errorf("expected ConnectionCount=3, got %d", s.ConnectionCount)
}
if s.SourceAddr != "10.0.0.1" {
t.Errorf("expected source addr to be IP only (multiple ports), got %q", s.SourceAddr)
}
if s.DestAddr != "192.168.1.100:443" {
t.Errorf("expected dest addr preserved, got %q", s.DestAddr)
}
if s.StartedAt != now {
t.Errorf("expected StartedAt to be earliest time")
}
if s.EndedAt != now.Add(500*time.Millisecond) {
t.Errorf("expected EndedAt to be latest time")
}
expectedTx := int64(300)
expectedRx := int64(525)
if s.BytesTx != expectedTx {
t.Errorf("expected BytesTx=%d, got %d", expectedTx, s.BytesTx)
}
if s.BytesRx != expectedRx {
t.Errorf("expected BytesRx=%d, got %d", expectedRx, s.BytesRx)
}
}
func TestConsolidateSessions_SameSourcePortPreserved(t *testing.T) {
now := time.Now()
sessions := []*AccessSession{
{
SessionID: "s1",
ResourceID: 1,
SourceAddr: "10.0.0.1:5000",
DestAddr: "192.168.1.100:443",
Protocol: "tcp",
StartedAt: now,
EndedAt: now.Add(100 * time.Millisecond),
},
{
SessionID: "s2",
ResourceID: 1,
SourceAddr: "10.0.0.1:5000",
DestAddr: "192.168.1.100:443",
Protocol: "tcp",
StartedAt: now.Add(200 * time.Millisecond),
EndedAt: now.Add(300 * time.Millisecond),
},
}
result := consolidateSessions(sessions)
if len(result) != 1 {
t.Fatalf("expected 1 session, got %d", len(result))
}
if result[0].SourceAddr != "10.0.0.1:5000" {
t.Errorf("expected source addr with port preserved when all ports are the same, got %q", result[0].SourceAddr)
}
if result[0].ConnectionCount != 2 {
t.Errorf("expected ConnectionCount=2, got %d", result[0].ConnectionCount)
}
}
func TestConsolidateSessions_GapSplitsSessions(t *testing.T) {
now := time.Now()
// First burst
sessions := []*AccessSession{
{
SessionID: "s1",
ResourceID: 1,
SourceAddr: "10.0.0.1:5000",
DestAddr: "192.168.1.100:443",
Protocol: "tcp",
StartedAt: now,
EndedAt: now.Add(100 * time.Millisecond),
},
{
SessionID: "s2",
ResourceID: 1,
SourceAddr: "10.0.0.1:5001",
DestAddr: "192.168.1.100:443",
Protocol: "tcp",
StartedAt: now.Add(200 * time.Millisecond),
EndedAt: now.Add(300 * time.Millisecond),
},
// Big gap here (10 seconds)
{
SessionID: "s3",
ResourceID: 1,
SourceAddr: "10.0.0.1:5002",
DestAddr: "192.168.1.100:443",
Protocol: "tcp",
StartedAt: now.Add(10 * time.Second),
EndedAt: now.Add(10*time.Second + 100*time.Millisecond),
},
{
SessionID: "s4",
ResourceID: 1,
SourceAddr: "10.0.0.1:5003",
DestAddr: "192.168.1.100:443",
Protocol: "tcp",
StartedAt: now.Add(10*time.Second + 200*time.Millisecond),
EndedAt: now.Add(10*time.Second + 300*time.Millisecond),
},
}
result := consolidateSessions(sessions)
if len(result) != 2 {
t.Fatalf("expected 2 consolidated sessions (gap split), got %d", len(result))
}
// Find the sessions by their start time
var first, second *AccessSession
for _, s := range result {
if s.StartedAt.Equal(now) {
first = s
} else {
second = s
}
}
if first == nil || second == nil {
t.Fatal("could not find both consolidated sessions")
}
if first.ConnectionCount != 2 {
t.Errorf("first burst: expected ConnectionCount=2, got %d", first.ConnectionCount)
}
if second.ConnectionCount != 2 {
t.Errorf("second burst: expected ConnectionCount=2, got %d", second.ConnectionCount)
}
}
func TestConsolidateSessions_DifferentDestinationsNotMerged(t *testing.T) {
now := time.Now()
sessions := []*AccessSession{
{
SessionID: "s1",
ResourceID: 1,
SourceAddr: "10.0.0.1:5000",
DestAddr: "192.168.1.100:443",
Protocol: "tcp",
StartedAt: now,
EndedAt: now.Add(100 * time.Millisecond),
},
{
SessionID: "s2",
ResourceID: 1,
SourceAddr: "10.0.0.1:5001",
DestAddr: "192.168.1.100:8080",
Protocol: "tcp",
StartedAt: now.Add(200 * time.Millisecond),
EndedAt: now.Add(300 * time.Millisecond),
},
}
result := consolidateSessions(sessions)
// Each goes to a different dest port so they should not be merged
if len(result) != 2 {
t.Fatalf("expected 2 sessions (different destinations), got %d", len(result))
}
}
func TestConsolidateSessions_DifferentProtocolsNotMerged(t *testing.T) {
now := time.Now()
sessions := []*AccessSession{
{
SessionID: "s1",
ResourceID: 1,
SourceAddr: "10.0.0.1:5000",
DestAddr: "192.168.1.100:443",
Protocol: "tcp",
StartedAt: now,
EndedAt: now.Add(100 * time.Millisecond),
},
{
SessionID: "s2",
ResourceID: 1,
SourceAddr: "10.0.0.1:5001",
DestAddr: "192.168.1.100:443",
Protocol: "udp",
StartedAt: now.Add(200 * time.Millisecond),
EndedAt: now.Add(300 * time.Millisecond),
},
}
result := consolidateSessions(sessions)
if len(result) != 2 {
t.Fatalf("expected 2 sessions (different protocols), got %d", len(result))
}
}
func TestConsolidateSessions_DifferentResourceIDsNotMerged(t *testing.T) {
now := time.Now()
sessions := []*AccessSession{
{
SessionID: "s1",
ResourceID: 1,
SourceAddr: "10.0.0.1:5000",
DestAddr: "192.168.1.100:443",
Protocol: "tcp",
StartedAt: now,
EndedAt: now.Add(100 * time.Millisecond),
},
{
SessionID: "s2",
ResourceID: 2,
SourceAddr: "10.0.0.1:5001",
DestAddr: "192.168.1.100:443",
Protocol: "tcp",
StartedAt: now.Add(200 * time.Millisecond),
EndedAt: now.Add(300 * time.Millisecond),
},
}
result := consolidateSessions(sessions)
if len(result) != 2 {
t.Fatalf("expected 2 sessions (different resource IDs), got %d", len(result))
}
}
func TestConsolidateSessions_DifferentSourceIPsNotMerged(t *testing.T) {
now := time.Now()
sessions := []*AccessSession{
{
SessionID: "s1",
ResourceID: 1,
SourceAddr: "10.0.0.1:5000",
DestAddr: "192.168.1.100:443",
Protocol: "tcp",
StartedAt: now,
EndedAt: now.Add(100 * time.Millisecond),
},
{
SessionID: "s2",
ResourceID: 1,
SourceAddr: "10.0.0.2:5001",
DestAddr: "192.168.1.100:443",
Protocol: "tcp",
StartedAt: now.Add(200 * time.Millisecond),
EndedAt: now.Add(300 * time.Millisecond),
},
}
result := consolidateSessions(sessions)
if len(result) != 2 {
t.Fatalf("expected 2 sessions (different source IPs), got %d", len(result))
}
}
func TestConsolidateSessions_OutOfOrderInput(t *testing.T) {
now := time.Now()
// Provide sessions out of chronological order to verify sorting
sessions := []*AccessSession{
{
SessionID: "s3",
ResourceID: 1,
SourceAddr: "10.0.0.1:5002",
DestAddr: "192.168.1.100:443",
Protocol: "tcp",
StartedAt: now.Add(400 * time.Millisecond),
EndedAt: now.Add(500 * time.Millisecond),
BytesTx: 30,
},
{
SessionID: "s1",
ResourceID: 1,
SourceAddr: "10.0.0.1:5000",
DestAddr: "192.168.1.100:443",
Protocol: "tcp",
StartedAt: now,
EndedAt: now.Add(100 * time.Millisecond),
BytesTx: 10,
},
{
SessionID: "s2",
ResourceID: 1,
SourceAddr: "10.0.0.1:5001",
DestAddr: "192.168.1.100:443",
Protocol: "tcp",
StartedAt: now.Add(200 * time.Millisecond),
EndedAt: now.Add(300 * time.Millisecond),
BytesTx: 20,
},
}
result := consolidateSessions(sessions)
if len(result) != 1 {
t.Fatalf("expected 1 consolidated session, got %d", len(result))
}
s := result[0]
if s.ConnectionCount != 3 {
t.Errorf("expected ConnectionCount=3, got %d", s.ConnectionCount)
}
if s.StartedAt != now {
t.Errorf("expected StartedAt to be earliest time")
}
if s.EndedAt != now.Add(500*time.Millisecond) {
t.Errorf("expected EndedAt to be latest time")
}
if s.BytesTx != 60 {
t.Errorf("expected BytesTx=60, got %d", s.BytesTx)
}
}
func TestConsolidateSessions_ExactlyAtGapThreshold(t *testing.T) {
now := time.Now()
sessions := []*AccessSession{
{
SessionID: "s1",
ResourceID: 1,
SourceAddr: "10.0.0.1:5000",
DestAddr: "192.168.1.100:443",
Protocol: "tcp",
StartedAt: now,
EndedAt: now.Add(100 * time.Millisecond),
},
{
// Starts exactly sessionGapThreshold after s1 ends — should still merge
SessionID: "s2",
ResourceID: 1,
SourceAddr: "10.0.0.1:5001",
DestAddr: "192.168.1.100:443",
Protocol: "tcp",
StartedAt: now.Add(100*time.Millisecond + sessionGapThreshold),
EndedAt: now.Add(100*time.Millisecond + sessionGapThreshold + 50*time.Millisecond),
},
}
result := consolidateSessions(sessions)
if len(result) != 1 {
t.Fatalf("expected 1 session (gap exactly at threshold merges), got %d", len(result))
}
if result[0].ConnectionCount != 2 {
t.Errorf("expected ConnectionCount=2, got %d", result[0].ConnectionCount)
}
}
func TestConsolidateSessions_JustOverGapThreshold(t *testing.T) {
now := time.Now()
sessions := []*AccessSession{
{
SessionID: "s1",
ResourceID: 1,
SourceAddr: "10.0.0.1:5000",
DestAddr: "192.168.1.100:443",
Protocol: "tcp",
StartedAt: now,
EndedAt: now.Add(100 * time.Millisecond),
},
{
// Starts 1ms over the gap threshold after s1 ends — should split
SessionID: "s2",
ResourceID: 1,
SourceAddr: "10.0.0.1:5001",
DestAddr: "192.168.1.100:443",
Protocol: "tcp",
StartedAt: now.Add(100*time.Millisecond + sessionGapThreshold + 1*time.Millisecond),
EndedAt: now.Add(100*time.Millisecond + sessionGapThreshold + 50*time.Millisecond),
},
}
result := consolidateSessions(sessions)
if len(result) != 2 {
t.Fatalf("expected 2 sessions (gap just over threshold splits), got %d", len(result))
}
}
func TestConsolidateSessions_UDPSessions(t *testing.T) {
now := time.Now()
sessions := []*AccessSession{
{
SessionID: "u1",
ResourceID: 5,
SourceAddr: "10.0.0.1:6000",
DestAddr: "192.168.1.100:53",
Protocol: "udp",
StartedAt: now,
EndedAt: now.Add(50 * time.Millisecond),
BytesTx: 64,
BytesRx: 512,
},
{
SessionID: "u2",
ResourceID: 5,
SourceAddr: "10.0.0.1:6001",
DestAddr: "192.168.1.100:53",
Protocol: "udp",
StartedAt: now.Add(100 * time.Millisecond),
EndedAt: now.Add(150 * time.Millisecond),
BytesTx: 64,
BytesRx: 256,
},
{
SessionID: "u3",
ResourceID: 5,
SourceAddr: "10.0.0.1:6002",
DestAddr: "192.168.1.100:53",
Protocol: "udp",
StartedAt: now.Add(200 * time.Millisecond),
EndedAt: now.Add(250 * time.Millisecond),
BytesTx: 64,
BytesRx: 128,
},
}
result := consolidateSessions(sessions)
if len(result) != 1 {
t.Fatalf("expected 1 consolidated UDP session, got %d", len(result))
}
s := result[0]
if s.Protocol != "udp" {
t.Errorf("expected protocol=udp, got %q", s.Protocol)
}
if s.ConnectionCount != 3 {
t.Errorf("expected ConnectionCount=3, got %d", s.ConnectionCount)
}
if s.SourceAddr != "10.0.0.1" {
t.Errorf("expected source addr to be IP only, got %q", s.SourceAddr)
}
if s.BytesTx != 192 {
t.Errorf("expected BytesTx=192, got %d", s.BytesTx)
}
if s.BytesRx != 896 {
t.Errorf("expected BytesRx=896, got %d", s.BytesRx)
}
}
func TestConsolidateSessions_MixedGroupsSomeConsolidatedSomeNot(t *testing.T) {
now := time.Now()
sessions := []*AccessSession{
// Group 1: 3 connections to :443 from same IP — should consolidate
{
SessionID: "s1",
ResourceID: 1,
SourceAddr: "10.0.0.1:5000",
DestAddr: "192.168.1.100:443",
Protocol: "tcp",
StartedAt: now,
EndedAt: now.Add(100 * time.Millisecond),
},
{
SessionID: "s2",
ResourceID: 1,
SourceAddr: "10.0.0.1:5001",
DestAddr: "192.168.1.100:443",
Protocol: "tcp",
StartedAt: now.Add(200 * time.Millisecond),
EndedAt: now.Add(300 * time.Millisecond),
},
{
SessionID: "s3",
ResourceID: 1,
SourceAddr: "10.0.0.1:5002",
DestAddr: "192.168.1.100:443",
Protocol: "tcp",
StartedAt: now.Add(400 * time.Millisecond),
EndedAt: now.Add(500 * time.Millisecond),
},
// Group 2: 1 connection to :8080 from different IP — should pass through
{
SessionID: "s4",
ResourceID: 2,
SourceAddr: "10.0.0.2:6000",
DestAddr: "192.168.1.200:8080",
Protocol: "tcp",
StartedAt: now.Add(1 * time.Second),
EndedAt: now.Add(2 * time.Second),
},
}
result := consolidateSessions(sessions)
if len(result) != 2 {
t.Fatalf("expected 2 sessions total, got %d", len(result))
}
var consolidated, passthrough *AccessSession
for _, s := range result {
if s.ConnectionCount > 1 {
consolidated = s
} else {
passthrough = s
}
}
if consolidated == nil {
t.Fatal("expected a consolidated session")
}
if consolidated.ConnectionCount != 3 {
t.Errorf("consolidated: expected ConnectionCount=3, got %d", consolidated.ConnectionCount)
}
if passthrough == nil {
t.Fatal("expected a passthrough session")
}
if passthrough.SessionID != "s4" {
t.Errorf("passthrough: expected session s4, got %s", passthrough.SessionID)
}
}
func TestConsolidateSessions_OverlappingConnections(t *testing.T) {
now := time.Now()
// Connections that overlap in time (not sequential)
sessions := []*AccessSession{
{
SessionID: "s1",
ResourceID: 1,
SourceAddr: "10.0.0.1:5000",
DestAddr: "192.168.1.100:443",
Protocol: "tcp",
StartedAt: now,
EndedAt: now.Add(5 * time.Second),
BytesTx: 100,
},
{
SessionID: "s2",
ResourceID: 1,
SourceAddr: "10.0.0.1:5001",
DestAddr: "192.168.1.100:443",
Protocol: "tcp",
StartedAt: now.Add(1 * time.Second),
EndedAt: now.Add(3 * time.Second),
BytesTx: 200,
},
{
SessionID: "s3",
ResourceID: 1,
SourceAddr: "10.0.0.1:5002",
DestAddr: "192.168.1.100:443",
Protocol: "tcp",
StartedAt: now.Add(2 * time.Second),
EndedAt: now.Add(6 * time.Second),
BytesTx: 300,
},
}
result := consolidateSessions(sessions)
if len(result) != 1 {
t.Fatalf("expected 1 consolidated session, got %d", len(result))
}
s := result[0]
if s.ConnectionCount != 3 {
t.Errorf("expected ConnectionCount=3, got %d", s.ConnectionCount)
}
if s.StartedAt != now {
t.Error("expected StartedAt to be earliest")
}
if s.EndedAt != now.Add(6*time.Second) {
t.Error("expected EndedAt to be the latest end time")
}
if s.BytesTx != 600 {
t.Errorf("expected BytesTx=600, got %d", s.BytesTx)
}
}
func TestConsolidateSessions_DoesNotMutateOriginals(t *testing.T) {
now := time.Now()
s1 := &AccessSession{
SessionID: "s1",
ResourceID: 1,
SourceAddr: "10.0.0.1:5000",
DestAddr: "192.168.1.100:443",
Protocol: "tcp",
StartedAt: now,
EndedAt: now.Add(100 * time.Millisecond),
BytesTx: 100,
}
s2 := &AccessSession{
SessionID: "s2",
ResourceID: 1,
SourceAddr: "10.0.0.1:5001",
DestAddr: "192.168.1.100:443",
Protocol: "tcp",
StartedAt: now.Add(200 * time.Millisecond),
EndedAt: now.Add(300 * time.Millisecond),
BytesTx: 200,
}
// Save original values
origS1Addr := s1.SourceAddr
origS1Bytes := s1.BytesTx
origS2Addr := s2.SourceAddr
origS2Bytes := s2.BytesTx
_ = consolidateSessions([]*AccessSession{s1, s2})
if s1.SourceAddr != origS1Addr {
t.Errorf("s1.SourceAddr was mutated: %q -> %q", origS1Addr, s1.SourceAddr)
}
if s1.BytesTx != origS1Bytes {
t.Errorf("s1.BytesTx was mutated: %d -> %d", origS1Bytes, s1.BytesTx)
}
if s2.SourceAddr != origS2Addr {
t.Errorf("s2.SourceAddr was mutated: %q -> %q", origS2Addr, s2.SourceAddr)
}
if s2.BytesTx != origS2Bytes {
t.Errorf("s2.BytesTx was mutated: %d -> %d", origS2Bytes, s2.BytesTx)
}
}
func TestConsolidateSessions_ThreeBurstsWithGaps(t *testing.T) {
now := time.Now()
sessions := make([]*AccessSession, 0, 9)
// Burst 1: 3 connections at t=0
for i := 0; i < 3; i++ {
sessions = append(sessions, &AccessSession{
SessionID: generateSessionID(),
ResourceID: 1,
SourceAddr: "10.0.0.1:" + string(rune('A'+i)),
DestAddr: "192.168.1.100:443",
Protocol: "tcp",
StartedAt: now.Add(time.Duration(i*100) * time.Millisecond),
EndedAt: now.Add(time.Duration(i*100+50) * time.Millisecond),
})
}
// Burst 2: 3 connections at t=20s (well past the 5s gap)
for i := 0; i < 3; i++ {
sessions = append(sessions, &AccessSession{
SessionID: generateSessionID(),
ResourceID: 1,
SourceAddr: "10.0.0.1:" + string(rune('D'+i)),
DestAddr: "192.168.1.100:443",
Protocol: "tcp",
StartedAt: now.Add(20*time.Second + time.Duration(i*100)*time.Millisecond),
EndedAt: now.Add(20*time.Second + time.Duration(i*100+50)*time.Millisecond),
})
}
// Burst 3: 3 connections at t=40s
for i := 0; i < 3; i++ {
sessions = append(sessions, &AccessSession{
SessionID: generateSessionID(),
ResourceID: 1,
SourceAddr: "10.0.0.1:" + string(rune('G'+i)),
DestAddr: "192.168.1.100:443",
Protocol: "tcp",
StartedAt: now.Add(40*time.Second + time.Duration(i*100)*time.Millisecond),
EndedAt: now.Add(40*time.Second + time.Duration(i*100+50)*time.Millisecond),
})
}
result := consolidateSessions(sessions)
if len(result) != 3 {
t.Fatalf("expected 3 consolidated sessions (3 bursts), got %d", len(result))
}
for _, s := range result {
if s.ConnectionCount != 3 {
t.Errorf("expected each burst to have ConnectionCount=3, got %d (started=%v)", s.ConnectionCount, s.StartedAt)
}
}
}
func TestFinalizeMergedSourceAddr(t *testing.T) {
s := &AccessSession{SourceAddr: "10.0.0.1:5000"}
ports := map[string]struct{}{"10.0.0.1:5000": {}}
finalizeMergedSourceAddr(s, "10.0.0.1", ports)
if s.SourceAddr != "10.0.0.1:5000" {
t.Errorf("single port: expected addr preserved, got %q", s.SourceAddr)
}
s2 := &AccessSession{SourceAddr: "10.0.0.1:5000"}
ports2 := map[string]struct{}{"10.0.0.1:5000": {}, "10.0.0.1:5001": {}}
finalizeMergedSourceAddr(s2, "10.0.0.1", ports2)
if s2.SourceAddr != "10.0.0.1" {
t.Errorf("multiple ports: expected IP only, got %q", s2.SourceAddr)
}
}
func TestCloneSession(t *testing.T) {
original := &AccessSession{
SessionID: "test",
ResourceID: 42,
SourceAddr: "1.2.3.4:100",
DestAddr: "5.6.7.8:443",
Protocol: "tcp",
BytesTx: 999,
}
clone := cloneSession(original)
if clone == original {
t.Error("clone should be a different pointer")
}
if clone.SessionID != original.SessionID {
t.Error("clone should have same SessionID")
}
// Mutating clone should not affect original
clone.BytesTx = 0
clone.SourceAddr = "changed"
if original.BytesTx != 999 {
t.Error("mutating clone affected original BytesTx")
}
if original.SourceAddr != "1.2.3.4:100" {
t.Error("mutating clone affected original SourceAddr")
}
}

View File

@@ -158,18 +158,6 @@ func (h *TCPHandler) handleTCPConn(netstackConn *gonet.TCPConn, id stack.Transpo
targetAddr := fmt.Sprintf("%s:%d", actualDstIP, dstPort) targetAddr := fmt.Sprintf("%s:%d", actualDstIP, dstPort)
// Look up resource ID and start access session if applicable
var accessSessionID string
if h.proxyHandler != nil {
resourceId := h.proxyHandler.LookupResourceId(srcIP, dstIP, dstPort, uint8(tcp.ProtocolNumber))
if resourceId != 0 {
if al := h.proxyHandler.GetAccessLogger(); al != nil {
srcAddr := fmt.Sprintf("%s:%d", srcIP, srcPort)
accessSessionID = al.StartTCPSession(resourceId, srcAddr, targetAddr)
}
}
}
// Create context with timeout for connection establishment // Create context with timeout for connection establishment
ctx, cancel := context.WithTimeout(context.Background(), tcpConnectTimeout) ctx, cancel := context.WithTimeout(context.Background(), tcpConnectTimeout)
defer cancel() defer cancel()
@@ -179,26 +167,11 @@ func (h *TCPHandler) handleTCPConn(netstackConn *gonet.TCPConn, id stack.Transpo
targetConn, err := d.DialContext(ctx, "tcp", targetAddr) targetConn, err := d.DialContext(ctx, "tcp", targetAddr)
if err != nil { if err != nil {
logger.Info("TCP Forwarder: Failed to connect to %s: %v", targetAddr, err) logger.Info("TCP Forwarder: Failed to connect to %s: %v", targetAddr, err)
// End access session on connection failure
if accessSessionID != "" {
if al := h.proxyHandler.GetAccessLogger(); al != nil {
al.EndTCPSession(accessSessionID)
}
}
// Connection failed, netstack will handle RST // Connection failed, netstack will handle RST
return return
} }
defer targetConn.Close() defer targetConn.Close()
// End access session when connection closes
if accessSessionID != "" {
defer func() {
if al := h.proxyHandler.GetAccessLogger(); al != nil {
al.EndTCPSession(accessSessionID)
}
}()
}
logger.Info("TCP Forwarder: Successfully connected to %s, starting bidirectional copy", targetAddr) logger.Info("TCP Forwarder: Successfully connected to %s, starting bidirectional copy", targetAddr)
// Bidirectional copy between netstack and target // Bidirectional copy between netstack and target
@@ -307,27 +280,6 @@ func (h *UDPHandler) handleUDPConn(netstackConn *gonet.UDPConn, id stack.Transpo
targetAddr := fmt.Sprintf("%s:%d", actualDstIP, dstPort) targetAddr := fmt.Sprintf("%s:%d", actualDstIP, dstPort)
// Look up resource ID and start access session if applicable
var accessSessionID string
if h.proxyHandler != nil {
resourceId := h.proxyHandler.LookupResourceId(srcIP, dstIP, dstPort, uint8(udp.ProtocolNumber))
if resourceId != 0 {
if al := h.proxyHandler.GetAccessLogger(); al != nil {
srcAddr := fmt.Sprintf("%s:%d", srcIP, srcPort)
accessSessionID = al.TrackUDPSession(resourceId, srcAddr, targetAddr)
}
}
}
// End access session when UDP handler returns (timeout or error)
if accessSessionID != "" {
defer func() {
if al := h.proxyHandler.GetAccessLogger(); al != nil {
al.EndUDPSession(accessSessionID)
}
}()
}
// Resolve target address // Resolve target address
remoteUDPAddr, err := net.ResolveUDPAddr("udp", targetAddr) remoteUDPAddr, err := net.ResolveUDPAddr("udp", targetAddr)
if err != nil { if err != nil {

View File

@@ -22,12 +22,6 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
) )
const (
// udpAccessSessionTimeout is how long a UDP access session stays alive without traffic
// before being considered ended by the access logger
udpAccessSessionTimeout = 120 * time.Second
)
// PortRange represents an allowed range of ports (inclusive) with optional protocol filtering // PortRange represents an allowed range of ports (inclusive) with optional protocol filtering
// Protocol can be "tcp", "udp", or "" (empty string means both protocols) // Protocol can be "tcp", "udp", or "" (empty string means both protocols)
type PortRange struct { type PortRange struct {
@@ -52,24 +46,6 @@ type SubnetRule struct {
DisableIcmp bool // If true, ICMP traffic is blocked for this subnet DisableIcmp bool // If true, ICMP traffic is blocked for this subnet
RewriteTo string // Optional rewrite address for DNAT - can be IP/CIDR or domain name RewriteTo string // Optional rewrite address for DNAT - can be IP/CIDR or domain name
PortRanges []PortRange // empty slice means all ports allowed PortRanges []PortRange // empty slice means all ports allowed
ResourceId int // Optional resource ID from the server for access logging
}
// GetAllRules returns a copy of all subnet rules
func (sl *SubnetLookup) GetAllRules() []SubnetRule {
sl.mu.RLock()
defer sl.mu.RUnlock()
var rules []SubnetRule
for _, destTriePtr := range sl.sourceTrie.All() {
if destTriePtr == nil {
continue
}
for _, rule := range destTriePtr.rules {
rules = append(rules, *rule)
}
}
return rules
} }
// connKey uniquely identifies a connection for NAT tracking // connKey uniquely identifies a connection for NAT tracking
@@ -118,12 +94,10 @@ type ProxyHandler struct {
natTable map[connKey]*natState natTable map[connKey]*natState
reverseNatTable map[reverseConnKey]*natState // Reverse lookup map for O(1) reply packet NAT reverseNatTable map[reverseConnKey]*natState // Reverse lookup map for O(1) reply packet NAT
destRewriteTable map[destKey]netip.Addr // Maps original dest to rewritten dest for handler lookups destRewriteTable map[destKey]netip.Addr // Maps original dest to rewritten dest for handler lookups
resourceTable map[destKey]int // Maps connection key to resource ID for access logging
natMu sync.RWMutex natMu sync.RWMutex
enabled bool enabled bool
icmpReplies chan []byte // Channel for ICMP reply packets to be sent back through the tunnel icmpReplies chan []byte // Channel for ICMP reply packets to be sent back through the tunnel
notifiable channel.Notification // Notification handler for triggering reads notifiable channel.Notification // Notification handler for triggering reads
accessLogger *AccessLogger // Access logger for tracking sessions
} }
// ProxyHandlerOptions configures the proxy handler // ProxyHandlerOptions configures the proxy handler
@@ -146,9 +120,7 @@ func NewProxyHandler(options ProxyHandlerOptions) (*ProxyHandler, error) {
natTable: make(map[connKey]*natState), natTable: make(map[connKey]*natState),
reverseNatTable: make(map[reverseConnKey]*natState), reverseNatTable: make(map[reverseConnKey]*natState),
destRewriteTable: make(map[destKey]netip.Addr), destRewriteTable: make(map[destKey]netip.Addr),
resourceTable: make(map[destKey]int),
icmpReplies: make(chan []byte, 256), // Buffer for ICMP reply packets icmpReplies: make(chan []byte, 256), // Buffer for ICMP reply packets
accessLogger: NewAccessLogger(udpAccessSessionTimeout),
proxyEp: channel.New(1024, uint32(options.MTU), ""), proxyEp: channel.New(1024, uint32(options.MTU), ""),
proxyStack: stack.New(stack.Options{ proxyStack: stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ NetworkProtocols: []stack.NetworkProtocolFactory{
@@ -213,11 +185,11 @@ func NewProxyHandler(options ProxyHandlerOptions) (*ProxyHandler, error) {
// destPrefix: The IP prefix of the destination // destPrefix: The IP prefix of the destination
// rewriteTo: Optional address to rewrite destination to - can be IP/CIDR or domain name // rewriteTo: Optional address to rewrite destination to - can be IP/CIDR or domain name
// If portRanges is nil or empty, all ports are allowed for this subnet // If portRanges is nil or empty, all ports are allowed for this subnet
func (p *ProxyHandler) AddSubnetRule(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange, disableIcmp bool, resourceId int) { func (p *ProxyHandler) AddSubnetRule(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange, disableIcmp bool) {
if p == nil || !p.enabled { if p == nil || !p.enabled {
return return
} }
p.subnetLookup.AddSubnet(sourcePrefix, destPrefix, rewriteTo, portRanges, disableIcmp, resourceId) p.subnetLookup.AddSubnet(sourcePrefix, destPrefix, rewriteTo, portRanges, disableIcmp)
} }
// RemoveSubnetRule removes a subnet from the proxy handler // RemoveSubnetRule removes a subnet from the proxy handler
@@ -228,51 +200,6 @@ func (p *ProxyHandler) RemoveSubnetRule(sourcePrefix, destPrefix netip.Prefix) {
p.subnetLookup.RemoveSubnet(sourcePrefix, destPrefix) p.subnetLookup.RemoveSubnet(sourcePrefix, destPrefix)
} }
// GetAllRules returns all subnet rules from the proxy handler
func (p *ProxyHandler) GetAllRules() []SubnetRule {
if p == nil || !p.enabled {
return nil
}
return p.subnetLookup.GetAllRules()
}
// LookupResourceId looks up the resource ID for a connection
// Returns 0 if no resource ID is associated with this connection
func (p *ProxyHandler) LookupResourceId(srcIP, dstIP string, dstPort uint16, proto uint8) int {
if p == nil || !p.enabled {
return 0
}
key := destKey{
srcIP: srcIP,
dstIP: dstIP,
dstPort: dstPort,
proto: proto,
}
p.natMu.RLock()
defer p.natMu.RUnlock()
return p.resourceTable[key]
}
// GetAccessLogger returns the access logger for session tracking
func (p *ProxyHandler) GetAccessLogger() *AccessLogger {
if p == nil {
return nil
}
return p.accessLogger
}
// SetAccessLogSender configures the function used to send compressed access log
// batches to the server. This should be called once the websocket client is available.
func (p *ProxyHandler) SetAccessLogSender(fn SendFunc) {
if p == nil || !p.enabled || p.accessLogger == nil {
return
}
p.accessLogger.SetSendFunc(fn)
}
// LookupDestinationRewrite looks up the rewritten destination for a connection // LookupDestinationRewrite looks up the rewritten destination for a connection
// This is used by TCP/UDP handlers to find the actual target address // This is used by TCP/UDP handlers to find the actual target address
func (p *ProxyHandler) LookupDestinationRewrite(srcIP, dstIP string, dstPort uint16, proto uint8) (netip.Addr, bool) { func (p *ProxyHandler) LookupDestinationRewrite(srcIP, dstIP string, dstPort uint16, proto uint8) (netip.Addr, bool) {
@@ -435,22 +362,8 @@ func (p *ProxyHandler) HandleIncomingPacket(packet []byte) bool {
// Check if the source IP, destination IP, port, and protocol match any subnet rule // Check if the source IP, destination IP, port, and protocol match any subnet rule
matchedRule := p.subnetLookup.Match(srcAddr, dstAddr, dstPort, protocol) matchedRule := p.subnetLookup.Match(srcAddr, dstAddr, dstPort, protocol)
if matchedRule != nil { if matchedRule != nil {
logger.Debug("HandleIncomingPacket: Matched rule for %s -> %s (proto=%d, port=%d, resourceId=%d)", logger.Debug("HandleIncomingPacket: Matched rule for %s -> %s (proto=%d, port=%d)",
srcAddr, dstAddr, protocol, dstPort, matchedRule.ResourceId) srcAddr, dstAddr, protocol, dstPort)
// Store resource ID for connections without DNAT as well
if matchedRule.ResourceId != 0 && matchedRule.RewriteTo == "" {
dKey := destKey{
srcIP: srcAddr.String(),
dstIP: dstAddr.String(),
dstPort: dstPort,
proto: uint8(protocol),
}
p.natMu.Lock()
p.resourceTable[dKey] = matchedRule.ResourceId
p.natMu.Unlock()
}
// Check if we need to perform DNAT // Check if we need to perform DNAT
if matchedRule.RewriteTo != "" { if matchedRule.RewriteTo != "" {
// Create connection tracking key using original destination // Create connection tracking key using original destination
@@ -482,13 +395,6 @@ func (p *ProxyHandler) HandleIncomingPacket(packet []byte) bool {
proto: uint8(protocol), proto: uint8(protocol),
} }
// Store resource ID for access logging if present
if matchedRule.ResourceId != 0 {
p.natMu.Lock()
p.resourceTable[dKey] = matchedRule.ResourceId
p.natMu.Unlock()
}
// Check if we already have a NAT entry for this connection // Check if we already have a NAT entry for this connection
p.natMu.RLock() p.natMu.RLock()
existingEntry, exists := p.natTable[key] existingEntry, exists := p.natTable[key]
@@ -789,11 +695,6 @@ func (p *ProxyHandler) Close() error {
return nil return nil
} }
// Shut down access logger
if p.accessLogger != nil {
p.accessLogger.Close()
}
// Close ICMP replies channel // Close ICMP replies channel
if p.icmpReplies != nil { if p.icmpReplies != nil {
close(p.icmpReplies) close(p.icmpReplies)

View File

@@ -47,7 +47,7 @@ func prefixEqual(a, b netip.Prefix) bool {
// AddSubnet adds a subnet rule with source and destination prefixes and optional port restrictions // AddSubnet adds a subnet rule with source and destination prefixes and optional port restrictions
// If portRanges is nil or empty, all ports are allowed for this subnet // If portRanges is nil or empty, all ports are allowed for this subnet
// rewriteTo can be either an IP/CIDR (e.g., "192.168.1.1/32") or a domain name (e.g., "example.com") // rewriteTo can be either an IP/CIDR (e.g., "192.168.1.1/32") or a domain name (e.g., "example.com")
func (sl *SubnetLookup) AddSubnet(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange, disableIcmp bool, resourceId int) { func (sl *SubnetLookup) AddSubnet(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange, disableIcmp bool) {
sl.mu.Lock() sl.mu.Lock()
defer sl.mu.Unlock() defer sl.mu.Unlock()
@@ -57,7 +57,6 @@ func (sl *SubnetLookup) AddSubnet(sourcePrefix, destPrefix netip.Prefix, rewrite
DisableIcmp: disableIcmp, DisableIcmp: disableIcmp,
RewriteTo: rewriteTo, RewriteTo: rewriteTo,
PortRanges: portRanges, PortRanges: portRanges,
ResourceId: resourceId,
} }
// Canonicalize source prefix to handle host bits correctly // Canonicalize source prefix to handle host bits correctly

View File

@@ -354,10 +354,10 @@ func (net *Net) ListenUDP(laddr *net.UDPAddr) (*gonet.UDPConn, error) {
// AddProxySubnetRule adds a subnet rule to the proxy handler // AddProxySubnetRule adds a subnet rule to the proxy handler
// If portRanges is nil or empty, all ports are allowed for this subnet // If portRanges is nil or empty, all ports are allowed for this subnet
// rewriteTo can be either an IP/CIDR (e.g., "192.168.1.1/32") or a domain name (e.g., "example.com") // rewriteTo can be either an IP/CIDR (e.g., "192.168.1.1/32") or a domain name (e.g., "example.com")
func (net *Net) AddProxySubnetRule(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange, disableIcmp bool, resourceId int) { func (net *Net) AddProxySubnetRule(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange, disableIcmp bool) {
tun := (*netTun)(net) tun := (*netTun)(net)
if tun.proxyHandler != nil { if tun.proxyHandler != nil {
tun.proxyHandler.AddSubnetRule(sourcePrefix, destPrefix, rewriteTo, portRanges, disableIcmp, resourceId) tun.proxyHandler.AddSubnetRule(sourcePrefix, destPrefix, rewriteTo, portRanges, disableIcmp)
} }
} }
@@ -369,15 +369,6 @@ func (net *Net) RemoveProxySubnetRule(sourcePrefix, destPrefix netip.Prefix) {
} }
} }
// GetProxySubnetRules returns all subnet rules from the proxy handler
func (net *Net) GetProxySubnetRules() []SubnetRule {
tun := (*netTun)(net)
if tun.proxyHandler != nil {
return tun.proxyHandler.GetAllRules()
}
return nil
}
// GetProxyHandler returns the proxy handler (for advanced use cases) // GetProxyHandler returns the proxy handler (for advanced use cases)
// Returns nil if proxy is not enabled // Returns nil if proxy is not enabled
func (net *Net) GetProxyHandler() *ProxyHandler { func (net *Net) GetProxyHandler() *ProxyHandler {
@@ -385,15 +376,6 @@ func (net *Net) GetProxyHandler() *ProxyHandler {
return tun.proxyHandler return tun.proxyHandler
} }
// SetAccessLogSender configures the function used to send compressed access log
// batches to the server. This should be called once the websocket client is available.
func (net *Net) SetAccessLogSender(fn SendFunc) {
tun := (*netTun)(net)
if tun.proxyHandler != nil {
tun.proxyHandler.SetAccessLogSender(fn)
}
}
type PingConn struct { type PingConn struct {
laddr PingAddr laddr PingAddr
raddr PingAddr raddr PingAddr

View File

@@ -21,10 +21,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
) )
const ( const errUnsupportedProtoFmt = "unsupported protocol: %s"
errUnsupportedProtoFmt = "unsupported protocol: %s"
maxUDPPacketSize = 65507
)
// Target represents a proxy target with its address and port // Target represents a proxy target with its address and port
type Target struct { type Target struct {
@@ -108,10 +105,14 @@ func classifyProxyError(err error) string {
if errors.Is(err, net.ErrClosed) { if errors.Is(err, net.ErrClosed) {
return "closed" return "closed"
} }
var ne net.Error if ne, ok := err.(net.Error); ok {
if errors.As(err, &ne) && ne.Timeout() { if ne.Timeout() {
return "timeout" return "timeout"
} }
if ne.Temporary() {
return "temporary"
}
}
msg := strings.ToLower(err.Error()) msg := strings.ToLower(err.Error())
switch { switch {
case strings.Contains(msg, "refused"): case strings.Contains(msg, "refused"):
@@ -436,6 +437,14 @@ func (pm *ProxyManager) Stop() error {
pm.udpConns = append(pm.udpConns[:i], pm.udpConns[i+1:]...) pm.udpConns = append(pm.udpConns[:i], pm.udpConns[i+1:]...)
} }
// // Clear the target maps
// for k := range pm.tcpTargets {
// delete(pm.tcpTargets, k)
// }
// for k := range pm.udpTargets {
// delete(pm.udpTargets, k)
// }
// Give active connections a chance to close gracefully // Give active connections a chance to close gracefully
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
@@ -489,7 +498,7 @@ func (pm *ProxyManager) handleTCPProxy(listener net.Listener, targetAddr string)
if !pm.running { if !pm.running {
return return
} }
if errors.Is(err, net.ErrClosed) { if ne, ok := err.(net.Error); ok && !ne.Temporary() {
logger.Info("TCP listener closed, stopping proxy handler for %v", listener.Addr()) logger.Info("TCP listener closed, stopping proxy handler for %v", listener.Addr())
return return
} }
@@ -555,7 +564,7 @@ func (pm *ProxyManager) handleTCPProxy(listener net.Listener, targetAddr string)
} }
func (pm *ProxyManager) handleUDPProxy(conn *gonet.UDPConn, targetAddr string) { func (pm *ProxyManager) handleUDPProxy(conn *gonet.UDPConn, targetAddr string) {
buffer := make([]byte, maxUDPPacketSize) // Max UDP packet size buffer := make([]byte, 65507) // Max UDP packet size
clientConns := make(map[string]*net.UDPConn) clientConns := make(map[string]*net.UDPConn)
var clientsMutex sync.RWMutex var clientsMutex sync.RWMutex
@@ -574,7 +583,7 @@ func (pm *ProxyManager) handleUDPProxy(conn *gonet.UDPConn, targetAddr string) {
} }
// Check for connection closed conditions // Check for connection closed conditions
if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) { if err == io.EOF || strings.Contains(err.Error(), "use of closed network connection") {
logger.Info("UDP connection closed, stopping proxy handler") logger.Info("UDP connection closed, stopping proxy handler")
// Clean up existing client connections // Clean up existing client connections
@@ -653,14 +662,10 @@ func (pm *ProxyManager) handleUDPProxy(conn *gonet.UDPConn, targetAddr string) {
telemetry.IncProxyConnectionEvent(context.Background(), tunnelID, "udp", telemetry.ProxyConnectionClosed) telemetry.IncProxyConnectionEvent(context.Background(), tunnelID, "udp", telemetry.ProxyConnectionClosed)
}() }()
buffer := make([]byte, maxUDPPacketSize) buffer := make([]byte, 65507)
for { for {
n, _, err := targetConn.ReadFromUDP(buffer) n, _, err := targetConn.ReadFromUDP(buffer)
if err != nil { if err != nil {
// Connection closed is normal during cleanup
if errors.Is(err, net.ErrClosed) || errors.Is(err, io.EOF) {
return // defer will handle cleanup, result stays "success"
}
logger.Error("Error reading from target: %v", err) logger.Error("Error reading from target: %v", err)
result = "failure" result = "failure"
return // defer will handle cleanup return // defer will handle cleanup
@@ -731,28 +736,3 @@ func (pm *ProxyManager) PrintTargets() {
} }
} }
} }
// GetTargets returns a copy of the current TCP and UDP targets
// Returns map[listenIP]map[port]targetAddress for both TCP and UDP
func (pm *ProxyManager) GetTargets() (tcpTargets map[string]map[int]string, udpTargets map[string]map[int]string) {
pm.mutex.RLock()
defer pm.mutex.RUnlock()
tcpTargets = make(map[string]map[int]string)
for listenIP, targets := range pm.tcpTargets {
tcpTargets[listenIP] = make(map[int]string)
for port, targetAddr := range targets {
tcpTargets[listenIP][port] = targetAddr
}
}
udpTargets = make(map[string]map[int]string)
for listenIP, targets := range pm.udpTargets {
udpTargets[listenIP] = make(map[int]string)
for port, targetAddr := range targets {
udpTargets[listenIP][port] = targetAddr
}
}
return tcpTargets, udpTargets
}

View File

@@ -2,7 +2,6 @@ package websocket
import ( import (
"bytes" "bytes"
"compress/gzip"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"encoding/json" "encoding/json"
@@ -38,21 +37,16 @@ type Client struct {
isConnected bool isConnected bool
reconnectMux sync.RWMutex reconnectMux sync.RWMutex
pingInterval time.Duration pingInterval time.Duration
pingTimeout time.Duration
onConnect func() error onConnect func() error
onTokenUpdate func(token string) onTokenUpdate func(token string)
writeMux sync.Mutex writeMux sync.Mutex
clientType string // Type of client (e.g., "newt", "olm") clientType string // Type of client (e.g., "newt", "olm")
configFilePath string // Optional override for the config file path
tlsConfig TLSConfig tlsConfig TLSConfig
metricsCtxMu sync.RWMutex metricsCtxMu sync.RWMutex
metricsCtx context.Context metricsCtx context.Context
configNeedsSave bool // Flag to track if config needs to be saved configNeedsSave bool // Flag to track if config needs to be saved
serverVersion string serverVersion string
configVersion int64 // Latest config version received from server
configVersionMux sync.RWMutex
processingMessage bool // Flag to track if a message is currently being processed
processingMux sync.RWMutex // Protects processingMessage
processingWg sync.WaitGroup // WaitGroup to wait for message processing to complete
} }
type ClientOption func(*Client) type ClientOption func(*Client)
@@ -78,12 +72,6 @@ func WithBaseURL(url string) ClientOption {
} }
// WithTLSConfig sets the TLS configuration for the client // WithTLSConfig sets the TLS configuration for the client
func WithConfigFile(path string) ClientOption {
return func(c *Client) {
c.configFilePath = path
}
}
func WithTLSConfig(config TLSConfig) ClientOption { func WithTLSConfig(config TLSConfig) ClientOption {
return func(c *Client) { return func(c *Client) {
c.tlsConfig = config c.tlsConfig = config
@@ -123,7 +111,7 @@ func (c *Client) MetricsContext() context.Context {
} }
// NewClient creates a new websocket client // NewClient creates a new websocket client
func NewClient(clientType string, ID, secret string, endpoint string, pingInterval time.Duration, opts ...ClientOption) (*Client, error) { func NewClient(clientType string, ID, secret string, endpoint string, pingInterval time.Duration, pingTimeout time.Duration, opts ...ClientOption) (*Client, error) {
config := &Config{ config := &Config{
ID: ID, ID: ID,
Secret: secret, Secret: secret,
@@ -138,6 +126,7 @@ func NewClient(clientType string, ID, secret string, endpoint string, pingInterv
reconnectInterval: 3 * time.Second, reconnectInterval: 3 * time.Second,
isConnected: false, isConnected: false,
pingInterval: pingInterval, pingInterval: pingInterval,
pingTimeout: pingTimeout,
clientType: clientType, clientType: clientType,
} }
@@ -165,20 +154,6 @@ func (c *Client) GetServerVersion() string {
return c.serverVersion return c.serverVersion
} }
// GetConfigVersion returns the latest config version received from server
func (c *Client) GetConfigVersion() int64 {
c.configVersionMux.RLock()
defer c.configVersionMux.RUnlock()
return c.configVersion
}
// setConfigVersion updates the config version
func (c *Client) setConfigVersion(version int64) {
c.configVersionMux.Lock()
defer c.configVersionMux.Unlock()
c.configVersion = version
}
// Connect establishes the WebSocket connection // Connect establishes the WebSocket connection
func (c *Client) Connect() error { func (c *Client) Connect() error {
go c.connectWithRetry() go c.connectWithRetry()
@@ -488,11 +463,6 @@ func (c *Client) connectWithRetry() {
func (c *Client) establishConnection() error { func (c *Client) establishConnection() error {
ctx := context.Background() ctx := context.Background()
// Exchange provisioning key for permanent credentials if needed.
if err := c.provisionIfNeeded(); err != nil {
return fmt.Errorf("failed to provision newt credentials: %w", err)
}
// Get token for authentication // Get token for authentication
token, err := c.getToken() token, err := c.getToken()
if err != nil { if err != nil {
@@ -671,37 +641,24 @@ func (c *Client) setupPKCS12TLS() (*tls.Config, error) {
} }
// pingMonitor sends pings at a short interval and triggers reconnect on failure // pingMonitor sends pings at a short interval and triggers reconnect on failure
func (c *Client) sendPing() { func (c *Client) pingMonitor() {
ticker := time.NewTicker(c.pingInterval)
defer ticker.Stop()
for {
select {
case <-c.done:
return
case <-ticker.C:
if c.conn == nil { if c.conn == nil {
return return
} }
// Skip ping if a message is currently being processed
c.processingMux.RLock()
isProcessing := c.processingMessage
c.processingMux.RUnlock()
if isProcessing {
logger.Debug("Skipping ping, message is being processed")
return
}
c.configVersionMux.RLock()
configVersion := c.configVersion
c.configVersionMux.RUnlock()
pingMsg := WSMessage{
Type: "newt/ping",
Data: map[string]interface{}{},
ConfigVersion: configVersion,
}
c.writeMux.Lock() c.writeMux.Lock()
err := c.conn.WriteJSON(pingMsg) err := c.conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(c.pingTimeout))
if err == nil { if err == nil {
telemetry.IncWSMessage(c.metricsContext(), "out", "ping") telemetry.IncWSMessage(c.metricsContext(), "out", "ping")
} }
c.writeMux.Unlock() c.writeMux.Unlock()
if err != nil { if err != nil {
// Check if we're shutting down before logging error and reconnecting // Check if we're shutting down before logging error and reconnecting
select { select {
@@ -716,21 +673,6 @@ func (c *Client) sendPing() {
return return
} }
} }
}
func (c *Client) pingMonitor() {
// Send an immediate ping as soon as we connect
c.sendPing()
ticker := time.NewTicker(c.pingInterval)
defer ticker.Stop()
for {
select {
case <-c.done:
return
case <-ticker.C:
c.sendPing()
} }
} }
} }
@@ -767,14 +709,11 @@ func (c *Client) readPumpWithDisconnectDetection(started time.Time) {
disconnectResult = "success" disconnectResult = "success"
return return
default: default:
msgType, p, err := c.conn.ReadMessage() var msg WSMessage
err := c.conn.ReadJSON(&msg)
if err == nil { if err == nil {
if msgType == websocket.BinaryMessage {
telemetry.IncWSMessage(c.metricsContext(), "in", "binary")
} else {
telemetry.IncWSMessage(c.metricsContext(), "in", "text") telemetry.IncWSMessage(c.metricsContext(), "in", "text")
} }
}
if err != nil { if err != nil {
// Check if we're shutting down before logging error // Check if we're shutting down before logging error
select { select {
@@ -798,47 +737,9 @@ func (c *Client) readPumpWithDisconnectDetection(started time.Time) {
} }
} }
// Update config version from incoming message
var data []byte
if msgType == websocket.BinaryMessage {
gr, err := gzip.NewReader(bytes.NewReader(p))
if err != nil {
logger.Error("WebSocket failed to create gzip reader: %v", err)
continue
}
data, err = io.ReadAll(gr)
gr.Close()
if err != nil {
logger.Error("WebSocket failed to decompress message: %v", err)
continue
}
} else {
data = p
}
var msg WSMessage
if err = json.Unmarshal(data, &msg); err != nil {
logger.Error("WebSocket failed to parse message: %v", err)
continue
}
c.setConfigVersion(msg.ConfigVersion)
c.handlersMux.RLock() c.handlersMux.RLock()
if handler, ok := c.handlers[msg.Type]; ok { if handler, ok := c.handlers[msg.Type]; ok {
// Mark that we're processing a message
c.processingMux.Lock()
c.processingMessage = true
c.processingMux.Unlock()
c.processingWg.Add(1)
handler(msg) handler(msg)
// Mark that we're done processing
c.processingWg.Done()
c.processingMux.Lock()
c.processingMessage = false
c.processingMux.Unlock()
} }
c.handlersMux.RUnlock() c.handlersMux.RUnlock()
} }

View File

@@ -1,28 +1,16 @@
package websocket package websocket
import ( import (
"bytes"
"context"
"crypto/tls"
"encoding/json" "encoding/json"
"fmt"
"io"
"log" "log"
"net/http"
"net/url"
"os" "os"
"path/filepath" "path/filepath"
"runtime" "runtime"
"strings"
"time"
"github.com/fosrl/newt/logger" "github.com/fosrl/newt/logger"
) )
func getConfigPath(clientType string, overridePath string) string { func getConfigPath(clientType string) string {
if overridePath != "" {
return overridePath
}
configFile := os.Getenv("CONFIG_FILE") configFile := os.Getenv("CONFIG_FILE")
if configFile == "" { if configFile == "" {
var configDir string var configDir string
@@ -48,7 +36,7 @@ func getConfigPath(clientType string, overridePath string) string {
func (c *Client) loadConfig() error { func (c *Client) loadConfig() error {
originalConfig := *c.config // Store original config to detect changes originalConfig := *c.config // Store original config to detect changes
configPath := getConfigPath(c.clientType, c.configFilePath) configPath := getConfigPath(c.clientType)
if c.config.ID != "" && c.config.Secret != "" && c.config.Endpoint != "" { if c.config.ID != "" && c.config.Secret != "" && c.config.Endpoint != "" {
logger.Debug("Config already provided, skipping loading from file") logger.Debug("Config already provided, skipping loading from file")
@@ -95,10 +83,6 @@ func (c *Client) loadConfig() error {
c.config.Endpoint = config.Endpoint c.config.Endpoint = config.Endpoint
c.baseURL = config.Endpoint c.baseURL = config.Endpoint
} }
// Always load the provisioning key from the file if not already set
if c.config.ProvisioningKey == "" {
c.config.ProvisioningKey = config.ProvisioningKey
}
// Check if CLI args provided values that override file values // Check if CLI args provided values that override file values
if (!fileHadID && originalConfig.ID != "") || if (!fileHadID && originalConfig.ID != "") ||
@@ -121,7 +105,7 @@ func (c *Client) saveConfig() error {
return nil return nil
} }
configPath := getConfigPath(c.clientType, c.configFilePath) configPath := getConfigPath(c.clientType)
data, err := json.MarshalIndent(c.config, "", " ") data, err := json.MarshalIndent(c.config, "", " ")
if err != nil { if err != nil {
return err return err
@@ -134,116 +118,3 @@ func (c *Client) saveConfig() error {
} }
return err return err
} }
// provisionIfNeeded checks whether a provisioning key is present and, if so,
// exchanges it for a newt ID and secret by calling the registration endpoint.
// On success the config is updated in-place and flagged for saving so that
// subsequent runs use the permanent credentials directly.
func (c *Client) provisionIfNeeded() error {
if c.config.ProvisioningKey == "" {
return nil
}
// If we already have both credentials there is nothing to provision.
if c.config.ID != "" && c.config.Secret != "" {
logger.Debug("Credentials already present, skipping provisioning")
return nil
}
logger.Info("Provisioning key found exchanging for newt credentials...")
baseURL, err := url.Parse(c.baseURL)
if err != nil {
return fmt.Errorf("failed to parse base URL for provisioning: %w", err)
}
baseEndpoint := strings.TrimRight(baseURL.String(), "/")
reqBody := map[string]interface{}{
"provisioningKey": c.config.ProvisioningKey,
}
jsonData, err := json.Marshal(reqBody)
if err != nil {
return fmt.Errorf("failed to marshal provisioning request: %w", err)
}
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(
ctx,
"POST",
baseEndpoint+"/api/v1/auth/newt/register",
bytes.NewBuffer(jsonData),
)
if err != nil {
return fmt.Errorf("failed to create provisioning request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("X-CSRF-Token", "x-csrf-protection")
// Mirror the TLS setup used by getToken so mTLS / self-signed CAs work.
var tlsCfg *tls.Config
if c.tlsConfig.ClientCertFile != "" || c.tlsConfig.ClientKeyFile != "" ||
len(c.tlsConfig.CAFiles) > 0 || c.tlsConfig.PKCS12File != "" {
tlsCfg, err = c.setupTLS()
if err != nil {
return fmt.Errorf("failed to setup TLS for provisioning: %w", err)
}
}
if os.Getenv("SKIP_TLS_VERIFY") == "true" {
if tlsCfg == nil {
tlsCfg = &tls.Config{}
}
tlsCfg.InsecureSkipVerify = true
logger.Debug("TLS certificate verification disabled for provisioning via SKIP_TLS_VERIFY")
}
httpClient := &http.Client{}
if tlsCfg != nil {
httpClient.Transport = &http.Transport{TLSClientConfig: tlsCfg}
}
resp, err := httpClient.Do(req)
if err != nil {
return fmt.Errorf("provisioning request failed: %w", err)
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
logger.Debug("Provisioning response body: %s", string(body))
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return fmt.Errorf("provisioning endpoint returned status %d: %s", resp.StatusCode, string(body))
}
var provResp ProvisioningResponse
if err := json.Unmarshal(body, &provResp); err != nil {
return fmt.Errorf("failed to decode provisioning response: %w", err)
}
if !provResp.Success {
return fmt.Errorf("provisioning failed: %s", provResp.Message)
}
if provResp.Data.NewtID == "" || provResp.Data.Secret == "" {
return fmt.Errorf("provisioning response is missing newt ID or secret")
}
logger.Info("Successfully provisioned newt ID: %s", provResp.Data.NewtID)
// Persist the returned credentials and clear the one-time provisioning key
// so subsequent runs authenticate normally.
c.config.ID = provResp.Data.NewtID
c.config.Secret = provResp.Data.Secret
c.config.ProvisioningKey = ""
c.configNeedsSave = true
// Save immediately so that if the subsequent connection attempt fails the
// provisioning key is already gone from disk and the next retry uses the
// permanent credentials instead of trying to provision again.
if err := c.saveConfig(); err != nil {
logger.Error("Failed to save config after provisioning: %v", err)
}
return nil
}

View File

@@ -5,7 +5,6 @@ type Config struct {
Secret string `json:"secret"` Secret string `json:"secret"`
Endpoint string `json:"endpoint"` Endpoint string `json:"endpoint"`
TlsClientCert string `json:"tlsClientCert"` TlsClientCert string `json:"tlsClientCert"`
ProvisioningKey string `json:"provisioningKey,omitempty"`
} }
type TokenResponse struct { type TokenResponse struct {
@@ -17,17 +16,7 @@ type TokenResponse struct {
Message string `json:"message"` Message string `json:"message"`
} }
type ProvisioningResponse struct {
Data struct {
NewtID string `json:"newtId"`
Secret string `json:"secret"`
} `json:"data"`
Success bool `json:"success"`
Message string `json:"message"`
}
type WSMessage struct { type WSMessage struct {
Type string `json:"type"` Type string `json:"type"`
Data interface{} `json:"data"` Data interface{} `json:"data"`
ConfigVersion int64 `json:"configVersion,omitempty"`
} }