Support prefixes sent from server

This commit is contained in:
Owen
2026-03-02 18:11:20 -08:00
parent 19f143fc6a
commit 039ae07b7b

View File

@@ -37,11 +37,12 @@ type WgConfig struct {
} }
type Target struct { type Target struct {
SourcePrefix string `json:"sourcePrefix"` SourcePrefix string `json:"sourcePrefix"`
DestPrefix string `json:"destPrefix"` SourcePrefixes []string `json:"sourcePrefixes"`
RewriteTo string `json:"rewriteTo,omitempty"` DestPrefix string `json:"destPrefix"`
DisableIcmp bool `json:"disableIcmp,omitempty"` RewriteTo string `json:"rewriteTo,omitempty"`
PortRange []PortRange `json:"portRange,omitempty"` DisableIcmp bool `json:"disableIcmp,omitempty"`
PortRange []PortRange `json:"portRange,omitempty"`
} }
type PortRange struct { type PortRange struct {
@@ -277,7 +278,7 @@ func (s *WireGuardService) StartHolepunch(publicKey string, endpoint string, rel
} }
if relayPort == 0 { if relayPort == 0 {
relayPort = 21820 relayPort = 21820
} }
// Convert websocket.ExitNode to holepunch.ExitNode // Convert websocket.ExitNode to holepunch.ExitNode
@@ -695,6 +696,19 @@ func (s *WireGuardService) ensureWireguardPeers(peers []Peer) error {
return nil return nil
} }
// resolveSourcePrefixes returns the effective list of source prefixes for a target,
// supporting both the legacy single SourcePrefix field and the new SourcePrefixes array.
// If SourcePrefixes is non-empty it takes precedence; otherwise SourcePrefix is used.
func resolveSourcePrefixes(target Target) []string {
if len(target.SourcePrefixes) > 0 {
return target.SourcePrefixes
}
if target.SourcePrefix != "" {
return []string{target.SourcePrefix}
}
return nil
}
func (s *WireGuardService) ensureTargets(targets []Target) error { func (s *WireGuardService) ensureTargets(targets []Target) error {
if s.tnet == nil { if s.tnet == nil {
// Native interface mode - proxy features not available, skip silently // Native interface mode - proxy features not available, skip silently
@@ -703,11 +717,6 @@ func (s *WireGuardService) ensureTargets(targets []Target) error {
} }
for _, target := range targets { for _, target := range targets {
sourcePrefix, err := netip.ParsePrefix(target.SourcePrefix)
if err != nil {
return fmt.Errorf("invalid CIDR %s: %v", target.SourcePrefix, err)
}
destPrefix, err := netip.ParsePrefix(target.DestPrefix) destPrefix, err := netip.ParsePrefix(target.DestPrefix)
if err != nil { if err != nil {
return fmt.Errorf("invalid CIDR %s: %v", target.DestPrefix, err) return fmt.Errorf("invalid CIDR %s: %v", target.DestPrefix, err)
@@ -722,9 +731,14 @@ func (s *WireGuardService) ensureTargets(targets []Target) error {
}) })
} }
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp) for _, sp := range resolveSourcePrefixes(target) {
sourcePrefix, err := netip.ParsePrefix(sp)
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", target.SourcePrefix, target.DestPrefix, target.RewriteTo, target.PortRange) if err != nil {
return fmt.Errorf("invalid CIDR %s: %v", sp, err)
}
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp)
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", sp, target.DestPrefix, target.RewriteTo, target.PortRange)
}
} }
return nil return nil
@@ -1043,7 +1057,7 @@ func (s *WireGuardService) processPeerBandwidth(publicKey string, rxBytes, txByt
BytesOut: bytesOutMB, BytesOut: bytesOutMB,
} }
} }
return nil return nil
} }
} }
@@ -1094,12 +1108,6 @@ func (s *WireGuardService) handleAddTarget(msg websocket.WSMessage) {
// Process all targets // Process all targets
for _, target := range targets { for _, target := range targets {
sourcePrefix, err := netip.ParsePrefix(target.SourcePrefix)
if err != nil {
logger.Info("Invalid CIDR %s: %v", target.SourcePrefix, err)
continue
}
destPrefix, err := netip.ParsePrefix(target.DestPrefix) destPrefix, err := netip.ParsePrefix(target.DestPrefix)
if err != nil { if err != nil {
logger.Info("Invalid CIDR %s: %v", target.DestPrefix, err) logger.Info("Invalid CIDR %s: %v", target.DestPrefix, err)
@@ -1109,15 +1117,21 @@ func (s *WireGuardService) handleAddTarget(msg websocket.WSMessage) {
var portRanges []netstack2.PortRange var portRanges []netstack2.PortRange
for _, pr := range target.PortRange { for _, pr := range target.PortRange {
portRanges = append(portRanges, netstack2.PortRange{ portRanges = append(portRanges, netstack2.PortRange{
Min: pr.Min, Min: pr.Min,
Max: pr.Max, Max: pr.Max,
Protocol: pr.Protocol, Protocol: pr.Protocol,
}) })
} }
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp) for _, sp := range resolveSourcePrefixes(target) {
sourcePrefix, err := netip.ParsePrefix(sp)
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", target.SourcePrefix, target.DestPrefix, target.RewriteTo, target.PortRange) if err != nil {
logger.Info("Invalid CIDR %s: %v", sp, err)
continue
}
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp)
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", sp, target.DestPrefix, target.RewriteTo, target.PortRange)
}
} }
} }
@@ -1146,21 +1160,21 @@ func (s *WireGuardService) handleRemoveTarget(msg websocket.WSMessage) {
// Process all targets // Process all targets
for _, target := range targets { for _, target := range targets {
sourcePrefix, err := netip.ParsePrefix(target.SourcePrefix)
if err != nil {
logger.Info("Invalid CIDR %s: %v", target.SourcePrefix, err)
continue
}
destPrefix, err := netip.ParsePrefix(target.DestPrefix) destPrefix, err := netip.ParsePrefix(target.DestPrefix)
if err != nil { if err != nil {
logger.Info("Invalid CIDR %s: %v", target.DestPrefix, err) logger.Info("Invalid CIDR %s: %v", target.DestPrefix, err)
continue continue
} }
s.tnet.RemoveProxySubnetRule(sourcePrefix, destPrefix) for _, sp := range resolveSourcePrefixes(target) {
sourcePrefix, err := netip.ParsePrefix(sp)
logger.Info("Removed target subnet %s with destination %s", target.SourcePrefix, target.DestPrefix) if err != nil {
logger.Info("Invalid CIDR %s: %v", sp, err)
continue
}
s.tnet.RemoveProxySubnetRule(sourcePrefix, destPrefix)
logger.Info("Removed target subnet %s with destination %s", sp, target.DestPrefix)
}
} }
} }
@@ -1194,30 +1208,24 @@ func (s *WireGuardService) handleUpdateTarget(msg websocket.WSMessage) {
// Process all update requests // Process all update requests
for _, target := range requests.OldTargets { for _, target := range requests.OldTargets {
sourcePrefix, err := netip.ParsePrefix(target.SourcePrefix)
if err != nil {
logger.Info("Invalid CIDR %s: %v", target.SourcePrefix, err)
continue
}
destPrefix, err := netip.ParsePrefix(target.DestPrefix) destPrefix, err := netip.ParsePrefix(target.DestPrefix)
if err != nil { if err != nil {
logger.Info("Invalid CIDR %s: %v", target.DestPrefix, err) logger.Info("Invalid CIDR %s: %v", target.DestPrefix, err)
continue continue
} }
s.tnet.RemoveProxySubnetRule(sourcePrefix, destPrefix) for _, sp := range resolveSourcePrefixes(target) {
logger.Info("Removed target subnet %s with destination %s", target.SourcePrefix, target.DestPrefix) sourcePrefix, err := netip.ParsePrefix(sp)
if err != nil {
logger.Info("Invalid CIDR %s: %v", sp, err)
continue
}
s.tnet.RemoveProxySubnetRule(sourcePrefix, destPrefix)
logger.Info("Removed target subnet %s with destination %s", sp, target.DestPrefix)
}
} }
for _, target := range requests.NewTargets { for _, target := range requests.NewTargets {
// Now add the new target
sourcePrefix, err := netip.ParsePrefix(target.SourcePrefix)
if err != nil {
logger.Info("Invalid CIDR %s: %v", target.SourcePrefix, err)
continue
}
destPrefix, err := netip.ParsePrefix(target.DestPrefix) destPrefix, err := netip.ParsePrefix(target.DestPrefix)
if err != nil { if err != nil {
logger.Info("Invalid CIDR %s: %v", target.DestPrefix, err) logger.Info("Invalid CIDR %s: %v", target.DestPrefix, err)
@@ -1227,14 +1235,21 @@ func (s *WireGuardService) handleUpdateTarget(msg websocket.WSMessage) {
var portRanges []netstack2.PortRange var portRanges []netstack2.PortRange
for _, pr := range target.PortRange { for _, pr := range target.PortRange {
portRanges = append(portRanges, netstack2.PortRange{ portRanges = append(portRanges, netstack2.PortRange{
Min: pr.Min, Min: pr.Min,
Max: pr.Max, Max: pr.Max,
Protocol: pr.Protocol, Protocol: pr.Protocol,
}) })
} }
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp) for _, sp := range resolveSourcePrefixes(target) {
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", target.SourcePrefix, target.DestPrefix, target.RewriteTo, target.PortRange) sourcePrefix, err := netip.ParsePrefix(sp)
if err != nil {
logger.Info("Invalid CIDR %s: %v", sp, err)
continue
}
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp)
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", sp, target.DestPrefix, target.RewriteTo, target.PortRange)
}
} }
} }