diff --git a/clients/clients.go b/clients/clients.go index b7065fa..dff5025 100644 --- a/clients/clients.go +++ b/clients/clients.go @@ -37,11 +37,12 @@ type WgConfig struct { } type Target struct { - SourcePrefix string `json:"sourcePrefix"` - DestPrefix string `json:"destPrefix"` - RewriteTo string `json:"rewriteTo,omitempty"` - DisableIcmp bool `json:"disableIcmp,omitempty"` - PortRange []PortRange `json:"portRange,omitempty"` + SourcePrefix string `json:"sourcePrefix"` + SourcePrefixes []string `json:"sourcePrefixes"` + DestPrefix string `json:"destPrefix"` + RewriteTo string `json:"rewriteTo,omitempty"` + DisableIcmp bool `json:"disableIcmp,omitempty"` + PortRange []PortRange `json:"portRange,omitempty"` } type PortRange struct { @@ -277,7 +278,7 @@ func (s *WireGuardService) StartHolepunch(publicKey string, endpoint string, rel } if relayPort == 0 { - relayPort = 21820 + relayPort = 21820 } // Convert websocket.ExitNode to holepunch.ExitNode @@ -695,6 +696,19 @@ func (s *WireGuardService) ensureWireguardPeers(peers []Peer) error { return nil } +// resolveSourcePrefixes returns the effective list of source prefixes for a target, +// supporting both the legacy single SourcePrefix field and the new SourcePrefixes array. +// If SourcePrefixes is non-empty it takes precedence; otherwise SourcePrefix is used. +func resolveSourcePrefixes(target Target) []string { + if len(target.SourcePrefixes) > 0 { + return target.SourcePrefixes + } + if target.SourcePrefix != "" { + return []string{target.SourcePrefix} + } + return nil +} + func (s *WireGuardService) ensureTargets(targets []Target) error { if s.tnet == nil { // Native interface mode - proxy features not available, skip silently @@ -703,11 +717,6 @@ func (s *WireGuardService) ensureTargets(targets []Target) error { } for _, target := range targets { - sourcePrefix, err := netip.ParsePrefix(target.SourcePrefix) - if err != nil { - return fmt.Errorf("invalid CIDR %s: %v", target.SourcePrefix, err) - } - destPrefix, err := netip.ParsePrefix(target.DestPrefix) if err != nil { return fmt.Errorf("invalid CIDR %s: %v", target.DestPrefix, err) @@ -722,9 +731,14 @@ func (s *WireGuardService) ensureTargets(targets []Target) error { }) } - s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp) - - logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", target.SourcePrefix, target.DestPrefix, target.RewriteTo, target.PortRange) + for _, sp := range resolveSourcePrefixes(target) { + sourcePrefix, err := netip.ParsePrefix(sp) + if err != nil { + return fmt.Errorf("invalid CIDR %s: %v", sp, err) + } + s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp) + logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", sp, target.DestPrefix, target.RewriteTo, target.PortRange) + } } return nil @@ -1043,7 +1057,7 @@ func (s *WireGuardService) processPeerBandwidth(publicKey string, rxBytes, txByt BytesOut: bytesOutMB, } } - + return nil } } @@ -1094,12 +1108,6 @@ func (s *WireGuardService) handleAddTarget(msg websocket.WSMessage) { // Process all targets for _, target := range targets { - sourcePrefix, err := netip.ParsePrefix(target.SourcePrefix) - if err != nil { - logger.Info("Invalid CIDR %s: %v", target.SourcePrefix, err) - continue - } - destPrefix, err := netip.ParsePrefix(target.DestPrefix) if err != nil { logger.Info("Invalid CIDR %s: %v", target.DestPrefix, err) @@ -1109,15 +1117,21 @@ func (s *WireGuardService) handleAddTarget(msg websocket.WSMessage) { var portRanges []netstack2.PortRange for _, pr := range target.PortRange { portRanges = append(portRanges, netstack2.PortRange{ - Min: pr.Min, - Max: pr.Max, - Protocol: pr.Protocol, + Min: pr.Min, + Max: pr.Max, + Protocol: pr.Protocol, }) } - s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp) - - logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", target.SourcePrefix, target.DestPrefix, target.RewriteTo, target.PortRange) + for _, sp := range resolveSourcePrefixes(target) { + sourcePrefix, err := netip.ParsePrefix(sp) + if err != nil { + logger.Info("Invalid CIDR %s: %v", sp, err) + continue + } + s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp) + logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", sp, target.DestPrefix, target.RewriteTo, target.PortRange) + } } } @@ -1146,21 +1160,21 @@ func (s *WireGuardService) handleRemoveTarget(msg websocket.WSMessage) { // Process all targets for _, target := range targets { - sourcePrefix, err := netip.ParsePrefix(target.SourcePrefix) - if err != nil { - logger.Info("Invalid CIDR %s: %v", target.SourcePrefix, err) - continue - } - destPrefix, err := netip.ParsePrefix(target.DestPrefix) if err != nil { logger.Info("Invalid CIDR %s: %v", target.DestPrefix, err) continue } - s.tnet.RemoveProxySubnetRule(sourcePrefix, destPrefix) - - logger.Info("Removed target subnet %s with destination %s", target.SourcePrefix, target.DestPrefix) + for _, sp := range resolveSourcePrefixes(target) { + sourcePrefix, err := netip.ParsePrefix(sp) + if err != nil { + logger.Info("Invalid CIDR %s: %v", sp, err) + continue + } + s.tnet.RemoveProxySubnetRule(sourcePrefix, destPrefix) + logger.Info("Removed target subnet %s with destination %s", sp, target.DestPrefix) + } } } @@ -1194,30 +1208,24 @@ func (s *WireGuardService) handleUpdateTarget(msg websocket.WSMessage) { // Process all update requests for _, target := range requests.OldTargets { - sourcePrefix, err := netip.ParsePrefix(target.SourcePrefix) - if err != nil { - logger.Info("Invalid CIDR %s: %v", target.SourcePrefix, err) - continue - } - destPrefix, err := netip.ParsePrefix(target.DestPrefix) if err != nil { logger.Info("Invalid CIDR %s: %v", target.DestPrefix, err) continue } - s.tnet.RemoveProxySubnetRule(sourcePrefix, destPrefix) - logger.Info("Removed target subnet %s with destination %s", target.SourcePrefix, target.DestPrefix) + for _, sp := range resolveSourcePrefixes(target) { + sourcePrefix, err := netip.ParsePrefix(sp) + if err != nil { + logger.Info("Invalid CIDR %s: %v", sp, err) + continue + } + s.tnet.RemoveProxySubnetRule(sourcePrefix, destPrefix) + logger.Info("Removed target subnet %s with destination %s", sp, target.DestPrefix) + } } for _, target := range requests.NewTargets { - // Now add the new target - sourcePrefix, err := netip.ParsePrefix(target.SourcePrefix) - if err != nil { - logger.Info("Invalid CIDR %s: %v", target.SourcePrefix, err) - continue - } - destPrefix, err := netip.ParsePrefix(target.DestPrefix) if err != nil { logger.Info("Invalid CIDR %s: %v", target.DestPrefix, err) @@ -1227,14 +1235,21 @@ func (s *WireGuardService) handleUpdateTarget(msg websocket.WSMessage) { var portRanges []netstack2.PortRange for _, pr := range target.PortRange { portRanges = append(portRanges, netstack2.PortRange{ - Min: pr.Min, - Max: pr.Max, - Protocol: pr.Protocol, + Min: pr.Min, + Max: pr.Max, + Protocol: pr.Protocol, }) } - s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp) - logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", target.SourcePrefix, target.DestPrefix, target.RewriteTo, target.PortRange) + for _, sp := range resolveSourcePrefixes(target) { + sourcePrefix, err := netip.ParsePrefix(sp) + if err != nil { + logger.Info("Invalid CIDR %s: %v", sp, err) + continue + } + s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp) + logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", sp, target.DestPrefix, target.RewriteTo, target.PortRange) + } } }