[management, client] Add IPv6 overlay support (#5631)

This commit is contained in:
Viktor Liu
2026-05-07 18:33:37 +09:00
committed by GitHub
parent f23aaa9ae7
commit 205ebcfda2
229 changed files with 10155 additions and 2816 deletions

View File

@@ -18,10 +18,11 @@ import (
)
type selectRoute struct {
NetID route.NetID
Network netip.Prefix
Domains domain.List
Selected bool
NetID route.NetID
Network netip.Prefix
Domains domain.List
Selected bool
extraNetworks []netip.Prefix
}
// ListNetworks returns a list of all available networks.
@@ -50,18 +51,32 @@ func (s *Server) ListNetworks(context.Context, *proto.ListNetworksRequest) (*pro
routesMap := routeMgr.GetClientRoutesWithNetID()
routeSelector := routeMgr.GetRouteSelector()
v6ExitMerged := route.V6ExitMergeSet(routesMap)
var routes []*selectRoute
for id, rt := range routesMap {
if len(rt) == 0 {
continue
}
route := &selectRoute{
// Skip v6 exit nodes that are merged into their v4 counterpart.
if _, ok := v6ExitMerged[id]; ok {
continue
}
r := &selectRoute{
NetID: id,
Network: rt[0].Network,
Domains: rt[0].Domains,
Selected: routeSelector.IsSelected(id),
}
routes = append(routes, route)
// Merge paired v6 exit node prefix into this entry.
v6ID := route.NetID(string(id) + route.V6ExitSuffix)
if _, ok := v6ExitMerged[v6ID]; ok && len(routesMap[v6ID]) > 0 {
r.extraNetworks = []netip.Prefix{routesMap[v6ID][0].Network}
}
routes = append(routes, r)
}
sort.Slice(routes, func(i, j int) bool {
@@ -82,9 +97,13 @@ func (s *Server) ListNetworks(context.Context, *proto.ListNetworksRequest) (*pro
resolvedDomains := s.statusRecorder.GetResolvedDomainsStates()
var pbRoutes []*proto.Network
for _, route := range routes {
rangeStr := route.Network.String()
for _, extra := range route.extraNetworks {
rangeStr += ", " + extra.String()
}
pbRoute := &proto.Network{
ID: string(route.NetID),
Range: route.Network.String(),
Range: rangeStr,
Domains: route.Domains.ToSafeStringList(),
ResolvedIPs: map[string]*proto.IPList{},
Selected: route.Selected,
@@ -147,7 +166,9 @@ func (s *Server) SelectNetworks(_ context.Context, req *proto.SelectNetworksRequ
routeSelector.SelectAllRoutes()
} else {
routes := toNetIDs(req.GetNetworkIDs())
netIdRoutes := maps.Keys(routeManager.GetClientRoutesWithNetID())
routesMap := routeManager.GetClientRoutesWithNetID()
routes = route.ExpandV6ExitPairs(routes, routesMap)
netIdRoutes := maps.Keys(routesMap)
if err := routeSelector.SelectRoutes(routes, req.GetAppend(), netIdRoutes); err != nil {
return nil, fmt.Errorf("select routes: %w", err)
}
@@ -197,7 +218,9 @@ func (s *Server) DeselectNetworks(_ context.Context, req *proto.SelectNetworksRe
routeSelector.DeselectAllRoutes()
} else {
routes := toNetIDs(req.GetNetworkIDs())
netIdRoutes := maps.Keys(routeManager.GetClientRoutesWithNetID())
routesMap := routeManager.GetClientRoutesWithNetID()
routes = route.ExpandV6ExitPairs(routes, routesMap)
netIdRoutes := maps.Keys(routesMap)
if err := routeSelector.DeselectRoutes(routes, netIdRoutes); err != nil {
return nil, fmt.Errorf("deselect routes: %w", err)
}

View File

@@ -385,6 +385,7 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques
config.DisableNotifications = msg.DisableNotifications
config.LazyConnectionEnabled = msg.LazyConnectionEnabled
config.BlockInbound = msg.BlockInbound
config.DisableIPv6 = msg.DisableIpv6
config.EnableSSHRoot = msg.EnableSSHRoot
config.EnableSSHSFTP = msg.EnableSSHSFTP
config.EnableSSHLocalPortForwarding = msg.EnableSSHLocalPortForwarding
@@ -1483,6 +1484,7 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
disableDNS := cfg.DisableDNS
disableClientRoutes := cfg.DisableClientRoutes
disableServerRoutes := cfg.DisableServerRoutes
disableIPv6 := cfg.DisableIPv6
blockLANAccess := cfg.BlockLANAccess
enableSSHRoot := false
@@ -1533,6 +1535,7 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
DisableDns: disableDNS,
DisableClientRoutes: disableClientRoutes,
DisableServerRoutes: disableServerRoutes,
DisableIpv6: disableIPv6,
BlockLanAccess: blockLANAccess,
EnableSSHRoot: enableSSHRoot,
EnableSSHSFTP: enableSSHSFTP,

View File

@@ -71,6 +71,7 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
disableNotifications := true
lazyConnectionEnabled := true
blockInbound := true
disableIPv6 := true
mtu := int64(1280)
sshJWTCacheTTL := int32(300)
@@ -95,6 +96,7 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
DisableNotifications: &disableNotifications,
LazyConnectionEnabled: &lazyConnectionEnabled,
BlockInbound: &blockInbound,
DisableIpv6: &disableIPv6,
NatExternalIPs: []string{"1.2.3.4", "5.6.7.8"},
CleanNATExternalIPs: false,
CustomDNSAddress: []byte("1.1.1.1:53"),
@@ -140,6 +142,7 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
require.Equal(t, disableNotifications, *cfg.DisableNotifications)
require.Equal(t, lazyConnectionEnabled, cfg.LazyConnectionEnabled)
require.Equal(t, blockInbound, cfg.BlockInbound)
require.Equal(t, disableIPv6, cfg.DisableIPv6)
require.Equal(t, []string{"1.2.3.4", "5.6.7.8"}, cfg.NATExternalIPs)
require.Equal(t, "1.1.1.1:53", cfg.CustomDNSAddress)
// IFaceBlackList contains defaults + extras
@@ -189,6 +192,7 @@ func verifyAllFieldsCovered(t *testing.T, req *proto.SetConfigRequest) {
"DisableNotifications": true,
"LazyConnectionEnabled": true,
"BlockInbound": true,
"DisableIpv6": true,
"NatExternalIPs": true,
"CustomDNSAddress": true,
"ExtraIFaceBlacklist": true,
@@ -247,6 +251,7 @@ func TestCLIFlags_MappedToSetConfig(t *testing.T) {
"disable-firewall": "DisableFirewall",
"block-lan-access": "BlockLanAccess",
"block-inbound": "BlockInbound",
"disable-ipv6": "DisableIpv6",
"enable-lazy-connection": "LazyConnectionEnabled",
"external-ip-map": "NatExternalIPs",
"dns-resolver-address": "CustomDNSAddress",

View File

@@ -24,14 +24,9 @@ func (s *Server) TracePacket(_ context.Context, req *proto.TracePacketRequest) (
return nil, err
}
srcAddr, err := s.parseAddress(req.GetSourceIp(), engine)
srcAddr, dstAddr, err := s.resolveTraceAddresses(req.GetSourceIp(), req.GetDestinationIp(), engine)
if err != nil {
return nil, fmt.Errorf("invalid source IP address: %w", err)
}
dstAddr, err := s.parseAddress(req.GetDestinationIp(), engine)
if err != nil {
return nil, fmt.Errorf("invalid destination IP address: %w", err)
return nil, err
}
protocol, err := s.parseProtocol(req.GetProtocol())
@@ -89,16 +84,73 @@ func (s *Server) getPacketTracer() (packetTracer, *internal.Engine, error) {
return tracer, engine, nil
}
func (s *Server) parseAddress(addr string, engine *internal.Engine) (netip.Addr, error) {
if addr == "self" {
return engine.GetWgAddr(), nil
// resolveTraceAddresses parses src/dst, resolving "self" to the local overlay
// address matching the peer's address family.
func (s *Server) resolveTraceAddresses(src, dst string, engine *internal.Engine) (netip.Addr, netip.Addr, error) {
srcSelf := src == "self"
dstSelf := dst == "self"
if srcSelf && dstSelf {
return netip.Addr{}, netip.Addr{}, fmt.Errorf("both source and destination cannot be 'self'")
}
var srcAddr, dstAddr netip.Addr
var err error
// Parse the non-self address first so we know the family for self resolution.
if !srcSelf {
if srcAddr, err = parseAddr(src); err != nil {
return netip.Addr{}, netip.Addr{}, fmt.Errorf("invalid source IP: %w", err)
}
}
if !dstSelf {
if dstAddr, err = parseAddr(dst); err != nil {
return netip.Addr{}, netip.Addr{}, fmt.Errorf("invalid destination IP: %w", err)
}
}
// Determine the peer address to pick the right self address.
peer := srcAddr
if srcSelf {
peer = dstAddr
}
if srcSelf {
if srcAddr, err = selfAddr(engine, peer); err != nil {
return netip.Addr{}, netip.Addr{}, err
}
}
if dstSelf {
if dstAddr, err = selfAddr(engine, peer); err != nil {
return netip.Addr{}, netip.Addr{}, err
}
}
return srcAddr, dstAddr, nil
}
func selfAddr(engine *internal.Engine, peer netip.Addr) (netip.Addr, error) {
var addr netip.Addr
if peer.Is6() {
addr = engine.GetWgV6Addr()
} else {
addr = engine.GetWgAddr()
}
if !addr.IsValid() {
family := "IPv4"
if peer.Is6() {
family = "IPv6"
}
return netip.Addr{}, fmt.Errorf("no local %s overlay address configured", family)
}
return addr, nil
}
func parseAddr(addr string) (netip.Addr, error) {
a, err := netip.ParseAddr(addr)
if err != nil {
return netip.Addr{}, err
}
return a.Unmap(), nil
}