mirror of
https://github.com/fosrl/newt.git
synced 2026-03-27 13:06:38 +00:00
Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
44ca592a5c | ||
|
|
e1edbcea07 |
@@ -38,7 +38,6 @@ type WgConfig struct {
|
|||||||
|
|
||||||
type Target struct {
|
type Target struct {
|
||||||
SourcePrefix string `json:"sourcePrefix"`
|
SourcePrefix string `json:"sourcePrefix"`
|
||||||
SourcePrefixes []string `json:"sourcePrefixes"`
|
|
||||||
DestPrefix string `json:"destPrefix"`
|
DestPrefix string `json:"destPrefix"`
|
||||||
RewriteTo string `json:"rewriteTo,omitempty"`
|
RewriteTo string `json:"rewriteTo,omitempty"`
|
||||||
DisableIcmp bool `json:"disableIcmp,omitempty"`
|
DisableIcmp bool `json:"disableIcmp,omitempty"`
|
||||||
@@ -113,6 +112,8 @@ func NewWireGuardService(interfaceName string, port uint16, mtu int, host string
|
|||||||
return nil, fmt.Errorf("failed to generate private key: %v", err)
|
return nil, fmt.Errorf("failed to generate private key: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
logger.Debug("+++++++++++++++++++++++++++++++= the port is %d", port)
|
||||||
|
|
||||||
if port == 0 {
|
if port == 0 {
|
||||||
// Find an available port
|
// Find an available port
|
||||||
portRandom, err := util.FindAvailableUDPPort(49152, 65535)
|
portRandom, err := util.FindAvailableUDPPort(49152, 65535)
|
||||||
@@ -173,7 +174,6 @@ func NewWireGuardService(interfaceName string, port uint16, mtu int, host string
|
|||||||
wsClient.RegisterHandler("newt/wg/targets/add", service.handleAddTarget)
|
wsClient.RegisterHandler("newt/wg/targets/add", service.handleAddTarget)
|
||||||
wsClient.RegisterHandler("newt/wg/targets/remove", service.handleRemoveTarget)
|
wsClient.RegisterHandler("newt/wg/targets/remove", service.handleRemoveTarget)
|
||||||
wsClient.RegisterHandler("newt/wg/targets/update", service.handleUpdateTarget)
|
wsClient.RegisterHandler("newt/wg/targets/update", service.handleUpdateTarget)
|
||||||
wsClient.RegisterHandler("newt/wg/sync", service.handleSyncConfig)
|
|
||||||
|
|
||||||
return service, nil
|
return service, nil
|
||||||
}
|
}
|
||||||
@@ -494,183 +494,6 @@ func (s *WireGuardService) handleConfig(msg websocket.WSMessage) {
|
|||||||
logger.Info("Client connectivity setup. Ready to accept connections from clients!")
|
logger.Info("Client connectivity setup. Ready to accept connections from clients!")
|
||||||
}
|
}
|
||||||
|
|
||||||
// SyncConfig represents the configuration sent from server for syncing
|
|
||||||
type SyncConfig struct {
|
|
||||||
Targets []Target `json:"targets"`
|
|
||||||
Peers []Peer `json:"peers"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *WireGuardService) handleSyncConfig(msg websocket.WSMessage) {
|
|
||||||
var syncConfig SyncConfig
|
|
||||||
|
|
||||||
logger.Debug("Received sync message: %v", msg)
|
|
||||||
logger.Info("Received sync configuration from remote server")
|
|
||||||
|
|
||||||
jsonData, err := json.Marshal(msg.Data)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Error marshaling sync data: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := json.Unmarshal(jsonData, &syncConfig); err != nil {
|
|
||||||
logger.Error("Error unmarshaling sync data: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sync peers
|
|
||||||
if err := s.syncPeers(syncConfig.Peers); err != nil {
|
|
||||||
logger.Error("Failed to sync peers: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sync targets
|
|
||||||
if err := s.syncTargets(syncConfig.Targets); err != nil {
|
|
||||||
logger.Error("Failed to sync targets: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// syncPeers synchronizes the current peers with the desired state
|
|
||||||
// It removes peers not in the desired list and adds missing ones
|
|
||||||
func (s *WireGuardService) syncPeers(desiredPeers []Peer) error {
|
|
||||||
if s.device == nil {
|
|
||||||
return fmt.Errorf("WireGuard device is not initialized")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get current peers from the device
|
|
||||||
currentConfig, err := s.device.IpcGet()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to get current device config: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse current peer public keys
|
|
||||||
lines := strings.Split(currentConfig, "\n")
|
|
||||||
currentPeerKeys := make(map[string]bool)
|
|
||||||
for _, line := range lines {
|
|
||||||
if strings.HasPrefix(line, "public_key=") {
|
|
||||||
pubKey := strings.TrimPrefix(line, "public_key=")
|
|
||||||
currentPeerKeys[pubKey] = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build a map of desired peers by their public key (normalized)
|
|
||||||
desiredPeerMap := make(map[string]Peer)
|
|
||||||
for _, peer := range desiredPeers {
|
|
||||||
// Normalize the public key for comparison
|
|
||||||
pubKey, err := wgtypes.ParseKey(peer.PublicKey)
|
|
||||||
if err != nil {
|
|
||||||
logger.Warn("Invalid public key in desired peers: %s", peer.PublicKey)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
normalizedKey := util.FixKey(pubKey.String())
|
|
||||||
desiredPeerMap[normalizedKey] = peer
|
|
||||||
}
|
|
||||||
|
|
||||||
// Remove peers that are not in the desired list
|
|
||||||
for currentKey := range currentPeerKeys {
|
|
||||||
if _, exists := desiredPeerMap[currentKey]; !exists {
|
|
||||||
// Parse the key back to get the original format for removal
|
|
||||||
removeConfig := fmt.Sprintf("public_key=%s\nremove=true", currentKey)
|
|
||||||
if err := s.device.IpcSet(removeConfig); err != nil {
|
|
||||||
logger.Warn("Failed to remove peer %s during sync: %v", currentKey, err)
|
|
||||||
} else {
|
|
||||||
logger.Info("Removed peer %s during sync", currentKey)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add peers that are missing
|
|
||||||
for normalizedKey, peer := range desiredPeerMap {
|
|
||||||
if _, exists := currentPeerKeys[normalizedKey]; !exists {
|
|
||||||
if err := s.addPeerToDevice(peer); err != nil {
|
|
||||||
logger.Warn("Failed to add peer %s during sync: %v", peer.PublicKey, err)
|
|
||||||
} else {
|
|
||||||
logger.Info("Added peer %s during sync", peer.PublicKey)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// syncTargets synchronizes the current targets with the desired state
|
|
||||||
// It removes targets not in the desired list and adds missing ones
|
|
||||||
func (s *WireGuardService) syncTargets(desiredTargets []Target) error {
|
|
||||||
if s.tnet == nil {
|
|
||||||
// Native interface mode - proxy features not available, skip silently
|
|
||||||
logger.Debug("Skipping target sync - using native interface (no proxy support)")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get current rules from the proxy handler
|
|
||||||
currentRules := s.tnet.GetProxySubnetRules()
|
|
||||||
|
|
||||||
// Build a map of current rules by source+dest prefix
|
|
||||||
type ruleKey struct {
|
|
||||||
sourcePrefix string
|
|
||||||
destPrefix string
|
|
||||||
}
|
|
||||||
currentRuleMap := make(map[ruleKey]bool)
|
|
||||||
for _, rule := range currentRules {
|
|
||||||
key := ruleKey{
|
|
||||||
sourcePrefix: rule.SourcePrefix.String(),
|
|
||||||
destPrefix: rule.DestPrefix.String(),
|
|
||||||
}
|
|
||||||
currentRuleMap[key] = true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build a map of desired targets
|
|
||||||
desiredTargetMap := make(map[ruleKey]Target)
|
|
||||||
for _, target := range desiredTargets {
|
|
||||||
key := ruleKey{
|
|
||||||
sourcePrefix: target.SourcePrefix,
|
|
||||||
destPrefix: target.DestPrefix,
|
|
||||||
}
|
|
||||||
desiredTargetMap[key] = target
|
|
||||||
}
|
|
||||||
|
|
||||||
// Remove targets that are not in the desired list
|
|
||||||
for _, rule := range currentRules {
|
|
||||||
key := ruleKey{
|
|
||||||
sourcePrefix: rule.SourcePrefix.String(),
|
|
||||||
destPrefix: rule.DestPrefix.String(),
|
|
||||||
}
|
|
||||||
if _, exists := desiredTargetMap[key]; !exists {
|
|
||||||
s.tnet.RemoveProxySubnetRule(rule.SourcePrefix, rule.DestPrefix)
|
|
||||||
logger.Info("Removed target %s -> %s during sync", rule.SourcePrefix.String(), rule.DestPrefix.String())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add targets that are missing
|
|
||||||
for key, target := range desiredTargetMap {
|
|
||||||
if _, exists := currentRuleMap[key]; !exists {
|
|
||||||
sourcePrefix, err := netip.ParsePrefix(target.SourcePrefix)
|
|
||||||
if err != nil {
|
|
||||||
logger.Warn("Invalid source prefix %s during sync: %v", target.SourcePrefix, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
destPrefix, err := netip.ParsePrefix(target.DestPrefix)
|
|
||||||
if err != nil {
|
|
||||||
logger.Warn("Invalid dest prefix %s during sync: %v", target.DestPrefix, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
var portRanges []netstack2.PortRange
|
|
||||||
for _, pr := range target.PortRange {
|
|
||||||
portRanges = append(portRanges, netstack2.PortRange{
|
|
||||||
Min: pr.Min,
|
|
||||||
Max: pr.Max,
|
|
||||||
Protocol: pr.Protocol,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp)
|
|
||||||
logger.Info("Added target %s -> %s during sync", target.SourcePrefix, target.DestPrefix)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
|
func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
|
|
||||||
@@ -874,19 +697,6 @@ func (s *WireGuardService) ensureWireguardPeers(peers []Peer) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// resolveSourcePrefixes returns the effective list of source prefixes for a target,
|
|
||||||
// supporting both the legacy single SourcePrefix field and the new SourcePrefixes array.
|
|
||||||
// If SourcePrefixes is non-empty it takes precedence; otherwise SourcePrefix is used.
|
|
||||||
func resolveSourcePrefixes(target Target) []string {
|
|
||||||
if len(target.SourcePrefixes) > 0 {
|
|
||||||
return target.SourcePrefixes
|
|
||||||
}
|
|
||||||
if target.SourcePrefix != "" {
|
|
||||||
return []string{target.SourcePrefix}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *WireGuardService) ensureTargets(targets []Target) error {
|
func (s *WireGuardService) ensureTargets(targets []Target) error {
|
||||||
if s.tnet == nil {
|
if s.tnet == nil {
|
||||||
// Native interface mode - proxy features not available, skip silently
|
// Native interface mode - proxy features not available, skip silently
|
||||||
@@ -895,6 +705,11 @@ func (s *WireGuardService) ensureTargets(targets []Target) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, target := range targets {
|
for _, target := range targets {
|
||||||
|
sourcePrefix, err := netip.ParsePrefix(target.SourcePrefix)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid CIDR %s: %v", target.SourcePrefix, err)
|
||||||
|
}
|
||||||
|
|
||||||
destPrefix, err := netip.ParsePrefix(target.DestPrefix)
|
destPrefix, err := netip.ParsePrefix(target.DestPrefix)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("invalid CIDR %s: %v", target.DestPrefix, err)
|
return fmt.Errorf("invalid CIDR %s: %v", target.DestPrefix, err)
|
||||||
@@ -909,14 +724,9 @@ func (s *WireGuardService) ensureTargets(targets []Target) error {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, sp := range resolveSourcePrefixes(target) {
|
|
||||||
sourcePrefix, err := netip.ParsePrefix(sp)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("invalid CIDR %s: %v", sp, err)
|
|
||||||
}
|
|
||||||
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp)
|
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp)
|
||||||
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", sp, target.DestPrefix, target.RewriteTo, target.PortRange)
|
|
||||||
}
|
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v disableIcmp: %v", target.SourcePrefix, target.DestPrefix, target.RewriteTo, target.PortRange, target.DisableIcmp)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -1286,6 +1096,12 @@ func (s *WireGuardService) handleAddTarget(msg websocket.WSMessage) {
|
|||||||
|
|
||||||
// Process all targets
|
// Process all targets
|
||||||
for _, target := range targets {
|
for _, target := range targets {
|
||||||
|
sourcePrefix, err := netip.ParsePrefix(target.SourcePrefix)
|
||||||
|
if err != nil {
|
||||||
|
logger.Info("Invalid CIDR %s: %v", target.SourcePrefix, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
destPrefix, err := netip.ParsePrefix(target.DestPrefix)
|
destPrefix, err := netip.ParsePrefix(target.DestPrefix)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Info("Invalid CIDR %s: %v", target.DestPrefix, err)
|
logger.Info("Invalid CIDR %s: %v", target.DestPrefix, err)
|
||||||
@@ -1301,15 +1117,9 @@ func (s *WireGuardService) handleAddTarget(msg websocket.WSMessage) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, sp := range resolveSourcePrefixes(target) {
|
|
||||||
sourcePrefix, err := netip.ParsePrefix(sp)
|
|
||||||
if err != nil {
|
|
||||||
logger.Info("Invalid CIDR %s: %v", sp, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp)
|
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp)
|
||||||
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", sp, target.DestPrefix, target.RewriteTo, target.PortRange)
|
|
||||||
}
|
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v disableIcmp: %v", target.SourcePrefix, target.DestPrefix, target.RewriteTo, target.PortRange, target.DisableIcmp)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1338,21 +1148,21 @@ func (s *WireGuardService) handleRemoveTarget(msg websocket.WSMessage) {
|
|||||||
|
|
||||||
// Process all targets
|
// Process all targets
|
||||||
for _, target := range targets {
|
for _, target := range targets {
|
||||||
|
sourcePrefix, err := netip.ParsePrefix(target.SourcePrefix)
|
||||||
|
if err != nil {
|
||||||
|
logger.Info("Invalid CIDR %s: %v", target.SourcePrefix, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
destPrefix, err := netip.ParsePrefix(target.DestPrefix)
|
destPrefix, err := netip.ParsePrefix(target.DestPrefix)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Info("Invalid CIDR %s: %v", target.DestPrefix, err)
|
logger.Info("Invalid CIDR %s: %v", target.DestPrefix, err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, sp := range resolveSourcePrefixes(target) {
|
|
||||||
sourcePrefix, err := netip.ParsePrefix(sp)
|
|
||||||
if err != nil {
|
|
||||||
logger.Info("Invalid CIDR %s: %v", sp, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
s.tnet.RemoveProxySubnetRule(sourcePrefix, destPrefix)
|
s.tnet.RemoveProxySubnetRule(sourcePrefix, destPrefix)
|
||||||
logger.Info("Removed target subnet %s with destination %s", sp, target.DestPrefix)
|
|
||||||
}
|
logger.Info("Removed target subnet %s with destination %s", target.SourcePrefix, target.DestPrefix)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1386,24 +1196,30 @@ func (s *WireGuardService) handleUpdateTarget(msg websocket.WSMessage) {
|
|||||||
|
|
||||||
// Process all update requests
|
// Process all update requests
|
||||||
for _, target := range requests.OldTargets {
|
for _, target := range requests.OldTargets {
|
||||||
|
sourcePrefix, err := netip.ParsePrefix(target.SourcePrefix)
|
||||||
|
if err != nil {
|
||||||
|
logger.Info("Invalid CIDR %s: %v", target.SourcePrefix, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
destPrefix, err := netip.ParsePrefix(target.DestPrefix)
|
destPrefix, err := netip.ParsePrefix(target.DestPrefix)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Info("Invalid CIDR %s: %v", target.DestPrefix, err)
|
logger.Info("Invalid CIDR %s: %v", target.DestPrefix, err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, sp := range resolveSourcePrefixes(target) {
|
|
||||||
sourcePrefix, err := netip.ParsePrefix(sp)
|
|
||||||
if err != nil {
|
|
||||||
logger.Info("Invalid CIDR %s: %v", sp, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
s.tnet.RemoveProxySubnetRule(sourcePrefix, destPrefix)
|
s.tnet.RemoveProxySubnetRule(sourcePrefix, destPrefix)
|
||||||
logger.Info("Removed target subnet %s with destination %s", sp, target.DestPrefix)
|
logger.Info("Removed target subnet %s with destination %s", target.SourcePrefix, target.DestPrefix)
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, target := range requests.NewTargets {
|
for _, target := range requests.NewTargets {
|
||||||
|
// Now add the new target
|
||||||
|
sourcePrefix, err := netip.ParsePrefix(target.SourcePrefix)
|
||||||
|
if err != nil {
|
||||||
|
logger.Info("Invalid CIDR %s: %v", target.SourcePrefix, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
destPrefix, err := netip.ParsePrefix(target.DestPrefix)
|
destPrefix, err := netip.ParsePrefix(target.DestPrefix)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Info("Invalid CIDR %s: %v", target.DestPrefix, err)
|
logger.Info("Invalid CIDR %s: %v", target.DestPrefix, err)
|
||||||
@@ -1419,15 +1235,8 @@ func (s *WireGuardService) handleUpdateTarget(msg websocket.WSMessage) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, sp := range resolveSourcePrefixes(target) {
|
|
||||||
sourcePrefix, err := netip.ParsePrefix(sp)
|
|
||||||
if err != nil {
|
|
||||||
logger.Info("Invalid CIDR %s: %v", sp, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp)
|
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp)
|
||||||
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", sp, target.DestPrefix, target.RewriteTo, target.PortRange)
|
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v disableIcmp: %v", target.SourcePrefix, target.DestPrefix, target.RewriteTo, target.PortRange, target.DisableIcmp)
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
66
common.go
66
common.go
@@ -5,7 +5,6 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -364,62 +363,27 @@ func parseTargetData(data interface{}) (TargetData, error) {
|
|||||||
return targetData, nil
|
return targetData, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// parseTargetString parses a target string in the format "listenPort:host:targetPort"
|
|
||||||
// It properly handles IPv6 addresses which must be in brackets: "listenPort:[ipv6]:targetPort"
|
|
||||||
// Examples:
|
|
||||||
// - IPv4: "3001:192.168.1.1:80"
|
|
||||||
// - IPv6: "3001:[::1]:8080" or "3001:[fd70:1452:b736:4dd5:caca:7db9:c588:f5b3]:80"
|
|
||||||
//
|
|
||||||
// Returns listenPort, targetAddress (in host:port format suitable for net.Dial), and error
|
|
||||||
func parseTargetString(target string) (int, string, error) {
|
|
||||||
// Find the first colon to extract the listen port
|
|
||||||
firstColon := strings.Index(target, ":")
|
|
||||||
if firstColon == -1 {
|
|
||||||
return 0, "", fmt.Errorf("invalid target format, no colon found: %s", target)
|
|
||||||
}
|
|
||||||
|
|
||||||
listenPortStr := target[:firstColon]
|
|
||||||
var listenPort int
|
|
||||||
_, err := fmt.Sscanf(listenPortStr, "%d", &listenPort)
|
|
||||||
if err != nil {
|
|
||||||
return 0, "", fmt.Errorf("invalid listen port: %s", listenPortStr)
|
|
||||||
}
|
|
||||||
if listenPort <= 0 || listenPort > 65535 {
|
|
||||||
return 0, "", fmt.Errorf("listen port out of range: %d", listenPort)
|
|
||||||
}
|
|
||||||
|
|
||||||
// The remainder is host:targetPort - use net.SplitHostPort which handles IPv6 brackets
|
|
||||||
remainder := target[firstColon+1:]
|
|
||||||
host, targetPort, err := net.SplitHostPort(remainder)
|
|
||||||
if err != nil {
|
|
||||||
return 0, "", fmt.Errorf("invalid host:port format '%s': %w", remainder, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Reject empty host or target port
|
|
||||||
if host == "" {
|
|
||||||
return 0, "", fmt.Errorf("empty host in target: %s", target)
|
|
||||||
}
|
|
||||||
if targetPort == "" {
|
|
||||||
return 0, "", fmt.Errorf("empty target port in target: %s", target)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Reconstruct the target address using JoinHostPort (handles IPv6 properly)
|
|
||||||
targetAddr := net.JoinHostPort(host, targetPort)
|
|
||||||
|
|
||||||
return listenPort, targetAddr, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto string, targetData TargetData) error {
|
func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto string, targetData TargetData) error {
|
||||||
for _, t := range targetData.Targets {
|
for _, t := range targetData.Targets {
|
||||||
// Parse the target string, handling both IPv4 and IPv6 addresses
|
// Split the first number off of the target with : separator and use as the port
|
||||||
port, target, err := parseTargetString(t)
|
parts := strings.Split(t, ":")
|
||||||
|
if len(parts) != 3 {
|
||||||
|
logger.Info("Invalid target format: %s", t)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the port as an int
|
||||||
|
port := 0
|
||||||
|
_, err := fmt.Sscanf(parts[0], "%d", &port)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Info("Invalid target format: %s (%v)", t, err)
|
logger.Info("Invalid port: %s", parts[0])
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
switch action {
|
switch action {
|
||||||
case "add":
|
case "add":
|
||||||
|
target := parts[1] + ":" + parts[2]
|
||||||
|
|
||||||
// Call updown script if provided
|
// Call updown script if provided
|
||||||
processedTarget := target
|
processedTarget := target
|
||||||
if updownScript != "" {
|
if updownScript != "" {
|
||||||
@@ -446,6 +410,8 @@ func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto
|
|||||||
case "remove":
|
case "remove":
|
||||||
logger.Info("Removing target with port %d", port)
|
logger.Info("Removing target with port %d", port)
|
||||||
|
|
||||||
|
target := parts[1] + ":" + parts[2]
|
||||||
|
|
||||||
// Call updown script if provided
|
// Call updown script if provided
|
||||||
if updownScript != "" {
|
if updownScript != "" {
|
||||||
_, err := executeUpdownScript(action, proto, target)
|
_, err := executeUpdownScript(action, proto, target)
|
||||||
@@ -454,7 +420,7 @@ func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = pm.RemoveTarget(proto, tunnelIP, port)
|
err := pm.RemoveTarget(proto, tunnelIP, port)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Failed to remove target: %v", err)
|
logger.Error("Failed to remove target: %v", err)
|
||||||
return err
|
return err
|
||||||
|
|||||||
212
common_test.go
212
common_test.go
@@ -1,212 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestParseTargetString(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
input string
|
|
||||||
wantListenPort int
|
|
||||||
wantTargetAddr string
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
// IPv4 test cases
|
|
||||||
{
|
|
||||||
name: "valid IPv4 basic",
|
|
||||||
input: "3001:192.168.1.1:80",
|
|
||||||
wantListenPort: 3001,
|
|
||||||
wantTargetAddr: "192.168.1.1:80",
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "valid IPv4 localhost",
|
|
||||||
input: "8080:127.0.0.1:3000",
|
|
||||||
wantListenPort: 8080,
|
|
||||||
wantTargetAddr: "127.0.0.1:3000",
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "valid IPv4 same ports",
|
|
||||||
input: "443:10.0.0.1:443",
|
|
||||||
wantListenPort: 443,
|
|
||||||
wantTargetAddr: "10.0.0.1:443",
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
|
|
||||||
// IPv6 test cases
|
|
||||||
{
|
|
||||||
name: "valid IPv6 loopback",
|
|
||||||
input: "3001:[::1]:8080",
|
|
||||||
wantListenPort: 3001,
|
|
||||||
wantTargetAddr: "[::1]:8080",
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "valid IPv6 full address",
|
|
||||||
input: "80:[fd70:1452:b736:4dd5:caca:7db9:c588:f5b3]:8080",
|
|
||||||
wantListenPort: 80,
|
|
||||||
wantTargetAddr: "[fd70:1452:b736:4dd5:caca:7db9:c588:f5b3]:8080",
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "valid IPv6 link-local",
|
|
||||||
input: "443:[fe80::1]:443",
|
|
||||||
wantListenPort: 443,
|
|
||||||
wantTargetAddr: "[fe80::1]:443",
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "valid IPv6 all zeros compressed",
|
|
||||||
input: "8000:[::]:9000",
|
|
||||||
wantListenPort: 8000,
|
|
||||||
wantTargetAddr: "[::]:9000",
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "valid IPv6 mixed notation",
|
|
||||||
input: "5000:[::ffff:192.168.1.1]:6000",
|
|
||||||
wantListenPort: 5000,
|
|
||||||
wantTargetAddr: "[::ffff:192.168.1.1]:6000",
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
|
|
||||||
// Hostname test cases
|
|
||||||
{
|
|
||||||
name: "valid hostname",
|
|
||||||
input: "8080:example.com:80",
|
|
||||||
wantListenPort: 8080,
|
|
||||||
wantTargetAddr: "example.com:80",
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "valid hostname with subdomain",
|
|
||||||
input: "443:api.example.com:8443",
|
|
||||||
wantListenPort: 443,
|
|
||||||
wantTargetAddr: "api.example.com:8443",
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "valid localhost hostname",
|
|
||||||
input: "3000:localhost:3000",
|
|
||||||
wantListenPort: 3000,
|
|
||||||
wantTargetAddr: "localhost:3000",
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
|
|
||||||
// Error cases
|
|
||||||
{
|
|
||||||
name: "invalid - no colons",
|
|
||||||
input: "invalid",
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "invalid - empty string",
|
|
||||||
input: "",
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "invalid - non-numeric listen port",
|
|
||||||
input: "abc:192.168.1.1:80",
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "invalid - missing target port",
|
|
||||||
input: "3001:192.168.1.1",
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "invalid - IPv6 without brackets",
|
|
||||||
input: "3001:fd70:1452:b736:4dd5:caca:7db9:c588:f5b3:80",
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "invalid - only listen port",
|
|
||||||
input: "3001:",
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "invalid - missing host",
|
|
||||||
input: "3001::80",
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "invalid - IPv6 unclosed bracket",
|
|
||||||
input: "3001:[::1:80",
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "invalid - listen port zero",
|
|
||||||
input: "0:192.168.1.1:80",
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "invalid - listen port negative",
|
|
||||||
input: "-1:192.168.1.1:80",
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "invalid - listen port out of range",
|
|
||||||
input: "70000:192.168.1.1:80",
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "invalid - empty target port",
|
|
||||||
input: "3001:192.168.1.1:",
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
listenPort, targetAddr, err := parseTargetString(tt.input)
|
|
||||||
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("parseTargetString(%q) error = %v, wantErr %v", tt.input, err, tt.wantErr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if tt.wantErr {
|
|
||||||
return // Don't check other values if we expected an error
|
|
||||||
}
|
|
||||||
|
|
||||||
if listenPort != tt.wantListenPort {
|
|
||||||
t.Errorf("parseTargetString(%q) listenPort = %d, want %d", tt.input, listenPort, tt.wantListenPort)
|
|
||||||
}
|
|
||||||
|
|
||||||
if targetAddr != tt.wantTargetAddr {
|
|
||||||
t.Errorf("parseTargetString(%q) targetAddr = %q, want %q", tt.input, targetAddr, tt.wantTargetAddr)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestParseTargetStringNetDialCompatibility verifies that the output is compatible with net.Dial
|
|
||||||
func TestParseTargetStringNetDialCompatibility(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
input string
|
|
||||||
}{
|
|
||||||
{"IPv4", "8080:127.0.0.1:80"},
|
|
||||||
{"IPv6 loopback", "8080:[::1]:80"},
|
|
||||||
{"IPv6 full", "8080:[2001:db8::1]:80"},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
_, targetAddr, err := parseTargetString(tt.input)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("parseTargetString(%q) unexpected error: %v", tt.input, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify the format is valid for net.Dial by checking it can be split back
|
|
||||||
// This doesn't actually dial, just validates the format
|
|
||||||
_, _, err = net.SplitHostPort(targetAddr)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("parseTargetString(%q) produced invalid net.Dial format %q: %v", tt.input, targetAddr, err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
79
get-newt.sh
79
get-newt.sh
@@ -1,7 +1,7 @@
|
|||||||
#!/bin/sh
|
#!/bin/bash
|
||||||
|
|
||||||
# Get Newt - Cross-platform installation script
|
# Get Newt - Cross-platform installation script
|
||||||
# Usage: curl -fsSL https://raw.githubusercontent.com/fosrl/newt/refs/heads/main/get-newt.sh | sh
|
# Usage: curl -fsSL https://raw.githubusercontent.com/fosrl/newt/refs/heads/main/get-newt.sh | bash
|
||||||
|
|
||||||
set -e
|
set -e
|
||||||
|
|
||||||
@@ -17,15 +17,15 @@ GITHUB_API_URL="https://api.github.com/repos/${REPO}/releases/latest"
|
|||||||
|
|
||||||
# Function to print colored output
|
# Function to print colored output
|
||||||
print_status() {
|
print_status() {
|
||||||
printf '%b[INFO]%b %s\n' "${GREEN}" "${NC}" "$1"
|
echo -e "${GREEN}[INFO]${NC} $1"
|
||||||
}
|
}
|
||||||
|
|
||||||
print_warning() {
|
print_warning() {
|
||||||
printf '%b[WARN]%b %s\n' "${YELLOW}" "${NC}" "$1"
|
echo -e "${YELLOW}[WARN]${NC} $1"
|
||||||
}
|
}
|
||||||
|
|
||||||
print_error() {
|
print_error() {
|
||||||
printf '%b[ERROR]%b %s\n' "${RED}" "${NC}" "$1"
|
echo -e "${RED}[ERROR]${NC} $1"
|
||||||
}
|
}
|
||||||
|
|
||||||
# Function to get latest version from GitHub API
|
# Function to get latest version from GitHub API
|
||||||
@@ -113,34 +113,16 @@ get_install_dir() {
|
|||||||
if [ "$OS" = "windows" ]; then
|
if [ "$OS" = "windows" ]; then
|
||||||
echo "$HOME/bin"
|
echo "$HOME/bin"
|
||||||
else
|
else
|
||||||
# Prefer /usr/local/bin for system-wide installation
|
# Try to use a directory in PATH, fallback to ~/.local/bin
|
||||||
|
if echo "$PATH" | grep -q "/usr/local/bin"; then
|
||||||
|
if [ -w "/usr/local/bin" ] 2>/dev/null; then
|
||||||
echo "/usr/local/bin"
|
echo "/usr/local/bin"
|
||||||
fi
|
|
||||||
}
|
|
||||||
|
|
||||||
# Check if we need sudo for installation
|
|
||||||
needs_sudo() {
|
|
||||||
local install_dir="$1"
|
|
||||||
if [ -w "$install_dir" ] 2>/dev/null; then
|
|
||||||
return 1 # No sudo needed
|
|
||||||
else
|
else
|
||||||
return 0 # Sudo needed
|
echo "$HOME/.local/bin"
|
||||||
fi
|
|
||||||
}
|
|
||||||
|
|
||||||
# Get the appropriate command prefix (sudo or empty)
|
|
||||||
get_sudo_cmd() {
|
|
||||||
local install_dir="$1"
|
|
||||||
if needs_sudo "$install_dir"; then
|
|
||||||
if command -v sudo >/dev/null 2>&1; then
|
|
||||||
echo "sudo"
|
|
||||||
else
|
|
||||||
print_error "Cannot write to ${install_dir} and sudo is not available."
|
|
||||||
print_error "Please run this script as root or install sudo."
|
|
||||||
exit 1
|
|
||||||
fi
|
fi
|
||||||
else
|
else
|
||||||
echo ""
|
echo "$HOME/.local/bin"
|
||||||
|
fi
|
||||||
fi
|
fi
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -148,17 +130,14 @@ get_sudo_cmd() {
|
|||||||
install_newt() {
|
install_newt() {
|
||||||
local platform="$1"
|
local platform="$1"
|
||||||
local install_dir="$2"
|
local install_dir="$2"
|
||||||
local sudo_cmd="$3"
|
|
||||||
local binary_name="newt_${platform}"
|
local binary_name="newt_${platform}"
|
||||||
local exe_suffix=""
|
local exe_suffix=""
|
||||||
|
|
||||||
# Add .exe suffix for Windows
|
# Add .exe suffix for Windows
|
||||||
case "$platform" in
|
if [[ "$platform" == *"windows"* ]]; then
|
||||||
*windows*)
|
|
||||||
binary_name="${binary_name}.exe"
|
binary_name="${binary_name}.exe"
|
||||||
exe_suffix=".exe"
|
exe_suffix=".exe"
|
||||||
;;
|
fi
|
||||||
esac
|
|
||||||
|
|
||||||
local download_url="${BASE_URL}/${binary_name}"
|
local download_url="${BASE_URL}/${binary_name}"
|
||||||
local temp_file="/tmp/newt${exe_suffix}"
|
local temp_file="/tmp/newt${exe_suffix}"
|
||||||
@@ -176,18 +155,14 @@ install_newt() {
|
|||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# Make executable before moving
|
|
||||||
chmod +x "$temp_file"
|
|
||||||
|
|
||||||
# Create install directory if it doesn't exist
|
# Create install directory if it doesn't exist
|
||||||
if [ -n "$sudo_cmd" ]; then
|
|
||||||
$sudo_cmd mkdir -p "$install_dir"
|
|
||||||
print_status "Using sudo to install to ${install_dir}"
|
|
||||||
$sudo_cmd mv "$temp_file" "$final_path"
|
|
||||||
else
|
|
||||||
mkdir -p "$install_dir"
|
mkdir -p "$install_dir"
|
||||||
|
|
||||||
|
# Move binary to install directory
|
||||||
mv "$temp_file" "$final_path"
|
mv "$temp_file" "$final_path"
|
||||||
fi
|
|
||||||
|
# Make executable (not needed on Windows, but doesn't hurt)
|
||||||
|
chmod +x "$final_path"
|
||||||
|
|
||||||
print_status "newt installed to ${final_path}"
|
print_status "newt installed to ${final_path}"
|
||||||
|
|
||||||
@@ -204,9 +179,9 @@ verify_installation() {
|
|||||||
local install_dir="$1"
|
local install_dir="$1"
|
||||||
local exe_suffix=""
|
local exe_suffix=""
|
||||||
|
|
||||||
case "$PLATFORM" in
|
if [[ "$PLATFORM" == *"windows"* ]]; then
|
||||||
*windows*) exe_suffix=".exe" ;;
|
exe_suffix=".exe"
|
||||||
esac
|
fi
|
||||||
|
|
||||||
local newt_path="${install_dir}/newt${exe_suffix}"
|
local newt_path="${install_dir}/newt${exe_suffix}"
|
||||||
|
|
||||||
@@ -240,19 +215,17 @@ main() {
|
|||||||
INSTALL_DIR=$(get_install_dir)
|
INSTALL_DIR=$(get_install_dir)
|
||||||
print_status "Install directory: ${INSTALL_DIR}"
|
print_status "Install directory: ${INSTALL_DIR}"
|
||||||
|
|
||||||
# Check if we need sudo
|
|
||||||
SUDO_CMD=$(get_sudo_cmd "$INSTALL_DIR")
|
|
||||||
if [ -n "$SUDO_CMD" ]; then
|
|
||||||
print_status "Root privileges required for installation to ${INSTALL_DIR}"
|
|
||||||
fi
|
|
||||||
|
|
||||||
# Install newt
|
# Install newt
|
||||||
install_newt "$PLATFORM" "$INSTALL_DIR" "$SUDO_CMD"
|
install_newt "$PLATFORM" "$INSTALL_DIR"
|
||||||
|
|
||||||
# Verify installation
|
# Verify installation
|
||||||
if verify_installation "$INSTALL_DIR"; then
|
if verify_installation "$INSTALL_DIR"; then
|
||||||
print_status "newt is ready to use!"
|
print_status "newt is ready to use!"
|
||||||
|
if [[ "$PLATFORM" == *"windows"* ]]; then
|
||||||
print_status "Run 'newt --help' to get started"
|
print_status "Run 'newt --help' to get started"
|
||||||
|
else
|
||||||
|
print_status "Run 'newt --help' to get started"
|
||||||
|
fi
|
||||||
else
|
else
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
|||||||
@@ -5,9 +5,7 @@ import (
|
|||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
@@ -367,12 +365,11 @@ func (m *Monitor) performHealthCheck(target *Target) {
|
|||||||
target.LastCheck = time.Now()
|
target.LastCheck = time.Now()
|
||||||
target.LastError = ""
|
target.LastError = ""
|
||||||
|
|
||||||
// Build URL (use net.JoinHostPort to properly handle IPv6 addresses with ports)
|
// Build URL
|
||||||
host := target.Config.Hostname
|
url := fmt.Sprintf("%s://%s", target.Config.Scheme, target.Config.Hostname)
|
||||||
if target.Config.Port > 0 {
|
if target.Config.Port > 0 {
|
||||||
host = net.JoinHostPort(target.Config.Hostname, strconv.Itoa(target.Config.Port))
|
url = fmt.Sprintf("%s:%d", url, target.Config.Port)
|
||||||
}
|
}
|
||||||
url := fmt.Sprintf("%s://%s", target.Config.Scheme, host)
|
|
||||||
if target.Config.Path != "" {
|
if target.Config.Path != "" {
|
||||||
if !strings.HasPrefix(target.Config.Path, "/") {
|
if !strings.HasPrefix(target.Config.Path, "/") {
|
||||||
url += "/"
|
url += "/"
|
||||||
@@ -524,82 +521,3 @@ func (m *Monitor) DisableTarget(id int) error {
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetTargetIDs returns a slice of all current target IDs
|
|
||||||
func (m *Monitor) GetTargetIDs() []int {
|
|
||||||
m.mutex.RLock()
|
|
||||||
defer m.mutex.RUnlock()
|
|
||||||
|
|
||||||
ids := make([]int, 0, len(m.targets))
|
|
||||||
for id := range m.targets {
|
|
||||||
ids = append(ids, id)
|
|
||||||
}
|
|
||||||
return ids
|
|
||||||
}
|
|
||||||
|
|
||||||
// SyncTargets synchronizes the current targets to match the desired set.
|
|
||||||
// It removes targets not in the desired set and adds targets that are missing.
|
|
||||||
func (m *Monitor) SyncTargets(desiredConfigs []Config) error {
|
|
||||||
m.mutex.Lock()
|
|
||||||
defer m.mutex.Unlock()
|
|
||||||
|
|
||||||
logger.Info("Syncing health check targets: %d desired targets", len(desiredConfigs))
|
|
||||||
|
|
||||||
// Build a set of desired target IDs
|
|
||||||
desiredIDs := make(map[int]Config)
|
|
||||||
for _, config := range desiredConfigs {
|
|
||||||
desiredIDs[config.ID] = config
|
|
||||||
}
|
|
||||||
|
|
||||||
// Find targets to remove (exist but not in desired set)
|
|
||||||
var toRemove []int
|
|
||||||
for id := range m.targets {
|
|
||||||
if _, exists := desiredIDs[id]; !exists {
|
|
||||||
toRemove = append(toRemove, id)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Remove targets that are not in the desired set
|
|
||||||
for _, id := range toRemove {
|
|
||||||
logger.Info("Sync: removing health check target %d", id)
|
|
||||||
if target, exists := m.targets[id]; exists {
|
|
||||||
target.cancel()
|
|
||||||
delete(m.targets, id)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add or update targets from the desired set
|
|
||||||
var addedCount, updatedCount int
|
|
||||||
for id, config := range desiredIDs {
|
|
||||||
if existing, exists := m.targets[id]; exists {
|
|
||||||
// Target exists - check if config changed and update if needed
|
|
||||||
// For now, we'll replace it to ensure config is up to date
|
|
||||||
logger.Debug("Sync: updating health check target %d", id)
|
|
||||||
existing.cancel()
|
|
||||||
delete(m.targets, id)
|
|
||||||
if err := m.addTargetUnsafe(config); err != nil {
|
|
||||||
logger.Error("Sync: failed to update target %d: %v", id, err)
|
|
||||||
return fmt.Errorf("failed to update target %d: %v", id, err)
|
|
||||||
}
|
|
||||||
updatedCount++
|
|
||||||
} else {
|
|
||||||
// Target doesn't exist - add it
|
|
||||||
logger.Debug("Sync: adding health check target %d", id)
|
|
||||||
if err := m.addTargetUnsafe(config); err != nil {
|
|
||||||
logger.Error("Sync: failed to add target %d: %v", id, err)
|
|
||||||
return fmt.Errorf("failed to add target %d: %v", id, err)
|
|
||||||
}
|
|
||||||
addedCount++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Info("Sync complete: removed %d, added %d, updated %d targets",
|
|
||||||
len(toRemove), addedCount, updatedCount)
|
|
||||||
|
|
||||||
// Notify callback if any changes were made
|
|
||||||
if (len(toRemove) > 0 || addedCount > 0 || updatedCount > 0) && m.callback != nil {
|
|
||||||
go m.callback(m.getAllTargetsUnsafe())
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|||||||
193
main.go
193
main.go
@@ -10,7 +10,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/pprof"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
@@ -148,7 +147,6 @@ var (
|
|||||||
adminAddr string
|
adminAddr string
|
||||||
region string
|
region string
|
||||||
metricsAsyncBytes bool
|
metricsAsyncBytes bool
|
||||||
pprofEnabled bool
|
|
||||||
blueprintFile string
|
blueprintFile string
|
||||||
noCloud bool
|
noCloud bool
|
||||||
|
|
||||||
@@ -227,7 +225,6 @@ func runNewtMain(ctx context.Context) {
|
|||||||
adminAddrEnv := os.Getenv("NEWT_ADMIN_ADDR")
|
adminAddrEnv := os.Getenv("NEWT_ADMIN_ADDR")
|
||||||
regionEnv := os.Getenv("NEWT_REGION")
|
regionEnv := os.Getenv("NEWT_REGION")
|
||||||
asyncBytesEnv := os.Getenv("NEWT_METRICS_ASYNC_BYTES")
|
asyncBytesEnv := os.Getenv("NEWT_METRICS_ASYNC_BYTES")
|
||||||
pprofEnabledEnv := os.Getenv("NEWT_PPROF_ENABLED")
|
|
||||||
|
|
||||||
disableClientsEnv := os.Getenv("DISABLE_CLIENTS")
|
disableClientsEnv := os.Getenv("DISABLE_CLIENTS")
|
||||||
disableClients = disableClientsEnv == "true"
|
disableClients = disableClientsEnv == "true"
|
||||||
@@ -305,10 +302,10 @@ func runNewtMain(ctx context.Context) {
|
|||||||
flag.StringVar(&dockerSocket, "docker-socket", "", "Path or address to Docker socket (typically unix:///var/run/docker.sock)")
|
flag.StringVar(&dockerSocket, "docker-socket", "", "Path or address to Docker socket (typically unix:///var/run/docker.sock)")
|
||||||
}
|
}
|
||||||
if pingIntervalStr == "" {
|
if pingIntervalStr == "" {
|
||||||
flag.StringVar(&pingIntervalStr, "ping-interval", "15s", "Interval for pinging the server (default 15s)")
|
flag.StringVar(&pingIntervalStr, "ping-interval", "3s", "Interval for pinging the server (default 3s)")
|
||||||
}
|
}
|
||||||
if pingTimeoutStr == "" {
|
if pingTimeoutStr == "" {
|
||||||
flag.StringVar(&pingTimeoutStr, "ping-timeout", "7s", " Timeout for each ping (default 7s)")
|
flag.StringVar(&pingTimeoutStr, "ping-timeout", "5s", " Timeout for each ping (default 5s)")
|
||||||
}
|
}
|
||||||
// load the prefer endpoint just as a flag
|
// load the prefer endpoint just as a flag
|
||||||
flag.StringVar(&preferEndpoint, "prefer-endpoint", "", "Prefer this endpoint for the connection (if set, will override the endpoint from the server)")
|
flag.StringVar(&preferEndpoint, "prefer-endpoint", "", "Prefer this endpoint for the connection (if set, will override the endpoint from the server)")
|
||||||
@@ -333,21 +330,21 @@ func runNewtMain(ctx context.Context) {
|
|||||||
if pingIntervalStr != "" {
|
if pingIntervalStr != "" {
|
||||||
pingInterval, err = time.ParseDuration(pingIntervalStr)
|
pingInterval, err = time.ParseDuration(pingIntervalStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("Invalid PING_INTERVAL value: %s, using default 15 seconds\n", pingIntervalStr)
|
fmt.Printf("Invalid PING_INTERVAL value: %s, using default 3 seconds\n", pingIntervalStr)
|
||||||
pingInterval = 15 * time.Second
|
pingInterval = 3 * time.Second
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
pingInterval = 15 * time.Second
|
pingInterval = 3 * time.Second
|
||||||
}
|
}
|
||||||
|
|
||||||
if pingTimeoutStr != "" {
|
if pingTimeoutStr != "" {
|
||||||
pingTimeout, err = time.ParseDuration(pingTimeoutStr)
|
pingTimeout, err = time.ParseDuration(pingTimeoutStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("Invalid PING_TIMEOUT value: %s, using default 7 seconds\n", pingTimeoutStr)
|
fmt.Printf("Invalid PING_TIMEOUT value: %s, using default 5 seconds\n", pingTimeoutStr)
|
||||||
pingTimeout = 7 * time.Second
|
pingTimeout = 5 * time.Second
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
pingTimeout = 7 * time.Second
|
pingTimeout = 5 * time.Second
|
||||||
}
|
}
|
||||||
|
|
||||||
if dockerEnforceNetworkValidation == "" {
|
if dockerEnforceNetworkValidation == "" {
|
||||||
@@ -393,14 +390,6 @@ func runNewtMain(ctx context.Context) {
|
|||||||
metricsAsyncBytes = v
|
metricsAsyncBytes = v
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// pprof debug endpoint toggle
|
|
||||||
if pprofEnabledEnv == "" {
|
|
||||||
flag.BoolVar(&pprofEnabled, "pprof", false, "Enable pprof debug endpoints on admin server")
|
|
||||||
} else {
|
|
||||||
if v, err := strconv.ParseBool(pprofEnabledEnv); err == nil {
|
|
||||||
pprofEnabled = v
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Optional region flag (resource attribute)
|
// Optional region flag (resource attribute)
|
||||||
if regionEnv == "" {
|
if regionEnv == "" {
|
||||||
flag.StringVar(®ion, "region", "", "Optional region resource attribute (also NEWT_REGION)")
|
flag.StringVar(®ion, "region", "", "Optional region resource attribute (also NEWT_REGION)")
|
||||||
@@ -496,14 +485,6 @@ func runNewtMain(ctx context.Context) {
|
|||||||
if tel.PrometheusHandler != nil {
|
if tel.PrometheusHandler != nil {
|
||||||
mux.Handle("/metrics", tel.PrometheusHandler)
|
mux.Handle("/metrics", tel.PrometheusHandler)
|
||||||
}
|
}
|
||||||
if pprofEnabled {
|
|
||||||
mux.HandleFunc("/debug/pprof/", pprof.Index)
|
|
||||||
mux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline)
|
|
||||||
mux.HandleFunc("/debug/pprof/profile", pprof.Profile)
|
|
||||||
mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol)
|
|
||||||
mux.HandleFunc("/debug/pprof/trace", pprof.Trace)
|
|
||||||
logger.Info("pprof debugging enabled on %s/debug/pprof/", tcfg.AdminAddr)
|
|
||||||
}
|
|
||||||
admin := &http.Server{
|
admin := &http.Server{
|
||||||
Addr: tcfg.AdminAddr,
|
Addr: tcfg.AdminAddr,
|
||||||
Handler: otelhttp.NewHandler(mux, "newt-admin"),
|
Handler: otelhttp.NewHandler(mux, "newt-admin"),
|
||||||
@@ -584,7 +565,8 @@ func runNewtMain(ctx context.Context) {
|
|||||||
id, // CLI arg takes precedence
|
id, // CLI arg takes precedence
|
||||||
secret, // CLI arg takes precedence
|
secret, // CLI arg takes precedence
|
||||||
endpoint,
|
endpoint,
|
||||||
30*time.Second,
|
pingInterval,
|
||||||
|
pingTimeout,
|
||||||
opt,
|
opt,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -637,6 +619,8 @@ func runNewtMain(ctx context.Context) {
|
|||||||
var wgData WgData
|
var wgData WgData
|
||||||
var dockerEventMonitor *docker.EventMonitor
|
var dockerEventMonitor *docker.EventMonitor
|
||||||
|
|
||||||
|
logger.Debug("++++++++++++++++++++++ the port is %d", port)
|
||||||
|
|
||||||
if !disableClients {
|
if !disableClients {
|
||||||
setupClients(client)
|
setupClients(client)
|
||||||
}
|
}
|
||||||
@@ -975,7 +959,7 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey(
|
|||||||
"publicKey": publicKey.String(),
|
"publicKey": publicKey.String(),
|
||||||
"pingResults": pingResults,
|
"pingResults": pingResults,
|
||||||
"newtVersion": newtVersion,
|
"newtVersion": newtVersion,
|
||||||
}, 2*time.Second)
|
}, 1*time.Second)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -1078,7 +1062,7 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey(
|
|||||||
"publicKey": publicKey.String(),
|
"publicKey": publicKey.String(),
|
||||||
"pingResults": pingResults,
|
"pingResults": pingResults,
|
||||||
"newtVersion": newtVersion,
|
"newtVersion": newtVersion,
|
||||||
}, 2*time.Second)
|
}, 1*time.Second)
|
||||||
|
|
||||||
logger.Debug("Sent exit node ping results to cloud for selection: pingResults=%+v", pingResults)
|
logger.Debug("Sent exit node ping results to cloud for selection: pingResults=%+v", pingResults)
|
||||||
})
|
})
|
||||||
@@ -1183,153 +1167,6 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey(
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
// Register handler for syncing targets (TCP, UDP, and health checks)
|
|
||||||
client.RegisterHandler("newt/sync", func(msg websocket.WSMessage) {
|
|
||||||
logger.Info("Received sync message")
|
|
||||||
|
|
||||||
// if there is no wgData or pm, we can't sync targets
|
|
||||||
if wgData.TunnelIP == "" || pm == nil {
|
|
||||||
logger.Info(msgNoTunnelOrProxy)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Define the sync data structure
|
|
||||||
type SyncData struct {
|
|
||||||
Targets TargetsByType `json:"targets"`
|
|
||||||
HealthCheckTargets []healthcheck.Config `json:"healthCheckTargets"`
|
|
||||||
}
|
|
||||||
|
|
||||||
var syncData SyncData
|
|
||||||
jsonData, err := json.Marshal(msg.Data)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Error marshaling sync data: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := json.Unmarshal(jsonData, &syncData); err != nil {
|
|
||||||
logger.Error("Error unmarshaling sync data: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Debug("Sync data received: TCP targets=%d, UDP targets=%d, health check targets=%d",
|
|
||||||
len(syncData.Targets.TCP), len(syncData.Targets.UDP), len(syncData.HealthCheckTargets))
|
|
||||||
|
|
||||||
//TODO: TEST AND IMPLEMENT THIS
|
|
||||||
|
|
||||||
// // Build sets of desired targets (port -> target string)
|
|
||||||
// desiredTCP := make(map[int]string)
|
|
||||||
// for _, t := range syncData.Targets.TCP {
|
|
||||||
// parts := strings.Split(t, ":")
|
|
||||||
// if len(parts) != 3 {
|
|
||||||
// logger.Warn("Invalid TCP target format: %s", t)
|
|
||||||
// continue
|
|
||||||
// }
|
|
||||||
// port := 0
|
|
||||||
// if _, err := fmt.Sscanf(parts[0], "%d", &port); err != nil {
|
|
||||||
// logger.Warn("Invalid port in TCP target: %s", parts[0])
|
|
||||||
// continue
|
|
||||||
// }
|
|
||||||
// desiredTCP[port] = parts[1] + ":" + parts[2]
|
|
||||||
// }
|
|
||||||
|
|
||||||
// desiredUDP := make(map[int]string)
|
|
||||||
// for _, t := range syncData.Targets.UDP {
|
|
||||||
// parts := strings.Split(t, ":")
|
|
||||||
// if len(parts) != 3 {
|
|
||||||
// logger.Warn("Invalid UDP target format: %s", t)
|
|
||||||
// continue
|
|
||||||
// }
|
|
||||||
// port := 0
|
|
||||||
// if _, err := fmt.Sscanf(parts[0], "%d", &port); err != nil {
|
|
||||||
// logger.Warn("Invalid port in UDP target: %s", parts[0])
|
|
||||||
// continue
|
|
||||||
// }
|
|
||||||
// desiredUDP[port] = parts[1] + ":" + parts[2]
|
|
||||||
// }
|
|
||||||
|
|
||||||
// // Get current targets from proxy manager
|
|
||||||
// currentTCP, currentUDP := pm.GetTargets()
|
|
||||||
|
|
||||||
// // Sync TCP targets
|
|
||||||
// // Remove TCP targets not in desired set
|
|
||||||
// if tcpForIP, ok := currentTCP[wgData.TunnelIP]; ok {
|
|
||||||
// for port := range tcpForIP {
|
|
||||||
// if _, exists := desiredTCP[port]; !exists {
|
|
||||||
// logger.Info("Sync: removing TCP target on port %d", port)
|
|
||||||
// targetStr := fmt.Sprintf("%d:%s", port, tcpForIP[port])
|
|
||||||
// updateTargets(pm, "remove", wgData.TunnelIP, "tcp", TargetData{Targets: []string{targetStr}})
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
|
|
||||||
// // Add TCP targets that are missing
|
|
||||||
// for port, target := range desiredTCP {
|
|
||||||
// needsAdd := true
|
|
||||||
// if tcpForIP, ok := currentTCP[wgData.TunnelIP]; ok {
|
|
||||||
// if currentTarget, exists := tcpForIP[port]; exists {
|
|
||||||
// // Check if target address changed
|
|
||||||
// if currentTarget == target {
|
|
||||||
// needsAdd = false
|
|
||||||
// } else {
|
|
||||||
// // Target changed, remove old one first
|
|
||||||
// logger.Info("Sync: updating TCP target on port %d", port)
|
|
||||||
// targetStr := fmt.Sprintf("%d:%s", port, currentTarget)
|
|
||||||
// updateTargets(pm, "remove", wgData.TunnelIP, "tcp", TargetData{Targets: []string{targetStr}})
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// if needsAdd {
|
|
||||||
// logger.Info("Sync: adding TCP target on port %d -> %s", port, target)
|
|
||||||
// targetStr := fmt.Sprintf("%d:%s", port, target)
|
|
||||||
// updateTargets(pm, "add", wgData.TunnelIP, "tcp", TargetData{Targets: []string{targetStr}})
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
|
|
||||||
// // Sync UDP targets
|
|
||||||
// // Remove UDP targets not in desired set
|
|
||||||
// if udpForIP, ok := currentUDP[wgData.TunnelIP]; ok {
|
|
||||||
// for port := range udpForIP {
|
|
||||||
// if _, exists := desiredUDP[port]; !exists {
|
|
||||||
// logger.Info("Sync: removing UDP target on port %d", port)
|
|
||||||
// targetStr := fmt.Sprintf("%d:%s", port, udpForIP[port])
|
|
||||||
// updateTargets(pm, "remove", wgData.TunnelIP, "udp", TargetData{Targets: []string{targetStr}})
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
|
|
||||||
// // Add UDP targets that are missing
|
|
||||||
// for port, target := range desiredUDP {
|
|
||||||
// needsAdd := true
|
|
||||||
// if udpForIP, ok := currentUDP[wgData.TunnelIP]; ok {
|
|
||||||
// if currentTarget, exists := udpForIP[port]; exists {
|
|
||||||
// // Check if target address changed
|
|
||||||
// if currentTarget == target {
|
|
||||||
// needsAdd = false
|
|
||||||
// } else {
|
|
||||||
// // Target changed, remove old one first
|
|
||||||
// logger.Info("Sync: updating UDP target on port %d", port)
|
|
||||||
// targetStr := fmt.Sprintf("%d:%s", port, currentTarget)
|
|
||||||
// updateTargets(pm, "remove", wgData.TunnelIP, "udp", TargetData{Targets: []string{targetStr}})
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// if needsAdd {
|
|
||||||
// logger.Info("Sync: adding UDP target on port %d -> %s", port, target)
|
|
||||||
// targetStr := fmt.Sprintf("%d:%s", port, target)
|
|
||||||
// updateTargets(pm, "add", wgData.TunnelIP, "udp", TargetData{Targets: []string{targetStr}})
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
|
|
||||||
// // Sync health check targets
|
|
||||||
// if err := healthMonitor.SyncTargets(syncData.HealthCheckTargets); err != nil {
|
|
||||||
// logger.Error("Failed to sync health check targets: %v", err)
|
|
||||||
// } else {
|
|
||||||
// logger.Info("Successfully synced health check targets")
|
|
||||||
// }
|
|
||||||
|
|
||||||
logger.Info("Sync complete")
|
|
||||||
})
|
|
||||||
|
|
||||||
// Register handler for Docker socket check
|
// Register handler for Docker socket check
|
||||||
client.RegisterHandler("newt/socket/check", func(msg websocket.WSMessage) {
|
client.RegisterHandler("newt/socket/check", func(msg websocket.WSMessage) {
|
||||||
logger.Debug("Received Docker socket check request")
|
logger.Debug("Received Docker socket check request")
|
||||||
@@ -1812,8 +1649,6 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey(
|
|||||||
pm.Stop()
|
pm.Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
client.SendMessage("newt/disconnecting", map[string]any{})
|
|
||||||
|
|
||||||
if client != nil {
|
if client != nil {
|
||||||
client.Close()
|
client.Close()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -48,23 +48,6 @@ type SubnetRule struct {
|
|||||||
PortRanges []PortRange // empty slice means all ports allowed
|
PortRanges []PortRange // empty slice means all ports allowed
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAllRules returns a copy of all subnet rules
|
|
||||||
func (sl *SubnetLookup) GetAllRules() []SubnetRule {
|
|
||||||
sl.mu.RLock()
|
|
||||||
defer sl.mu.RUnlock()
|
|
||||||
|
|
||||||
var rules []SubnetRule
|
|
||||||
for _, destTriePtr := range sl.sourceTrie.All() {
|
|
||||||
if destTriePtr == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
for _, rule := range destTriePtr.rules {
|
|
||||||
rules = append(rules, *rule)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return rules
|
|
||||||
}
|
|
||||||
|
|
||||||
// connKey uniquely identifies a connection for NAT tracking
|
// connKey uniquely identifies a connection for NAT tracking
|
||||||
type connKey struct {
|
type connKey struct {
|
||||||
srcIP string
|
srcIP string
|
||||||
@@ -217,14 +200,6 @@ func (p *ProxyHandler) RemoveSubnetRule(sourcePrefix, destPrefix netip.Prefix) {
|
|||||||
p.subnetLookup.RemoveSubnet(sourcePrefix, destPrefix)
|
p.subnetLookup.RemoveSubnet(sourcePrefix, destPrefix)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAllRules returns all subnet rules from the proxy handler
|
|
||||||
func (p *ProxyHandler) GetAllRules() []SubnetRule {
|
|
||||||
if p == nil || !p.enabled {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return p.subnetLookup.GetAllRules()
|
|
||||||
}
|
|
||||||
|
|
||||||
// LookupDestinationRewrite looks up the rewritten destination for a connection
|
// LookupDestinationRewrite looks up the rewritten destination for a connection
|
||||||
// This is used by TCP/UDP handlers to find the actual target address
|
// This is used by TCP/UDP handlers to find the actual target address
|
||||||
func (p *ProxyHandler) LookupDestinationRewrite(srcIP, dstIP string, dstPort uint16, proto uint8) (netip.Addr, bool) {
|
func (p *ProxyHandler) LookupDestinationRewrite(srcIP, dstIP string, dstPort uint16, proto uint8) (netip.Addr, bool) {
|
||||||
|
|||||||
@@ -369,15 +369,6 @@ func (net *Net) RemoveProxySubnetRule(sourcePrefix, destPrefix netip.Prefix) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetProxySubnetRules returns all subnet rules from the proxy handler
|
|
||||||
func (net *Net) GetProxySubnetRules() []SubnetRule {
|
|
||||||
tun := (*netTun)(net)
|
|
||||||
if tun.proxyHandler != nil {
|
|
||||||
return tun.proxyHandler.GetAllRules()
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetProxyHandler returns the proxy handler (for advanced use cases)
|
// GetProxyHandler returns the proxy handler (for advanced use cases)
|
||||||
// Returns nil if proxy is not enabled
|
// Returns nil if proxy is not enabled
|
||||||
func (net *Net) GetProxyHandler() *ProxyHandler {
|
func (net *Net) GetProxyHandler() *ProxyHandler {
|
||||||
|
|||||||
@@ -21,10 +21,7 @@ import (
|
|||||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const errUnsupportedProtoFmt = "unsupported protocol: %s"
|
||||||
errUnsupportedProtoFmt = "unsupported protocol: %s"
|
|
||||||
maxUDPPacketSize = 65507
|
|
||||||
)
|
|
||||||
|
|
||||||
// Target represents a proxy target with its address and port
|
// Target represents a proxy target with its address and port
|
||||||
type Target struct {
|
type Target struct {
|
||||||
@@ -108,10 +105,14 @@ func classifyProxyError(err error) string {
|
|||||||
if errors.Is(err, net.ErrClosed) {
|
if errors.Is(err, net.ErrClosed) {
|
||||||
return "closed"
|
return "closed"
|
||||||
}
|
}
|
||||||
var ne net.Error
|
if ne, ok := err.(net.Error); ok {
|
||||||
if errors.As(err, &ne) && ne.Timeout() {
|
if ne.Timeout() {
|
||||||
return "timeout"
|
return "timeout"
|
||||||
}
|
}
|
||||||
|
if ne.Temporary() {
|
||||||
|
return "temporary"
|
||||||
|
}
|
||||||
|
}
|
||||||
msg := strings.ToLower(err.Error())
|
msg := strings.ToLower(err.Error())
|
||||||
switch {
|
switch {
|
||||||
case strings.Contains(msg, "refused"):
|
case strings.Contains(msg, "refused"):
|
||||||
@@ -436,6 +437,14 @@ func (pm *ProxyManager) Stop() error {
|
|||||||
pm.udpConns = append(pm.udpConns[:i], pm.udpConns[i+1:]...)
|
pm.udpConns = append(pm.udpConns[:i], pm.udpConns[i+1:]...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// // Clear the target maps
|
||||||
|
// for k := range pm.tcpTargets {
|
||||||
|
// delete(pm.tcpTargets, k)
|
||||||
|
// }
|
||||||
|
// for k := range pm.udpTargets {
|
||||||
|
// delete(pm.udpTargets, k)
|
||||||
|
// }
|
||||||
|
|
||||||
// Give active connections a chance to close gracefully
|
// Give active connections a chance to close gracefully
|
||||||
time.Sleep(100 * time.Millisecond)
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|
||||||
@@ -489,7 +498,7 @@ func (pm *ProxyManager) handleTCPProxy(listener net.Listener, targetAddr string)
|
|||||||
if !pm.running {
|
if !pm.running {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if errors.Is(err, net.ErrClosed) {
|
if ne, ok := err.(net.Error); ok && !ne.Temporary() {
|
||||||
logger.Info("TCP listener closed, stopping proxy handler for %v", listener.Addr())
|
logger.Info("TCP listener closed, stopping proxy handler for %v", listener.Addr())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -555,7 +564,7 @@ func (pm *ProxyManager) handleTCPProxy(listener net.Listener, targetAddr string)
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (pm *ProxyManager) handleUDPProxy(conn *gonet.UDPConn, targetAddr string) {
|
func (pm *ProxyManager) handleUDPProxy(conn *gonet.UDPConn, targetAddr string) {
|
||||||
buffer := make([]byte, maxUDPPacketSize) // Max UDP packet size
|
buffer := make([]byte, 65507) // Max UDP packet size
|
||||||
clientConns := make(map[string]*net.UDPConn)
|
clientConns := make(map[string]*net.UDPConn)
|
||||||
var clientsMutex sync.RWMutex
|
var clientsMutex sync.RWMutex
|
||||||
|
|
||||||
@@ -574,7 +583,7 @@ func (pm *ProxyManager) handleUDPProxy(conn *gonet.UDPConn, targetAddr string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Check for connection closed conditions
|
// Check for connection closed conditions
|
||||||
if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) {
|
if err == io.EOF || strings.Contains(err.Error(), "use of closed network connection") {
|
||||||
logger.Info("UDP connection closed, stopping proxy handler")
|
logger.Info("UDP connection closed, stopping proxy handler")
|
||||||
|
|
||||||
// Clean up existing client connections
|
// Clean up existing client connections
|
||||||
@@ -653,14 +662,10 @@ func (pm *ProxyManager) handleUDPProxy(conn *gonet.UDPConn, targetAddr string) {
|
|||||||
telemetry.IncProxyConnectionEvent(context.Background(), tunnelID, "udp", telemetry.ProxyConnectionClosed)
|
telemetry.IncProxyConnectionEvent(context.Background(), tunnelID, "udp", telemetry.ProxyConnectionClosed)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
buffer := make([]byte, maxUDPPacketSize)
|
buffer := make([]byte, 65507)
|
||||||
for {
|
for {
|
||||||
n, _, err := targetConn.ReadFromUDP(buffer)
|
n, _, err := targetConn.ReadFromUDP(buffer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Connection closed is normal during cleanup
|
|
||||||
if errors.Is(err, net.ErrClosed) || errors.Is(err, io.EOF) {
|
|
||||||
return // defer will handle cleanup, result stays "success"
|
|
||||||
}
|
|
||||||
logger.Error("Error reading from target: %v", err)
|
logger.Error("Error reading from target: %v", err)
|
||||||
result = "failure"
|
result = "failure"
|
||||||
return // defer will handle cleanup
|
return // defer will handle cleanup
|
||||||
@@ -731,28 +736,3 @@ func (pm *ProxyManager) PrintTargets() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetTargets returns a copy of the current TCP and UDP targets
|
|
||||||
// Returns map[listenIP]map[port]targetAddress for both TCP and UDP
|
|
||||||
func (pm *ProxyManager) GetTargets() (tcpTargets map[string]map[int]string, udpTargets map[string]map[int]string) {
|
|
||||||
pm.mutex.RLock()
|
|
||||||
defer pm.mutex.RUnlock()
|
|
||||||
|
|
||||||
tcpTargets = make(map[string]map[int]string)
|
|
||||||
for listenIP, targets := range pm.tcpTargets {
|
|
||||||
tcpTargets[listenIP] = make(map[int]string)
|
|
||||||
for port, targetAddr := range targets {
|
|
||||||
tcpTargets[listenIP][port] = targetAddr
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
udpTargets = make(map[string]map[int]string)
|
|
||||||
for listenIP, targets := range pm.udpTargets {
|
|
||||||
udpTargets[listenIP] = make(map[int]string)
|
|
||||||
for port, targetAddr := range targets {
|
|
||||||
udpTargets[listenIP][port] = targetAddr
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return tcpTargets, udpTargets
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package websocket
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"compress/gzip"
|
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
@@ -38,6 +37,7 @@ type Client struct {
|
|||||||
isConnected bool
|
isConnected bool
|
||||||
reconnectMux sync.RWMutex
|
reconnectMux sync.RWMutex
|
||||||
pingInterval time.Duration
|
pingInterval time.Duration
|
||||||
|
pingTimeout time.Duration
|
||||||
onConnect func() error
|
onConnect func() error
|
||||||
onTokenUpdate func(token string)
|
onTokenUpdate func(token string)
|
||||||
writeMux sync.Mutex
|
writeMux sync.Mutex
|
||||||
@@ -47,11 +47,6 @@ type Client struct {
|
|||||||
metricsCtx context.Context
|
metricsCtx context.Context
|
||||||
configNeedsSave bool // Flag to track if config needs to be saved
|
configNeedsSave bool // Flag to track if config needs to be saved
|
||||||
serverVersion string
|
serverVersion string
|
||||||
configVersion int64 // Latest config version received from server
|
|
||||||
configVersionMux sync.RWMutex
|
|
||||||
processingMessage bool // Flag to track if a message is currently being processed
|
|
||||||
processingMux sync.RWMutex // Protects processingMessage
|
|
||||||
processingWg sync.WaitGroup // WaitGroup to wait for message processing to complete
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type ClientOption func(*Client)
|
type ClientOption func(*Client)
|
||||||
@@ -116,7 +111,7 @@ func (c *Client) MetricsContext() context.Context {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewClient creates a new websocket client
|
// NewClient creates a new websocket client
|
||||||
func NewClient(clientType string, ID, secret string, endpoint string, pingInterval time.Duration, opts ...ClientOption) (*Client, error) {
|
func NewClient(clientType string, ID, secret string, endpoint string, pingInterval time.Duration, pingTimeout time.Duration, opts ...ClientOption) (*Client, error) {
|
||||||
config := &Config{
|
config := &Config{
|
||||||
ID: ID,
|
ID: ID,
|
||||||
Secret: secret,
|
Secret: secret,
|
||||||
@@ -131,6 +126,7 @@ func NewClient(clientType string, ID, secret string, endpoint string, pingInterv
|
|||||||
reconnectInterval: 3 * time.Second,
|
reconnectInterval: 3 * time.Second,
|
||||||
isConnected: false,
|
isConnected: false,
|
||||||
pingInterval: pingInterval,
|
pingInterval: pingInterval,
|
||||||
|
pingTimeout: pingTimeout,
|
||||||
clientType: clientType,
|
clientType: clientType,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -158,20 +154,6 @@ func (c *Client) GetServerVersion() string {
|
|||||||
return c.serverVersion
|
return c.serverVersion
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetConfigVersion returns the latest config version received from server
|
|
||||||
func (c *Client) GetConfigVersion() int64 {
|
|
||||||
c.configVersionMux.RLock()
|
|
||||||
defer c.configVersionMux.RUnlock()
|
|
||||||
return c.configVersion
|
|
||||||
}
|
|
||||||
|
|
||||||
// setConfigVersion updates the config version
|
|
||||||
func (c *Client) setConfigVersion(version int64) {
|
|
||||||
c.configVersionMux.Lock()
|
|
||||||
defer c.configVersionMux.Unlock()
|
|
||||||
c.configVersion = version
|
|
||||||
}
|
|
||||||
|
|
||||||
// Connect establishes the WebSocket connection
|
// Connect establishes the WebSocket connection
|
||||||
func (c *Client) Connect() error {
|
func (c *Client) Connect() error {
|
||||||
go c.connectWithRetry()
|
go c.connectWithRetry()
|
||||||
@@ -659,37 +641,24 @@ func (c *Client) setupPKCS12TLS() (*tls.Config, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// pingMonitor sends pings at a short interval and triggers reconnect on failure
|
// pingMonitor sends pings at a short interval and triggers reconnect on failure
|
||||||
func (c *Client) sendPing() {
|
func (c *Client) pingMonitor() {
|
||||||
|
ticker := time.NewTicker(c.pingInterval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-c.done:
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
if c.conn == nil {
|
if c.conn == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Skip ping if a message is currently being processed
|
|
||||||
c.processingMux.RLock()
|
|
||||||
isProcessing := c.processingMessage
|
|
||||||
c.processingMux.RUnlock()
|
|
||||||
if isProcessing {
|
|
||||||
logger.Debug("Skipping ping, message is being processed")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
c.configVersionMux.RLock()
|
|
||||||
configVersion := c.configVersion
|
|
||||||
c.configVersionMux.RUnlock()
|
|
||||||
|
|
||||||
pingMsg := WSMessage{
|
|
||||||
Type: "newt/ping",
|
|
||||||
Data: map[string]interface{}{},
|
|
||||||
ConfigVersion: configVersion,
|
|
||||||
}
|
|
||||||
|
|
||||||
c.writeMux.Lock()
|
c.writeMux.Lock()
|
||||||
err := c.conn.WriteJSON(pingMsg)
|
err := c.conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(c.pingTimeout))
|
||||||
if err == nil {
|
if err == nil {
|
||||||
telemetry.IncWSMessage(c.metricsContext(), "out", "ping")
|
telemetry.IncWSMessage(c.metricsContext(), "out", "ping")
|
||||||
}
|
}
|
||||||
c.writeMux.Unlock()
|
c.writeMux.Unlock()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Check if we're shutting down before logging error and reconnecting
|
// Check if we're shutting down before logging error and reconnecting
|
||||||
select {
|
select {
|
||||||
@@ -704,21 +673,6 @@ func (c *Client) sendPing() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) pingMonitor() {
|
|
||||||
// Send an immediate ping as soon as we connect
|
|
||||||
c.sendPing()
|
|
||||||
|
|
||||||
ticker := time.NewTicker(c.pingInterval)
|
|
||||||
defer ticker.Stop()
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-c.done:
|
|
||||||
return
|
|
||||||
case <-ticker.C:
|
|
||||||
c.sendPing()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -755,14 +709,11 @@ func (c *Client) readPumpWithDisconnectDetection(started time.Time) {
|
|||||||
disconnectResult = "success"
|
disconnectResult = "success"
|
||||||
return
|
return
|
||||||
default:
|
default:
|
||||||
msgType, p, err := c.conn.ReadMessage()
|
var msg WSMessage
|
||||||
|
err := c.conn.ReadJSON(&msg)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
if msgType == websocket.BinaryMessage {
|
|
||||||
telemetry.IncWSMessage(c.metricsContext(), "in", "binary")
|
|
||||||
} else {
|
|
||||||
telemetry.IncWSMessage(c.metricsContext(), "in", "text")
|
telemetry.IncWSMessage(c.metricsContext(), "in", "text")
|
||||||
}
|
}
|
||||||
}
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Check if we're shutting down before logging error
|
// Check if we're shutting down before logging error
|
||||||
select {
|
select {
|
||||||
@@ -786,47 +737,9 @@ func (c *Client) readPumpWithDisconnectDetection(started time.Time) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update config version from incoming message
|
|
||||||
var data []byte
|
|
||||||
if msgType == websocket.BinaryMessage {
|
|
||||||
gr, err := gzip.NewReader(bytes.NewReader(p))
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("WebSocket failed to create gzip reader: %v", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
data, err = io.ReadAll(gr)
|
|
||||||
gr.Close()
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("WebSocket failed to decompress message: %v", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
data = p
|
|
||||||
}
|
|
||||||
|
|
||||||
var msg WSMessage
|
|
||||||
if err = json.Unmarshal(data, &msg); err != nil {
|
|
||||||
logger.Error("WebSocket failed to parse message: %v", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
c.setConfigVersion(msg.ConfigVersion)
|
|
||||||
|
|
||||||
c.handlersMux.RLock()
|
c.handlersMux.RLock()
|
||||||
if handler, ok := c.handlers[msg.Type]; ok {
|
if handler, ok := c.handlers[msg.Type]; ok {
|
||||||
// Mark that we're processing a message
|
|
||||||
c.processingMux.Lock()
|
|
||||||
c.processingMessage = true
|
|
||||||
c.processingMux.Unlock()
|
|
||||||
c.processingWg.Add(1)
|
|
||||||
|
|
||||||
handler(msg)
|
handler(msg)
|
||||||
|
|
||||||
// Mark that we're done processing
|
|
||||||
c.processingWg.Done()
|
|
||||||
c.processingMux.Lock()
|
|
||||||
c.processingMessage = false
|
|
||||||
c.processingMux.Unlock()
|
|
||||||
}
|
}
|
||||||
c.handlersMux.RUnlock()
|
c.handlersMux.RUnlock()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -19,5 +19,4 @@ type TokenResponse struct {
|
|||||||
type WSMessage struct {
|
type WSMessage struct {
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
Data interface{} `json:"data"`
|
Data interface{} `json:"data"`
|
||||||
ConfigVersion int64 `json:"configVersion,omitempty"`
|
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user