Support prefixes sent from server

This commit is contained in:
Owen
2026-03-02 18:11:20 -08:00
parent 19f143fc6a
commit 039ae07b7b

View File

@@ -38,6 +38,7 @@ type WgConfig struct {
type Target struct { type Target struct {
SourcePrefix string `json:"sourcePrefix"` SourcePrefix string `json:"sourcePrefix"`
SourcePrefixes []string `json:"sourcePrefixes"`
DestPrefix string `json:"destPrefix"` DestPrefix string `json:"destPrefix"`
RewriteTo string `json:"rewriteTo,omitempty"` RewriteTo string `json:"rewriteTo,omitempty"`
DisableIcmp bool `json:"disableIcmp,omitempty"` DisableIcmp bool `json:"disableIcmp,omitempty"`
@@ -695,6 +696,19 @@ func (s *WireGuardService) ensureWireguardPeers(peers []Peer) error {
return nil return nil
} }
// resolveSourcePrefixes returns the effective list of source prefixes for a target,
// supporting both the legacy single SourcePrefix field and the new SourcePrefixes array.
// If SourcePrefixes is non-empty it takes precedence; otherwise SourcePrefix is used.
func resolveSourcePrefixes(target Target) []string {
if len(target.SourcePrefixes) > 0 {
return target.SourcePrefixes
}
if target.SourcePrefix != "" {
return []string{target.SourcePrefix}
}
return nil
}
func (s *WireGuardService) ensureTargets(targets []Target) error { func (s *WireGuardService) ensureTargets(targets []Target) error {
if s.tnet == nil { if s.tnet == nil {
// Native interface mode - proxy features not available, skip silently // Native interface mode - proxy features not available, skip silently
@@ -703,11 +717,6 @@ func (s *WireGuardService) ensureTargets(targets []Target) error {
} }
for _, target := range targets { for _, target := range targets {
sourcePrefix, err := netip.ParsePrefix(target.SourcePrefix)
if err != nil {
return fmt.Errorf("invalid CIDR %s: %v", target.SourcePrefix, err)
}
destPrefix, err := netip.ParsePrefix(target.DestPrefix) destPrefix, err := netip.ParsePrefix(target.DestPrefix)
if err != nil { if err != nil {
return fmt.Errorf("invalid CIDR %s: %v", target.DestPrefix, err) return fmt.Errorf("invalid CIDR %s: %v", target.DestPrefix, err)
@@ -722,9 +731,14 @@ func (s *WireGuardService) ensureTargets(targets []Target) error {
}) })
} }
for _, sp := range resolveSourcePrefixes(target) {
sourcePrefix, err := netip.ParsePrefix(sp)
if err != nil {
return fmt.Errorf("invalid CIDR %s: %v", sp, err)
}
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp) s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp)
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", sp, target.DestPrefix, target.RewriteTo, target.PortRange)
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", target.SourcePrefix, target.DestPrefix, target.RewriteTo, target.PortRange) }
} }
return nil return nil
@@ -1094,12 +1108,6 @@ func (s *WireGuardService) handleAddTarget(msg websocket.WSMessage) {
// Process all targets // Process all targets
for _, target := range targets { for _, target := range targets {
sourcePrefix, err := netip.ParsePrefix(target.SourcePrefix)
if err != nil {
logger.Info("Invalid CIDR %s: %v", target.SourcePrefix, err)
continue
}
destPrefix, err := netip.ParsePrefix(target.DestPrefix) destPrefix, err := netip.ParsePrefix(target.DestPrefix)
if err != nil { if err != nil {
logger.Info("Invalid CIDR %s: %v", target.DestPrefix, err) logger.Info("Invalid CIDR %s: %v", target.DestPrefix, err)
@@ -1115,9 +1123,15 @@ func (s *WireGuardService) handleAddTarget(msg websocket.WSMessage) {
}) })
} }
for _, sp := range resolveSourcePrefixes(target) {
sourcePrefix, err := netip.ParsePrefix(sp)
if err != nil {
logger.Info("Invalid CIDR %s: %v", sp, err)
continue
}
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp) s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp)
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", sp, target.DestPrefix, target.RewriteTo, target.PortRange)
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", target.SourcePrefix, target.DestPrefix, target.RewriteTo, target.PortRange) }
} }
} }
@@ -1146,21 +1160,21 @@ func (s *WireGuardService) handleRemoveTarget(msg websocket.WSMessage) {
// Process all targets // Process all targets
for _, target := range targets { for _, target := range targets {
sourcePrefix, err := netip.ParsePrefix(target.SourcePrefix)
if err != nil {
logger.Info("Invalid CIDR %s: %v", target.SourcePrefix, err)
continue
}
destPrefix, err := netip.ParsePrefix(target.DestPrefix) destPrefix, err := netip.ParsePrefix(target.DestPrefix)
if err != nil { if err != nil {
logger.Info("Invalid CIDR %s: %v", target.DestPrefix, err) logger.Info("Invalid CIDR %s: %v", target.DestPrefix, err)
continue continue
} }
for _, sp := range resolveSourcePrefixes(target) {
sourcePrefix, err := netip.ParsePrefix(sp)
if err != nil {
logger.Info("Invalid CIDR %s: %v", sp, err)
continue
}
s.tnet.RemoveProxySubnetRule(sourcePrefix, destPrefix) s.tnet.RemoveProxySubnetRule(sourcePrefix, destPrefix)
logger.Info("Removed target subnet %s with destination %s", sp, target.DestPrefix)
logger.Info("Removed target subnet %s with destination %s", target.SourcePrefix, target.DestPrefix) }
} }
} }
@@ -1194,30 +1208,24 @@ func (s *WireGuardService) handleUpdateTarget(msg websocket.WSMessage) {
// Process all update requests // Process all update requests
for _, target := range requests.OldTargets { for _, target := range requests.OldTargets {
sourcePrefix, err := netip.ParsePrefix(target.SourcePrefix)
if err != nil {
logger.Info("Invalid CIDR %s: %v", target.SourcePrefix, err)
continue
}
destPrefix, err := netip.ParsePrefix(target.DestPrefix) destPrefix, err := netip.ParsePrefix(target.DestPrefix)
if err != nil { if err != nil {
logger.Info("Invalid CIDR %s: %v", target.DestPrefix, err) logger.Info("Invalid CIDR %s: %v", target.DestPrefix, err)
continue continue
} }
for _, sp := range resolveSourcePrefixes(target) {
sourcePrefix, err := netip.ParsePrefix(sp)
if err != nil {
logger.Info("Invalid CIDR %s: %v", sp, err)
continue
}
s.tnet.RemoveProxySubnetRule(sourcePrefix, destPrefix) s.tnet.RemoveProxySubnetRule(sourcePrefix, destPrefix)
logger.Info("Removed target subnet %s with destination %s", target.SourcePrefix, target.DestPrefix) logger.Info("Removed target subnet %s with destination %s", sp, target.DestPrefix)
}
} }
for _, target := range requests.NewTargets { for _, target := range requests.NewTargets {
// Now add the new target
sourcePrefix, err := netip.ParsePrefix(target.SourcePrefix)
if err != nil {
logger.Info("Invalid CIDR %s: %v", target.SourcePrefix, err)
continue
}
destPrefix, err := netip.ParsePrefix(target.DestPrefix) destPrefix, err := netip.ParsePrefix(target.DestPrefix)
if err != nil { if err != nil {
logger.Info("Invalid CIDR %s: %v", target.DestPrefix, err) logger.Info("Invalid CIDR %s: %v", target.DestPrefix, err)
@@ -1233,8 +1241,15 @@ func (s *WireGuardService) handleUpdateTarget(msg websocket.WSMessage) {
}) })
} }
for _, sp := range resolveSourcePrefixes(target) {
sourcePrefix, err := netip.ParsePrefix(sp)
if err != nil {
logger.Info("Invalid CIDR %s: %v", sp, err)
continue
}
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp) s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp)
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", target.SourcePrefix, target.DestPrefix, target.RewriteTo, target.PortRange) logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", sp, target.DestPrefix, target.RewriteTo, target.PortRange)
}
} }
} }