From 1c4e5e71d7027815b67b1539705a86a4cf217b7a Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Thu, 9 Apr 2026 16:56:08 +0800 Subject: [PATCH] [client] Add IPv6 support to ACL manager, USP filter, and forwarder (#5688) --- client/android/client.go | 87 +++-- client/android/route_command.go | 7 +- client/anonymize/anonymize.go | 11 +- client/anonymize/anonymize_test.go | 14 +- client/cmd/ssh.go | 4 +- client/cmd/ssh_test.go | 4 +- client/firewall/iptables/acl_linux.go | 31 +- client/firewall/iptables/manager_linux.go | 236 ++++++++++-- client/firewall/iptables/router_linux.go | 68 +++- client/firewall/iptables/rule.go | 1 + client/firewall/iptables/state_linux.go | 30 ++ client/firewall/nftables/acl_linux.go | 44 +-- client/firewall/nftables/addr_family_linux.go | 81 ++++ client/firewall/nftables/manager_linux.go | 287 +++++++++++++-- .../firewall/nftables/manager_linux_test.go | 124 +++++++ client/firewall/nftables/router_linux.go | 165 +++++---- client/firewall/nftables/router_linux_test.go | 189 +++++++++- .../uspfilter/allow_netbird_windows.go | 53 ++- client/firewall/uspfilter/conntrack/common.go | 7 +- client/firewall/uspfilter/conntrack/icmp.go | 83 +++-- client/firewall/uspfilter/filter.go | 257 ++++++++++--- .../firewall/uspfilter/filter_bench_test.go | 9 +- .../firewall/uspfilter/filter_filter_test.go | 345 ++++++++++++++++-- client/firewall/uspfilter/filter_test.go | 75 +++- .../firewall/uspfilter/forwarder/endpoint.go | 25 +- .../firewall/uspfilter/forwarder/forwarder.go | 228 ++++++++++-- client/firewall/uspfilter/forwarder/icmp.go | 218 +++++++++-- client/firewall/uspfilter/forwarder/tcp.go | 18 +- client/firewall/uspfilter/forwarder/udp.go | 17 +- client/firewall/uspfilter/localip.go | 135 ++----- .../firewall/uspfilter/localip_bench_test.go | 72 ++++ client/firewall/uspfilter/localip_test.go | 124 ++----- client/firewall/uspfilter/nat.go | 159 ++++++-- client/firewall/uspfilter/nat_bench_test.go | 22 +- client/firewall/uspfilter/nat_test.go | 11 +- client/firewall/uspfilter/tracer.go | 102 ++++-- client/iface/configurer/usp.go | 2 +- client/iface/device/adapter.go | 2 +- client/iface/device/device_android.go | 2 +- client/iface/wgproxy/bind/proxy.go | 22 +- client/internal/acl/manager.go | 58 +-- client/internal/debug/debug_test.go | 49 ++- client/internal/dns/service_listener.go | 10 +- client/internal/dns/upstream.go | 7 + client/internal/dns/upstream_android.go | 2 +- client/internal/dns/upstream_general.go | 2 +- client/internal/dns/upstream_ios.go | 57 ++- client/internal/dnsfwd/manager.go | 1 + client/internal/ebpf/ebpf/dns_fwd_linux.go | 15 +- client/internal/ebpf/manager/manager.go | 4 +- client/internal/engine.go | 49 ++- client/internal/engine_ssh.go | 28 +- client/internal/engine_test.go | 16 +- .../lazyconn/activity/listener_bind.go | 23 +- client/internal/peer/status.go | 6 +- client/internal/routemanager/client/client.go | 5 +- .../routemanager/dnsinterceptor/handler.go | 2 +- client/internal/routemanager/dynamic/route.go | 4 +- .../routemanager/dynamic/route_ios.go | 46 ++- client/internal/routemanager/fakeip/fakeip.go | 144 +++++--- .../routemanager/fakeip/fakeip_test.go | 169 ++++++--- .../routemanager/ipfwdstate/ipfwdstate.go | 6 +- client/internal/routemanager/manager.go | 20 +- client/internal/routemanager/server/server.go | 3 +- .../routemanager/systemops/systemops.go | 7 +- .../systemops/systemops_generic.go | 70 ++-- .../routemanager/systemops/systemops_linux.go | 23 +- client/ios/NetBirdSDK/client.go | 74 ++-- client/server/network.go | 42 ++- client/ui/network.go | 2 +- client/wasm/cmd/main.go | 115 ++++-- client/wasm/internal/ssh/client.go | 15 +- proxy/cmd/proxy/cmd/debug.go | 21 +- proxy/internal/debug/client.go | 22 +- proxy/internal/debug/handler.go | 18 +- route/route.go | 61 ++++ route/route_test.go | 108 ++++++ shared/relay/client/dialer/quic/quic.go | 2 +- 78 files changed, 3606 insertions(+), 1071 deletions(-) create mode 100644 client/firewall/nftables/addr_family_linux.go create mode 100644 client/firewall/uspfilter/localip_bench_test.go create mode 100644 route/route_test.go diff --git a/client/android/client.go b/client/android/client.go index 70ebc0011..a8766afd2 100644 --- a/client/android/client.go +++ b/client/android/client.go @@ -238,43 +238,84 @@ func (c *Client) Networks() *NetworkArray { return nil } + routesMap := routeManager.GetClientRoutesWithNetID() + v6Merged := route.V6ExitMergeSet(routesMap) + resolvedDomains := c.recorder.GetResolvedDomainsStates() + networkArray := &NetworkArray{ items: make([]Network, 0), } - resolvedDomains := c.recorder.GetResolvedDomainsStates() - - for id, routes := range routeManager.GetClientRoutesWithNetID() { + for id, routes := range routesMap { if len(routes) == 0 { continue } - - r := routes[0] - domains := c.getNetworkDomainsFromRoute(r, resolvedDomains) - netStr := r.Network.String() - - if r.IsDynamic() { - netStr = r.Domains.SafeString() - } - - routePeer, err := c.recorder.GetPeer(routes[0].Peer) - if err != nil { - log.Errorf("could not get peer info for %s: %v", routes[0].Peer, err) + if _, skip := v6Merged[id]; skip { continue } - network := Network{ - Name: string(id), - Network: netStr, - Peer: routePeer.FQDN, - Status: routePeer.ConnStatus.String(), - IsSelected: routeSelector.IsSelected(id), - Domains: domains, + + network := c.buildNetwork(id, routes, routeSelector.IsSelected(id), resolvedDomains, v6Merged) + if network == nil { + continue } - networkArray.Add(network) + networkArray.Add(*network) } return networkArray } +func (c *Client) buildNetwork(id route.NetID, routes []*route.Route, selected bool, resolvedDomains map[domain.Domain]peer.ResolvedDomainInfo, v6Merged map[route.NetID]struct{}) *Network { + r := routes[0] + netStr := r.Network.String() + if r.IsDynamic() { + netStr = r.Domains.SafeString() + } + + routePeer, err := c.findBestRoutePeer(routes) + if err != nil { + log.Errorf("could not get peer info for route %s: %v", id, err) + return nil + } + + network := &Network{ + Name: string(id), + Network: netStr, + Peer: routePeer.FQDN, + Status: routePeer.ConnStatus.String(), + IsSelected: selected, + Domains: c.getNetworkDomainsFromRoute(r, resolvedDomains), + } + + if route.IsV4DefaultRoute(r.Network) && route.HasV6ExitPair(id, v6Merged) { + network.Network = "0.0.0.0/0, ::/0" + } + + return network +} + +// findBestRoutePeer returns the peer actively routing traffic for the given +// HA route group. Falls back to the first connected peer, then the first peer. +func (c *Client) findBestRoutePeer(routes []*route.Route) (peer.State, error) { + netStr := routes[0].Network.String() + + fullStatus := c.recorder.GetFullStatus() + for _, p := range fullStatus.Peers { + if _, ok := p.GetRoutes()[netStr]; ok { + return p, nil + } + } + + for _, r := range routes { + p, err := c.recorder.GetPeer(r.Peer) + if err != nil { + continue + } + if p.ConnStatus == peer.StatusConnected { + return p, nil + } + } + return c.recorder.GetPeer(routes[0].Peer) +} + // OnUpdatedHostDNS update the DNS servers addresses for root zones func (c *Client) OnUpdatedHostDNS(list *DNSList) error { dnsServer, err := dns.GetServerDns() diff --git a/client/android/route_command.go b/client/android/route_command.go index b47d5ca6c..5e7357335 100644 --- a/client/android/route_command.go +++ b/client/android/route_command.go @@ -18,9 +18,12 @@ func executeRouteToggle(id string, manager routemanager.Manager, netID := route.NetID(id) routes := []route.NetID{netID} - log.Debugf("%s with id: %s", operationName, id) + routesMap := manager.GetClientRoutesWithNetID() + routes = route.ExpandV6ExitPairs(routes, routesMap) - if err := routeOperation(routes, maps.Keys(manager.GetClientRoutesWithNetID())); err != nil { + log.Debugf("%s with ids: %v", operationName, routes) + + if err := routeOperation(routes, maps.Keys(routesMap)); err != nil { log.Debugf("error when %s: %s", operationName, err) return fmt.Errorf("error %s: %w", operationName, err) } diff --git a/client/anonymize/anonymize.go b/client/anonymize/anonymize.go index 89e653300..b7b6a20dd 100644 --- a/client/anonymize/anonymize.go +++ b/client/anonymize/anonymize.go @@ -9,6 +9,7 @@ import ( "net/url" "regexp" "slices" + "strconv" "strings" ) @@ -26,8 +27,9 @@ type Anonymizer struct { } func DefaultAddresses() (netip.Addr, netip.Addr) { - // 198.51.100.0, 100:: - return netip.AddrFrom4([4]byte{198, 51, 100, 0}), netip.AddrFrom16([16]byte{0x01}) + // 198.51.100.0 (RFC 5737 TEST-NET-2), 2001:db8:ffff:: (RFC 3849 documentation, last /48) + // The old start 100:: (discard, RFC 6666) is now used for fake IPs on Android. + return netip.AddrFrom4([4]byte{198, 51, 100, 0}), netip.MustParseAddr("2001:db8:ffff::") } func NewAnonymizer(startIPv4, startIPv6 netip.Addr) *Anonymizer { @@ -96,6 +98,11 @@ func (a *Anonymizer) isInAnonymizedRange(ip netip.Addr) bool { } func (a *Anonymizer) AnonymizeIPString(ip string) string { + // Handle CIDR notation (e.g. "2001:db8::/32") + if prefix, err := netip.ParsePrefix(ip); err == nil { + return a.AnonymizeIP(prefix.Addr()).String() + "/" + strconv.Itoa(prefix.Bits()) + } + addr, err := netip.ParseAddr(ip) if err != nil { return ip diff --git a/client/anonymize/anonymize_test.go b/client/anonymize/anonymize_test.go index ff2e48869..45e205834 100644 --- a/client/anonymize/anonymize_test.go +++ b/client/anonymize/anonymize_test.go @@ -13,7 +13,7 @@ import ( func TestAnonymizeIP(t *testing.T) { startIPv4 := netip.MustParseAddr("198.51.100.0") - startIPv6 := netip.MustParseAddr("100::") + startIPv6 := netip.MustParseAddr("2001:db8:ffff::") anonymizer := anonymize.NewAnonymizer(startIPv4, startIPv6) tests := []struct { @@ -26,9 +26,9 @@ func TestAnonymizeIP(t *testing.T) { {"Second Public IPv4", "4.3.2.1", "198.51.100.1"}, {"Repeated IPv4", "1.2.3.4", "198.51.100.0"}, {"Private IPv4", "192.168.1.1", "192.168.1.1"}, - {"First Public IPv6", "2607:f8b0:4005:805::200e", "100::"}, - {"Second Public IPv6", "a::b", "100::1"}, - {"Repeated IPv6", "2607:f8b0:4005:805::200e", "100::"}, + {"First Public IPv6", "2607:f8b0:4005:805::200e", "2001:db8:ffff::"}, + {"Second Public IPv6", "a::b", "2001:db8:ffff::1"}, + {"Repeated IPv6", "2607:f8b0:4005:805::200e", "2001:db8:ffff::"}, {"Private IPv6", "fe80::1", "fe80::1"}, {"In Range IPv4", "198.51.100.2", "198.51.100.2"}, } @@ -274,17 +274,17 @@ func TestAnonymizeString_IPAddresses(t *testing.T) { { name: "IPv6 Address", input: "Access attempted from 2001:db8::ff00:42", - expect: "Access attempted from 100::", + expect: "Access attempted from 2001:db8:ffff::", }, { name: "IPv6 Address with Port", input: "Access attempted from [2001:db8::ff00:42]:8080", - expect: "Access attempted from [100::]:8080", + expect: "Access attempted from [2001:db8:ffff::]:8080", }, { name: "Both IPv4 and IPv6", input: "IPv4: 142.108.0.1 and IPv6: 2001:db8::ff00:43", - expect: "IPv4: 198.51.100.1 and IPv6: 100::1", + expect: "IPv4: 198.51.100.1 and IPv6: 2001:db8:ffff::1", }, } diff --git a/client/cmd/ssh.go b/client/cmd/ssh.go index 0acf0b133..de5150b1f 100644 --- a/client/cmd/ssh.go +++ b/client/cmd/ssh.go @@ -787,10 +787,10 @@ func isUnixSocket(path string) bool { return strings.HasPrefix(path, "/") || strings.HasPrefix(path, "./") } -// normalizeLocalHost converts "*" to "0.0.0.0" for binding to all interfaces. +// normalizeLocalHost converts "*" to "" for binding to all interfaces (dual-stack). func normalizeLocalHost(host string) string { if host == "*" { - return "0.0.0.0" + return "" } return host } diff --git a/client/cmd/ssh_test.go b/client/cmd/ssh_test.go index 43291fa87..16ffadb90 100644 --- a/client/cmd/ssh_test.go +++ b/client/cmd/ssh_test.go @@ -527,10 +527,10 @@ func TestParsePortForward(t *testing.T) { { name: "wildcard bind all interfaces", spec: "*:8080:localhost:80", - expectedLocal: "0.0.0.0:8080", + expectedLocal: ":8080", expectedRemote: "localhost:80", expectError: false, - description: "Wildcard * should bind to all interfaces (0.0.0.0)", + description: "Wildcard * should bind to all interfaces (dual-stack)", }, { name: "wildcard for port only", diff --git a/client/firewall/iptables/acl_linux.go b/client/firewall/iptables/acl_linux.go index d83798f09..4740c4127 100644 --- a/client/firewall/iptables/acl_linux.go +++ b/client/firewall/iptables/acl_linux.go @@ -36,6 +36,7 @@ type aclManager struct { entries aclEntries optionalEntries map[string][]entry ipsetStore *ipsetStore + v6 bool stateManager *statemanager.Manager } @@ -47,6 +48,7 @@ func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*acl entries: make(map[string][][]string), optionalEntries: make(map[string][]entry), ipsetStore: newIpsetStore(), + v6: iptablesClient.Proto() == iptables.ProtocolIPv6, }, nil } @@ -81,7 +83,11 @@ func (m *aclManager) AddPeerFiltering( chain := chainNameInputRules ipsetName = transformIPsetName(ipsetName, sPort, dPort, action) - specs := filterRuleSpecs(ip, string(protocol), sPort, dPort, action, ipsetName) + if m.v6 && ipsetName != "" { + ipsetName += "-v6" + } + proto := protoForFamily(protocol, m.v6) + specs := filterRuleSpecs(ip, proto, sPort, dPort, action, ipsetName) mangleSpecs := slices.Clone(specs) mangleSpecs = append(mangleSpecs, @@ -105,6 +111,7 @@ func (m *aclManager) AddPeerFiltering( ip: ip.String(), chain: chain, specs: specs, + v6: m.v6, }}, nil } @@ -157,6 +164,7 @@ func (m *aclManager) AddPeerFiltering( ipsetName: ipsetName, ip: ip.String(), chain: chain, + v6: m.v6, } m.updateState() @@ -376,8 +384,13 @@ func (m *aclManager) updateState() { currentState.Lock() defer currentState.Unlock() - currentState.ACLEntries = m.entries - currentState.ACLIPsetStore = m.ipsetStore + if m.v6 { + currentState.ACLEntries6 = m.entries + currentState.ACLIPsetStore6 = m.ipsetStore + } else { + currentState.ACLEntries = m.entries + currentState.ACLIPsetStore = m.ipsetStore + } if err := m.stateManager.UpdateState(currentState); err != nil { log.Errorf("failed to update state: %v", err) @@ -385,6 +398,15 @@ func (m *aclManager) updateState() { } // filterRuleSpecs returns the specs of a filtering rule +// protoForFamily translates ICMP to ICMPv6 for ip6tables. +// ip6tables requires "ipv6-icmp" (or "icmpv6") instead of "icmp". +func protoForFamily(protocol firewall.Protocol, v6 bool) string { + if v6 && protocol == firewall.ProtocolICMP { + return "ipv6-icmp" + } + return string(protocol) +} + func filterRuleSpecs(ip net.IP, protocol string, sPort, dPort *firewall.Port, action firewall.Action, ipsetName string) (specs []string) { // don't use IP matching if IP is 0.0.0.0 matchByIP := !ip.IsUnspecified() @@ -437,6 +459,9 @@ func (m *aclManager) createIPSet(name string) error { opts := ipset.CreateOptions{ Replace: true, } + if m.v6 { + opts.Family = ipset.FamilyIPV6 + } if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil { return fmt.Errorf("create ipset %s: %w", name, err) diff --git a/client/firewall/iptables/manager_linux.go b/client/firewall/iptables/manager_linux.go index 2fc6f8ec8..c278924f2 100644 --- a/client/firewall/iptables/manager_linux.go +++ b/client/firewall/iptables/manager_linux.go @@ -17,6 +17,10 @@ import ( "github.com/netbirdio/netbird/client/internal/statemanager" ) +type resetter interface { + Reset() error +} + // Manager of iptables firewall type Manager struct { mutex sync.Mutex @@ -27,6 +31,11 @@ type Manager struct { aclMgr *aclManager router *router rawSupported bool + + // IPv6 counterparts, nil when no v6 overlay + ipv6Client *iptables.IPTables + aclMgr6 *aclManager + router6 *router } // iFaceMapper defines subset methods of interface required for manager @@ -58,9 +67,43 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) { return nil, fmt.Errorf("create acl manager: %w", err) } + if wgIface.Address().HasIPv6() { + if err := m.createIPv6Components(wgIface, mtu); err != nil { + return nil, fmt.Errorf("create IPv6 firewall: %w", err) + } + } + return m, nil } +func (m *Manager) createIPv6Components(wgIface iFaceMapper, mtu uint16) error { + ip6Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv6) + if err != nil { + return fmt.Errorf("init ip6tables: %w", err) + } + m.ipv6Client = ip6Client + + m.router6, err = newRouter(ip6Client, wgIface, mtu) + if err != nil { + return fmt.Errorf("create v6 router: %w", err) + } + + // Share the same IP forwarding state with the v4 router, since + // EnableIPForwarding controls both v4 and v6 sysctls. + m.router6.ipFwdState = m.router.ipFwdState + + m.aclMgr6, err = newAclManager(ip6Client, wgIface) + if err != nil { + return fmt.Errorf("create v6 acl manager: %w", err) + } + + return nil +} + +func (m *Manager) hasIPv6() bool { + return m.ipv6Client != nil +} + func (m *Manager) Init(stateManager *statemanager.Manager) error { state := &ShutdownState{ InterfaceState: &InterfaceState{ @@ -75,13 +118,8 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error { log.Errorf("failed to update state: %v", err) } - if err := m.router.init(stateManager); err != nil { - return fmt.Errorf("router init: %w", err) - } - - if err := m.aclMgr.init(stateManager); err != nil { - // TODO: cleanup router - return fmt.Errorf("acl manager init: %w", err) + if err := m.initChains(stateManager); err != nil { + return err } if err := m.initNoTrackChain(); err != nil { @@ -98,6 +136,41 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error { return nil } +// initChains initializes router and ACL chains for both address families, +// rolling back on failure. +func (m *Manager) initChains(stateManager *statemanager.Manager) error { + type initStep struct { + name string + init func(*statemanager.Manager) error + mgr resetter + } + + steps := []initStep{ + {"router", m.router.init, m.router}, + {"acl manager", m.aclMgr.init, m.aclMgr}, + } + if m.hasIPv6() { + steps = append(steps, + initStep{"v6 router", m.router6.init, m.router6}, + initStep{"v6 acl manager", m.aclMgr6.init, m.aclMgr6}, + ) + } + + var initialized []initStep + for _, s := range steps { + if err := s.init(stateManager); err != nil { + for i := len(initialized) - 1; i >= 0; i-- { + if rerr := initialized[i].mgr.Reset(); rerr != nil { + log.Warnf("rollback %s: %v", initialized[i].name, rerr) + } + } + return fmt.Errorf("%s init: %w", s.name, err) + } + initialized = append(initialized, s) + } + return nil +} + // AddPeerFiltering adds a rule to the firewall // // Comment will be ignored because some system this feature is not supported @@ -113,7 +186,13 @@ func (m *Manager) AddPeerFiltering( m.mutex.Lock() defer m.mutex.Unlock() - return m.aclMgr.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName) + if ip.To4() != nil { + return m.aclMgr.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName) + } + if !m.hasIPv6() { + return nil, fmt.Errorf("IPv6 not initialized, cannot add rule for %s", ip) + } + return m.aclMgr6.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName) } func (m *Manager) AddRouteFiltering( @@ -127,25 +206,48 @@ func (m *Manager) AddRouteFiltering( m.mutex.Lock() defer m.mutex.Unlock() - if destination.IsPrefix() && !destination.Prefix.Addr().Is4() { - return nil, fmt.Errorf("unsupported IP version: %s", destination.Prefix.Addr().String()) + if isIPv6RouteRule(sources, destination) { + if !m.hasIPv6() { + return nil, fmt.Errorf("IPv6 not initialized, cannot add route rule") + } + return m.router6.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action) } return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action) } +func isIPv6RouteRule(sources []netip.Prefix, destination firewall.Network) bool { + if destination.IsPrefix() { + return destination.Prefix.Addr().Is6() + } + return len(sources) > 0 && sources[0].Addr().Is6() +} + // DeletePeerRule from the firewall by rule definition func (m *Manager) DeletePeerRule(rule firewall.Rule) error { m.mutex.Lock() defer m.mutex.Unlock() + if m.hasIPv6() && isIPv6IptRule(rule) { + return m.aclMgr6.DeletePeerRule(rule) + } return m.aclMgr.DeletePeerRule(rule) } +func isIPv6IptRule(rule firewall.Rule) bool { + r, ok := rule.(*Rule) + return ok && r.v6 +} + +// DeleteRouteRule deletes a routing rule. +// Route rules are keyed by content hash. Check v4 first, try v6 if not found. func (m *Manager) DeleteRouteRule(rule firewall.Rule) error { m.mutex.Lock() defer m.mutex.Unlock() + if m.hasIPv6() && !m.router.hasRule(rule.ID()) { + return m.router6.DeleteRouteRule(rule) + } return m.router.DeleteRouteRule(rule) } @@ -161,18 +263,63 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error { m.mutex.Lock() defer m.mutex.Unlock() - return m.router.AddNatRule(pair) + if pair.Destination.IsPrefix() && pair.Destination.Prefix.Addr().Is6() { + if !m.hasIPv6() { + return fmt.Errorf("IPv6 not initialized, cannot add NAT rule") + } + return m.router6.AddNatRule(pair) + } + + if err := m.router.AddNatRule(pair); err != nil { + return err + } + + // Dynamic routes need NAT in both tables + if m.hasIPv6() && pair.Destination.IsSet() { + v6Pair := pair + v6Pair.Source = firewall.Network{Prefix: netip.PrefixFrom(netip.IPv6Unspecified(), 0)} + if err := m.router6.AddNatRule(v6Pair); err != nil { + return fmt.Errorf("add v6 NAT rule: %w", err) + } + } + + return nil } func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error { m.mutex.Lock() defer m.mutex.Unlock() - return m.router.RemoveNatRule(pair) + if pair.Destination.IsPrefix() && pair.Destination.Prefix.Addr().Is6() { + if !m.hasIPv6() { + return nil + } + return m.router6.RemoveNatRule(pair) + } + + if err := m.router.RemoveNatRule(pair); err != nil { + return err + } + + if m.hasIPv6() && pair.Destination.IsSet() { + v6Pair := pair + v6Pair.Source = firewall.Network{Prefix: netip.PrefixFrom(netip.IPv6Unspecified(), 0)} + if err := m.router6.RemoveNatRule(v6Pair); err != nil { + return fmt.Errorf("remove v6 NAT rule: %w", err) + } + } + + return nil } func (m *Manager) SetLegacyManagement(isLegacy bool) error { - return firewall.SetLegacyManagement(m.router, isLegacy) + if err := firewall.SetLegacyManagement(m.router, isLegacy); err != nil { + return err + } + if m.hasIPv6() { + return firewall.SetLegacyManagement(m.router6, isLegacy) + } + return nil } // Reset firewall to the default state @@ -186,6 +333,15 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error { merr = multierror.Append(merr, fmt.Errorf("cleanup notrack chain: %w", err)) } + if m.hasIPv6() { + if err := m.aclMgr6.Reset(); err != nil { + merr = multierror.Append(merr, fmt.Errorf("reset v6 acl manager: %w", err)) + } + if err := m.router6.Reset(); err != nil { + merr = multierror.Append(merr, fmt.Errorf("reset v6 router: %w", err)) + } + } + if err := m.aclMgr.Reset(); err != nil { merr = multierror.Append(merr, fmt.Errorf("reset acl manager: %w", err)) } @@ -209,19 +365,16 @@ func (m *Manager) AllowNetbird() error { return nil } - _, err := m.AddPeerFiltering( - nil, - net.IP{0, 0, 0, 0}, - firewall.ProtocolALL, - nil, - nil, - firewall.ActionAccept, - "", - ) - if err != nil { - return fmt.Errorf("allow netbird interface traffic: %w", err) + var merr *multierror.Error + if _, err := m.aclMgr.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolALL, nil, nil, firewall.ActionAccept, ""); err != nil { + merr = multierror.Append(merr, fmt.Errorf("allow netbird interface traffic: %w", err)) } - return nil + if m.hasIPv6() { + if _, err := m.aclMgr6.AddPeerFiltering(nil, net.IPv6zero, firewall.ProtocolALL, nil, nil, firewall.ActionAccept, ""); err != nil { + merr = multierror.Append(merr, fmt.Errorf("allow v6 netbird interface traffic: %w", err)) + } + } + return nberrors.FormatErrorOrNil(merr) } // Flush doesn't need to be implemented for this manager @@ -251,6 +404,9 @@ func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) m.mutex.Lock() defer m.mutex.Unlock() + if m.hasIPv6() && rule.TranslatedAddress.Is6() { + return m.router6.AddDNATRule(rule) + } return m.router.AddDNATRule(rule) } @@ -259,6 +415,9 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error { m.mutex.Lock() defer m.mutex.Unlock() + if m.hasIPv6() && !m.router.hasRule(rule.ID()+dnatSuffix) { + return m.router6.DeleteDNATRule(rule) + } return m.router.DeleteDNATRule(rule) } @@ -267,7 +426,26 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { m.mutex.Lock() defer m.mutex.Unlock() - return m.router.UpdateSet(set, prefixes) + var v4Prefixes, v6Prefixes []netip.Prefix + for _, p := range prefixes { + if p.Addr().Is6() { + v6Prefixes = append(v6Prefixes, p) + } else { + v4Prefixes = append(v4Prefixes, p) + } + } + + if err := m.router.UpdateSet(set, v4Prefixes); err != nil { + return err + } + + if m.hasIPv6() && len(v6Prefixes) > 0 { + if err := m.router6.UpdateSet(set, v6Prefixes); err != nil { + return fmt.Errorf("update v6 set: %w", err) + } + } + + return nil } // AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services. @@ -275,6 +453,9 @@ func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protoco m.mutex.Lock() defer m.mutex.Unlock() + if m.hasIPv6() && localAddr.Is6() { + return m.router6.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort) + } return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort) } @@ -283,6 +464,9 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot m.mutex.Lock() defer m.mutex.Unlock() + if m.hasIPv6() && localAddr.Is6() { + return m.router6.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort) + } return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort) } diff --git a/client/firewall/iptables/router_linux.go b/client/firewall/iptables/router_linux.go index a7c4f67dd..61921f7f9 100644 --- a/client/firewall/iptables/router_linux.go +++ b/client/firewall/iptables/router_linux.go @@ -54,8 +54,10 @@ const ( snatSuffix = "_snat" fwdSuffix = "_fwd" - // ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation - ipTCPHeaderMinSize = 40 + // ipv4TCPHeaderSize is the minimum IPv4 (20) + TCP (20) header size for MSS calculation. + ipv4TCPHeaderSize = 40 + // ipv6TCPHeaderSize is the minimum IPv6 (40) + TCP (20) header size for MSS calculation. + ipv6TCPHeaderSize = 60 ) type ruleInfo struct { @@ -86,6 +88,7 @@ type router struct { wgIface iFaceMapper legacyManagement bool mtu uint16 + v6 bool stateManager *statemanager.Manager ipFwdState *ipfwdstate.IPForwardingState @@ -97,6 +100,7 @@ func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper, mtu uint1 rules: make(map[string][]string), wgIface: wgIface, mtu: mtu, + v6: iptablesClient.Proto() == iptables.ProtocolIPv6, ipFwdState: ipfwdstate.NewIPForwardingState(), } @@ -186,6 +190,11 @@ func (r *router) AddRouteFiltering( return ruleKey, nil } +func (r *router) hasRule(id string) bool { + _, ok := r.rules[id] + return ok +} + func (r *router) DeleteRouteRule(rule firewall.Rule) error { ruleKey := rule.ID() @@ -434,6 +443,12 @@ func (r *router) createContainers() error { {chainRTRDR, tableNat}, {chainRTMSSCLAMP, tableMangle}, } { + // Fallback: clear chains that survived an unclean shutdown. + if ok, _ := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain); ok { + if err := r.iptablesClient.ClearAndDeleteChain(chainInfo.table, chainInfo.chain); err != nil { + log.Warnf("clear stale chain %s in %s: %v", chainInfo.chain, chainInfo.table, err) + } + } if err := r.iptablesClient.NewChain(chainInfo.table, chainInfo.chain); err != nil { return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err) } @@ -540,9 +555,12 @@ func (r *router) addPostroutingRules() error { } // addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic. -// TODO: Add IPv6 support func (r *router) addMSSClampingRules() error { - mss := r.mtu - ipTCPHeaderMinSize + overhead := uint16(ipv4TCPHeaderSize) + if r.v6 { + overhead = ipv6TCPHeaderSize + } + mss := r.mtu - overhead // Add jump rule from FORWARD chain in mangle table to our custom chain jumpRule := []string{ @@ -727,8 +745,13 @@ func (r *router) updateState() { currentState.Lock() defer currentState.Unlock() - currentState.RouteRules = r.rules - currentState.RouteIPsetCounter = r.ipsetCounter + if r.v6 { + currentState.RouteRules6 = r.rules + currentState.RouteIPsetCounter6 = r.ipsetCounter + } else { + currentState.RouteRules = r.rules + currentState.RouteIPsetCounter = r.ipsetCounter + } if err := r.stateManager.UpdateState(currentState); err != nil { log.Errorf("failed to update state: %v", err) @@ -856,7 +879,7 @@ func (r *router) DeleteDNATRule(rule firewall.Rule) error { } if fwdRule, exists := r.rules[ruleKey+fwdSuffix]; exists { - if err := r.iptablesClient.Delete(tableFilter, chainRTFWDIN, fwdRule...); err != nil { + if err := r.iptablesClient.Delete(tableFilter, chainRTFWDOUT, fwdRule...); err != nil { merr = multierror.Append(merr, fmt.Errorf("delete forward rule: %w", err)) } delete(r.rules, ruleKey+fwdSuffix) @@ -883,7 +906,7 @@ func (r *router) genRouteRuleSpec(params routeFilteringRuleParams, sources []net rule = append(rule, destExp...) if params.Proto != firewall.ProtocolALL { - rule = append(rule, "-p", strings.ToLower(string(params.Proto))) + rule = append(rule, "-p", strings.ToLower(protoForFamily(params.Proto, r.v6))) rule = append(rule, applyPort("--sport", params.SPort)...) rule = append(rule, applyPort("--dport", params.DPort)...) } @@ -900,11 +923,12 @@ func (r *router) applyNetwork(flag string, network firewall.Network, prefixes [] } if network.IsSet() { - if _, err := r.ipsetCounter.Increment(network.Set.HashedName(), prefixes); err != nil { + name := r.ipsetName(network.Set.HashedName()) + if _, err := r.ipsetCounter.Increment(name, prefixes); err != nil { return nil, fmt.Errorf("create or get ipset: %w", err) } - return []string{"-m", "set", matchSet, network.Set.HashedName(), direction}, nil + return []string{"-m", "set", matchSet, name, direction}, nil } if network.IsPrefix() { return []string{flag, network.Prefix.String()}, nil @@ -915,19 +939,15 @@ func (r *router) applyNetwork(flag string, network firewall.Network, prefixes [] } func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { + name := r.ipsetName(set.HashedName()) var merr *multierror.Error for _, prefix := range prefixes { - // TODO: Implement IPv6 support - if prefix.Addr().Is6() { - log.Tracef("skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix) - continue - } - if err := r.addPrefixToIPSet(set.HashedName(), prefix); err != nil { + if err := r.addPrefixToIPSet(name, prefix); err != nil { merr = multierror.Append(merr, fmt.Errorf("add prefix to ipset: %w", err)) } } if merr == nil { - log.Debugf("updated set %s with prefixes %v", set.HashedName(), prefixes) + log.Debugf("updated set %s with prefixes %v", name, prefixes) } return nberrors.FormatErrorOrNil(merr) @@ -943,7 +963,7 @@ func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol dnatRule := []string{ "-i", r.wgIface.Name(), - "-p", strings.ToLower(string(protocol)), + "-p", strings.ToLower(protoForFamily(protocol, r.v6)), "--dport", strconv.Itoa(int(sourcePort)), "-d", localAddr.String(), "-m", "addrtype", "--dst-type", "LOCAL", @@ -1076,10 +1096,22 @@ func applyPort(flag string, port *firewall.Port) []string { return []string{flag, strconv.Itoa(int(port.Values[0]))} } +// ipsetName returns the ipset name, suffixed with "-v6" for the v6 router +// to avoid collisions since ipsets are global in the kernel. +func (r *router) ipsetName(name string) string { + if r.v6 { + return name + "-v6" + } + return name +} + func (r *router) createIPSet(name string) error { opts := ipset.CreateOptions{ Replace: true, } + if r.v6 { + opts.Family = ipset.FamilyIPV6 + } if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil { return fmt.Errorf("create ipset %s: %w", name, err) diff --git a/client/firewall/iptables/rule.go b/client/firewall/iptables/rule.go index aa4d2d079..4f4eab167 100644 --- a/client/firewall/iptables/rule.go +++ b/client/firewall/iptables/rule.go @@ -9,6 +9,7 @@ type Rule struct { mangleSpecs []string ip string chain string + v6 bool } // GetRuleID returns the rule id diff --git a/client/firewall/iptables/state_linux.go b/client/firewall/iptables/state_linux.go index c88774c1f..6b2e99e31 100644 --- a/client/firewall/iptables/state_linux.go +++ b/client/firewall/iptables/state_linux.go @@ -4,6 +4,8 @@ import ( "fmt" "sync" + log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/wgaddr" ) @@ -37,6 +39,12 @@ type ShutdownState struct { ACLEntries aclEntries `json:"acl_entries,omitempty"` ACLIPsetStore *ipsetStore `json:"acl_ipset_store,omitempty"` + + // IPv6 counterparts + RouteRules6 routeRules `json:"route_rules_v6,omitempty"` + RouteIPsetCounter6 *ipsetCounter `json:"route_ipset_counter_v6,omitempty"` + ACLEntries6 aclEntries `json:"acl_entries_v6,omitempty"` + ACLIPsetStore6 *ipsetStore `json:"acl_ipset_store_v6,omitempty"` } func (s *ShutdownState) Name() string { @@ -67,6 +75,28 @@ func (s *ShutdownState) Cleanup() error { ipt.aclMgr.ipsetStore = s.ACLIPsetStore } + // Clean up v6 state even if the current run has no IPv6. + // The previous run may have left ip6tables rules behind. + if !ipt.hasIPv6() { + if err := ipt.createIPv6Components(s.InterfaceState, mtu); err != nil { + log.Warnf("failed to create v6 components for cleanup: %v", err) + } + } + if ipt.hasIPv6() { + if s.RouteRules6 != nil { + ipt.router6.rules = s.RouteRules6 + } + if s.RouteIPsetCounter6 != nil { + ipt.router6.ipsetCounter.LoadData(s.RouteIPsetCounter6) + } + if s.ACLEntries6 != nil { + ipt.aclMgr6.entries = s.ACLEntries6 + } + if s.ACLIPsetStore6 != nil { + ipt.aclMgr6.ipsetStore = s.ACLIPsetStore6 + } + } + if err := ipt.Close(nil); err != nil { return fmt.Errorf("reset iptables manager: %w", err) } diff --git a/client/firewall/nftables/acl_linux.go b/client/firewall/nftables/acl_linux.go index a9d066e2f..9d2ea7264 100644 --- a/client/firewall/nftables/acl_linux.go +++ b/client/firewall/nftables/acl_linux.go @@ -33,15 +33,12 @@ const ( const flushError = "flush: %w" -var ( - anyIP = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} -) - type AclManager struct { rConn *nftables.Conn sConn *nftables.Conn wgIface iFaceMapper routingFwChainName string + af addrFamily workTable *nftables.Table chainInputRules *nftables.Chain @@ -67,6 +64,7 @@ func newAclManager(table *nftables.Table, wgIface iFaceMapper, routingFwChainNam wgIface: wgIface, workTable: table, routingFwChainName: routingFwChainName, + af: familyForAddr(table.Family == nftables.TableFamilyIPv4), ipsetStore: newIpsetStore(), rules: make(map[string]*Rule), @@ -145,7 +143,7 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error { } if _, ok := ips[r.ip.String()]; ok { - err := m.sConn.SetDeleteElements(r.nftSet, []nftables.SetElement{{Key: r.ip.To4()}}) + err := m.sConn.SetDeleteElements(r.nftSet, []nftables.SetElement{{Key: ipToBytes(r.ip, m.af)}}) if err != nil { log.Errorf("delete elements for set %q: %v", r.nftSet.Name, err) } @@ -254,11 +252,11 @@ func (m *AclManager) addIOFiltering( expressions = append(expressions, &expr.Payload{ DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, - Offset: uint32(9), + Offset: m.af.protoOffset, Len: uint32(1), }) - protoData, err := protoToInt(proto) + protoData, err := m.af.protoNum(proto) if err != nil { return nil, fmt.Errorf("convert protocol to number: %v", err) } @@ -270,19 +268,16 @@ func (m *AclManager) addIOFiltering( }) } - rawIP := ip.To4() + rawIP := ipToBytes(ip, m.af) // check if rawIP contains zeroed IPv4 0.0.0.0 value // in that case not add IP match expression into the rule definition - if !bytes.HasPrefix(anyIP, rawIP) { - // source address position - addrOffset := uint32(12) - + if slices.ContainsFunc(rawIP, func(v byte) bool { return v != 0 }) { expressions = append(expressions, &expr.Payload{ DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, - Offset: addrOffset, - Len: 4, + Offset: m.af.srcAddrOffset, + Len: m.af.addrLen, }, ) // add individual IP for match if no ipset defined @@ -587,7 +582,7 @@ func (m *AclManager) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr func (m *AclManager) addIpToSet(ipsetName string, ip net.IP) (*nftables.Set, error) { ipset, err := m.rConn.GetSetByName(m.workTable, ipsetName) - rawIP := ip.To4() + rawIP := ipToBytes(ip, m.af) if err != nil { if ipset, err = m.createSet(m.workTable, ipsetName); err != nil { return nil, fmt.Errorf("get set name: %v", err) @@ -619,7 +614,7 @@ func (m *AclManager) createSet(table *nftables.Table, name string) (*nftables.Se Name: name, Table: table, Dynamic: true, - KeyType: nftables.TypeIPAddr, + KeyType: m.af.setKeyType, } if err := m.rConn.AddSet(ipset, nil); err != nil { @@ -707,15 +702,12 @@ func ifname(n string) []byte { return b } -func protoToInt(protocol firewall.Protocol) (uint8, error) { - switch protocol { - case firewall.ProtocolTCP: - return unix.IPPROTO_TCP, nil - case firewall.ProtocolUDP: - return unix.IPPROTO_UDP, nil - case firewall.ProtocolICMP: - return unix.IPPROTO_ICMP, nil - } - return 0, fmt.Errorf("unsupported protocol: %s", protocol) +// ipToBytes converts net.IP to the correct byte length for the address family. +func ipToBytes(ip net.IP, af addrFamily) []byte { + if af.addrLen == 4 { + return ip.To4() + } + return ip.To16() } + diff --git a/client/firewall/nftables/addr_family_linux.go b/client/firewall/nftables/addr_family_linux.go new file mode 100644 index 000000000..0c90d704a --- /dev/null +++ b/client/firewall/nftables/addr_family_linux.go @@ -0,0 +1,81 @@ +package nftables + +import ( + "fmt" + "net" + + "github.com/google/nftables" + "golang.org/x/sys/unix" + + firewall "github.com/netbirdio/netbird/client/firewall/manager" +) + +var ( + // afIPv4 defines IPv4 header layout and nftables types. + afIPv4 = addrFamily{ + protoOffset: 9, + srcAddrOffset: 12, + dstAddrOffset: 16, + addrLen: net.IPv4len, + totalBits: 8 * net.IPv4len, + setKeyType: nftables.TypeIPAddr, + tableFamily: nftables.TableFamilyIPv4, + icmpProto: unix.IPPROTO_ICMP, + } + // afIPv6 defines IPv6 header layout and nftables types. + afIPv6 = addrFamily{ + protoOffset: 6, + srcAddrOffset: 8, + dstAddrOffset: 24, + addrLen: net.IPv6len, + totalBits: 8 * net.IPv6len, + setKeyType: nftables.TypeIP6Addr, + tableFamily: nftables.TableFamilyIPv6, + icmpProto: unix.IPPROTO_ICMPV6, + } +) + +// addrFamily holds protocol-specific constants for nftables expression building. +type addrFamily struct { + // protoOffset is the IP header offset for the protocol/next-header field (9 for v4, 6 for v6) + protoOffset uint32 + // srcAddrOffset is the IP header offset for the source address (12 for v4, 8 for v6) + srcAddrOffset uint32 + // dstAddrOffset is the IP header offset for the destination address (16 for v4, 24 for v6) + dstAddrOffset uint32 + // addrLen is the byte length of addresses (4 for v4, 16 for v6) + addrLen uint32 + // totalBits is the address size in bits (32 for v4, 128 for v6) + totalBits int + // setKeyType is the nftables set data type for addresses + setKeyType nftables.SetDatatype + // tableFamily is the nftables table family + tableFamily nftables.TableFamily + // icmpProto is the ICMP protocol number for this family (1 for v4, 58 for v6) + icmpProto uint8 +} + +// familyForAddr returns the address family for the given IP. +func familyForAddr(is4 bool) addrFamily { + if is4 { + return afIPv4 + } + return afIPv6 +} + +// protoNum converts a firewall protocol to the IP protocol number, +// using the correct ICMP variant for the address family. +func (af addrFamily) protoNum(protocol firewall.Protocol) (uint8, error) { + switch protocol { + case firewall.ProtocolTCP: + return unix.IPPROTO_TCP, nil + case firewall.ProtocolUDP: + return unix.IPPROTO_UDP, nil + case firewall.ProtocolICMP: + return af.icmpProto, nil + case firewall.ProtocolALL: + return 0, nil + default: + return 0, fmt.Errorf("unsupported protocol: %s", protocol) + } +} diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go index beb5b70a7..c3c1c1a65 100644 --- a/client/firewall/nftables/manager_linux.go +++ b/client/firewall/nftables/manager_linux.go @@ -11,9 +11,11 @@ import ( "github.com/google/nftables" "github.com/google/nftables/binaryutil" "github.com/google/nftables/expr" + "github.com/hashicorp/go-multierror" log "github.com/sirupsen/logrus" "golang.org/x/sys/unix" + nberrors "github.com/netbirdio/netbird/client/errors" firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/internal/statemanager" @@ -49,8 +51,13 @@ type Manager struct { rConn *nftables.Conn wgIface iFaceMapper - router *router - aclManager *AclManager + router *router + aclManager *AclManager + + // IPv6 counterparts, nil when no v6 overlay + router6 *router + aclManager6 *AclManager + notrackOutputChain *nftables.Chain notrackPreroutingChain *nftables.Chain } @@ -62,7 +69,8 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) { wgIface: wgIface, } - workTable := &nftables.Table{Name: getTableName(), Family: nftables.TableFamilyIPv4} + tableName := getTableName() + workTable := &nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4} var err error m.router, err = newRouter(workTable, wgIface, mtu) @@ -75,11 +83,70 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) { return nil, fmt.Errorf("create acl manager: %w", err) } + if wgIface.Address().HasIPv6() { + if err := m.createIPv6Components(tableName, wgIface, mtu); err != nil { + return nil, fmt.Errorf("create IPv6 firewall: %w", err) + } + } + return m, nil } +func (m *Manager) createIPv6Components(tableName string, wgIface iFaceMapper, mtu uint16) error { + workTable6 := &nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv6} + + var err error + m.router6, err = newRouter(workTable6, wgIface, mtu) + if err != nil { + return fmt.Errorf("create v6 router: %w", err) + } + + // Share the same IP forwarding state with the v4 router, since + // EnableIPForwarding controls both v4 and v6 sysctls. + m.router6.ipFwdState = m.router.ipFwdState + + m.aclManager6, err = newAclManager(workTable6, wgIface, chainNameRoutingFw) + if err != nil { + return fmt.Errorf("create v6 acl manager: %w", err) + } + + return nil +} + +// hasIPv6 reports whether the manager has IPv6 components initialized. +func (m *Manager) hasIPv6() bool { + return m.router6 != nil +} + +func (m *Manager) initIPv6() error { + workTable6, err := m.createWorkTableFamily(nftables.TableFamilyIPv6) + if err != nil { + return fmt.Errorf("create v6 work table: %w", err) + } + + if err := m.router6.init(workTable6); err != nil { + return fmt.Errorf("v6 router init: %w", err) + } + + if err := m.aclManager6.init(workTable6); err != nil { + return fmt.Errorf("v6 acl manager init: %w", err) + } + + return nil +} + // Init nftables firewall manager func (m *Manager) Init(stateManager *statemanager.Manager) error { + if err := m.initFirewall(); err != nil { + return err + } + + m.persistState(stateManager) + + return nil +} + +func (m *Manager) initFirewall() error { workTable, err := m.createWorkTable() if err != nil { return fmt.Errorf("create work table: %w", err) @@ -90,20 +157,32 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error { } if err := m.aclManager.init(workTable); err != nil { - // TODO: cleanup router + m.rollbackInit() return fmt.Errorf("acl manager init: %w", err) } + if m.hasIPv6() { + if err := m.initIPv6(); err != nil { + // Peer has a v6 address: v6 firewall MUST work or we risk fail-open. + m.rollbackInit() + return fmt.Errorf("init IPv6 firewall (required because peer has IPv6 address): %w", err) + } + } + if err := m.initNoTrackChains(workTable); err != nil { log.Warnf("raw priority chains not available, notrack rules will be disabled: %v", err) } + return nil +} + +// persistState saves the current interface state for potential recreation on restart. +// Unlike iptables, which requires tracking individual rules, nftables maintains +// a known state (our netbird table plus a few static rules). This allows for easy +// cleanup using Close() without needing to store specific rules. +func (m *Manager) persistState(stateManager *statemanager.Manager) { stateManager.RegisterState(&ShutdownState{}) - // We only need to record minimal interface state for potential recreation. - // Unlike iptables, which requires tracking individual rules, nftables maintains - // a known state (our netbird table plus a few static rules). This allows for easy - // cleanup using Close() without needing to store specific rules. if err := stateManager.UpdateState(&ShutdownState{ InterfaceState: &InterfaceState{ NameStr: m.wgIface.Name(), @@ -115,14 +194,29 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error { log.Errorf("failed to update state: %v", err) } - // persist early go func() { if err := stateManager.PersistState(context.Background()); err != nil { log.Errorf("failed to persist state: %v", err) } }() +} - return nil +// rollbackInit performs best-effort cleanup of already-initialized state when Init fails partway through. +func (m *Manager) rollbackInit() { + if err := m.router.Reset(); err != nil { + log.Warnf("rollback router: %v", err) + } + if m.hasIPv6() { + if err := m.router6.Reset(); err != nil { + log.Warnf("rollback v6 router: %v", err) + } + } + if err := m.cleanupNetbirdTables(); err != nil { + log.Warnf("cleanup tables: %v", err) + } + if err := m.rConn.Flush(); err != nil { + log.Warnf("flush: %v", err) + } } // AddPeerFiltering rule to the firewall @@ -141,12 +235,14 @@ func (m *Manager) AddPeerFiltering( m.mutex.Lock() defer m.mutex.Unlock() - rawIP := ip.To4() - if rawIP == nil { - return nil, fmt.Errorf("unsupported IP version: %s", ip.String()) + if ip.To4() != nil { + return m.aclManager.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName) } - return m.aclManager.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName) + if !m.hasIPv6() { + return nil, fmt.Errorf("IPv6 not initialized, cannot add rule for %s", ip) + } + return m.aclManager6.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName) } func (m *Manager) AddRouteFiltering( @@ -160,8 +256,11 @@ func (m *Manager) AddRouteFiltering( m.mutex.Lock() defer m.mutex.Unlock() - if destination.IsPrefix() && !destination.Prefix.Addr().Is4() { - return nil, fmt.Errorf("unsupported IP version: %s", destination.Prefix.Addr().String()) + if isIPv6RouteRule(sources, destination) { + if !m.hasIPv6() { + return nil, fmt.Errorf("IPv6 not initialized, cannot add route rule") + } + return m.router6.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action) } return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action) @@ -172,14 +271,38 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error { m.mutex.Lock() defer m.mutex.Unlock() + if m.hasIPv6() && isIPv6Rule(rule) { + return m.aclManager6.DeletePeerRule(rule) + } return m.aclManager.DeletePeerRule(rule) } -// DeleteRouteRule deletes a routing rule +func isIPv6Rule(rule firewall.Rule) bool { + r, ok := rule.(*Rule) + return ok && r.nftRule != nil && r.nftRule.Table != nil && r.nftRule.Table.Family == nftables.TableFamilyIPv6 +} + +// isIPv6RouteRule determines whether a route rule belongs to the v6 table. +// For static routes, the destination prefix determines the family. For dynamic +// routes (DomainSet), the sources determine the family since management +// duplicates dynamic rules per family. +func isIPv6RouteRule(sources []netip.Prefix, destination firewall.Network) bool { + if destination.IsPrefix() { + return destination.Prefix.Addr().Is6() + } + return len(sources) > 0 && sources[0].Addr().Is6() +} + +// DeleteRouteRule deletes a routing rule. +// Route rules are keyed by content hash, so the rule exists in exactly one +// router. We check v4 first; if the key isn't there, try v6. func (m *Manager) DeleteRouteRule(rule firewall.Rule) error { m.mutex.Lock() defer m.mutex.Unlock() + if m.hasIPv6() && !m.router.hasRule(rule.ID()) { + return m.router6.DeleteRouteRule(rule) + } return m.router.DeleteRouteRule(rule) } @@ -195,17 +318,63 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error { m.mutex.Lock() defer m.mutex.Unlock() - return m.router.AddNatRule(pair) + if pair.Destination.IsPrefix() && pair.Destination.Prefix.Addr().Is6() { + if !m.hasIPv6() { + return fmt.Errorf("IPv6 not initialized, cannot add NAT rule") + } + return m.router6.AddNatRule(pair) + } + + if err := m.router.AddNatRule(pair); err != nil { + return err + } + + // Dynamic routes (DomainSet) need NAT in both tables since resolved IPs + // can be either v4 or v6. + if m.hasIPv6() && pair.Destination.IsSet() { + v6Pair := pair + v6Pair.Source = firewall.Network{Prefix: netip.PrefixFrom(netip.IPv6Unspecified(), 0)} + if err := m.router6.AddNatRule(v6Pair); err != nil { + return fmt.Errorf("add v6 NAT rule: %w", err) + } + } + + return nil } func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error { m.mutex.Lock() defer m.mutex.Unlock() - return m.router.RemoveNatRule(pair) + if pair.Destination.IsPrefix() && pair.Destination.Prefix.Addr().Is6() { + if !m.hasIPv6() { + return nil + } + return m.router6.RemoveNatRule(pair) + } + + if err := m.router.RemoveNatRule(pair); err != nil { + return err + } + + if m.hasIPv6() && pair.Destination.IsSet() { + v6Pair := pair + v6Pair.Source = firewall.Network{Prefix: netip.PrefixFrom(netip.IPv6Unspecified(), 0)} + if err := m.router6.RemoveNatRule(v6Pair); err != nil { + return fmt.Errorf("remove v6 NAT rule: %w", err) + } + } + + return nil } -// AllowNetbird allows netbird interface traffic +// AllowNetbird allows netbird interface traffic. +// TODO: In USP mode this only adds ACCEPT to the netbird table's own chains, +// which doesn't override DROP rules in external tables (e.g. firewalld). +// Should add passthrough rules to external chains (like the native mode router's +// addExternalChainsRules does) for both the netbird table family and inet tables. +// The netbird table itself is fine (routing chains already exist there), but +// non-netbird tables with INPUT/FORWARD hooks can still DROP our WG traffic. func (m *Manager) AllowNetbird() error { if !m.wgIface.IsUserspaceBind() { return nil @@ -217,6 +386,11 @@ func (m *Manager) AllowNetbird() error { if err := m.aclManager.createDefaultAllowRules(); err != nil { return fmt.Errorf("create default allow rules: %w", err) } + if m.hasIPv6() { + if err := m.aclManager6.createDefaultAllowRules(); err != nil { + return fmt.Errorf("create v6 default allow rules: %w", err) + } + } if err := m.rConn.Flush(); err != nil { return fmt.Errorf("flush allow input netbird rules: %w", err) } @@ -226,7 +400,13 @@ func (m *Manager) AllowNetbird() error { // SetLegacyManagement sets the route manager to use legacy management func (m *Manager) SetLegacyManagement(isLegacy bool) error { - return firewall.SetLegacyManagement(m.router, isLegacy) + if err := firewall.SetLegacyManagement(m.router, isLegacy); err != nil { + return err + } + if m.hasIPv6() { + return firewall.SetLegacyManagement(m.router6, isLegacy) + } + return nil } // Close closes the firewall manager @@ -234,23 +414,31 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error { m.mutex.Lock() defer m.mutex.Unlock() + var merr *multierror.Error + if err := m.router.Reset(); err != nil { - return fmt.Errorf("reset router: %v", err) + merr = multierror.Append(merr, fmt.Errorf("reset router: %v", err)) + } + + if m.hasIPv6() { + if err := m.router6.Reset(); err != nil { + merr = multierror.Append(merr, fmt.Errorf("reset v6 router: %v", err)) + } } if err := m.cleanupNetbirdTables(); err != nil { - return fmt.Errorf("cleanup netbird tables: %v", err) + merr = multierror.Append(merr, fmt.Errorf("cleanup netbird tables: %v", err)) } if err := m.rConn.Flush(); err != nil { - return fmt.Errorf(flushError, err) + merr = multierror.Append(merr, fmt.Errorf(flushError, err)) } if err := stateManager.DeleteState(&ShutdownState{}); err != nil { - return fmt.Errorf("delete state: %v", err) + merr = multierror.Append(merr, fmt.Errorf("delete state: %v", err)) } - return nil + return nberrors.FormatErrorOrNil(merr) } func (m *Manager) cleanupNetbirdTables() error { @@ -299,6 +487,12 @@ func (m *Manager) Flush() error { return err } + if m.hasIPv6() { + if err := m.aclManager6.Flush(); err != nil { + return fmt.Errorf("flush v6 acl: %w", err) + } + } + if err := m.refreshNoTrackChains(); err != nil { log.Errorf("failed to refresh notrack chains: %v", err) } @@ -311,6 +505,9 @@ func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) m.mutex.Lock() defer m.mutex.Unlock() + if m.hasIPv6() && rule.TranslatedAddress.Is6() { + return m.router6.AddDNATRule(rule) + } return m.router.AddDNATRule(rule) } @@ -319,6 +516,9 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error { m.mutex.Lock() defer m.mutex.Unlock() + if m.hasIPv6() && !m.router.hasDNATRule(rule.ID()) { + return m.router6.DeleteDNATRule(rule) + } return m.router.DeleteDNATRule(rule) } @@ -327,7 +527,26 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { m.mutex.Lock() defer m.mutex.Unlock() - return m.router.UpdateSet(set, prefixes) + var v4Prefixes, v6Prefixes []netip.Prefix + for _, p := range prefixes { + if p.Addr().Is6() { + v6Prefixes = append(v6Prefixes, p) + } else { + v4Prefixes = append(v4Prefixes, p) + } + } + + if err := m.router.UpdateSet(set, v4Prefixes); err != nil { + return err + } + + if m.hasIPv6() && len(v6Prefixes) > 0 { + if err := m.router6.UpdateSet(set, v6Prefixes); err != nil { + return fmt.Errorf("update v6 set: %w", err) + } + } + + return nil } // AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services. @@ -335,6 +554,9 @@ func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protoco m.mutex.Lock() defer m.mutex.Unlock() + if m.hasIPv6() && localAddr.Is6() { + return m.router6.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort) + } return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort) } @@ -343,6 +565,9 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot m.mutex.Lock() defer m.mutex.Unlock() + if m.hasIPv6() && localAddr.Is6() { + return m.router6.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort) + } return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort) } @@ -533,7 +758,11 @@ func (m *Manager) refreshNoTrackChains() error { } func (m *Manager) createWorkTable() (*nftables.Table, error) { - tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4) + return m.createWorkTableFamily(nftables.TableFamilyIPv4) +} + +func (m *Manager) createWorkTableFamily(family nftables.TableFamily) (*nftables.Table, error) { + tables, err := m.rConn.ListTablesOfFamily(family) if err != nil { return nil, fmt.Errorf("list of tables: %w", err) } @@ -545,7 +774,7 @@ func (m *Manager) createWorkTable() (*nftables.Table, error) { } } - table := m.rConn.AddTable(&nftables.Table{Name: getTableName(), Family: nftables.TableFamilyIPv4}) + table := m.rConn.AddTable(&nftables.Table{Name: tableName, Family: family}) err = m.rConn.Flush() return table, err } diff --git a/client/firewall/nftables/manager_linux_test.go b/client/firewall/nftables/manager_linux_test.go index 75b1e2b6c..d925f3ef3 100644 --- a/client/firewall/nftables/manager_linux_test.go +++ b/client/firewall/nftables/manager_linux_test.go @@ -385,10 +385,134 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) { err = manager.AddNatRule(pair) require.NoError(t, err, "failed to add NAT rule") + dnatRule, err := manager.AddDNATRule(fw.ForwardRule{ + Protocol: fw.ProtocolTCP, + DestinationPort: fw.Port{Values: []uint16{8080}}, + TranslatedAddress: netip.MustParseAddr("100.96.0.2"), + TranslatedPort: fw.Port{Values: []uint16{80}}, + }) + require.NoError(t, err, "failed to add DNAT rule") + + t.Cleanup(func() { + require.NoError(t, manager.DeleteDNATRule(dnatRule), "failed to delete DNAT rule") + }) + stdout, stderr = runIptablesSave(t) verifyIptablesOutput(t, stdout, stderr) } +func TestNftablesManagerIPv6CompatibilityWithIp6tables(t *testing.T) { + if check() != NFTABLES { + t.Skip("nftables not supported on this system") + } + + for _, bin := range []string{"ip6tables", "ip6tables-save", "iptables-save"} { + if _, err := exec.LookPath(bin); err != nil { + t.Skipf("%s not available on this system: %v", bin, err) + } + } + + // Seed ip6 tables in the nft backend. Docker may not create them. + seedIp6tables(t) + + ifaceMockV6 := &iFaceMock{ + NameFunc: func() string { return "wt-test" }, + AddressFunc: func() wgaddr.Address { + return wgaddr.Address{ + IP: netip.MustParseAddr("100.96.0.1"), + Network: netip.MustParsePrefix("100.96.0.0/16"), + IPv6: netip.MustParseAddr("fd00::1"), + IPv6Net: netip.MustParsePrefix("fd00::/64"), + } + }, + } + + manager, err := Create(ifaceMockV6, iface.DefaultMTU) + require.NoError(t, err, "create manager") + require.NoError(t, manager.Init(nil)) + + t.Cleanup(func() { + require.NoError(t, manager.Close(nil), "close manager") + + stdout, stderr := runIp6tablesSave(t) + verifyIp6tablesOutput(t, stdout, stderr) + }) + + ip := netip.MustParseAddr("fd00::2") + _, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "") + require.NoError(t, err, "add v6 peer filtering rule") + + _, err = manager.AddRouteFiltering( + nil, + []netip.Prefix{netip.MustParsePrefix("fd00:1::/64")}, + fw.Network{Prefix: netip.MustParsePrefix("2001:db8::/48")}, + fw.ProtocolTCP, + nil, + &fw.Port{Values: []uint16{443}}, + fw.ActionAccept, + ) + require.NoError(t, err, "add v6 route filtering rule") + + err = manager.AddNatRule(fw.RouterPair{ + Source: fw.Network{Prefix: netip.MustParsePrefix("fd00::/64")}, + Destination: fw.Network{Prefix: netip.MustParsePrefix("2001:db8::/48")}, + Masquerade: true, + }) + require.NoError(t, err, "add v6 NAT rule") + + dnatRule, err := manager.AddDNATRule(fw.ForwardRule{ + Protocol: fw.ProtocolTCP, + DestinationPort: fw.Port{Values: []uint16{8080}}, + TranslatedAddress: netip.MustParseAddr("fd00::2"), + TranslatedPort: fw.Port{Values: []uint16{80}}, + }) + require.NoError(t, err, "add v6 DNAT rule") + + t.Cleanup(func() { + require.NoError(t, manager.DeleteDNATRule(dnatRule), "delete v6 DNAT rule") + }) + + stdout, stderr := runIptablesSave(t) + verifyIptablesOutput(t, stdout, stderr) + + stdout, stderr = runIp6tablesSave(t) + verifyIp6tablesOutput(t, stdout, stderr) +} + +func seedIp6tables(t *testing.T) { + t.Helper() + for _, tc := range []struct{ table, chain string }{ + {"filter", "FORWARD"}, + {"nat", "POSTROUTING"}, + {"mangle", "FORWARD"}, + } { + add := exec.Command("ip6tables", "-t", tc.table, "-A", tc.chain, "-j", "ACCEPT") + require.NoError(t, add.Run(), "seed ip6tables -t %s", tc.table) + del := exec.Command("ip6tables", "-t", tc.table, "-D", tc.chain, "-j", "ACCEPT") + require.NoError(t, del.Run(), "unseed ip6tables -t %s", tc.table) + } +} + +func runIp6tablesSave(t *testing.T) (string, string) { + t.Helper() + var stdout, stderr bytes.Buffer + cmd := exec.Command("ip6tables-save") + cmd.Stdout = &stdout + cmd.Stderr = &stderr + require.NoError(t, cmd.Run(), "ip6tables-save failed") + return stdout.String(), stderr.String() +} + +func verifyIp6tablesOutput(t *testing.T, stdout, stderr string) { + t.Helper() + require.NotContains(t, stdout, "Table `nat' is incompatible", + "ip6tables-save: nat table incompatible. Full output: %s", stdout) + require.NotContains(t, stdout, "Table `mangle' is incompatible", + "ip6tables-save: mangle table incompatible. Full output: %s", stdout) + require.NotContains(t, stdout, "Table `filter' is incompatible", + "ip6tables-save: filter table incompatible. Full output: %s", stdout) +} + func TestNftablesManagerCompatibilityWithIptablesFor6kPrefixes(t *testing.T) { if check() != NFTABLES { t.Skip("nftables not supported on this system") diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go index 904daf7cb..02f8288fe 100644 --- a/client/firewall/nftables/router_linux.go +++ b/client/firewall/nftables/router_linux.go @@ -47,8 +47,10 @@ const ( dnatSuffix = "_dnat" snatSuffix = "_snat" - // ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation - ipTCPHeaderMinSize = 40 + // ipv4TCPHeaderSize is the minimum IPv4 (20) + TCP (20) header size for MSS calculation. + ipv4TCPHeaderSize = 40 + // ipv6TCPHeaderSize is the minimum IPv6 (40) + TCP (20) header size for MSS calculation. + ipv6TCPHeaderSize = 60 // maxPrefixesSet 1638 prefixes start to fail, taking some margin maxPrefixesSet = 1500 @@ -73,6 +75,7 @@ type router struct { rules map[string]*nftables.Rule ipsetCounter *refcounter.Counter[string, setInput, *nftables.Set] + af addrFamily wgIface iFaceMapper ipFwdState *ipfwdstate.IPForwardingState legacyManagement bool @@ -85,6 +88,7 @@ func newRouter(workTable *nftables.Table, wgIface iFaceMapper, mtu uint16) (*rou workTable: workTable, chains: make(map[string]*nftables.Chain), rules: make(map[string]*nftables.Rule), + af: familyForAddr(workTable.Family == nftables.TableFamilyIPv4), wgIface: wgIface, ipFwdState: ipfwdstate.NewIPForwardingState(), mtu: mtu, @@ -143,7 +147,7 @@ func (r *router) Reset() error { func (r *router) removeNatPreroutingRules() error { table := &nftables.Table{ Name: tableNat, - Family: nftables.TableFamilyIPv4, + Family: r.af.tableFamily, } chain := &nftables.Chain{ Name: chainNameNatPrerouting, @@ -176,7 +180,7 @@ func (r *router) removeNatPreroutingRules() error { } func (r *router) loadFilterTable() (*nftables.Table, error) { - tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4) + tables, err := r.conn.ListTablesOfFamily(r.af.tableFamily) if err != nil { return nil, fmt.Errorf("list tables: %w", err) } @@ -408,7 +412,7 @@ func (r *router) AddRouteFiltering( // Handle protocol if proto != firewall.ProtocolALL { - protoNum, err := protoToInt(proto) + protoNum, err := r.af.protoNum(proto) if err != nil { return nil, fmt.Errorf("convert protocol to number: %w", err) } @@ -468,7 +472,24 @@ func (r *router) getIpSet(set firewall.Set, prefixes []netip.Prefix, isSource bo return nil, fmt.Errorf("create or get ipset: %w", err) } - return getIpSetExprs(ref, isSource) + return r.getIpSetExprs(ref, isSource) +} + +func (r *router) iptablesProto() iptables.Protocol { + if r.af.tableFamily == nftables.TableFamilyIPv6 { + return iptables.ProtocolIPv6 + } + return iptables.ProtocolIPv4 +} + +func (r *router) hasRule(id string) bool { + _, ok := r.rules[id] + return ok +} + +func (r *router) hasDNATRule(id string) bool { + _, ok := r.rules[id+dnatSuffix] + return ok } func (r *router) DeleteRouteRule(rule firewall.Rule) error { @@ -517,10 +538,10 @@ func (r *router) createIpSet(setName string, input setInput) (*nftables.Set, err Table: r.workTable, // required for prefixes Interval: true, - KeyType: nftables.TypeIPAddr, + KeyType: r.af.setKeyType, } - elements := convertPrefixesToSet(prefixes) + elements := r.convertPrefixesToSet(prefixes) nElements := len(elements) maxElements := maxPrefixesSet * 2 @@ -553,23 +574,17 @@ func (r *router) createIpSet(setName string, input setInput) (*nftables.Set, err return nfset, nil } -func convertPrefixesToSet(prefixes []netip.Prefix) []nftables.SetElement { +func (r *router) convertPrefixesToSet(prefixes []netip.Prefix) []nftables.SetElement { var elements []nftables.SetElement for _, prefix := range prefixes { - // TODO: Implement IPv6 support - if prefix.Addr().Is6() { - log.Tracef("skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix) - continue - } - // nftables needs half-open intervals [firstIP, lastIP) for prefixes // e.g. 10.0.0.0/24 becomes [10.0.0.0, 10.0.1.0), 10.1.1.1/32 becomes [10.1.1.1, 10.1.1.2) etc firstIP := prefix.Addr() lastIP := calculateLastIP(prefix).Next() elements = append(elements, - // the nft tool also adds a line like this, see https://github.com/google/nftables/issues/247 - // nftables.SetElement{Key: []byte{0, 0, 0, 0}, IntervalEnd: true}, + // the nft tool also adds a zero-address IntervalEnd element, see https://github.com/google/nftables/issues/247 + // nftables.SetElement{Key: make([]byte, r.af.addrLen), IntervalEnd: true}, nftables.SetElement{Key: firstIP.AsSlice()}, nftables.SetElement{Key: lastIP.AsSlice(), IntervalEnd: true}, ) @@ -579,10 +594,20 @@ func convertPrefixesToSet(prefixes []netip.Prefix) []nftables.SetElement { // calculateLastIP determines the last IP in a given prefix. func calculateLastIP(prefix netip.Prefix) netip.Addr { - hostMask := ^uint32(0) >> prefix.Masked().Bits() - lastIP := uint32FromNetipAddr(prefix.Addr()) | hostMask + masked := prefix.Masked() + if masked.Addr().Is4() { + hostMask := ^uint32(0) >> masked.Bits() + lastIP := uint32FromNetipAddr(masked.Addr()) | hostMask + return netip.AddrFrom4(uint32ToBytes(lastIP)) + } - return netip.AddrFrom4(uint32ToBytes(lastIP)) + // IPv6: set host bits to all 1s + b := masked.Addr().As16() + bits := masked.Bits() + for i := bits; i < 128; i++ { + b[i/8] |= 1 << (7 - i%8) + } + return netip.AddrFrom16(b) } // Utility function to convert netip.Addr to uint32. @@ -834,9 +859,12 @@ func (r *router) addPostroutingRules() { } // addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic. -// TODO: Add IPv6 support func (r *router) addMSSClampingRules() error { - mss := r.mtu - ipTCPHeaderMinSize + overhead := uint16(ipv4TCPHeaderSize) + if r.af.tableFamily == nftables.TableFamilyIPv6 { + overhead = ipv6TCPHeaderSize + } + mss := r.mtu - overhead exprsOut := []expr.Any{ &expr.Meta{ @@ -1043,17 +1071,22 @@ func (r *router) acceptFilterTableRules() error { log.Debugf("Used %s to add accept forward and input rules", fw) }() - // Try iptables first and fallback to nftables if iptables is not available - ipt, err := iptables.New() + // Try iptables first and fallback to nftables if iptables is not available. + // Use the correct protocol (iptables vs ip6tables) for the address family. + ipt, err := iptables.NewWithProtocol(r.iptablesProto()) if err != nil { - // iptables is not available but the filter table exists log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err) fw = "nftables" return r.acceptFilterRulesNftables(r.filterTable) } - return r.acceptFilterRulesIptables(ipt) + if err := r.acceptFilterRulesIptables(ipt); err != nil { + log.Warnf("iptables failed (table may be incompatible), falling back to nftables: %v", err) + fw = "nftables" + return r.acceptFilterRulesNftables(r.filterTable) + } + return nil } func (r *router) acceptFilterRulesIptables(ipt *iptables.IPTables) error { @@ -1222,13 +1255,17 @@ func (r *router) removeFilterTableRules() error { return nil } - ipt, err := iptables.New() + ipt, err := iptables.NewWithProtocol(r.iptablesProto()) if err != nil { log.Debugf("iptables not available, using nftables to remove filter rules: %v", err) return r.removeAcceptRulesFromTable(r.filterTable) } - return r.removeAcceptFilterRulesIptables(ipt) + if err := r.removeAcceptFilterRulesIptables(ipt); err != nil { + log.Debugf("iptables removal failed (table may be incompatible), falling back to nftables: %v", err) + return r.removeAcceptRulesFromTable(r.filterTable) + } + return nil } func (r *router) removeAcceptRulesFromTable(table *nftables.Table) error { @@ -1295,7 +1332,7 @@ func (r *router) removeExternalChainsRules() error { func (r *router) findExternalChains() []*nftables.Chain { var chains []*nftables.Chain - families := []nftables.TableFamily{nftables.TableFamilyIPv4, nftables.TableFamilyINet} + families := []nftables.TableFamily{r.af.tableFamily, nftables.TableFamilyINet} for _, family := range families { allChains, err := r.conn.ListChainsOfTableFamily(family) @@ -1319,8 +1356,8 @@ func (r *router) isExternalChain(chain *nftables.Chain) bool { return false } - // Skip all iptables-managed tables in the ip family - if chain.Table.Family == nftables.TableFamilyIPv4 && isIptablesTable(chain.Table.Name) { + // Skip iptables/ip6tables-managed tables (adding nft-native rules breaks iptables-save compat) + if (chain.Table.Family == nftables.TableFamilyIPv4 || chain.Table.Family == nftables.TableFamilyIPv6) && isIptablesTable(chain.Table.Name) { return false } @@ -1461,7 +1498,7 @@ func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) { return rule, nil } - protoNum, err := protoToInt(rule.Protocol) + protoNum, err := r.af.protoNum(rule.Protocol) if err != nil { return nil, fmt.Errorf("convert protocol to number: %w", err) } @@ -1524,7 +1561,7 @@ func (r *router) addDnatRedirect(rule firewall.ForwardRule, protoNum uint8, rule dnatExprs = append(dnatExprs, &expr.NAT{ Type: expr.NATTypeDestNAT, - Family: uint32(nftables.TableFamilyIPv4), + Family: uint32(r.af.tableFamily), RegAddrMin: 1, RegProtoMin: regProtoMin, RegProtoMax: regProtoMax, @@ -1620,7 +1657,7 @@ func (r *router) addXTablesRedirect(dnatExprs []expr.Any, ruleKey string, rule f dnatRule := &nftables.Rule{ Table: &nftables.Table{ Name: tableNat, - Family: nftables.TableFamilyIPv4, + Family: r.af.tableFamily, }, Chain: &nftables.Chain{ Name: chainNameNatPrerouting, @@ -1655,8 +1692,8 @@ func (r *router) addDnatMasq(rule firewall.ForwardRule, protoNum uint8, ruleKey &expr.Payload{ DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, - Offset: 16, - Len: 4, + Offset: r.af.dstAddrOffset, + Len: r.af.addrLen, }, &expr.Cmp{ Op: expr.CmpOpEq, @@ -1734,7 +1771,7 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { return fmt.Errorf("get set %s: %w", set.HashedName(), err) } - elements := convertPrefixesToSet(prefixes) + elements := r.convertPrefixesToSet(prefixes) if err := r.conn.SetAddElements(nfset, elements); err != nil { return fmt.Errorf("add elements to set %s: %w", set.HashedName(), err) } @@ -1756,7 +1793,7 @@ func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol return nil } - protoNum, err := protoToInt(protocol) + protoNum, err := r.af.protoNum(protocol) if err != nil { return fmt.Errorf("convert protocol to number: %w", err) } @@ -1787,7 +1824,11 @@ func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol }, } - exprs = append(exprs, applyPrefix(netip.PrefixFrom(localAddr, 32), false)...) + bits := 32 + if localAddr.Is6() { + bits = 128 + } + exprs = append(exprs, r.applyPrefix(netip.PrefixFrom(localAddr, bits), false)...) exprs = append(exprs, &expr.Immediate{ @@ -1800,7 +1841,7 @@ func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol }, &expr.NAT{ Type: expr.NATTypeDestNAT, - Family: uint32(nftables.TableFamilyIPv4), + Family: uint32(r.af.tableFamily), RegAddrMin: 1, RegProtoMin: 2, RegProtoMax: 0, @@ -1887,7 +1928,7 @@ func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, return err } - protoNum, err := protoToInt(protocol) + protoNum, err := r.af.protoNum(protocol) if err != nil { return fmt.Errorf("convert protocol to number: %w", err) } @@ -1912,7 +1953,11 @@ func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, }, } - exprs = append(exprs, applyPrefix(netip.PrefixFrom(localAddr, 32), false)...) + bits := 32 + if localAddr.Is6() { + bits = 128 + } + exprs = append(exprs, r.applyPrefix(netip.PrefixFrom(localAddr, bits), false)...) exprs = append(exprs, &expr.Immediate{ @@ -1925,7 +1970,7 @@ func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, }, &expr.NAT{ Type: expr.NATTypeDestNAT, - Family: uint32(nftables.TableFamilyIPv4), + Family: uint32(r.af.tableFamily), RegAddrMin: 1, RegProtoMin: 2, }, @@ -1993,45 +2038,44 @@ func (r *router) applyNetwork( } if network.IsPrefix() { - return applyPrefix(network.Prefix, isSource), nil + return r.applyPrefix(network.Prefix, isSource), nil } return nil, nil } // applyPrefix generates nftables expressions for a CIDR prefix -func applyPrefix(prefix netip.Prefix, isSource bool) []expr.Any { - // dst offset - offset := uint32(16) +func (r *router) applyPrefix(prefix netip.Prefix, isSource bool) []expr.Any { + // dst offset by default + offset := r.af.dstAddrOffset if isSource { // src offset - offset = 12 + offset = r.af.srcAddrOffset } ones := prefix.Bits() - // 0.0.0.0/0 doesn't need extra expressions + // unspecified address (/0) doesn't need extra expressions if ones == 0 { return nil } - mask := net.CIDRMask(ones, 32) + mask := net.CIDRMask(ones, r.af.totalBits) + xor := make([]byte, r.af.addrLen) return []expr.Any{ &expr.Payload{ DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: offset, - Len: 4, + Len: r.af.addrLen, }, - // netmask &expr.Bitwise{ DestRegister: 1, SourceRegister: 1, - Len: 4, + Len: r.af.addrLen, Mask: mask, - Xor: []byte{0, 0, 0, 0}, + Xor: xor, }, - // net address &expr.Cmp{ Op: expr.CmpOpEq, Register: 1, @@ -2114,13 +2158,12 @@ func getCtNewExprs() []expr.Any { } } -func getIpSetExprs(ref refcounter.Ref[*nftables.Set], isSource bool) ([]expr.Any, error) { - - // dst offset - offset := uint32(16) +func (r *router) getIpSetExprs(ref refcounter.Ref[*nftables.Set], isSource bool) ([]expr.Any, error) { + // dst offset by default + offset := r.af.dstAddrOffset if isSource { // src offset - offset = 12 + offset = r.af.srcAddrOffset } return []expr.Any{ @@ -2128,7 +2171,7 @@ func getIpSetExprs(ref refcounter.Ref[*nftables.Set], isSource bool) ([]expr.Any DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: offset, - Len: 4, + Len: r.af.addrLen, }, &expr.Lookup{ SourceRegister: 1, diff --git a/client/firewall/nftables/router_linux_test.go b/client/firewall/nftables/router_linux_test.go index f0e34d211..c5d6729d9 100644 --- a/client/firewall/nftables/router_linux_test.go +++ b/client/firewall/nftables/router_linux_test.go @@ -90,8 +90,9 @@ func TestNftablesManager_AddNatRule(t *testing.T) { } // Build CIDR matching expressions - sourceExp := applyPrefix(testCase.InputPair.Source.Prefix, true) - destExp := applyPrefix(testCase.InputPair.Destination.Prefix, false) + testRouter := &router{af: afIPv4} + sourceExp := testRouter.applyPrefix(testCase.InputPair.Source.Prefix, true) + destExp := testRouter.applyPrefix(testCase.InputPair.Destination.Prefix, false) // Combine all expressions in the correct order // nolint:gocritic @@ -508,6 +509,136 @@ func TestNftablesCreateIpSet(t *testing.T) { } } +func TestNftablesCreateIpSet_IPv6(t *testing.T) { + if check() != NFTABLES { + t.Skip("nftables not supported on this system") + } + + workTable, err := createWorkTableIPv6() + require.NoError(t, err, "Failed to create v6 work table") + defer deleteWorkTableIPv6() + + r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU) + require.NoError(t, err, "Failed to create router") + require.NoError(t, r.init(workTable)) + defer func() { + require.NoError(t, r.Reset(), "Failed to reset router") + }() + + tests := []struct { + name string + sources []netip.Prefix + expected []netip.Prefix + }{ + { + name: "Single IPv6", + sources: []netip.Prefix{netip.MustParsePrefix("2001:db8::1/128")}, + }, + { + name: "Multiple IPv6 Subnets", + sources: []netip.Prefix{ + netip.MustParsePrefix("fd00::/64"), + netip.MustParsePrefix("2001:db8::/48"), + netip.MustParsePrefix("fe80::/10"), + }, + }, + { + name: "Overlapping IPv6", + sources: []netip.Prefix{ + netip.MustParsePrefix("fd00::/48"), + netip.MustParsePrefix("fd00::/64"), + netip.MustParsePrefix("fd00::1/128"), + }, + expected: []netip.Prefix{ + netip.MustParsePrefix("fd00::/48"), + }, + }, + { + name: "Mixed prefix lengths", + sources: []netip.Prefix{ + netip.MustParsePrefix("2001:db8:1::/48"), + netip.MustParsePrefix("2001:db8:2::1/128"), + netip.MustParsePrefix("fd00:abcd::/32"), + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + setName := firewall.NewPrefixSet(tt.sources).HashedName() + set, err := r.createIpSet(setName, setInput{prefixes: tt.sources}) + require.NoError(t, err, "Failed to create IPv6 set") + require.NotNil(t, set) + + assert.Equal(t, setName, set.Name) + assert.True(t, set.Interval) + assert.Equal(t, nftables.TypeIP6Addr, set.KeyType) + + fetchedSet, err := r.conn.GetSetByName(r.workTable, setName) + require.NoError(t, err, "Failed to fetch created set") + + elements, err := r.conn.GetSetElements(fetchedSet) + require.NoError(t, err, "Failed to get set elements") + + uniquePrefixes := make(map[string]bool) + for _, elem := range elements { + if !elem.IntervalEnd && len(elem.Key) == 16 { + ip := netip.AddrFrom16([16]byte(elem.Key)) + uniquePrefixes[ip.String()] = true + } + } + + expectedCount := len(tt.expected) + if expectedCount == 0 { + expectedCount = len(tt.sources) + } + assert.Equal(t, expectedCount, len(uniquePrefixes), "unique prefix count mismatch") + + r.conn.DelSet(set) + require.NoError(t, r.conn.Flush()) + }) + } +} + +func createWorkTableIPv6() (*nftables.Table, error) { + sConn, err := nftables.New(nftables.AsLasting()) + if err != nil { + return nil, err + } + + tables, err := sConn.ListTablesOfFamily(nftables.TableFamilyIPv6) + if err != nil { + return nil, err + } + for _, t := range tables { + if t.Name == tableNameNetbird { + sConn.DelTable(t) + } + } + + table := sConn.AddTable(&nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv6}) + err = sConn.Flush() + return table, err +} + +func deleteWorkTableIPv6() { + sConn, err := nftables.New(nftables.AsLasting()) + if err != nil { + return + } + + tables, err := sConn.ListTablesOfFamily(nftables.TableFamilyIPv6) + if err != nil { + return + } + for _, t := range tables { + if t.Name == tableNameNetbird { + sConn.DelTable(t) + _ = sConn.Flush() + } + } +} + func verifyRule(t *testing.T, rule *nftables.Rule, sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort, dPort *firewall.Port, direction firewall.RuleDirection, action firewall.Action, expectSet bool) { t.Helper() @@ -627,7 +758,7 @@ func containsPort(exprs []expr.Any, port *firewall.Port, isSource bool) bool { func containsProtocol(exprs []expr.Any, proto firewall.Protocol) bool { var metaFound, cmpFound bool - expectedProto, _ := protoToInt(proto) + expectedProto, _ := afIPv4.protoNum(proto) for _, e := range exprs { switch ex := e.(type) { case *expr.Meta: @@ -854,3 +985,55 @@ func TestRouter_AddNatRule_WithStaleEntry(t *testing.T) { } assert.Equal(t, 1, found, "NAT rule should exist in kernel") } + +func TestCalculateLastIP(t *testing.T) { + tests := []struct { + prefix string + want string + }{ + {"10.0.0.0/24", "10.0.0.255"}, + {"10.0.0.0/32", "10.0.0.0"}, + {"0.0.0.0/0", "255.255.255.255"}, + {"192.168.1.0/28", "192.168.1.15"}, + {"fd00::/64", "fd00::ffff:ffff:ffff:ffff"}, + {"fd00::/128", "fd00::"}, + {"2001:db8::/48", "2001:db8:0:ffff:ffff:ffff:ffff:ffff"}, + {"::/0", "ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff"}, + } + for _, tt := range tests { + t.Run(tt.prefix, func(t *testing.T) { + prefix := netip.MustParsePrefix(tt.prefix) + got := calculateLastIP(prefix) + assert.Equal(t, tt.want, got.String()) + }) + } +} + +func TestConvertPrefixesToSet_IPv6(t *testing.T) { + r := &router{af: afIPv6} + prefixes := []netip.Prefix{ + netip.MustParsePrefix("fd00::/64"), + netip.MustParsePrefix("2001:db8::1/128"), + } + + elements := r.convertPrefixesToSet(prefixes) + + // Each prefix produces 2 elements (start + end) + require.Len(t, elements, 4) + + // fd00::/64 start + assert.Equal(t, netip.MustParseAddr("fd00::").As16(), [16]byte(elements[0].Key)) + assert.False(t, elements[0].IntervalEnd) + + // fd00::/64 end (fd00:0:0:1::, one past the last) + assert.Equal(t, netip.MustParseAddr("fd00:0:0:1::").As16(), [16]byte(elements[1].Key)) + assert.True(t, elements[1].IntervalEnd) + + // 2001:db8::1/128 start + assert.Equal(t, netip.MustParseAddr("2001:db8::1").As16(), [16]byte(elements[2].Key)) + assert.False(t, elements[2].IntervalEnd) + + // 2001:db8::1/128 end (2001:db8::2) + assert.Equal(t, netip.MustParseAddr("2001:db8::2").As16(), [16]byte(elements[3].Key)) + assert.True(t, elements[3].IntervalEnd) +} diff --git a/client/firewall/uspfilter/allow_netbird_windows.go b/client/firewall/uspfilter/allow_netbird_windows.go index 6aef2ecfd..10a2b9116 100644 --- a/client/firewall/uspfilter/allow_netbird_windows.go +++ b/client/firewall/uspfilter/allow_netbird_windows.go @@ -5,8 +5,10 @@ import ( "os/exec" "syscall" + "github.com/hashicorp/go-multierror" log "github.com/sirupsen/logrus" + nberrors "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/internal/statemanager" ) @@ -29,15 +31,20 @@ func (m *Manager) Close(*statemanager.Manager) error { return nil } - if !isFirewallRuleActive(firewallRuleName) { - return nil + var merr *multierror.Error + if isFirewallRuleActive(firewallRuleName) { + if err := manageFirewallRule(firewallRuleName, deleteRule); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove windows firewall rule: %w", err)) + } } - if err := manageFirewallRule(firewallRuleName, deleteRule); err != nil { - return fmt.Errorf("couldn't remove windows firewall: %w", err) + if isFirewallRuleActive(firewallRuleName + "-v6") { + if err := manageFirewallRule(firewallRuleName+"-v6", deleteRule); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove windows v6 firewall rule: %w", err)) + } } - return nil + return nberrors.FormatErrorOrNil(merr) } // AllowNetbird allows netbird interface traffic @@ -46,17 +53,33 @@ func (m *Manager) AllowNetbird() error { return nil } - if isFirewallRuleActive(firewallRuleName) { - return nil + if !isFirewallRuleActive(firewallRuleName) { + if err := manageFirewallRule(firewallRuleName, + addRule, + "dir=in", + "enable=yes", + "action=allow", + "profile=any", + "localip="+m.wgIface.Address().IP.String(), + ); err != nil { + return err + } } - return manageFirewallRule(firewallRuleName, - addRule, - "dir=in", - "enable=yes", - "action=allow", - "profile=any", - "localip="+m.wgIface.Address().IP.String(), - ) + + if v6 := m.wgIface.Address().IPv6; v6.IsValid() && !isFirewallRuleActive(firewallRuleName+"-v6") { + if err := manageFirewallRule(firewallRuleName+"-v6", + addRule, + "dir=in", + "enable=yes", + "action=allow", + "profile=any", + "localip="+v6.String(), + ); err != nil { + return err + } + } + + return nil } func manageFirewallRule(ruleName string, action action, extraArgs ...string) error { diff --git a/client/firewall/uspfilter/conntrack/common.go b/client/firewall/uspfilter/conntrack/common.go index 7be0dd78f..88e90317c 100644 --- a/client/firewall/uspfilter/conntrack/common.go +++ b/client/firewall/uspfilter/conntrack/common.go @@ -1,8 +1,9 @@ package conntrack import ( - "fmt" + "net" "net/netip" + "strconv" "sync/atomic" "time" @@ -64,5 +65,7 @@ type ConnKey struct { } func (c ConnKey) String() string { - return fmt.Sprintf("%s:%d → %s:%d", c.SrcIP.Unmap(), c.SrcPort, c.DstIP.Unmap(), c.DstPort) + return net.JoinHostPort(c.SrcIP.Unmap().String(), strconv.Itoa(int(c.SrcPort))) + + " → " + + net.JoinHostPort(c.DstIP.Unmap().String(), strconv.Itoa(int(c.DstPort))) } diff --git a/client/firewall/uspfilter/conntrack/icmp.go b/client/firewall/uspfilter/conntrack/icmp.go index 50b663642..85b6f13be 100644 --- a/client/firewall/uspfilter/conntrack/icmp.go +++ b/client/firewall/uspfilter/conntrack/icmp.go @@ -21,9 +21,10 @@ const ( // ICMPCleanupInterval is how often we check for stale ICMP connections ICMPCleanupInterval = 15 * time.Second - // MaxICMPPayloadLength is the maximum length of ICMP payload we consider for original packet info, - // which includes the IP header (20 bytes) and transport header (8 bytes) - MaxICMPPayloadLength = 28 + // MaxICMPPayloadLength is the maximum length of ICMP payload we consider for original packet info. + // IPv4: 20-byte header + 8-byte transport = 28 bytes. + // IPv6: 40-byte header + 8-byte transport = 48 bytes. + MaxICMPPayloadLength = 48 ) // ICMPConnKey uniquely identifies an ICMP connection @@ -74,32 +75,64 @@ func (info ICMPInfo) String() string { return info.TypeCode.String() } -// isErrorMessage returns true if this ICMP type carries original packet info +// isErrorMessage returns true if this ICMP type carries original packet info. +// Covers both ICMPv4 and ICMPv6 error types. Without a family field we match +// both sets; type 3 overlaps (v4 DestUnreachable / v6 TimeExceeded) so it's +// kept as a literal. func (info ICMPInfo) isErrorMessage() bool { typ := info.TypeCode.Type() - return typ == 3 || // Destination Unreachable - typ == 5 || // Redirect - typ == 11 || // Time Exceeded - typ == 12 // Parameter Problem + // ICMPv4 error types + if typ == layers.ICMPv4TypeDestinationUnreachable || + typ == layers.ICMPv4TypeRedirect || + typ == layers.ICMPv4TypeTimeExceeded || + typ == layers.ICMPv4TypeParameterProblem { + return true + } + // ICMPv6 error types (type 3 already matched above as v4 DestUnreachable) + if typ == layers.ICMPv6TypeDestinationUnreachable || + typ == layers.ICMPv6TypePacketTooBig || + typ == layers.ICMPv6TypeParameterProblem { + return true + } + return false } // parseOriginalPacket extracts info about the original packet from ICMP payload func (info ICMPInfo) parseOriginalPacket() string { - if info.PayloadLen < MaxICMPPayloadLength { + if info.PayloadLen == 0 { return "" } - // TODO: handle IPv6 - if version := (info.PayloadData[0] >> 4) & 0xF; version != 4 { + version := (info.PayloadData[0] >> 4) & 0xF + + var protocol uint8 + var srcIP, dstIP net.IP + var transportData []byte + + switch version { + case 4: + // 20-byte IPv4 header + 8-byte transport minimum + if info.PayloadLen < 28 { + return "" + } + protocol = info.PayloadData[9] + srcIP = net.IP(info.PayloadData[12:16]) + dstIP = net.IP(info.PayloadData[16:20]) + transportData = info.PayloadData[20:] + case 6: + // 40-byte IPv6 header + 8-byte transport minimum + if info.PayloadLen < 48 { + return "" + } + // Next Header field in IPv6 header + protocol = info.PayloadData[6] + srcIP = net.IP(info.PayloadData[8:24]) + dstIP = net.IP(info.PayloadData[24:40]) + transportData = info.PayloadData[40:] + default: return "" } - protocol := info.PayloadData[9] - srcIP := net.IP(info.PayloadData[12:16]) - dstIP := net.IP(info.PayloadData[16:20]) - - transportData := info.PayloadData[20:] - switch nftypes.Protocol(protocol) { case nftypes.TCP: srcPort := uint16(transportData[0])<<8 | uint16(transportData[1]) @@ -247,9 +280,10 @@ func (t *ICMPTracker) track( t.sendEvent(nftypes.TypeStart, conn, ruleId) } -// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request +// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request. +// Accepts both ICMPv4 (type 0) and ICMPv6 (type 129) echo replies. func (t *ICMPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, icmpType uint8, size int) bool { - if icmpType != uint8(layers.ICMPv4TypeEchoReply) { + if icmpType != uint8(layers.ICMPv4TypeEchoReply) && icmpType != uint8(layers.ICMPv6TypeEchoReply) { return false } @@ -301,6 +335,13 @@ func (t *ICMPTracker) cleanup() { } } +func icmpProtocolForAddr(ip netip.Addr) nftypes.Protocol { + if ip.Is6() { + return nftypes.ICMPv6 + } + return nftypes.ICMP +} + // Close stops the cleanup routine and releases resources func (t *ICMPTracker) Close() { t.tickerCancel() @@ -316,7 +357,7 @@ func (t *ICMPTracker) sendEvent(typ nftypes.Type, conn *ICMPConnTrack, ruleID [] Type: typ, RuleID: ruleID, Direction: conn.Direction, - Protocol: nftypes.ICMP, // TODO: adjust for IPv6/icmpv6 + Protocol: icmpProtocolForAddr(conn.SourceIP), SourceIP: conn.SourceIP, DestIP: conn.DestIP, ICMPType: conn.ICMPType, @@ -334,7 +375,7 @@ func (t *ICMPTracker) sendStartEvent(direction nftypes.Direction, srcIP netip.Ad Type: nftypes.TypeStart, RuleID: ruleID, Direction: direction, - Protocol: nftypes.ICMP, + Protocol: icmpProtocolForAddr(srcIP), SourceIP: srcIP, DestIP: dstIP, ICMPType: typ, diff --git a/client/firewall/uspfilter/filter.go b/client/firewall/uspfilter/filter.go index cb9e1bb0a..75a02ac6f 100644 --- a/client/firewall/uspfilter/filter.go +++ b/client/firewall/uspfilter/filter.go @@ -35,8 +35,10 @@ import ( const ( layerTypeAll = 255 - // ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation - ipTCPHeaderMinSize = 40 + // ipv4TCPHeaderMinSize represents minimum IPv4 (20) + TCP (20) header size for MSS calculation + ipv4TCPHeaderMinSize = 40 + // ipv6TCPHeaderMinSize represents minimum IPv6 (40) + TCP (20) header size for MSS calculation + ipv6TCPHeaderMinSize = 60 ) // serviceKey represents a protocol/port combination for netstack service registry @@ -137,9 +139,10 @@ type Manager struct { netstackServices map[serviceKey]struct{} netstackServiceMutex sync.RWMutex - mtu uint16 - mssClampValue uint16 - mssClampEnabled bool + mtu uint16 + mssClampValueIPv4 uint16 + mssClampValueIPv6 uint16 + mssClampEnabled bool // Only one hook per protocol is supported. Outbound direction only. udpHookOut atomic.Pointer[packetHook] @@ -163,11 +166,28 @@ type decoder struct { icmp4 layers.ICMPv4 icmp6 layers.ICMPv6 decoded []gopacket.LayerType - parser *gopacket.DecodingLayerParser + parser4 *gopacket.DecodingLayerParser + parser6 *gopacket.DecodingLayerParser dnatOrigPort uint16 } +// decodePacket decodes packet data using the appropriate parser based on IP version. +func (d *decoder) decodePacket(data []byte) error { + if len(data) == 0 { + return errors.New("empty packet") + } + version := data[0] >> 4 + switch version { + case 4: + return d.parser4.DecodeLayers(data, &d.decoded) + case 6: + return d.parser6.DecodeLayers(data, &d.decoded) + default: + return fmt.Errorf("unknown IP version %d", version) + } +} + // Create userspace firewall manager constructor func Create(iface common.IFaceMapper, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) { return create(iface, nil, disableServerRoutes, flowLogger, mtu) @@ -225,11 +245,17 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe d := &decoder{ decoded: []gopacket.LayerType{}, } - d.parser = gopacket.NewDecodingLayerParser( + d.parser4 = gopacket.NewDecodingLayerParser( layers.LayerTypeIPv4, &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, ) - d.parser.IgnoreUnsupported = true + d.parser4.IgnoreUnsupported = true + + d.parser6 = gopacket.NewDecodingLayerParser( + layers.LayerTypeIPv6, + &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, + ) + d.parser6.IgnoreUnsupported = true return d }, }, @@ -255,7 +281,8 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe if !disableMSSClamping { m.mssClampEnabled = true - m.mssClampValue = mtu - ipTCPHeaderMinSize + m.mssClampValueIPv4 = mtu - ipv4TCPHeaderMinSize + m.mssClampValueIPv6 = mtu - ipv6TCPHeaderMinSize } if err := m.localipmanager.UpdateLocalIPs(iface); err != nil { return nil, fmt.Errorf("update local IPs: %w", err) @@ -282,9 +309,14 @@ func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) (firewall.Rule, e wgPrefix := iface.Address().Network log.Debugf("blocking invalid routed traffic for %s", wgPrefix) + sources := []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)} + if v6 := iface.Address().IPv6Net; v6.IsValid() { + sources = append(sources, netip.PrefixFrom(netip.IPv6Unspecified(), 0)) + } + rule, err := m.addRouteFiltering( nil, - []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}, + sources, firewall.Network{Prefix: wgPrefix}, firewall.ProtocolALL, nil, @@ -292,7 +324,22 @@ func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) (firewall.Rule, e firewall.ActionDrop, ) if err != nil { - return nil, fmt.Errorf("block wg nte : %w", err) + return nil, fmt.Errorf("block wg v4 net: %w", err) + } + + if v6Net := iface.Address().IPv6Net; v6Net.IsValid() { + log.Debugf("blocking invalid routed traffic for %s", v6Net) + if _, err := m.addRouteFiltering( + nil, + sources, + firewall.Network{Prefix: v6Net}, + firewall.ProtocolALL, + nil, + nil, + firewall.ActionDrop, + ); err != nil { + return nil, fmt.Errorf("block wg v6 net: %w", err) + } } // TODO: Block networks that we're a client of @@ -509,7 +556,7 @@ func (m *Manager) addRouteFiltering( mgmtId: id, sources: sources, dstSet: destination.Set, - protoLayer: protoToLayer(proto, layers.LayerTypeIPv4), + protoLayer: protoToLayer(proto, ipLayerFromPrefix(destination.Prefix)), srcPort: sPort, dstPort: dPort, action: action, @@ -663,11 +710,7 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { } destinations := matches[0].destinations - for _, prefix := range prefixes { - if prefix.Addr().Is4() { - destinations = append(destinations, prefix) - } - } + destinations = append(destinations, prefixes...) slices.SortFunc(destinations, func(a, b netip.Prefix) int { cmp := a.Addr().Compare(b.Addr()) @@ -706,7 +749,7 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool { d := m.decoders.Get().(*decoder) defer m.decoders.Put(d) - if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { + if err := d.decodePacket(packetData); err != nil { return false } @@ -790,12 +833,28 @@ func (m *Manager) clampTCPMSS(packetData []byte, d *decoder) bool { return false } + var mssClampValue uint16 + var ipHeaderSize int + switch d.decoded[0] { + case layers.LayerTypeIPv4: + mssClampValue = m.mssClampValueIPv4 + ipHeaderSize = int(d.ip4.IHL) * 4 + if ipHeaderSize < 20 { + return false + } + case layers.LayerTypeIPv6: + mssClampValue = m.mssClampValueIPv6 + ipHeaderSize = 40 + default: + return false + } + mssOptionIndex := -1 var currentMSS uint16 for i, opt := range d.tcp.Options { if opt.OptionType == layers.TCPOptionKindMSS && len(opt.OptionData) == 2 { currentMSS = binary.BigEndian.Uint16(opt.OptionData) - if currentMSS > m.mssClampValue { + if currentMSS > mssClampValue { mssOptionIndex = i break } @@ -806,20 +865,15 @@ func (m *Manager) clampTCPMSS(packetData []byte, d *decoder) bool { return false } - ipHeaderSize := int(d.ip4.IHL) * 4 - if ipHeaderSize < 20 { + if !m.updateMSSOption(packetData, d, mssOptionIndex, mssClampValue, ipHeaderSize) { return false } - if !m.updateMSSOption(packetData, d, mssOptionIndex, ipHeaderSize) { - return false - } - - m.logger.Trace2("Clamped TCP MSS from %d to %d", currentMSS, m.mssClampValue) + m.logger.Trace2("Clamped TCP MSS from %d to %d", currentMSS, mssClampValue) return true } -func (m *Manager) updateMSSOption(packetData []byte, d *decoder, mssOptionIndex, ipHeaderSize int) bool { +func (m *Manager) updateMSSOption(packetData []byte, d *decoder, mssOptionIndex int, mssClampValue uint16, ipHeaderSize int) bool { tcpHeaderStart := ipHeaderSize tcpOptionsStart := tcpHeaderStart + 20 @@ -834,7 +888,7 @@ func (m *Manager) updateMSSOption(packetData []byte, d *decoder, mssOptionIndex, } mssValueOffset := optOffset + 2 - binary.BigEndian.PutUint16(packetData[mssValueOffset:mssValueOffset+2], m.mssClampValue) + binary.BigEndian.PutUint16(packetData[mssValueOffset:mssValueOffset+2], mssClampValue) m.recalculateTCPChecksum(packetData, d, tcpHeaderStart) return true @@ -844,18 +898,32 @@ func (m *Manager) recalculateTCPChecksum(packetData []byte, d *decoder, tcpHeade tcpLayer := packetData[tcpHeaderStart:] tcpLength := len(packetData) - tcpHeaderStart + // Zero out existing checksum tcpLayer[16] = 0 tcpLayer[17] = 0 + // Build pseudo-header checksum based on IP version var pseudoSum uint32 - pseudoSum += uint32(d.ip4.SrcIP[0])<<8 | uint32(d.ip4.SrcIP[1]) - pseudoSum += uint32(d.ip4.SrcIP[2])<<8 | uint32(d.ip4.SrcIP[3]) - pseudoSum += uint32(d.ip4.DstIP[0])<<8 | uint32(d.ip4.DstIP[1]) - pseudoSum += uint32(d.ip4.DstIP[2])<<8 | uint32(d.ip4.DstIP[3]) - pseudoSum += uint32(d.ip4.Protocol) - pseudoSum += uint32(tcpLength) + switch d.decoded[0] { + case layers.LayerTypeIPv4: + pseudoSum += uint32(d.ip4.SrcIP[0])<<8 | uint32(d.ip4.SrcIP[1]) + pseudoSum += uint32(d.ip4.SrcIP[2])<<8 | uint32(d.ip4.SrcIP[3]) + pseudoSum += uint32(d.ip4.DstIP[0])<<8 | uint32(d.ip4.DstIP[1]) + pseudoSum += uint32(d.ip4.DstIP[2])<<8 | uint32(d.ip4.DstIP[3]) + pseudoSum += uint32(d.ip4.Protocol) + pseudoSum += uint32(tcpLength) + case layers.LayerTypeIPv6: + for i := 0; i < 16; i += 2 { + pseudoSum += uint32(d.ip6.SrcIP[i])<<8 | uint32(d.ip6.SrcIP[i+1]) + } + for i := 0; i < 16; i += 2 { + pseudoSum += uint32(d.ip6.DstIP[i])<<8 | uint32(d.ip6.DstIP[i+1]) + } + pseudoSum += uint32(tcpLength) + pseudoSum += uint32(layers.IPProtocolTCP) + } - var sum = pseudoSum + sum := pseudoSum for i := 0; i < tcpLength-1; i += 2 { sum += uint32(tcpLayer[i])<<8 | uint32(tcpLayer[i+1]) } @@ -893,6 +961,9 @@ func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, packetData } case layers.LayerTypeICMPv4: m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, d.icmp4.Payload, size) + case layers.LayerTypeICMPv6: + id, tc := icmpv6EchoFields(d) + m.icmpTracker.TrackOutbound(srcIP, dstIP, id, tc, d.icmp6.Payload, size) } } @@ -906,6 +977,9 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byt m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size, d.dnatOrigPort) case layers.LayerTypeICMPv4: m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, d.icmp4.Payload, size) + case layers.LayerTypeICMPv6: + id, tc := icmpv6EchoFields(d) + m.icmpTracker.TrackInbound(srcIP, dstIP, id, tc, ruleID, d.icmp6.Payload, size) } d.dnatOrigPort = 0 @@ -948,15 +1022,19 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool { // TODO: pass fragments of routed packets to forwarder if fragment { - m.logger.Trace4("packet is a fragment: src=%v dst=%v id=%v flags=%v", - srcIP, dstIP, d.ip4.Id, d.ip4.Flags) + if d.decoded[0] == layers.LayerTypeIPv4 { + m.logger.Trace4("packet is a fragment: src=%v dst=%v id=%v flags=%v", + srcIP, dstIP, d.ip4.Id, d.ip4.Flags) + } else { + m.logger.Trace2("packet is an IPv6 fragment: src=%v dst=%v", srcIP, dstIP) + } return false } // TODO: optimize port DNAT by caching matched rules in conntrack if translated := m.translateInboundPortDNAT(packetData, d, srcIP, dstIP); translated { // Re-decode after port DNAT translation to update port information - if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { + if err := d.decodePacket(packetData); err != nil { m.logger.Error1("failed to re-decode packet after port DNAT: %v", err) return true } @@ -965,7 +1043,7 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool { if translated := m.translateInboundReverse(packetData, d); translated { // Re-decode after translation to get original addresses - if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { + if err := d.decodePacket(packetData); err != nil { m.logger.Error1("failed to re-decode packet after reverse DNAT: %v", err) return true } @@ -1097,6 +1175,48 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe return true } +// icmpv6EchoFields extracts the echo identifier from an ICMPv6 packet and maps +// the ICMPv6 type code to an ICMPv4TypeCode so the ICMP conntrack can handle +// both families uniformly. The echo ID is in the first two payload bytes. +func icmpv6EchoFields(d *decoder) (id uint16, tc layers.ICMPv4TypeCode) { + if len(d.icmp6.Payload) >= 2 { + id = uint16(d.icmp6.Payload[0])<<8 | uint16(d.icmp6.Payload[1]) + } + // Map ICMPv6 echo types to ICMPv4 equivalents for unified tracking. + switch d.icmp6.TypeCode.Type() { + case layers.ICMPv6TypeEchoRequest: + tc = layers.CreateICMPv4TypeCode(layers.ICMPv4TypeEchoRequest, 0) + case layers.ICMPv6TypeEchoReply: + tc = layers.CreateICMPv4TypeCode(layers.ICMPv4TypeEchoReply, 0) + default: + tc = layers.CreateICMPv4TypeCode(d.icmp6.TypeCode.Type(), d.icmp6.TypeCode.Code()) + } + return id, tc +} + +// protoLayerMatches checks if a packet's protocol layer matches a rule's expected +// protocol layer. ICMPv4 and ICMPv6 are treated as equivalent when matching +// ICMP rules since management sends a single ICMP rule for both families. +func protoLayerMatches(ruleLayer, packetLayer gopacket.LayerType) bool { + if ruleLayer == packetLayer { + return true + } + if ruleLayer == layers.LayerTypeICMPv4 && packetLayer == layers.LayerTypeICMPv6 { + return true + } + if ruleLayer == layers.LayerTypeICMPv6 && packetLayer == layers.LayerTypeICMPv4 { + return true + } + return false +} + +func ipLayerFromPrefix(p netip.Prefix) gopacket.LayerType { + if p.Addr().Is6() { + return layers.LayerTypeIPv6 + } + return layers.LayerTypeIPv4 +} + func protoToLayer(proto firewall.Protocol, ipLayer gopacket.LayerType) gopacket.LayerType { switch proto { case firewall.ProtocolTCP: @@ -1120,8 +1240,10 @@ func getProtocolFromPacket(d *decoder) nftypes.Protocol { return nftypes.TCP case layers.LayerTypeUDP: return nftypes.UDP - case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6: + case layers.LayerTypeICMPv4: return nftypes.ICMP + case layers.LayerTypeICMPv6: + return nftypes.ICMPv6 default: return nftypes.ProtocolUnknown } @@ -1142,7 +1264,7 @@ func getPortsFromPacket(d *decoder) (srcPort, dstPort uint16) { // It returns true, false if the packet is valid and not a fragment. // It returns true, true if the packet is a fragment and valid. func (m *Manager) isValidPacket(d *decoder, packetData []byte) (bool, bool) { - if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { + if err := d.decodePacket(packetData); err != nil { m.logger.Trace1("couldn't decode packet, err: %s", err) return false, false } @@ -1155,10 +1277,18 @@ func (m *Manager) isValidPacket(d *decoder, packetData []byte) (bool, bool) { } // Fragments are also valid - if l == 1 && d.decoded[0] == layers.LayerTypeIPv4 { - ip4 := d.ip4 - if ip4.Flags&layers.IPv4MoreFragments != 0 || ip4.FragOffset != 0 { - return true, true + if l == 1 { + switch d.decoded[0] { + case layers.LayerTypeIPv4: + if d.ip4.Flags&layers.IPv4MoreFragments != 0 || d.ip4.FragOffset != 0 { + return true, true + } + case layers.LayerTypeIPv6: + // IPv6 uses Fragment extension header (NextHeader=44). If gopacket + // only decoded the IPv6 layer, the transport is in a fragment. + if d.ip6.NextHeader == layers.IPProtocolIPv6Fragment { + return true, true + } } } @@ -1196,21 +1326,34 @@ func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP netip.Addr, size, ) - // TODO: ICMPv6 + case layers.LayerTypeICMPv6: + id, _ := icmpv6EchoFields(d) + return m.icmpTracker.IsValidInbound( + srcIP, + dstIP, + id, + d.icmp6.TypeCode.Type(), + size, + ) } return false } -// isSpecialICMP returns true if the packet is a special ICMP packet that should be allowed +// isSpecialICMP returns true if the packet is a special ICMP error packet that should be allowed. func (m *Manager) isSpecialICMP(d *decoder) bool { - if d.decoded[1] != layers.LayerTypeICMPv4 { - return false + switch d.decoded[1] { + case layers.LayerTypeICMPv4: + icmpType := d.icmp4.TypeCode.Type() + return icmpType == layers.ICMPv4TypeDestinationUnreachable || + icmpType == layers.ICMPv4TypeTimeExceeded + case layers.LayerTypeICMPv6: + icmpType := d.icmp6.TypeCode.Type() + return icmpType == layers.ICMPv6TypeDestinationUnreachable || + icmpType == layers.ICMPv6TypePacketTooBig || + icmpType == layers.ICMPv6TypeTimeExceeded } - - icmpType := d.icmp4.TypeCode.Type() - return icmpType == layers.ICMPv4TypeDestinationUnreachable || - icmpType == layers.ICMPv4TypeTimeExceeded + return false } func (m *Manager) peerACLsBlock(srcIP netip.Addr, d *decoder, packetData []byte) ([]byte, bool) { @@ -1267,7 +1410,7 @@ func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d return rule.mgmtId, rule.drop, true } - if payloadLayer != rule.protoLayer { + if !protoLayerMatches(rule.protoLayer, payloadLayer) { continue } @@ -1302,8 +1445,7 @@ func (m *Manager) routeACLsPass(srcIP, dstIP netip.Addr, protoLayer gopacket.Lay } func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, protoLayer gopacket.LayerType, srcPort, dstPort uint16) bool { - // TODO: handle ipv6 vs ipv4 icmp rules - if rule.protoLayer != layerTypeAll && rule.protoLayer != protoLayer { + if rule.protoLayer != layerTypeAll && !protoLayerMatches(rule.protoLayer, protoLayer) { return false } @@ -1473,7 +1615,8 @@ func (m *Manager) shouldForward(d *decoder, dstIP netip.Addr) bool { } // traffic to our other local interfaces (not NetBird IP) - always forward - if dstIP != m.wgIface.Address().IP { + addr := m.wgIface.Address() + if dstIP != addr.IP && (!addr.IPv6.IsValid() || dstIP != addr.IPv6) { return true } diff --git a/client/firewall/uspfilter/filter_bench_test.go b/client/firewall/uspfilter/filter_bench_test.go index 10ff62ed3..4dccb0f65 100644 --- a/client/firewall/uspfilter/filter_bench_test.go +++ b/client/firewall/uspfilter/filter_bench_test.go @@ -1023,7 +1023,8 @@ func BenchmarkMSSClamping(b *testing.B) { }() manager.mssClampEnabled = true - manager.mssClampValue = 1240 + manager.mssClampValueIPv4 = 1240 + manager.mssClampValueIPv6 = 1220 srcIP := net.ParseIP("100.64.0.2") dstIP := net.ParseIP("8.8.8.8") @@ -1088,7 +1089,8 @@ func BenchmarkMSSClampingOverhead(b *testing.B) { manager.mssClampEnabled = sc.enabled if sc.enabled { - manager.mssClampValue = 1240 + manager.mssClampValueIPv4 = 1240 + manager.mssClampValueIPv6 = 1220 } srcIP := net.ParseIP("100.64.0.2") @@ -1141,7 +1143,8 @@ func BenchmarkMSSClampingMemory(b *testing.B) { }() manager.mssClampEnabled = true - manager.mssClampValue = 1240 + manager.mssClampValueIPv4 = 1240 + manager.mssClampValueIPv6 = 1220 srcIP := net.ParseIP("100.64.0.2") dstIP := net.ParseIP("8.8.8.8") diff --git a/client/firewall/uspfilter/filter_filter_test.go b/client/firewall/uspfilter/filter_filter_test.go index a8efbac1c..a64c83138 100644 --- a/client/firewall/uspfilter/filter_filter_test.go +++ b/client/firewall/uspfilter/filter_filter_test.go @@ -539,53 +539,236 @@ func TestPeerACLFiltering(t *testing.T) { } } +func TestPeerACLFilteringIPv6(t *testing.T) { + localIP := netip.MustParseAddr("100.10.0.100") + localIPv6 := netip.MustParseAddr("fd00::100") + wgNet := netip.MustParsePrefix("100.10.0.0/16") + wgNetV6 := netip.MustParsePrefix("fd00::/64") + + ifaceMock := &IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + AddressFunc: func() wgaddr.Address { + return wgaddr.Address{ + IP: localIP, + Network: wgNet, + IPv6: localIPv6, + IPv6Net: wgNetV6, + } + }, + } + + manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU) + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, manager.Close(nil)) }) + + err = manager.UpdateLocalIPs() + require.NoError(t, err) + + testCases := []struct { + name string + srcIP string + dstIP string + proto fw.Protocol + srcPort uint16 + dstPort uint16 + ruleIP string + ruleProto fw.Protocol + ruleDstPort *fw.Port + ruleAction fw.Action + shouldBeBlocked bool + }{ + { + name: "IPv6: allow TCP from peer", + srcIP: "fd00::1", + dstIP: "fd00::100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 443, + ruleIP: "fd00::1", + ruleProto: fw.ProtocolTCP, + ruleDstPort: &fw.Port{Values: []uint16{443}}, + ruleAction: fw.ActionAccept, + shouldBeBlocked: false, + }, + { + name: "IPv6: allow UDP from peer", + srcIP: "fd00::1", + dstIP: "fd00::100", + proto: fw.ProtocolUDP, + srcPort: 12345, + dstPort: 53, + ruleIP: "fd00::1", + ruleProto: fw.ProtocolUDP, + ruleDstPort: &fw.Port{Values: []uint16{53}}, + ruleAction: fw.ActionAccept, + shouldBeBlocked: false, + }, + { + name: "IPv6: allow ICMPv6 from peer", + srcIP: "fd00::1", + dstIP: "fd00::100", + proto: fw.ProtocolICMP, + ruleIP: "fd00::1", + ruleProto: fw.ProtocolICMP, + ruleAction: fw.ActionAccept, + shouldBeBlocked: false, + }, + { + name: "IPv6: block TCP without rule", + srcIP: "fd00::2", + dstIP: "fd00::100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 443, + ruleIP: "fd00::1", + ruleProto: fw.ProtocolTCP, + ruleDstPort: &fw.Port{Values: []uint16{443}}, + ruleAction: fw.ActionAccept, + shouldBeBlocked: true, + }, + { + name: "IPv6: drop rule", + srcIP: "fd00::1", + dstIP: "fd00::100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 22, + ruleIP: "fd00::1", + ruleProto: fw.ProtocolTCP, + ruleDstPort: &fw.Port{Values: []uint16{22}}, + ruleAction: fw.ActionDrop, + shouldBeBlocked: true, + }, + { + name: "IPv6: allow all protocols", + srcIP: "fd00::1", + dstIP: "fd00::100", + proto: fw.ProtocolUDP, + srcPort: 12345, + dstPort: 9999, + ruleIP: "fd00::1", + ruleProto: fw.ProtocolALL, + ruleAction: fw.ActionAccept, + shouldBeBlocked: false, + }, + { + name: "IPv6: v4 wildcard ICMP rule matches ICMPv6 via protoLayerMatches", + srcIP: "fd00::1", + dstIP: "fd00::100", + proto: fw.ProtocolICMP, + ruleIP: "0.0.0.0", + ruleProto: fw.ProtocolICMP, + ruleAction: fw.ActionAccept, + shouldBeBlocked: false, + }, + } + + t.Run("IPv6 implicit DROP (no rules)", func(t *testing.T) { + packet := createTestPacket(t, "fd00::1", "fd00::100", fw.ProtocolTCP, 12345, 443) + isDropped := manager.FilterInbound(packet, 0) + require.True(t, isDropped, "IPv6 packet should be dropped when no rules exist") + }) + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + if tc.ruleAction == fw.ActionDrop { + rules, err := manager.AddPeerFiltering(nil, net.ParseIP(tc.ruleIP), fw.ProtocolALL, nil, nil, fw.ActionAccept, "") + require.NoError(t, err) + t.Cleanup(func() { + for _, rule := range rules { + require.NoError(t, manager.DeletePeerRule(rule)) + } + }) + } + + rules, err := manager.AddPeerFiltering(nil, net.ParseIP(tc.ruleIP), tc.ruleProto, nil, tc.ruleDstPort, tc.ruleAction, "") + require.NoError(t, err) + require.NotEmpty(t, rules) + t.Cleanup(func() { + for _, rule := range rules { + require.NoError(t, manager.DeletePeerRule(rule)) + } + }) + + packet := createTestPacket(t, tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort) + isDropped := manager.FilterInbound(packet, 0) + require.Equal(t, tc.shouldBeBlocked, isDropped, "packet filter result mismatch") + }) + } +} + func createTestPacket(t *testing.T, srcIP, dstIP string, proto fw.Protocol, srcPort, dstPort uint16) []byte { t.Helper() + src := net.ParseIP(srcIP) + dst := net.ParseIP(dstIP) + buf := gopacket.NewSerializeBuffer() opts := gopacket.SerializeOptions{ ComputeChecksums: true, FixLengths: true, } - ipLayer := &layers.IPv4{ - Version: 4, - TTL: 64, - SrcIP: net.ParseIP(srcIP), - DstIP: net.ParseIP(dstIP), - } + // Detect address family + isV6 := src.To4() == nil var err error - switch proto { - case fw.ProtocolTCP: - ipLayer.Protocol = layers.IPProtocolTCP - tcp := &layers.TCP{ - SrcPort: layers.TCPPort(srcPort), - DstPort: layers.TCPPort(dstPort), - } - err = tcp.SetNetworkLayerForChecksum(ipLayer) - require.NoError(t, err) - err = gopacket.SerializeLayers(buf, opts, ipLayer, tcp) - case fw.ProtocolUDP: - ipLayer.Protocol = layers.IPProtocolUDP - udp := &layers.UDP{ - SrcPort: layers.UDPPort(srcPort), - DstPort: layers.UDPPort(dstPort), + if isV6 { + ip6 := &layers.IPv6{ + Version: 6, + HopLimit: 64, + SrcIP: src, + DstIP: dst, } - err = udp.SetNetworkLayerForChecksum(ipLayer) - require.NoError(t, err) - err = gopacket.SerializeLayers(buf, opts, ipLayer, udp) - case fw.ProtocolICMP: - ipLayer.Protocol = layers.IPProtocolICMPv4 - icmp := &layers.ICMPv4{ - TypeCode: layers.CreateICMPv4TypeCode(layers.ICMPv4TypeEchoRequest, 0), + switch proto { + case fw.ProtocolTCP: + ip6.NextHeader = layers.IPProtocolTCP + tcp := &layers.TCP{SrcPort: layers.TCPPort(srcPort), DstPort: layers.TCPPort(dstPort)} + _ = tcp.SetNetworkLayerForChecksum(ip6) + err = gopacket.SerializeLayers(buf, opts, ip6, tcp) + case fw.ProtocolUDP: + ip6.NextHeader = layers.IPProtocolUDP + udp := &layers.UDP{SrcPort: layers.UDPPort(srcPort), DstPort: layers.UDPPort(dstPort)} + _ = udp.SetNetworkLayerForChecksum(ip6) + err = gopacket.SerializeLayers(buf, opts, ip6, udp) + case fw.ProtocolICMP: + ip6.NextHeader = layers.IPProtocolICMPv6 + icmp := &layers.ICMPv6{ + TypeCode: layers.CreateICMPv6TypeCode(layers.ICMPv6TypeEchoRequest, 0), + } + _ = icmp.SetNetworkLayerForChecksum(ip6) + err = gopacket.SerializeLayers(buf, opts, ip6, icmp) + default: + err = gopacket.SerializeLayers(buf, opts, ip6) + } + } else { + ip4 := &layers.IPv4{ + Version: 4, + TTL: 64, + SrcIP: src, + DstIP: dst, } - err = gopacket.SerializeLayers(buf, opts, ipLayer, icmp) - default: - err = gopacket.SerializeLayers(buf, opts, ipLayer) + switch proto { + case fw.ProtocolTCP: + ip4.Protocol = layers.IPProtocolTCP + tcp := &layers.TCP{SrcPort: layers.TCPPort(srcPort), DstPort: layers.TCPPort(dstPort)} + _ = tcp.SetNetworkLayerForChecksum(ip4) + err = gopacket.SerializeLayers(buf, opts, ip4, tcp) + case fw.ProtocolUDP: + ip4.Protocol = layers.IPProtocolUDP + udp := &layers.UDP{SrcPort: layers.UDPPort(srcPort), DstPort: layers.UDPPort(dstPort)} + _ = udp.SetNetworkLayerForChecksum(ip4) + err = gopacket.SerializeLayers(buf, opts, ip4, udp) + case fw.ProtocolICMP: + ip4.Protocol = layers.IPProtocolICMPv4 + icmp := &layers.ICMPv4{TypeCode: layers.CreateICMPv4TypeCode(layers.ICMPv4TypeEchoRequest, 0)} + err = gopacket.SerializeLayers(buf, opts, ip4, icmp) + default: + err = gopacket.SerializeLayers(buf, opts, ip4) + } } require.NoError(t, err) @@ -1498,3 +1681,103 @@ func TestRouteACLSet(t *testing.T) { _, isAllowed = manager.routeACLsPass(srcIP, dstIP, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80) require.True(t, isAllowed, "After set update, traffic to the added network should be allowed") } + +// TestRouteACLFilteringIPv6 tests IPv6 route ACL matching directly via routeACLsPass. +// Note: full FilterInbound for routed IPv6 traffic drops at the forwarder stage (IPv4-only) +// but the ACL decision itself is correct. +func TestRouteACLFilteringIPv6(t *testing.T) { + manager := setupRoutedManager(t, "10.10.0.100/16") + + v6Dst := netip.MustParsePrefix("fd00:dead:beef::/48") + _, err := manager.AddRouteFiltering( + nil, + []netip.Prefix{netip.MustParsePrefix("fd00::/16")}, + fw.Network{Prefix: v6Dst}, + fw.ProtocolTCP, + nil, + &fw.Port{Values: []uint16{80}}, + fw.ActionAccept, + ) + require.NoError(t, err) + + _, err = manager.AddRouteFiltering( + nil, + []netip.Prefix{netip.MustParsePrefix("fd00::/16")}, + fw.Network{Prefix: netip.MustParsePrefix("fd00:dead:beef:1::/64")}, + fw.ProtocolALL, + nil, + nil, + fw.ActionDrop, + ) + require.NoError(t, err) + + tests := []struct { + name string + srcIP netip.Addr + dstIP netip.Addr + proto gopacket.LayerType + srcPort uint16 + dstPort uint16 + allowed bool + }{ + { + name: "IPv6 TCP to allowed dest", + srcIP: netip.MustParseAddr("fd00::1"), + dstIP: netip.MustParseAddr("fd00:dead:beef::80"), + proto: layers.LayerTypeTCP, + srcPort: 12345, + dstPort: 80, + allowed: true, + }, + { + name: "IPv6 TCP wrong port", + srcIP: netip.MustParseAddr("fd00::1"), + dstIP: netip.MustParseAddr("fd00:dead:beef::80"), + proto: layers.LayerTypeTCP, + srcPort: 12345, + dstPort: 443, + allowed: false, + }, + { + name: "IPv6 UDP not matched by TCP rule", + srcIP: netip.MustParseAddr("fd00::1"), + dstIP: netip.MustParseAddr("fd00:dead:beef::80"), + proto: layers.LayerTypeUDP, + srcPort: 12345, + dstPort: 80, + allowed: false, + }, + { + name: "IPv6 ICMPv6 matches ICMP rule via protoLayerMatches", + srcIP: netip.MustParseAddr("fd00::1"), + dstIP: netip.MustParseAddr("fd00:dead:beef::80"), + proto: layers.LayerTypeICMPv6, + allowed: false, + }, + { + name: "IPv6 to denied subnet", + srcIP: netip.MustParseAddr("fd00::1"), + dstIP: netip.MustParseAddr("fd00:dead:beef:1::1"), + proto: layers.LayerTypeTCP, + srcPort: 12345, + dstPort: 80, + allowed: false, + }, + { + name: "IPv6 source outside allowed range", + srcIP: netip.MustParseAddr("fe80::1"), + dstIP: netip.MustParseAddr("fd00:dead:beef::80"), + proto: layers.LayerTypeTCP, + srcPort: 12345, + dstPort: 80, + allowed: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + _, pass := manager.routeACLsPass(tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort) + require.Equal(t, tc.allowed, pass, "route ACL result mismatch") + }) + } +} diff --git a/client/firewall/uspfilter/filter_test.go b/client/firewall/uspfilter/filter_test.go index 5f0f9f860..01e5f97c1 100644 --- a/client/firewall/uspfilter/filter_test.go +++ b/client/firewall/uspfilter/filter_test.go @@ -527,11 +527,16 @@ func TestProcessOutgoingHooks(t *testing.T) { d := &decoder{ decoded: []gopacket.LayerType{}, } - d.parser = gopacket.NewDecodingLayerParser( + d.parser4 = gopacket.NewDecodingLayerParser( layers.LayerTypeIPv4, &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, ) - d.parser.IgnoreUnsupported = true + d.parser4.IgnoreUnsupported = true + d.parser6 = gopacket.NewDecodingLayerParser( + layers.LayerTypeIPv6, + &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, + ) + d.parser6.IgnoreUnsupported = true return d }, } @@ -630,11 +635,16 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) { d := &decoder{ decoded: []gopacket.LayerType{}, } - d.parser = gopacket.NewDecodingLayerParser( + d.parser4 = gopacket.NewDecodingLayerParser( layers.LayerTypeIPv4, &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, ) - d.parser.IgnoreUnsupported = true + d.parser4.IgnoreUnsupported = true + d.parser6 = gopacket.NewDecodingLayerParser( + layers.LayerTypeIPv6, + &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, + ) + d.parser6.IgnoreUnsupported = true return d }, } @@ -1040,8 +1050,8 @@ func TestMSSClamping(t *testing.T) { }() require.True(t, manager.mssClampEnabled, "MSS clamping should be enabled by default") - expectedMSSValue := uint16(1280 - ipTCPHeaderMinSize) - require.Equal(t, expectedMSSValue, manager.mssClampValue, "MSS clamp value should be MTU - 40") + require.Equal(t, uint16(1280-ipv4TCPHeaderMinSize), manager.mssClampValueIPv4, "IPv4 MSS clamp value should be MTU - 40") + require.Equal(t, uint16(1280-ipv6TCPHeaderMinSize), manager.mssClampValueIPv6, "IPv6 MSS clamp value should be MTU - 60") err = manager.UpdateLocalIPs() require.NoError(t, err) @@ -1059,7 +1069,7 @@ func TestMSSClamping(t *testing.T) { require.Len(t, d.tcp.Options, 1, "Should have MSS option") require.Equal(t, uint8(layers.TCPOptionKindMSS), uint8(d.tcp.Options[0].OptionType)) actualMSS := binary.BigEndian.Uint16(d.tcp.Options[0].OptionData) - require.Equal(t, expectedMSSValue, actualMSS, "MSS should be clamped to MTU - 40") + require.Equal(t, manager.mssClampValueIPv4, actualMSS, "MSS should be clamped to MTU - 40") }) t.Run("SYN packet with low MSS unchanged", func(t *testing.T) { @@ -1083,7 +1093,7 @@ func TestMSSClamping(t *testing.T) { d := parsePacket(t, packet) require.Len(t, d.tcp.Options, 1, "Should have MSS option") actualMSS := binary.BigEndian.Uint16(d.tcp.Options[0].OptionData) - require.Equal(t, expectedMSSValue, actualMSS, "MSS in SYN-ACK should be clamped") + require.Equal(t, manager.mssClampValueIPv4, actualMSS, "MSS in SYN-ACK should be clamped") }) t.Run("Non-SYN packet unchanged", func(t *testing.T) { @@ -1255,13 +1265,18 @@ func TestShouldForward(t *testing.T) { d := &decoder{ decoded: []gopacket.LayerType{}, } - d.parser = gopacket.NewDecodingLayerParser( + d.parser4 = gopacket.NewDecodingLayerParser( layers.LayerTypeIPv4, &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, ) - d.parser.IgnoreUnsupported = true + d.parser4.IgnoreUnsupported = true + d.parser6 = gopacket.NewDecodingLayerParser( + layers.LayerTypeIPv6, + &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, + ) + d.parser6.IgnoreUnsupported = true - err = d.parser.DecodeLayers(buf.Bytes(), &d.decoded) + err = d.decodePacket(buf.Bytes()) require.NoError(t, err) return d @@ -1321,6 +1336,44 @@ func TestShouldForward(t *testing.T) { }, } + // Add IPv6 to the interface and test dual-stack cases + wgIPv6 := netip.MustParseAddr("fd00::1") + otherIPv6 := netip.MustParseAddr("fd00::2") + ifaceMock.AddressFunc = func() wgaddr.Address { + return wgaddr.Address{ + IP: wgIP, + Network: netip.PrefixFrom(wgIP, 24), + IPv6: wgIPv6, + IPv6Net: netip.PrefixFrom(wgIPv6, 64), + } + } + + // Re-create manager to pick up the new address with IPv6 + require.NoError(t, manager.Close(nil)) + manager, err = Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU) + require.NoError(t, err) + + v6Cases := []struct { + name string + dstIP netip.Addr + expected bool + description string + }{ + {"v6 traffic to other address", otherIPv6, true, "should forward v6 traffic not destined to our v6 address"}, + {"v6 traffic to our v6 IP", wgIPv6, false, "should not forward traffic destined to our v6 address"}, + {"v4 traffic to other with v6 configured", otherIP, true, "should forward v4 traffic when v6 configured"}, + {"v4 traffic to our v4 IP with v6 configured", wgIP, false, "should not forward traffic to our v4 address"}, + } + for _, tt := range v6Cases { + t.Run(tt.name, func(t *testing.T) { + manager.localForwarding = true + manager.netstack = false + decoder := createTCPDecoder(8080) + result := manager.shouldForward(decoder, tt.dstIP) + require.Equal(t, tt.expected, result, tt.description) + }) + } + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { // Configure manager diff --git a/client/firewall/uspfilter/forwarder/endpoint.go b/client/firewall/uspfilter/forwarder/endpoint.go index 692a24140..bec6fb3e5 100644 --- a/client/firewall/uspfilter/forwarder/endpoint.go +++ b/client/firewall/uspfilter/forwarder/endpoint.go @@ -1,7 +1,8 @@ package forwarder import ( - "fmt" + "net" + "strconv" "sync/atomic" wgdevice "golang.zx2c4.com/wireguard/device" @@ -47,17 +48,23 @@ func (e *endpoint) LinkAddress() tcpip.LinkAddress { func (e *endpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) { var written int for _, pkt := range pkts.AsSlice() { - netHeader := header.IPv4(pkt.NetworkHeader().View().AsSlice()) - data := stack.PayloadSince(pkt.NetworkHeader()) if data == nil { continue } - // Send the packet through WireGuard - address := netHeader.DestinationAddress() - err := e.device.CreateOutboundPacket(data.AsSlice(), address.AsSlice()) - if err != nil { + raw := pkt.NetworkHeader().View().AsSlice() + if len(raw) == 0 { + continue + } + var address tcpip.Address + if raw[0]>>4 == 6 { + address = header.IPv6(raw).DestinationAddress() + } else { + address = header.IPv4(raw).DestinationAddress() + } + + if err := e.device.CreateOutboundPacket(data.AsSlice(), address.AsSlice()); err != nil { e.logger.Error1("CreateOutboundPacket: %v", err) continue } @@ -103,5 +110,7 @@ type epID stack.TransportEndpointID func (i epID) String() string { // src and remote is swapped - return fmt.Sprintf("%s:%d → %s:%d", i.RemoteAddress, i.RemotePort, i.LocalAddress, i.LocalPort) + return net.JoinHostPort(i.RemoteAddress.String(), strconv.Itoa(int(i.RemotePort))) + + " → " + + net.JoinHostPort(i.LocalAddress.String(), strconv.Itoa(int(i.LocalPort))) } diff --git a/client/firewall/uspfilter/forwarder/forwarder.go b/client/firewall/uspfilter/forwarder/forwarder.go index d17c3cd5c..85c5bbc03 100644 --- a/client/firewall/uspfilter/forwarder/forwarder.go +++ b/client/firewall/uspfilter/forwarder/forwarder.go @@ -14,6 +14,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" @@ -36,25 +37,31 @@ type Forwarder struct { logger *nblog.Logger flowLogger nftypes.FlowLogger // ruleIdMap is used to store the rule ID for a given connection - ruleIdMap sync.Map - stack *stack.Stack - endpoint *endpoint - udpForwarder *udpForwarder - ctx context.Context - cancel context.CancelFunc - ip tcpip.Address - netstack bool - hasRawICMPAccess bool - pingSemaphore chan struct{} + ruleIdMap sync.Map + stack *stack.Stack + endpoint *endpoint + udpForwarder *udpForwarder + ctx context.Context + cancel context.CancelFunc + ip tcpip.Address + ipv6 tcpip.Address + netstack bool + hasRawICMPAccess bool + hasRawICMPv6Access bool + pingSemaphore chan struct{} } func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool, mtu uint16) (*Forwarder, error) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + NetworkProtocols: []stack.NetworkProtocolFactory{ + ipv4.NewProtocol, + ipv6.NewProtocol, + }, TransportProtocols: []stack.TransportProtocolFactory{ tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol4, + icmp.NewProtocol6, }, HandleLocal: false, }) @@ -73,7 +80,7 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow protoAddr := tcpip.ProtocolAddress{ Protocol: ipv4.ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()), + Address: tcpip.AddrFrom4(iface.Address().IP.As4()), PrefixLen: iface.Address().Network.Bits(), }, } @@ -82,6 +89,19 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow return nil, fmt.Errorf("failed to add protocol address: %s", err) } + if v6 := iface.Address().IPv6; v6.IsValid() { + v6Addr := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.AddrFrom16(v6.As16()), + PrefixLen: iface.Address().IPv6Net.Bits(), + }, + } + if err := s.AddProtocolAddress(nicID, v6Addr, stack.AddressProperties{}); err != nil { + return nil, fmt.Errorf("add IPv6 protocol address: %s", err) + } + } + defaultSubnet, err := tcpip.NewSubnet( tcpip.AddrFrom4([4]byte{0, 0, 0, 0}), tcpip.MaskFromBytes([]byte{0, 0, 0, 0}), @@ -90,6 +110,14 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow return nil, fmt.Errorf("creating default subnet: %w", err) } + defaultSubnetV6, err := tcpip.NewSubnet( + tcpip.AddrFrom16([16]byte{}), + tcpip.MaskFromBytes(make([]byte, 16)), + ) + if err != nil { + return nil, fmt.Errorf("creating default v6 subnet: %w", err) + } + if err := s.SetPromiscuousMode(nicID, true); err != nil { return nil, fmt.Errorf("set promiscuous mode: %s", err) } @@ -98,10 +126,8 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow } s.SetRouteTable([]tcpip.Route{ - { - Destination: defaultSubnet, - NIC: nicID, - }, + {Destination: defaultSubnet, NIC: nicID}, + {Destination: defaultSubnetV6, NIC: nicID}, }) ctx, cancel := context.WithCancel(context.Background()) @@ -114,7 +140,8 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow ctx: ctx, cancel: cancel, netstack: netstack, - ip: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()), + ip: tcpip.AddrFrom4(iface.Address().IP.As4()), + ipv6: addrFromNetipAddr(iface.Address().IPv6), pingSemaphore: make(chan struct{}, 3), } @@ -131,7 +158,10 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow udpForwarder := udp.NewForwarder(s, f.handleUDP) s.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket) - s.SetTransportProtocolHandler(icmp.ProtocolNumber4, f.handleICMP) + // ICMP is handled directly in InjectIncomingPacket, bypassing gVisor's + // network layer. This avoids duplicate echo replies (v4) and the v6 + // auto-reply bug where gVisor responds at the network layer before + // our transport handler fires. f.checkICMPCapability() @@ -140,8 +170,30 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow } func (f *Forwarder) InjectIncomingPacket(payload []byte) error { - if len(payload) < header.IPv4MinimumSize { - return fmt.Errorf("packet too small: %d bytes", len(payload)) + if len(payload) == 0 { + return fmt.Errorf("empty packet") + } + + var protoNum tcpip.NetworkProtocolNumber + switch payload[0] >> 4 { + case 4: + if len(payload) < header.IPv4MinimumSize { + return fmt.Errorf("IPv4 packet too small: %d bytes", len(payload)) + } + if f.handleICMPDirect(payload) { + return nil + } + protoNum = ipv4.ProtocolNumber + case 6: + if len(payload) < header.IPv6MinimumSize { + return fmt.Errorf("IPv6 packet too small: %d bytes", len(payload)) + } + if f.handleICMPDirect(payload) { + return nil + } + protoNum = ipv6.ProtocolNumber + default: + return fmt.Errorf("unknown IP version: %d", payload[0]>>4) } pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -150,11 +202,95 @@ func (f *Forwarder) InjectIncomingPacket(payload []byte) error { defer pkt.DecRef() if f.endpoint.dispatcher != nil { - f.endpoint.dispatcher.DeliverNetworkPacket(ipv4.ProtocolNumber, pkt) + f.endpoint.dispatcher.DeliverNetworkPacket(protoNum, pkt) } return nil } +// handleICMPDirect intercepts ICMP packets from raw IP payloads before they +// enter gVisor. It synthesizes the TransportEndpointID and PacketBuffer that +// the existing handlers expect, then dispatches to handleICMP/handleICMPv6. +// This bypasses gVisor's network layer which causes duplicate v4 echo replies +// and auto-replies to all v6 echo requests in promiscuous mode. +// +// Unlike gVisor's network layer, this does not validate ICMP checksums or +// reassemble IP fragments. Fragmented ICMP packets fall through to gVisor. +func parseICMPv4(payload []byte) (ipHdrLen int, src, dst tcpip.Address, ok bool) { + ip := header.IPv4(payload) + if ip.Protocol() != uint8(header.ICMPv4ProtocolNumber) { + return 0, src, dst, false + } + if ip.FragmentOffset() != 0 || ip.Flags()&header.IPv4FlagMoreFragments != 0 { + return 0, src, dst, false + } + ipHdrLen = int(ip.HeaderLength()) + if len(payload)-ipHdrLen < header.ICMPv4MinimumSize { + return 0, src, dst, false + } + return ipHdrLen, ip.SourceAddress(), ip.DestinationAddress(), true +} + +func parseICMPv6(payload []byte) (ipHdrLen int, src, dst tcpip.Address, ok bool) { + ip := header.IPv6(payload) + if ip.NextHeader() != uint8(header.ICMPv6ProtocolNumber) { + return 0, src, dst, false + } + ipHdrLen = header.IPv6MinimumSize + if len(payload)-ipHdrLen < header.ICMPv6MinimumSize { + return 0, src, dst, false + } + return ipHdrLen, ip.SourceAddress(), ip.DestinationAddress(), true +} + +func (f *Forwarder) handleICMPDirect(payload []byte) bool { + var ( + ipHdrLen int + srcAddr tcpip.Address + dstAddr tcpip.Address + ok bool + ) + switch payload[0] >> 4 { + case 4: + ipHdrLen, srcAddr, dstAddr, ok = parseICMPv4(payload) + case 6: + ipHdrLen, srcAddr, dstAddr, ok = parseICMPv6(payload) + } + if !ok { + return false + } + + // Let gVisor handle ICMP destined for our own addresses natively. + // Its network-layer auto-reply is correct and efficient for local traffic. + if f.ip.Equal(dstAddr) || f.ipv6.Equal(dstAddr) { + return false + } + + id := stack.TransportEndpointID{ + LocalAddress: dstAddr, + RemoteAddress: srcAddr, + } + + // Build a PacketBuffer with headers consumed the same way gVisor would. + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Payload: buffer.MakeWithData(payload), + }) + defer pkt.DecRef() + + if _, ok := pkt.NetworkHeader().Consume(ipHdrLen); !ok { + return false + } + + icmpPayload := payload[ipHdrLen:] + if _, ok := pkt.TransportHeader().Consume(len(icmpPayload)); !ok { + return false + } + + if payload[0]>>4 == 6 { + return f.handleICMPv6(id, pkt) + } + return f.handleICMP(id, pkt) +} + // Stop gracefully shuts down the forwarder func (f *Forwarder) Stop() { f.cancel() @@ -167,11 +303,14 @@ func (f *Forwarder) Stop() { f.stack.Wait() } -func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP { +func (f *Forwarder) determineDialAddr(addr tcpip.Address) netip.Addr { if f.netstack && f.ip.Equal(addr) { - return net.IPv4(127, 0, 0, 1) + return netip.AddrFrom4([4]byte{127, 0, 0, 1}) } - return addr.AsSlice() + if f.netstack && f.ipv6.Equal(addr) { + return netip.IPv6Loopback() + } + return addrToNetipAddr(addr) } func (f *Forwarder) RegisterRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, ruleID []byte) { @@ -205,23 +344,50 @@ func buildKey(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) conntrack.ConnKe } } +// addrFromNetipAddr converts a netip.Addr to a gvisor tcpip.Address without allocating. +func addrFromNetipAddr(addr netip.Addr) tcpip.Address { + if !addr.IsValid() { + return tcpip.Address{} + } + if addr.Is4() { + return tcpip.AddrFrom4(addr.As4()) + } + return tcpip.AddrFrom16(addr.As16()) +} + +// addrToNetipAddr converts a gvisor tcpip.Address to netip.Addr without allocating. +func addrToNetipAddr(addr tcpip.Address) netip.Addr { + switch addr.Len() { + case 4: + return netip.AddrFrom4(addr.As4()) + case 16: + return netip.AddrFrom16(addr.As16()) + default: + return netip.Addr{} + } +} + // checkICMPCapability tests whether we have raw ICMP socket access at startup. func (f *Forwarder) checkICMPCapability() { + f.hasRawICMPAccess = probeRawICMP("ip4:icmp", "0.0.0.0", f.logger) + f.hasRawICMPv6Access = probeRawICMP("ip6:ipv6-icmp", "::", f.logger) +} + +func probeRawICMP(network, addr string, logger *nblog.Logger) bool { ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() lc := net.ListenConfig{} - conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0") + conn, err := lc.ListenPacket(ctx, network, addr) if err != nil { - f.hasRawICMPAccess = false - f.logger.Debug("forwarder: No raw ICMP socket access, will use ping binary fallback") - return + logger.Debug1("forwarder: no raw %s socket access, will use ping binary fallback", network) + return false } if err := conn.Close(); err != nil { - f.logger.Debug1("forwarder: Failed to close ICMP capability test socket: %v", err) + logger.Debug2("forwarder: failed to close %s capability test socket: %v", network, err) } - f.hasRawICMPAccess = true - f.logger.Debug("forwarder: Raw ICMP socket access available") + logger.Debug1("forwarder: raw %s socket access available", network) + return true } diff --git a/client/firewall/uspfilter/forwarder/icmp.go b/client/firewall/uspfilter/forwarder/icmp.go index cb3db325d..4dde2c50c 100644 --- a/client/firewall/uspfilter/forwarder/icmp.go +++ b/client/firewall/uspfilter/forwarder/icmp.go @@ -35,7 +35,7 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt *stack.PacketBu } icmpData := stack.PayloadSince(pkt.TransportHeader()).AsSlice() - conn, err := f.forwardICMPPacket(id, icmpData, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()), 100*time.Millisecond) + conn, err := f.forwardICMPPacket(id, icmpData, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()), false, 100*time.Millisecond) if err != nil { f.logger.Error2("forwarder: Failed to forward ICMP packet for %v: %v", epID(id), err) return true @@ -58,7 +58,7 @@ func (f *Forwarder) handleICMPEcho(flowID uuid.UUID, id stack.TransportEndpointI defer func() { <-f.pingSemaphore }() if f.hasRawICMPAccess { - f.handleICMPViaSocket(flowID, id, icmpType, icmpCode, icmpData, rxBytes) + f.handleICMPViaSocket(flowID, id, icmpType, icmpCode, icmpData, rxBytes, false) } else { f.handleICMPViaPing(flowID, id, icmpType, icmpCode, icmpData, rxBytes) } @@ -72,18 +72,23 @@ func (f *Forwarder) handleICMPEcho(flowID uuid.UUID, id stack.TransportEndpointI // forwardICMPPacket creates a raw ICMP socket and sends the packet, returning the connection. // The caller is responsible for closing the returned connection. -func (f *Forwarder) forwardICMPPacket(id stack.TransportEndpointID, payload []byte, icmpType, icmpCode uint8, timeout time.Duration) (net.PacketConn, error) { +func (f *Forwarder) forwardICMPPacket(id stack.TransportEndpointID, payload []byte, icmpType, icmpCode uint8, v6 bool, timeout time.Duration) (net.PacketConn, error) { ctx, cancel := context.WithTimeout(f.ctx, timeout) defer cancel() + network, listenAddr := "ip4:icmp", "0.0.0.0" + if v6 { + network, listenAddr = "ip6:ipv6-icmp", "::" + } + lc := net.ListenConfig{} - conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0") + conn, err := lc.ListenPacket(ctx, network, listenAddr) if err != nil { return nil, fmt.Errorf("create ICMP socket: %w", err) } dstIP := f.determineDialAddr(id.LocalAddress) - dst := &net.IPAddr{IP: dstIP} + dst := &net.IPAddr{IP: dstIP.AsSlice()} if _, err = conn.WriteTo(payload, dst); err != nil { if closeErr := conn.Close(); closeErr != nil { @@ -98,11 +103,11 @@ func (f *Forwarder) forwardICMPPacket(id stack.TransportEndpointID, payload []by return conn, nil } -// handleICMPViaSocket handles ICMP echo requests using raw sockets. -func (f *Forwarder) handleICMPViaSocket(flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8, icmpData []byte, rxBytes int) { +// handleICMPViaSocket handles ICMP echo requests using raw sockets for both v4 and v6. +func (f *Forwarder) handleICMPViaSocket(flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8, icmpData []byte, rxBytes int, v6 bool) { sendTime := time.Now() - conn, err := f.forwardICMPPacket(id, icmpData, icmpType, icmpCode, 5*time.Second) + conn, err := f.forwardICMPPacket(id, icmpData, icmpType, icmpCode, v6, 5*time.Second) if err != nil { f.logger.Error2("forwarder: Failed to send ICMP packet for %v: %v", epID(id), err) return @@ -113,16 +118,20 @@ func (f *Forwarder) handleICMPViaSocket(flowID uuid.UUID, id stack.TransportEndp } }() - txBytes := f.handleEchoResponse(conn, id) + txBytes := f.handleEchoResponse(conn, id, v6) rtt := time.Since(sendTime).Round(10 * time.Microsecond) - f.logger.Trace4("forwarder: Forwarded ICMP echo reply %v type %v code %v (rtt=%v, raw socket)", - epID(id), icmpType, icmpCode, rtt) + proto := "ICMP" + if v6 { + proto = "ICMPv6" + } + f.logger.Trace5("forwarder: Forwarded %s echo reply %v type %v code %v (rtt=%v, raw socket)", + proto, epID(id), icmpType, icmpCode, rtt) f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes)) } -func (f *Forwarder) handleEchoResponse(conn net.PacketConn, id stack.TransportEndpointID) int { +func (f *Forwarder) handleEchoResponse(conn net.PacketConn, id stack.TransportEndpointID, v6 bool) int { if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil { f.logger.Error1("forwarder: Failed to set read deadline for ICMP response: %v", err) return 0 @@ -137,6 +146,19 @@ func (f *Forwarder) handleEchoResponse(conn net.PacketConn, id stack.TransportEn return 0 } + if v6 { + // Recompute checksum: the raw socket response has a checksum computed + // over the real endpoint addresses, but we inject with overlay addresses. + icmpHdr := header.ICMPv6(response[:n]) + icmpHdr.SetChecksum(0) + icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: icmpHdr, + Src: id.LocalAddress, + Dst: id.RemoteAddress, + })) + return f.injectICMPv6Reply(id, response[:n]) + } + return f.injectICMPReply(id, response[:n]) } @@ -150,19 +172,23 @@ func (f *Forwarder) sendICMPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.T txPackets = 1 } - srcIp := netip.AddrFrom4(id.RemoteAddress.As4()) - dstIp := netip.AddrFrom4(id.LocalAddress.As4()) + srcIp := addrToNetipAddr(id.RemoteAddress) + dstIp := addrToNetipAddr(id.LocalAddress) + + proto := nftypes.ICMP + if srcIp.Is6() { + proto = nftypes.ICMPv6 + } fields := nftypes.EventFields{ FlowID: flowID, Type: typ, Direction: nftypes.Ingress, - Protocol: nftypes.ICMP, - // TODO: handle ipv6 - SourceIP: srcIp, - DestIP: dstIp, - ICMPType: icmpType, - ICMPCode: icmpCode, + Protocol: proto, + SourceIP: srcIp, + DestIP: dstIp, + ICMPType: icmpType, + ICMPCode: icmpCode, RxBytes: rxBytes, TxBytes: txBytes, @@ -209,26 +235,164 @@ func (f *Forwarder) handleICMPViaPing(flowID uuid.UUID, id stack.TransportEndpoi f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes)) } +// handleICMPv6 handles ICMPv6 packets from the network stack. +func (f *Forwarder) handleICMPv6(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { + icmpHdr := header.ICMPv6(pkt.TransportHeader().View().AsSlice()) + + flowID := uuid.New() + f.sendICMPEvent(nftypes.TypeStart, flowID, id, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()), 0, 0) + + if icmpHdr.Type() == header.ICMPv6EchoRequest { + return f.handleICMPv6Echo(flowID, id, pkt, uint8(icmpHdr.Type()), uint8(icmpHdr.Code())) + } + + // For non-echo types (Destination Unreachable, Packet Too Big, etc), forward without waiting + if !f.hasRawICMPv6Access { + f.logger.Debug2("forwarder: Cannot handle ICMPv6 type %v without raw socket access for %v", icmpHdr.Type(), epID(id)) + return false + } + + icmpData := stack.PayloadSince(pkt.TransportHeader()).AsSlice() + conn, err := f.forwardICMPPacket(id, icmpData, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()), true, 100*time.Millisecond) + if err != nil { + f.logger.Error2("forwarder: Failed to forward ICMPv6 packet for %v: %v", epID(id), err) + return true + } + if err := conn.Close(); err != nil { + f.logger.Debug1("forwarder: Failed to close ICMPv6 socket: %v", err) + } + + return true +} + +// handleICMPv6Echo handles ICMPv6 echo requests via raw socket or ping binary fallback. +func (f *Forwarder) handleICMPv6Echo(flowID uuid.UUID, id stack.TransportEndpointID, pkt *stack.PacketBuffer, icmpType, icmpCode uint8) bool { + select { + case f.pingSemaphore <- struct{}{}: + icmpData := stack.PayloadSince(pkt.TransportHeader()).ToSlice() + rxBytes := pkt.Size() + + go func() { + defer func() { <-f.pingSemaphore }() + + if f.hasRawICMPv6Access { + f.handleICMPViaSocket(flowID, id, icmpType, icmpCode, icmpData, rxBytes, true) + } else { + f.handleICMPv6ViaPing(flowID, id, icmpType, icmpCode, icmpData, rxBytes) + } + }() + default: + f.logger.Debug3("forwarder: ICMPv6 rate limit exceeded for %v type %v code %v", epID(id), icmpType, icmpCode) + } + return true +} + +// handleICMPv6ViaPing uses the system ping6 binary for ICMPv6 echo. +func (f *Forwarder) handleICMPv6ViaPing(flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8, icmpData []byte, rxBytes int) { + ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second) + defer cancel() + + dstIP := f.determineDialAddr(id.LocalAddress) + cmd := buildPingCommand(ctx, dstIP, 5*time.Second) + + pingStart := time.Now() + if err := cmd.Run(); err != nil { + f.logger.Warn4("forwarder: Ping6 failed for %v type %v code %v: %v", epID(id), icmpType, icmpCode, err) + return + } + rtt := time.Since(pingStart).Round(10 * time.Microsecond) + + f.logger.Trace3("forwarder: Forwarded ICMPv6 echo request %v type %v code %v", + epID(id), icmpType, icmpCode) + + txBytes := f.synthesizeICMPv6EchoReply(id, icmpData) + + f.logger.Trace4("forwarder: Forwarded ICMPv6 echo reply %v type %v code %v (rtt=%v, ping binary)", + epID(id), icmpType, icmpCode, rtt) + + f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes)) +} + +// synthesizeICMPv6EchoReply creates an ICMPv6 echo reply and injects it back. +func (f *Forwarder) synthesizeICMPv6EchoReply(id stack.TransportEndpointID, icmpData []byte) int { + replyICMP := make([]byte, len(icmpData)) + copy(replyICMP, icmpData) + + replyHdr := header.ICMPv6(replyICMP) + replyHdr.SetType(header.ICMPv6EchoReply) + replyHdr.SetChecksum(0) + // ICMPv6Checksum computes the pseudo-header internally from Src/Dst. + // Header contains the full ICMP message, so PayloadCsum/PayloadLen are zero. + replyHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: replyHdr, + Src: id.LocalAddress, + Dst: id.RemoteAddress, + })) + + return f.injectICMPv6Reply(id, replyICMP) +} + +// injectICMPv6Reply wraps an ICMPv6 payload in an IPv6 header and sends to the peer. +func (f *Forwarder) injectICMPv6Reply(id stack.TransportEndpointID, icmpPayload []byte) int { + ipHdr := make([]byte, header.IPv6MinimumSize) + ip := header.IPv6(ipHdr) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(len(icmpPayload)), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: 64, + SrcAddr: id.LocalAddress, + DstAddr: id.RemoteAddress, + }) + + fullPacket := make([]byte, 0, len(ipHdr)+len(icmpPayload)) + fullPacket = append(fullPacket, ipHdr...) + fullPacket = append(fullPacket, icmpPayload...) + + if err := f.endpoint.device.CreateOutboundPacket(fullPacket, id.RemoteAddress.AsSlice()); err != nil { + f.logger.Error1("forwarder: Failed to send ICMPv6 reply to peer: %v", err) + return 0 + } + + return len(fullPacket) +} + +const ( + pingBin = "ping" + ping6Bin = "ping6" +) + // buildPingCommand creates a platform-specific ping command. -func buildPingCommand(ctx context.Context, target net.IP, timeout time.Duration) *exec.Cmd { +// Most platforms auto-detect IPv6 from raw addresses. macOS/iOS/OpenBSD require ping6. +func buildPingCommand(ctx context.Context, target netip.Addr, timeout time.Duration) *exec.Cmd { timeoutSec := int(timeout.Seconds()) if timeoutSec < 1 { timeoutSec = 1 } + isV6 := target.Is6() + timeoutStr := fmt.Sprintf("%d", timeoutSec) + switch runtime.GOOS { case "linux", "android": - return exec.CommandContext(ctx, "ping", "-c", "1", "-W", fmt.Sprintf("%d", timeoutSec), "-q", target.String()) + return exec.CommandContext(ctx, pingBin, "-c", "1", "-W", timeoutStr, "-q", target.String()) case "darwin", "ios": - return exec.CommandContext(ctx, "ping", "-c", "1", "-t", fmt.Sprintf("%d", timeoutSec), "-q", target.String()) + bin := pingBin + if isV6 { + bin = ping6Bin + } + return exec.CommandContext(ctx, bin, "-c", "1", "-t", timeoutStr, "-q", target.String()) case "freebsd": - return exec.CommandContext(ctx, "ping", "-c", "1", "-t", fmt.Sprintf("%d", timeoutSec), target.String()) + return exec.CommandContext(ctx, pingBin, "-c", "1", "-t", timeoutStr, target.String()) case "openbsd", "netbsd": - return exec.CommandContext(ctx, "ping", "-c", "1", "-w", fmt.Sprintf("%d", timeoutSec), target.String()) + bin := pingBin + if isV6 { + bin = ping6Bin + } + return exec.CommandContext(ctx, bin, "-c", "1", "-w", timeoutStr, target.String()) case "windows": - return exec.CommandContext(ctx, "ping", "-n", "1", "-w", fmt.Sprintf("%d", timeoutSec*1000), target.String()) + return exec.CommandContext(ctx, pingBin, "-n", "1", "-w", fmt.Sprintf("%d", timeoutSec*1000), target.String()) default: - return exec.CommandContext(ctx, "ping", "-c", "1", target.String()) + return exec.CommandContext(ctx, pingBin, "-c", "1", target.String()) } } diff --git a/client/firewall/uspfilter/forwarder/tcp.go b/client/firewall/uspfilter/forwarder/tcp.go index aef420061..8844463f5 100644 --- a/client/firewall/uspfilter/forwarder/tcp.go +++ b/client/firewall/uspfilter/forwarder/tcp.go @@ -2,10 +2,9 @@ package forwarder import ( "context" - "fmt" "io" "net" - "net/netip" + "strconv" "sync" "github.com/google/uuid" @@ -33,7 +32,7 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) { } }() - dialAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort) + dialAddr := net.JoinHostPort(f.determineDialAddr(id.LocalAddress).String(), strconv.Itoa(int(id.LocalPort))) outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr) if err != nil { @@ -133,15 +132,14 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn } func (f *Forwarder) sendTCPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, rxBytes, txBytes, rxPackets, txPackets uint64) { - srcIp := netip.AddrFrom4(id.RemoteAddress.As4()) - dstIp := netip.AddrFrom4(id.LocalAddress.As4()) + srcIp := addrToNetipAddr(id.RemoteAddress) + dstIp := addrToNetipAddr(id.LocalAddress) fields := nftypes.EventFields{ - FlowID: flowID, - Type: typ, - Direction: nftypes.Ingress, - Protocol: nftypes.TCP, - // TODO: handle ipv6 + FlowID: flowID, + Type: typ, + Direction: nftypes.Ingress, + Protocol: nftypes.TCP, SourceIP: srcIp, DestIP: dstIp, SourcePort: id.RemotePort, diff --git a/client/firewall/uspfilter/forwarder/udp.go b/client/firewall/uspfilter/forwarder/udp.go index f175e275b..c92fa1f32 100644 --- a/client/firewall/uspfilter/forwarder/udp.go +++ b/client/firewall/uspfilter/forwarder/udp.go @@ -6,7 +6,7 @@ import ( "fmt" "io" "net" - "net/netip" + "strconv" "sync" "sync/atomic" "time" @@ -158,7 +158,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool { } }() - dstAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort) + dstAddr := net.JoinHostPort(f.determineDialAddr(id.LocalAddress).String(), strconv.Itoa(int(id.LocalPort))) outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr) if err != nil { f.logger.Debug2("forwarder: UDP dial error for %v: %v", epID(id), err) @@ -276,15 +276,14 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack // sendUDPEvent stores flow events for UDP connections func (f *Forwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, rxBytes, txBytes, rxPackets, txPackets uint64) { - srcIp := netip.AddrFrom4(id.RemoteAddress.As4()) - dstIp := netip.AddrFrom4(id.LocalAddress.As4()) + srcIp := addrToNetipAddr(id.RemoteAddress) + dstIp := addrToNetipAddr(id.LocalAddress) fields := nftypes.EventFields{ - FlowID: flowID, - Type: typ, - Direction: nftypes.Ingress, - Protocol: nftypes.UDP, - // TODO: handle ipv6 + FlowID: flowID, + Type: typ, + Direction: nftypes.Ingress, + Protocol: nftypes.UDP, SourceIP: srcIp, DestIP: dstIp, SourcePort: id.RemotePort, diff --git a/client/firewall/uspfilter/localip.go b/client/firewall/uspfilter/localip.go index f63fe3e45..b35be56c6 100644 --- a/client/firewall/uspfilter/localip.go +++ b/client/firewall/uspfilter/localip.go @@ -4,89 +4,32 @@ import ( "fmt" "net" "net/netip" - "sync" + "sync/atomic" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/client/firewall/uspfilter/common" ) -type localIPManager struct { - mu sync.RWMutex - - // fixed-size high array for upper byte of a IPv4 address - ipv4Bitmap [256]*ipv4LowBitmap +// localIPSnapshot is an immutable snapshot of local IP addresses, swapped +// atomically so reads are lock-free. +type localIPSnapshot struct { + ips map[netip.Addr]struct{} } -// ipv4LowBitmap is a map for the low 16 bits of a IPv4 address -type ipv4LowBitmap struct { - bitmap [8192]uint32 +type localIPManager struct { + snapshot atomic.Pointer[localIPSnapshot] } func newLocalIPManager() *localIPManager { - return &localIPManager{} + m := &localIPManager{} + m.snapshot.Store(&localIPSnapshot{ + ips: make(map[netip.Addr]struct{}), + }) + return m } -func (m *localIPManager) setBitmapBit(ip net.IP) { - ipv4 := ip.To4() - if ipv4 == nil { - return - } - high := uint16(ipv4[0]) - low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3]) - - index := low / 32 - bit := low % 32 - - if m.ipv4Bitmap[high] == nil { - m.ipv4Bitmap[high] = &ipv4LowBitmap{} - } - - m.ipv4Bitmap[high].bitmap[index] |= 1 << bit -} - -func (m *localIPManager) setBitInBitmap(ip netip.Addr, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) { - if !ip.Is4() { - return - } - ipv4 := ip.AsSlice() - - high := uint16(ipv4[0]) - low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3]) - - if bitmap[high] == nil { - bitmap[high] = &ipv4LowBitmap{} - } - - index := low / 32 - bit := low % 32 - bitmap[high].bitmap[index] |= 1 << bit - - if _, exists := ipv4Set[ip]; !exists { - ipv4Set[ip] = struct{}{} - *ipv4Addresses = append(*ipv4Addresses, ip) - } -} - -func (m *localIPManager) checkBitmapBit(ip []byte) bool { - high := uint16(ip[0]) - low := (uint16(ip[1]) << 8) | (uint16(ip[2]) << 4) | uint16(ip[3]) - - if m.ipv4Bitmap[high] == nil { - return false - } - - index := low / 32 - bit := low % 32 - return (m.ipv4Bitmap[high].bitmap[index] & (1 << bit)) != 0 -} - -func (m *localIPManager) processIP(ip netip.Addr, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) error { - m.setBitInBitmap(ip, bitmap, ipv4Set, ipv4Addresses) - return nil -} - -func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) { +func processInterface(iface net.Interface, ips map[netip.Addr]struct{}, addresses *[]netip.Addr) { addrs, err := iface.Addrs() if err != nil { log.Debugf("get addresses for interface %s failed: %v", iface.Name, err) @@ -104,18 +47,19 @@ func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv continue } - addr, ok := netip.AddrFromSlice(ip) + parsed, ok := netip.AddrFromSlice(ip) if !ok { log.Warnf("invalid IP address %s in interface %s", ip.String(), iface.Name) continue } - if err := m.processIP(addr.Unmap(), bitmap, ipv4Set, ipv4Addresses); err != nil { - log.Debugf("process IP failed: %v", err) - } + parsed = parsed.Unmap() + ips[parsed] = struct{}{} + *addresses = append(*addresses, parsed) } } +// UpdateLocalIPs rebuilds the local IP snapshot and swaps it in atomically. func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) { defer func() { if r := recover(); r != nil { @@ -123,20 +67,20 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) { } }() - var newIPv4Bitmap [256]*ipv4LowBitmap - ipv4Set := make(map[netip.Addr]struct{}) - var ipv4Addresses []netip.Addr + ips := make(map[netip.Addr]struct{}) + var addresses []netip.Addr - // 127.0.0.0/8 - newIPv4Bitmap[127] = &ipv4LowBitmap{} - for i := 0; i < 8192; i++ { - // #nosec G602 -- bitmap is defined as [8192]uint32, loop range is correct - newIPv4Bitmap[127].bitmap[i] = 0xFFFFFFFF - } + // loopback + ips[netip.AddrFrom4([4]byte{127, 0, 0, 1})] = struct{}{} + ips[netip.IPv6Loopback()] = struct{}{} if iface != nil { - if err := m.processIP(iface.Address().IP, &newIPv4Bitmap, ipv4Set, &ipv4Addresses); err != nil { - return err + ip := iface.Address().IP + ips[ip] = struct{}{} + addresses = append(addresses, ip) + if v6 := iface.Address().IPv6; v6.IsValid() { + ips[v6] = struct{}{} + addresses = append(addresses, v6) } } @@ -147,25 +91,24 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) { // TODO: filter out down interfaces (net.FlagUp). Also handle the reverse // case where an interface comes up between refreshes. for _, intf := range interfaces { - m.processInterface(intf, &newIPv4Bitmap, ipv4Set, &ipv4Addresses) + processInterface(intf, ips, &addresses) } } - m.mu.Lock() - m.ipv4Bitmap = newIPv4Bitmap - m.mu.Unlock() + m.snapshot.Store(&localIPSnapshot{ips: ips}) - log.Debugf("Local IPv4 addresses: %v", ipv4Addresses) + log.Debugf("Local IP addresses: %v", addresses) return nil } +// IsLocalIP checks if the given IP is a local address. Lock-free on the read path. func (m *localIPManager) IsLocalIP(ip netip.Addr) bool { - if !ip.Is4() { - return false + s := m.snapshot.Load() + + if ip.Is4() && ip.As4()[0] == 127 { + return true } - m.mu.RLock() - defer m.mu.RUnlock() - - return m.checkBitmapBit(ip.AsSlice()) + _, found := s.ips[ip] + return found } diff --git a/client/firewall/uspfilter/localip_bench_test.go b/client/firewall/uspfilter/localip_bench_test.go new file mode 100644 index 000000000..14e12bd08 --- /dev/null +++ b/client/firewall/uspfilter/localip_bench_test.go @@ -0,0 +1,72 @@ +package uspfilter + +import ( + "net/netip" + "testing" + + "github.com/netbirdio/netbird/client/iface/wgaddr" +) + +func setupManager(b *testing.B) *localIPManager { + b.Helper() + m := newLocalIPManager() + mock := &IFaceMock{ + AddressFunc: func() wgaddr.Address { + return wgaddr.Address{ + IP: netip.MustParseAddr("100.64.0.1"), + Network: netip.MustParsePrefix("100.64.0.0/16"), + IPv6: netip.MustParseAddr("fd00::1"), + IPv6Net: netip.MustParsePrefix("fd00::/64"), + } + }, + } + if err := m.UpdateLocalIPs(mock); err != nil { + b.Fatalf("UpdateLocalIPs: %v", err) + } + return m +} + +func BenchmarkIsLocalIP_v4_hit(b *testing.B) { + m := setupManager(b) + ip := netip.MustParseAddr("100.64.0.1") + b.ResetTimer() + for i := 0; i < b.N; i++ { + m.IsLocalIP(ip) + } +} + +func BenchmarkIsLocalIP_v4_miss(b *testing.B) { + m := setupManager(b) + ip := netip.MustParseAddr("8.8.8.8") + b.ResetTimer() + for i := 0; i < b.N; i++ { + m.IsLocalIP(ip) + } +} + +func BenchmarkIsLocalIP_v6_hit(b *testing.B) { + m := setupManager(b) + ip := netip.MustParseAddr("fd00::1") + b.ResetTimer() + for i := 0; i < b.N; i++ { + m.IsLocalIP(ip) + } +} + +func BenchmarkIsLocalIP_v6_miss(b *testing.B) { + m := setupManager(b) + ip := netip.MustParseAddr("2001:db8::1") + b.ResetTimer() + for i := 0; i < b.N; i++ { + m.IsLocalIP(ip) + } +} + +func BenchmarkIsLocalIP_loopback(b *testing.B) { + m := setupManager(b) + ip := netip.MustParseAddr("127.0.0.1") + b.ResetTimer() + for i := 0; i < b.N; i++ { + m.IsLocalIP(ip) + } +} diff --git a/client/firewall/uspfilter/localip_test.go b/client/firewall/uspfilter/localip_test.go index 6653947fa..0dc524c41 100644 --- a/client/firewall/uspfilter/localip_test.go +++ b/client/firewall/uspfilter/localip_test.go @@ -72,14 +72,45 @@ func TestLocalIPManager(t *testing.T) { expected: false, }, { - name: "IPv6 address", + name: "IPv6 address matches", setupAddr: wgaddr.Address{ - IP: netip.MustParseAddr("fe80::1"), + IP: netip.MustParseAddr("100.64.0.1"), + Network: netip.MustParsePrefix("100.64.0.0/16"), + IPv6: netip.MustParseAddr("fd00::1"), + IPv6Net: netip.MustParsePrefix("fd00::/64"), + }, + testIP: netip.MustParseAddr("fd00::1"), + expected: true, + }, + { + name: "IPv6 address does not match", + setupAddr: wgaddr.Address{ + IP: netip.MustParseAddr("100.64.0.1"), + Network: netip.MustParsePrefix("100.64.0.0/16"), + IPv6: netip.MustParseAddr("fd00::1"), + IPv6Net: netip.MustParsePrefix("fd00::/64"), + }, + testIP: netip.MustParseAddr("fd00::99"), + expected: false, + }, + { + name: "No aliasing between similar IPs", + setupAddr: wgaddr.Address{ + IP: netip.MustParseAddr("192.168.1.1"), Network: netip.MustParsePrefix("192.168.1.0/24"), }, - testIP: netip.MustParseAddr("fe80::1"), + testIP: netip.MustParseAddr("192.168.0.17"), expected: false, }, + { + name: "IPv6 loopback", + setupAddr: wgaddr.Address{ + IP: netip.MustParseAddr("100.64.0.1"), + Network: netip.MustParsePrefix("100.64.0.0/16"), + }, + testIP: netip.MustParseAddr("::1"), + expected: true, + }, } for _, tt := range tests { @@ -171,90 +202,3 @@ func TestLocalIPManager_AllInterfaces(t *testing.T) { }) } } - -// MapImplementation is a version using map[string]struct{} -type MapImplementation struct { - localIPs map[string]struct{} -} - -func BenchmarkIPChecks(b *testing.B) { - interfaces := make([]net.IP, 16) - for i := range interfaces { - interfaces[i] = net.IPv4(10, 0, byte(i>>8), byte(i)) - } - - // Setup bitmap - bitmapManager := newLocalIPManager() - for _, ip := range interfaces[:8] { // Add half of IPs - bitmapManager.setBitmapBit(ip) - } - - // Setup map version - mapManager := &MapImplementation{ - localIPs: make(map[string]struct{}), - } - for _, ip := range interfaces[:8] { - mapManager.localIPs[ip.String()] = struct{}{} - } - - b.Run("Bitmap_Hit", func(b *testing.B) { - ip := interfaces[4] - b.ResetTimer() - for i := 0; i < b.N; i++ { - bitmapManager.checkBitmapBit(ip) - } - }) - - b.Run("Bitmap_Miss", func(b *testing.B) { - ip := interfaces[12] - b.ResetTimer() - for i := 0; i < b.N; i++ { - bitmapManager.checkBitmapBit(ip) - } - }) - - b.Run("Map_Hit", func(b *testing.B) { - ip := interfaces[4] - b.ResetTimer() - for i := 0; i < b.N; i++ { - // nolint:gosimple - _ = mapManager.localIPs[ip.String()] - } - }) - - b.Run("Map_Miss", func(b *testing.B) { - ip := interfaces[12] - b.ResetTimer() - for i := 0; i < b.N; i++ { - // nolint:gosimple - _ = mapManager.localIPs[ip.String()] - } - }) -} - -func BenchmarkWGPosition(b *testing.B) { - wgIP := net.ParseIP("10.10.0.1") - - // Create two managers - one checks WG IP first, other checks it last - b.Run("WG_First", func(b *testing.B) { - bm := newLocalIPManager() - bm.setBitmapBit(wgIP) - b.ResetTimer() - for i := 0; i < b.N; i++ { - bm.checkBitmapBit(wgIP) - } - }) - - b.Run("WG_Last", func(b *testing.B) { - bm := newLocalIPManager() - // Fill with other IPs first - for i := 0; i < 15; i++ { - bm.setBitmapBit(net.IPv4(10, 0, byte(i>>8), byte(i))) - } - bm.setBitmapBit(wgIP) // Add WG IP last - b.ResetTimer() - for i := 0; i < b.N; i++ { - bm.checkBitmapBit(wgIP) - } - }) -} diff --git a/client/firewall/uspfilter/nat.go b/client/firewall/uspfilter/nat.go index 8ed32eb5e..87ef4d4a0 100644 --- a/client/firewall/uspfilter/nat.go +++ b/client/firewall/uspfilter/nat.go @@ -13,8 +13,6 @@ import ( firewall "github.com/netbirdio/netbird/client/firewall/manager" ) -var ErrIPv4Only = errors.New("only IPv4 is supported for DNAT") - var ( errInvalidIPHeaderLength = errors.New("invalid IP header length") ) @@ -25,10 +23,33 @@ const ( destinationPortOffset = 2 // IP address offsets in IPv4 header - sourceIPOffset = 12 - destinationIPOffset = 16 + ipv4SrcOffset = 12 + ipv4DstOffset = 16 + + // IP address offsets in IPv6 header + ipv6SrcOffset = 8 + ipv6DstOffset = 24 + + // IPv6 fixed header length + ipv6HeaderLen = 40 ) +// ipHeaderLen returns the IP header length based on the decoded layer type. +func ipHeaderLen(d *decoder) (int, error) { + switch d.decoded[0] { + case layers.LayerTypeIPv4: + n := int(d.ip4.IHL) * 4 + if n < 20 { + return 0, errInvalidIPHeaderLength + } + return n, nil + case layers.LayerTypeIPv6: + return ipv6HeaderLen, nil + default: + return 0, fmt.Errorf("unknown IP layer: %v", d.decoded[0]) + } +} + // ipv4Checksum calculates IPv4 header checksum. func ipv4Checksum(header []byte) uint16 { if len(header) < 20 { @@ -234,14 +255,13 @@ func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool { return false } - dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]}) - + _, dstIP := extractPacketIPs(packetData, d) translatedIP, exists := m.getDNATTranslation(dstIP) if !exists { return false } - if err := m.rewritePacketIP(packetData, d, translatedIP, destinationIPOffset); err != nil { + if err := m.rewritePacketIP(packetData, d, translatedIP, false); err != nil { m.logger.Error1("failed to rewrite packet destination: %v", err) return false } @@ -256,14 +276,13 @@ func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool { return false } - srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]}) - + srcIP, _ := extractPacketIPs(packetData, d) originalIP, exists := m.findReverseDNATMapping(srcIP) if !exists { return false } - if err := m.rewritePacketIP(packetData, d, originalIP, sourceIPOffset); err != nil { + if err := m.rewritePacketIP(packetData, d, originalIP, true); err != nil { m.logger.Error1("failed to rewrite packet source: %v", err) return false } @@ -272,38 +291,96 @@ func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool { return true } -// rewritePacketIP replaces an IP address (source or destination) in the packet and updates checksums. -func (m *Manager) rewritePacketIP(packetData []byte, d *decoder, newIP netip.Addr, ipOffset int) error { +// extractPacketIPs extracts src and dst IP addresses directly from raw packet bytes. +func extractPacketIPs(packetData []byte, d *decoder) (src, dst netip.Addr) { + switch d.decoded[0] { + case layers.LayerTypeIPv4: + src = netip.AddrFrom4([4]byte{packetData[ipv4SrcOffset], packetData[ipv4SrcOffset+1], packetData[ipv4SrcOffset+2], packetData[ipv4SrcOffset+3]}) + dst = netip.AddrFrom4([4]byte{packetData[ipv4DstOffset], packetData[ipv4DstOffset+1], packetData[ipv4DstOffset+2], packetData[ipv4DstOffset+3]}) + case layers.LayerTypeIPv6: + src = netip.AddrFrom16([16]byte(packetData[ipv6SrcOffset : ipv6SrcOffset+16])) + dst = netip.AddrFrom16([16]byte(packetData[ipv6DstOffset : ipv6DstOffset+16])) + } + return src, dst +} + +// rewritePacketIP replaces a source (isSource=true) or destination IP address in the packet and updates checksums. +func (m *Manager) rewritePacketIP(packetData []byte, d *decoder, newIP netip.Addr, isSource bool) error { + hdrLen, err := ipHeaderLen(d) + if err != nil { + return err + } + + switch d.decoded[0] { + case layers.LayerTypeIPv4: + return m.rewriteIPv4(packetData, d, newIP, hdrLen, isSource) + case layers.LayerTypeIPv6: + return m.rewriteIPv6(packetData, d, newIP, hdrLen, isSource) + default: + return fmt.Errorf("unknown IP layer: %v", d.decoded[0]) + } +} + +func (m *Manager) rewriteIPv4(packetData []byte, d *decoder, newIP netip.Addr, hdrLen int, isSource bool) error { if !newIP.Is4() { - return ErrIPv4Only + return fmt.Errorf("cannot write IPv6 address into IPv4 packet") + } + + offset := ipv4DstOffset + if isSource { + offset = ipv4SrcOffset } var oldIP [4]byte - copy(oldIP[:], packetData[ipOffset:ipOffset+4]) + copy(oldIP[:], packetData[offset:offset+4]) newIPBytes := newIP.As4() + copy(packetData[offset:offset+4], newIPBytes[:]) - copy(packetData[ipOffset:ipOffset+4], newIPBytes[:]) - - ipHeaderLen := int(d.ip4.IHL) * 4 - if ipHeaderLen < 20 || ipHeaderLen > len(packetData) { - return errInvalidIPHeaderLength - } - + // Recalculate IPv4 header checksum binary.BigEndian.PutUint16(packetData[10:12], 0) - ipChecksum := ipv4Checksum(packetData[:ipHeaderLen]) - binary.BigEndian.PutUint16(packetData[10:12], ipChecksum) + binary.BigEndian.PutUint16(packetData[10:12], ipv4Checksum(packetData[:hdrLen])) + // Update transport checksums incrementally if len(d.decoded) > 1 { switch d.decoded[1] { case layers.LayerTypeTCP: - m.updateTCPChecksum(packetData, ipHeaderLen, oldIP[:], newIPBytes[:]) + m.updateTCPChecksum(packetData, hdrLen, oldIP[:], newIPBytes[:]) case layers.LayerTypeUDP: - m.updateUDPChecksum(packetData, ipHeaderLen, oldIP[:], newIPBytes[:]) + m.updateUDPChecksum(packetData, hdrLen, oldIP[:], newIPBytes[:]) case layers.LayerTypeICMPv4: - m.updateICMPChecksum(packetData, ipHeaderLen) + m.updateICMPChecksum(packetData, hdrLen) } } + return nil +} +func (m *Manager) rewriteIPv6(packetData []byte, d *decoder, newIP netip.Addr, hdrLen int, isSource bool) error { + if !newIP.Is6() { + return fmt.Errorf("cannot write IPv4 address into IPv6 packet") + } + + offset := ipv6DstOffset + if isSource { + offset = ipv6SrcOffset + } + + var oldIP [16]byte + copy(oldIP[:], packetData[offset:offset+16]) + newIPBytes := newIP.As16() + copy(packetData[offset:offset+16], newIPBytes[:]) + + // IPv6 has no header checksum, only update transport checksums + if len(d.decoded) > 1 { + switch d.decoded[1] { + case layers.LayerTypeTCP: + m.updateTCPChecksum(packetData, hdrLen, oldIP[:], newIPBytes[:]) + case layers.LayerTypeUDP: + m.updateUDPChecksum(packetData, hdrLen, oldIP[:], newIPBytes[:]) + case layers.LayerTypeICMPv6: + // ICMPv6 checksum includes pseudo-header with addresses, use incremental update + m.updateICMPv6Checksum(packetData, hdrLen, oldIP[:], newIPBytes[:]) + } + } return nil } @@ -351,6 +428,20 @@ func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) { binary.BigEndian.PutUint16(icmpData[2:4], checksum) } +// updateICMPv6Checksum updates ICMPv6 checksum after address change. +// ICMPv6 uses a pseudo-header (like TCP/UDP), so incremental update applies. +func (m *Manager) updateICMPv6Checksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) { + icmpStart := ipHeaderLen + if len(packetData) < icmpStart+4 { + return + } + + checksumOffset := icmpStart + 2 + oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2]) + newChecksum := incrementalUpdate(oldChecksum, oldIP, newIP) + binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum) +} + // incrementalUpdate performs incremental checksum update per RFC 1624. func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 { sum := uint32(^oldChecksum) @@ -532,12 +623,12 @@ func (m *Manager) applyPortRule(packetData []byte, d *decoder, srcIP, dstIP neti // rewriteTCPPort rewrites a TCP port (source or destination) and updates checksum. func (m *Manager) rewriteTCPPort(packetData []byte, d *decoder, newPort uint16, portOffset int) error { - ipHeaderLen := int(d.ip4.IHL) * 4 - if ipHeaderLen < 20 || ipHeaderLen > len(packetData) { - return errInvalidIPHeaderLength + hdrLen, err := ipHeaderLen(d) + if err != nil { + return err } - tcpStart := ipHeaderLen + tcpStart := hdrLen if len(packetData) < tcpStart+4 { return fmt.Errorf("packet too short for TCP header") } @@ -563,12 +654,12 @@ func (m *Manager) rewriteTCPPort(packetData []byte, d *decoder, newPort uint16, // rewriteUDPPort rewrites a UDP port (source or destination) and updates checksum. func (m *Manager) rewriteUDPPort(packetData []byte, d *decoder, newPort uint16, portOffset int) error { - ipHeaderLen := int(d.ip4.IHL) * 4 - if ipHeaderLen < 20 || ipHeaderLen > len(packetData) { - return errInvalidIPHeaderLength + hdrLen, err := ipHeaderLen(d) + if err != nil { + return err } - udpStart := ipHeaderLen + udpStart := hdrLen if len(packetData) < udpStart+8 { return fmt.Errorf("packet too short for UDP header") } diff --git a/client/firewall/uspfilter/nat_bench_test.go b/client/firewall/uspfilter/nat_bench_test.go index d2599e577..1e15c8c0c 100644 --- a/client/firewall/uspfilter/nat_bench_test.go +++ b/client/firewall/uspfilter/nat_bench_test.go @@ -342,12 +342,17 @@ func BenchmarkDNATMemoryAllocations(b *testing.B) { // Parse the packet fresh each time to get a clean decoder d := &decoder{decoded: []gopacket.LayerType{}} - d.parser = gopacket.NewDecodingLayerParser( + d.parser4 = gopacket.NewDecodingLayerParser( layers.LayerTypeIPv4, &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, ) - d.parser.IgnoreUnsupported = true - err = d.parser.DecodeLayers(testPacket, &d.decoded) + d.parser4.IgnoreUnsupported = true + d.parser6 = gopacket.NewDecodingLayerParser( + layers.LayerTypeIPv6, + &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, + ) + d.parser6.IgnoreUnsupported = true + err = d.decodePacket(testPacket) assert.NoError(b, err) manager.translateOutboundDNAT(testPacket, d) @@ -371,12 +376,17 @@ func BenchmarkDirectIPExtraction(b *testing.B) { b.Run("decoder_extraction", func(b *testing.B) { // Create decoder once for comparison d := &decoder{decoded: []gopacket.LayerType{}} - d.parser = gopacket.NewDecodingLayerParser( + d.parser4 = gopacket.NewDecodingLayerParser( layers.LayerTypeIPv4, &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, ) - d.parser.IgnoreUnsupported = true - err := d.parser.DecodeLayers(packet, &d.decoded) + d.parser4.IgnoreUnsupported = true + d.parser6 = gopacket.NewDecodingLayerParser( + layers.LayerTypeIPv6, + &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, + ) + d.parser6.IgnoreUnsupported = true + err := d.decodePacket(packet) assert.NoError(b, err) for i := 0; i < b.N; i++ { diff --git a/client/firewall/uspfilter/nat_test.go b/client/firewall/uspfilter/nat_test.go index 50743d006..4598c3901 100644 --- a/client/firewall/uspfilter/nat_test.go +++ b/client/firewall/uspfilter/nat_test.go @@ -86,13 +86,18 @@ func parsePacket(t testing.TB, packetData []byte) *decoder { d := &decoder{ decoded: []gopacket.LayerType{}, } - d.parser = gopacket.NewDecodingLayerParser( + d.parser4 = gopacket.NewDecodingLayerParser( layers.LayerTypeIPv4, &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, ) - d.parser.IgnoreUnsupported = true + d.parser4.IgnoreUnsupported = true + d.parser6 = gopacket.NewDecodingLayerParser( + layers.LayerTypeIPv6, + &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, + ) + d.parser6.IgnoreUnsupported = true - err := d.parser.DecodeLayers(packetData, &d.decoded) + err := d.decodePacket(packetData) require.NoError(t, err) return d } diff --git a/client/firewall/uspfilter/tracer.go b/client/firewall/uspfilter/tracer.go index 69c2519bf..3b066c160 100644 --- a/client/firewall/uspfilter/tracer.go +++ b/client/firewall/uspfilter/tracer.go @@ -112,10 +112,13 @@ func (t *PacketTrace) AddResultWithForwarder(stage PacketStage, message string, } func (p *PacketBuilder) Build() ([]byte, error) { - ip := p.buildIPLayer() - pktLayers := []gopacket.SerializableLayer{ip} + ipLayer, err := p.buildIPLayer() + if err != nil { + return nil, err + } + pktLayers := []gopacket.SerializableLayer{ipLayer} - transportLayer, err := p.buildTransportLayer(ip) + transportLayer, err := p.buildTransportLayer(ipLayer) if err != nil { return nil, err } @@ -129,30 +132,43 @@ func (p *PacketBuilder) Build() ([]byte, error) { return serializePacket(pktLayers) } -func (p *PacketBuilder) buildIPLayer() *layers.IPv4 { +func (p *PacketBuilder) buildIPLayer() (gopacket.SerializableLayer, error) { + if p.SrcIP.Is4() != p.DstIP.Is4() { + return nil, fmt.Errorf("mixed address families: src=%s dst=%s", p.SrcIP, p.DstIP) + } + proto := getIPProtocolNumber(p.Protocol, p.SrcIP.Is6()) + if p.SrcIP.Is6() { + return &layers.IPv6{ + Version: 6, + HopLimit: 64, + NextHeader: proto, + SrcIP: p.SrcIP.AsSlice(), + DstIP: p.DstIP.AsSlice(), + }, nil + } return &layers.IPv4{ Version: 4, TTL: 64, - Protocol: layers.IPProtocol(getIPProtocolNumber(p.Protocol)), + Protocol: proto, SrcIP: p.SrcIP.AsSlice(), DstIP: p.DstIP.AsSlice(), - } + }, nil } -func (p *PacketBuilder) buildTransportLayer(ip *layers.IPv4) ([]gopacket.SerializableLayer, error) { +func (p *PacketBuilder) buildTransportLayer(ipLayer gopacket.SerializableLayer) ([]gopacket.SerializableLayer, error) { switch p.Protocol { case "tcp": - return p.buildTCPLayer(ip) + return p.buildTCPLayer(ipLayer) case "udp": - return p.buildUDPLayer(ip) + return p.buildUDPLayer(ipLayer) case "icmp": - return p.buildICMPLayer() + return p.buildICMPLayer(ipLayer) default: return nil, fmt.Errorf("unsupported protocol: %s", p.Protocol) } } -func (p *PacketBuilder) buildTCPLayer(ip *layers.IPv4) ([]gopacket.SerializableLayer, error) { +func (p *PacketBuilder) buildTCPLayer(ipLayer gopacket.SerializableLayer) ([]gopacket.SerializableLayer, error) { tcp := &layers.TCP{ SrcPort: layers.TCPPort(p.SrcPort), DstPort: layers.TCPPort(p.DstPort), @@ -164,24 +180,44 @@ func (p *PacketBuilder) buildTCPLayer(ip *layers.IPv4) ([]gopacket.SerializableL PSH: p.TCPState != nil && p.TCPState.PSH, URG: p.TCPState != nil && p.TCPState.URG, } - if err := tcp.SetNetworkLayerForChecksum(ip); err != nil { - return nil, fmt.Errorf("set network layer for TCP checksum: %w", err) + if nl, ok := ipLayer.(gopacket.NetworkLayer); ok { + if err := tcp.SetNetworkLayerForChecksum(nl); err != nil { + return nil, fmt.Errorf("set network layer for TCP checksum: %w", err) + } } return []gopacket.SerializableLayer{tcp}, nil } -func (p *PacketBuilder) buildUDPLayer(ip *layers.IPv4) ([]gopacket.SerializableLayer, error) { +func (p *PacketBuilder) buildUDPLayer(ipLayer gopacket.SerializableLayer) ([]gopacket.SerializableLayer, error) { udp := &layers.UDP{ SrcPort: layers.UDPPort(p.SrcPort), DstPort: layers.UDPPort(p.DstPort), } - if err := udp.SetNetworkLayerForChecksum(ip); err != nil { - return nil, fmt.Errorf("set network layer for UDP checksum: %w", err) + if nl, ok := ipLayer.(gopacket.NetworkLayer); ok { + if err := udp.SetNetworkLayerForChecksum(nl); err != nil { + return nil, fmt.Errorf("set network layer for UDP checksum: %w", err) + } } return []gopacket.SerializableLayer{udp}, nil } -func (p *PacketBuilder) buildICMPLayer() ([]gopacket.SerializableLayer, error) { +func (p *PacketBuilder) buildICMPLayer(ipLayer gopacket.SerializableLayer) ([]gopacket.SerializableLayer, error) { + if p.SrcIP.Is6() || p.DstIP.Is6() { + icmp := &layers.ICMPv6{ + TypeCode: layers.CreateICMPv6TypeCode(p.ICMPType, p.ICMPCode), + } + if nl, ok := ipLayer.(gopacket.NetworkLayer); ok { + _ = icmp.SetNetworkLayerForChecksum(nl) + } + if p.ICMPType == layers.ICMPv6TypeEchoRequest || p.ICMPType == layers.ICMPv6TypeEchoReply { + echo := &layers.ICMPv6Echo{ + Identifier: 1, + SeqNumber: 1, + } + return []gopacket.SerializableLayer{icmp, echo}, nil + } + return []gopacket.SerializableLayer{icmp}, nil + } icmp := &layers.ICMPv4{ TypeCode: layers.CreateICMPv4TypeCode(p.ICMPType, p.ICMPCode), } @@ -204,14 +240,17 @@ func serializePacket(layers []gopacket.SerializableLayer) ([]byte, error) { return buf.Bytes(), nil } -func getIPProtocolNumber(protocol fw.Protocol) int { +func getIPProtocolNumber(protocol fw.Protocol, isV6 bool) layers.IPProtocol { switch protocol { case fw.ProtocolTCP: - return int(layers.IPProtocolTCP) + return layers.IPProtocolTCP case fw.ProtocolUDP: - return int(layers.IPProtocolUDP) + return layers.IPProtocolUDP case fw.ProtocolICMP: - return int(layers.IPProtocolICMPv4) + if isV6 { + return layers.IPProtocolICMPv6 + } + return layers.IPProtocolICMPv4 default: return 0 } @@ -234,7 +273,7 @@ func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *Pa trace := &PacketTrace{Direction: direction} // Initial packet decoding - if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { + if err := d.decodePacket(packetData); err != nil { trace.AddResult(StageReceived, fmt.Sprintf("Failed to decode packet: %v", err), false) return trace } @@ -256,6 +295,8 @@ func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *Pa trace.DestinationPort = uint16(d.udp.DstPort) case layers.LayerTypeICMPv4: trace.Protocol = "ICMP" + case layers.LayerTypeICMPv6: + trace.Protocol = "ICMPv6" } trace.AddResult(StageReceived, fmt.Sprintf("Received %s packet: %s:%d -> %s:%d", @@ -319,6 +360,13 @@ func (m *Manager) buildConntrackStateMessage(d *decoder) string { flags&conntrack.TCPFin != 0) case layers.LayerTypeICMPv4: msg += fmt.Sprintf(" (ICMP ID=%d, Seq=%d)", d.icmp4.Id, d.icmp4.Seq) + case layers.LayerTypeICMPv6: + var id, seq uint16 + if len(d.icmp6.Payload) >= 4 { + id = uint16(d.icmp6.Payload[0])<<8 | uint16(d.icmp6.Payload[1]) + seq = uint16(d.icmp6.Payload[2])<<8 | uint16(d.icmp6.Payload[3]) + } + msg += fmt.Sprintf(" (ICMPv6 ID=%d, Seq=%d)", id, seq) } return msg } @@ -415,7 +463,7 @@ func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTr d := m.decoders.Get().(*decoder) defer m.decoders.Put(d) - if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { + if err := d.decodePacket(packetData); err != nil { trace.AddResult(StageCompleted, "Packet dropped - decode error", false) return trace } @@ -434,7 +482,7 @@ func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTr func (m *Manager) handleInboundDNAT(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP *netip.Addr) bool { portDNATApplied := m.traceInboundPortDNAT(trace, packetData, d) if portDNATApplied { - if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { + if err := d.decodePacket(packetData); err != nil { trace.AddResult(StageInboundPortDNAT, "Failed to re-decode after port DNAT", false) return true } @@ -444,7 +492,7 @@ func (m *Manager) handleInboundDNAT(trace *PacketTrace, packetData []byte, d *de nat1to1Applied := m.traceInbound1to1NAT(trace, packetData, d) if nat1to1Applied { - if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { + if err := d.decodePacket(packetData); err != nil { trace.AddResult(StageInbound1to1NAT, "Failed to re-decode after 1:1 NAT", false) return true } @@ -509,7 +557,7 @@ func (m *Manager) traceInbound1to1NAT(trace *PacketTrace, packetData []byte, d * return false } - srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]}) + srcIP, _ := extractPacketIPs(packetData, d) translated := m.translateInboundReverse(packetData, d) if translated { @@ -539,7 +587,7 @@ func (m *Manager) traceOutbound1to1NAT(trace *PacketTrace, packetData []byte, d return false } - dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]}) + _, dstIP := extractPacketIPs(packetData, d) translated := m.translateOutboundDNAT(packetData, d) if translated { diff --git a/client/iface/configurer/usp.go b/client/iface/configurer/usp.go index e3a96590c..9b070aab8 100644 --- a/client/iface/configurer/usp.go +++ b/client/iface/configurer/usp.go @@ -119,7 +119,7 @@ func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix, if err != nil { return fmt.Errorf("failed to parse endpoint address: %w", err) } - addrPort := netip.AddrPortFrom(addr, uint16(endpoint.Port)) + addrPort := netip.AddrPortFrom(addr.Unmap(), uint16(endpoint.Port)) c.activityRecorder.UpsertAddress(peerKey, addrPort) } return nil diff --git a/client/iface/device/adapter.go b/client/iface/device/adapter.go index 6ebc05390..e3caaf930 100644 --- a/client/iface/device/adapter.go +++ b/client/iface/device/adapter.go @@ -2,7 +2,7 @@ package device // TunAdapter is an interface for create tun device from external service type TunAdapter interface { - ConfigureInterface(address string, mtu int, dns string, searchDomains string, routes string) (int, error) + ConfigureInterface(address string, addressV6 string, mtu int, dns string, searchDomains string, routes string) (int, error) UpdateAddr(address string) error ProtectSocket(fd int32) bool } diff --git a/client/iface/device/device_android.go b/client/iface/device/device_android.go index 198343fbd..cbe88c10c 100644 --- a/client/iface/device/device_android.go +++ b/client/iface/device/device_android.go @@ -63,7 +63,7 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string searchDomainsToString = "" } - fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), int(t.mtu), dns, searchDomainsToString, routesString) + fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), t.address.IPv6String(), int(t.mtu), dns, searchDomainsToString, routesString) if err != nil { log.Errorf("failed to create Android interface: %s", err) return nil, err diff --git a/client/iface/wgproxy/bind/proxy.go b/client/iface/wgproxy/bind/proxy.go index 9ac3ea6df..5bf670e07 100644 --- a/client/iface/wgproxy/bind/proxy.go +++ b/client/iface/wgproxy/bind/proxy.go @@ -6,7 +6,7 @@ import ( "fmt" "net" "net/netip" - "strings" + "sync" log "github.com/sirupsen/logrus" @@ -196,18 +196,22 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) { } } -// fakeAddress returns a fake address that is used to as an identifier for the peer. -// The fake address is in the format of 127.1.x.x where x.x is the last two octets of the peer address. +// fakeAddress returns a fake address that is used as an identifier for the peer. +// The fake address is in the format of 127.1.x.x where x.x is derived from the +// last two bytes of the peer address (works for both IPv4 and IPv6). func fakeAddress(peerAddress *net.UDPAddr) (*netip.AddrPort, error) { - octets := strings.Split(peerAddress.IP.String(), ".") - if len(octets) != 4 { - return nil, fmt.Errorf("invalid IP format") + if peerAddress == nil { + return nil, fmt.Errorf("nil peer address") } - fakeIP, err := netip.ParseAddr(fmt.Sprintf("127.1.%s.%s", octets[2], octets[3])) - if err != nil { - return nil, fmt.Errorf("parse new IP: %w", err) + addr, ok := netip.AddrFromSlice(peerAddress.IP) + if !ok { + return nil, fmt.Errorf("invalid IP format") } + addr = addr.Unmap() + + raw := addr.As16() + fakeIP := netip.AddrFrom4([4]byte{127, 1, raw[14], raw[15]}) netipAddr := netip.AddrPortFrom(fakeIP, uint16(peerAddress.Port)) return &netipAddr, nil diff --git a/client/internal/acl/manager.go b/client/internal/acl/manager.go index 54a97e38f..c54a3e897 100644 --- a/client/internal/acl/manager.go +++ b/client/internal/acl/manager.go @@ -5,7 +5,6 @@ import ( "encoding/hex" "errors" "fmt" - "net" "net/netip" "strconv" "sync" @@ -19,6 +18,7 @@ import ( "github.com/netbirdio/netbird/client/internal/acl/id" "github.com/netbirdio/netbird/shared/management/domain" mgmProto "github.com/netbirdio/netbird/shared/management/proto" + "github.com/netbirdio/netbird/shared/netiputil" ) var ErrSourceRangesEmpty = errors.New("sources range is empty") @@ -105,6 +105,10 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) { newRulePairs := make(map[id.RuleID][]firewall.Rule) ipsetByRuleSelectors := make(map[string]string) + // TODO: deny rules should be fatal: if a deny rule fails to apply, we must + // roll back all allow rules to avoid a fail-open where allowed traffic bypasses + // the missing deny. Currently we accumulate errors and continue. + var merr *multierror.Error for _, r := range rules { // if this rule is member of rule selection with more than DefaultIPsCountForSet // it's IP address can be used in the ipset for firewall manager which supports it @@ -117,9 +121,8 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) { } pairID, rulePair, err := d.protoRuleToFirewallRule(r, ipsetName) if err != nil { - log.Errorf("failed to apply firewall rule: %+v, %v", r, err) - d.rollBack(newRulePairs) - break + merr = multierror.Append(merr, fmt.Errorf("apply firewall rule: %w", err)) + continue } if len(rulePair) > 0 { d.peerRulesPairs[pairID] = rulePair @@ -127,6 +130,10 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) { } } + if merr != nil { + log.Errorf("failed to apply %d peer ACL rule(s): %v", merr.Len(), nberrors.FormatErrorOrNil(merr)) + } + for pairID, rules := range d.peerRulesPairs { if _, ok := newRulePairs[pairID]; !ok { for _, rule := range rules { @@ -216,10 +223,9 @@ func (d *DefaultManager) protoRuleToFirewallRule( r *mgmProto.FirewallRule, ipsetName string, ) (id.RuleID, []firewall.Rule, error) { - //nolint:staticcheck // PeerIP used for backward compatibility with old management - ip := net.ParseIP(r.PeerIP) - if ip == nil { - return "", nil, fmt.Errorf("invalid IP address, skipping firewall rule") + ip, err := extractRuleIP(r) + if err != nil { + return "", nil, err } protocol, err := convertToFirewallProtocol(r.Protocol) @@ -290,13 +296,13 @@ func portInfoEmpty(portInfo *mgmProto.PortInfo) bool { func (d *DefaultManager) addInRules( id []byte, - ip net.IP, + ip netip.Addr, protocol firewall.Protocol, port *firewall.Port, action firewall.Action, ipsetName string, ) ([]firewall.Rule, error) { - rule, err := d.firewall.AddPeerFiltering(id, ip, protocol, nil, port, action, ipsetName) + rule, err := d.firewall.AddPeerFiltering(id, ip.AsSlice(), protocol, nil, port, action, ipsetName) if err != nil { return nil, fmt.Errorf("add firewall rule: %w", err) } @@ -306,7 +312,7 @@ func (d *DefaultManager) addInRules( func (d *DefaultManager) addOutRules( id []byte, - ip net.IP, + ip netip.Addr, protocol firewall.Protocol, port *firewall.Port, action firewall.Action, @@ -316,7 +322,7 @@ func (d *DefaultManager) addOutRules( return nil, nil } - rule, err := d.firewall.AddPeerFiltering(id, ip, protocol, port, nil, action, ipsetName) + rule, err := d.firewall.AddPeerFiltering(id, ip.AsSlice(), protocol, port, nil, action, ipsetName) if err != nil { return nil, fmt.Errorf("add firewall rule: %w", err) } @@ -324,9 +330,9 @@ func (d *DefaultManager) addOutRules( return rule, nil } -// getPeerRuleID() returns unique ID for the rule based on its parameters. +// getPeerRuleID returns unique ID for the rule based on its parameters. func (d *DefaultManager) getPeerRuleID( - ip net.IP, + ip netip.Addr, proto firewall.Protocol, direction int, port *firewall.Port, @@ -345,15 +351,25 @@ func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) st return fmt.Sprintf("%v:%v:%v:%s:%v", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port, rule.PortInfo) } -func (d *DefaultManager) rollBack(newRulePairs map[id.RuleID][]firewall.Rule) { - log.Debugf("rollback ACL to previous state") - for _, rules := range newRulePairs { - for _, rule := range rules { - if err := d.firewall.DeletePeerRule(rule); err != nil { - log.Errorf("failed to delete new firewall rule (id: %v) during rollback: %v", rule.ID(), err) - } + +// extractRuleIP extracts the peer IP from a firewall rule. +// If sourcePrefixes is populated (new management), decode the first entry and use its address. +// Otherwise fall back to the deprecated PeerIP string field (old management). +func extractRuleIP(r *mgmProto.FirewallRule) (netip.Addr, error) { + if len(r.SourcePrefixes) > 0 { + addr, err := netiputil.DecodeAddr(r.SourcePrefixes[0]) + if err != nil { + return netip.Addr{}, fmt.Errorf("decode source prefix: %w", err) } + return addr.Unmap(), nil } + + //nolint:staticcheck // PeerIP used for backward compatibility with old management + addr, err := netip.ParseAddr(r.PeerIP) + if err != nil { + return netip.Addr{}, fmt.Errorf("invalid IP address, skipping firewall rule") + } + return addr.Unmap(), nil } func convertToFirewallProtocol(protocol mgmProto.RuleProtocol) (firewall.Protocol, error) { diff --git a/client/internal/debug/debug_test.go b/client/internal/debug/debug_test.go index 59837c328..e242b8b1b 100644 --- a/client/internal/debug/debug_test.go +++ b/client/internal/debug/debug_test.go @@ -430,8 +430,6 @@ func isInCGNATRange(ip net.IP) bool { } func TestAnonymizeFirewallRules(t *testing.T) { - // TODO: Add ipv6 - // Example iptables-save output iptablesSave := `# Generated by iptables-save v1.8.7 on Thu Dec 19 10:00:00 2024 *filter @@ -467,17 +465,31 @@ Chain FORWARD (policy ACCEPT 0 packets, 0 bytes) Chain OUTPUT (policy ACCEPT 0 packets, 0 bytes) pkts bytes target prot opt in out source destination` - // Example nftables output + // Example ip6tables-save output + ip6tablesSave := `# Generated by ip6tables-save v1.8.7 on Thu Dec 19 10:00:00 2024 +*filter +:INPUT ACCEPT [0:0] +:FORWARD ACCEPT [0:0] +:OUTPUT ACCEPT [0:0] +-A INPUT -s fd00:1234::1/128 -j ACCEPT +-A INPUT -s 2607:f8b0:4005::1/128 -j DROP +-A FORWARD -s 2001:db8::/32 -d 2607:f8b0:4005::200e/128 -j ACCEPT +COMMIT` + + // Example nftables output with IPv6 nftablesRules := `table inet filter { chain input { type filter hook input priority filter; policy accept; ip saddr 192.168.1.1 accept ip saddr 44.192.140.1 drop + ip6 saddr 2607:f8b0:4005::1 drop + ip6 saddr fd00:1234::1 accept } chain forward { type filter hook forward priority filter; policy accept; ip saddr 10.0.0.0/8 drop ip saddr 44.192.140.0/24 ip daddr 52.84.12.34/24 accept + ip6 saddr 2001:db8::/32 ip6 daddr 2607:f8b0:4005::200e/128 accept } }` @@ -540,4 +552,35 @@ Chain OUTPUT (policy ACCEPT 0 packets, 0 bytes) assert.Contains(t, anonNftables, "table inet filter {") assert.Contains(t, anonNftables, "chain input {") assert.Contains(t, anonNftables, "type filter hook input priority filter; policy accept;") + + // IPv6 public addresses in nftables should be anonymized + assert.NotContains(t, anonNftables, "2607:f8b0:4005::1") + assert.NotContains(t, anonNftables, "2607:f8b0:4005::200e") + assert.NotContains(t, anonNftables, "2001:db8::") + assert.Contains(t, anonNftables, "2001:db8:ffff::") // Default anonymous v6 range + + // ULA addresses in nftables should remain unchanged (private) + assert.Contains(t, anonNftables, "fd00:1234::1") + + // IPv6 nftables structure preserved + assert.Contains(t, anonNftables, "ip6 saddr") + assert.Contains(t, anonNftables, "ip6 daddr") + + // Test ip6tables-save anonymization + anonIp6tablesSave := anonymizer.AnonymizeString(ip6tablesSave) + + // ULA (private) IPv6 should remain unchanged + assert.Contains(t, anonIp6tablesSave, "fd00:1234::1/128") + + // Public IPv6 addresses should be anonymized + assert.NotContains(t, anonIp6tablesSave, "2607:f8b0:4005::1") + assert.NotContains(t, anonIp6tablesSave, "2607:f8b0:4005::200e") + assert.NotContains(t, anonIp6tablesSave, "2001:db8::") + assert.Contains(t, anonIp6tablesSave, "2001:db8:ffff::") // Default anonymous v6 range + + // Structure should be preserved + assert.Contains(t, anonIp6tablesSave, "*filter") + assert.Contains(t, anonIp6tablesSave, "COMMIT") + assert.Contains(t, anonIp6tablesSave, "-j DROP") + assert.Contains(t, anonIp6tablesSave, "-j ACCEPT") } diff --git a/client/internal/dns/service_listener.go b/client/internal/dns/service_listener.go index 4e09f1b7f..551555ad4 100644 --- a/client/internal/dns/service_listener.go +++ b/client/internal/dns/service_listener.go @@ -189,10 +189,10 @@ func (s *serviceViaListener) RuntimeIP() netip.Addr { } -// evalListenAddress figure out the listen address for the DNS server -// first check the 53 port availability on WG interface or lo, if not success -// pick a random port on WG interface for eBPF, if not success -// check the 5053 port availability on WG interface or lo without eBPF usage, +// evalListenAddress figures out the listen address for the DNS server. +// IPv4-only: all peers have a v4 overlay address, and DNS config points to v4. +// First checks port 53 on WG interface or lo, then tries eBPF on a random port, +// then falls back to port 5053. func (s *serviceViaListener) evalListenAddress() (netip.Addr, uint16, error) { if s.customAddr != nil { return s.customAddr.Addr(), s.customAddr.Port(), nil @@ -278,7 +278,7 @@ func (s *serviceViaListener) tryToUseeBPF() (ebpfMgr.Manager, uint16, bool) { } ebpfSrv := ebpf.GetEbpfManagerInstance() - err = ebpfSrv.LoadDNSFwd(s.wgInterface.Address().IP.String(), int(port)) + err = ebpfSrv.LoadDNSFwd(s.wgInterface.Address().IP, int(port)) if err != nil { log.Warnf("failed to load DNS forwarder eBPF program, error: %s", err) return nil, 0, false diff --git a/client/internal/dns/upstream.go b/client/internal/dns/upstream.go index 746b73ca7..a26536f6e 100644 --- a/client/internal/dns/upstream.go +++ b/client/internal/dns/upstream.go @@ -21,6 +21,7 @@ import ( "golang.zx2c4.com/wireguard/tun/netstack" "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/internal/dns/resutil" "github.com/netbirdio/netbird/client/internal/dns/types" "github.com/netbirdio/netbird/client/internal/peer" @@ -29,6 +30,12 @@ import ( var currentMTU uint16 = iface.DefaultMTU +// privateClientIface is the subset of the WireGuard interface needed by GetClientPrivate. +type privateClientIface interface { + Name() string + Address() wgaddr.Address +} + func SetCurrentMTU(mtu uint16) { currentMTU = mtu } diff --git a/client/internal/dns/upstream_android.go b/client/internal/dns/upstream_android.go index ee1ca42fe..988adb7d2 100644 --- a/client/internal/dns/upstream_android.go +++ b/client/internal/dns/upstream_android.go @@ -86,7 +86,7 @@ func (u *upstreamResolver) isLocalResolver(upstream string) bool { return false } -func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) { +func GetClientPrivate(_ privateClientIface, _ netip.Addr, dialTimeout time.Duration) (*dns.Client, error) { return &dns.Client{ Timeout: dialTimeout, Net: "udp", diff --git a/client/internal/dns/upstream_general.go b/client/internal/dns/upstream_general.go index 1143b6c51..910c3779e 100644 --- a/client/internal/dns/upstream_general.go +++ b/client/internal/dns/upstream_general.go @@ -52,7 +52,7 @@ func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns return ExchangeWithFallback(ctx, client, r, upstream) } -func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) { +func GetClientPrivate(_ privateClientIface, _ netip.Addr, dialTimeout time.Duration) (*dns.Client, error) { return &dns.Client{ Timeout: dialTimeout, Net: "udp", diff --git a/client/internal/dns/upstream_ios.go b/client/internal/dns/upstream_ios.go index 26b19dac3..0e04742a0 100644 --- a/client/internal/dns/upstream_ios.go +++ b/client/internal/dns/upstream_ios.go @@ -19,11 +19,7 @@ import ( type upstreamResolverIOS struct { *upstreamResolverBase - lIP netip.Addr - lNet netip.Prefix - lIPv6 netip.Addr - lNetV6 netip.Prefix - interfaceName string + wgIface WGIface } func newUpstreamResolver( @@ -37,11 +33,7 @@ func newUpstreamResolver( ios := &upstreamResolverIOS{ upstreamResolverBase: upstreamResolverBase, - lIP: wgIface.Address().IP, - lNet: wgIface.Address().Network, - lIPv6: wgIface.Address().IPv6, - lNetV6: wgIface.Address().IPv6Net, - interfaceName: wgIface.Name(), + wgIface: wgIface, } ios.upstreamClient = ios @@ -69,24 +61,15 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r * } else { upstreamIP = upstreamIP.Unmap() } - needsPrivate := u.lNet.Contains(upstreamIP) || - u.lNetV6.Contains(upstreamIP) || + addr := u.wgIface.Address() + needsPrivate := addr.Network.Contains(upstreamIP) || + addr.IPv6Net.Contains(upstreamIP) || (u.routeMatch != nil && u.routeMatch(upstreamIP)) if needsPrivate { - var bindIP netip.Addr - switch { - case upstreamIP.Is6() && u.lIPv6.IsValid(): - bindIP = u.lIPv6 - case upstreamIP.Is4() && u.lIP.IsValid(): - bindIP = u.lIP - } - - if bindIP.IsValid() { - log.Debugf("using private client to query %s via upstream %s", r.Question[0].Name, upstream) - client, err = GetClientPrivate(bindIP, u.interfaceName, timeout) - if err != nil { - return nil, 0, fmt.Errorf("create private client: %s", err) - } + log.Debugf("using private client to query %s via upstream %s", r.Question[0].Name, upstream) + client, err = GetClientPrivate(u.wgIface, upstreamIP, timeout) + if err != nil { + return nil, 0, fmt.Errorf("create private client: %s", err) } } @@ -94,23 +77,29 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r * return ExchangeWithFallback(nil, client, r, upstream) } -// GetClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface -// This method is needed for iOS -func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) { - index, err := getInterfaceIndex(interfaceName) +// GetClientPrivate returns a new DNS client bound to the local IP of the Netbird interface. +// It selects the v6 bind address when the upstream is IPv6 and the interface has one, otherwise v4. +func GetClientPrivate(iface privateClientIface, upstreamIP netip.Addr, dialTimeout time.Duration) (*dns.Client, error) { + index, err := getInterfaceIndex(iface.Name()) if err != nil { - log.Debugf("unable to get interface index for %s: %s", interfaceName, err) + log.Debugf("unable to get interface index for %s: %s", iface.Name(), err) return nil, err } + addr := iface.Address() + bindIP := addr.IP + if upstreamIP.Is6() && addr.HasIPv6() { + bindIP = addr.IPv6 + } + proto, opt := unix.IPPROTO_IP, unix.IP_BOUND_IF - if ip.Is6() { + if bindIP.Is6() { proto, opt = unix.IPPROTO_IPV6, unix.IPV6_BOUND_IF } dialer := &net.Dialer{ - LocalAddr: net.UDPAddrFromAddrPort(netip.AddrPortFrom(ip, 0)), - Timeout: dialTimeout, + LocalAddr: net.UDPAddrFromAddrPort(netip.AddrPortFrom(bindIP, 0)), + Timeout: dialTimeout, Control: func(network, address string, c syscall.RawConn) error { var operr error fn := func(s uintptr) { diff --git a/client/internal/dnsfwd/manager.go b/client/internal/dnsfwd/manager.go index 58b88d9ef..c4c16cd3f 100644 --- a/client/internal/dnsfwd/manager.go +++ b/client/internal/dnsfwd/manager.go @@ -80,6 +80,7 @@ func (m *Manager) Start(fwdEntries []*ForwarderEntry) error { return err } + // IPv4-only: peers reach the forwarder via its v4 overlay address. localAddr := m.wgIface.Address().IP if localAddr.IsValid() && m.firewall != nil { diff --git a/client/internal/ebpf/ebpf/dns_fwd_linux.go b/client/internal/ebpf/ebpf/dns_fwd_linux.go index 93797da76..1e7774573 100644 --- a/client/internal/ebpf/ebpf/dns_fwd_linux.go +++ b/client/internal/ebpf/ebpf/dns_fwd_linux.go @@ -2,7 +2,8 @@ package ebpf import ( "encoding/binary" - "net" + "fmt" + "net/netip" log "github.com/sirupsen/logrus" ) @@ -12,7 +13,7 @@ const ( mapKeyDNSPort uint32 = 1 ) -func (tf *GeneralManager) LoadDNSFwd(ip string, dnsPort int) error { +func (tf *GeneralManager) LoadDNSFwd(ip netip.Addr, dnsPort int) error { log.Debugf("load eBPF DNS forwarder, watching addr: %s:53, redirect to port: %d", ip, dnsPort) tf.lock.Lock() defer tf.lock.Unlock() @@ -22,7 +23,11 @@ func (tf *GeneralManager) LoadDNSFwd(ip string, dnsPort int) error { return err } - err = tf.bpfObjs.NbMapDnsIp.Put(mapKeyDNSIP, ip2int(ip)) + if !ip.Is4() { + return fmt.Errorf("eBPF DNS forwarder only supports IPv4, got %s", ip) + } + ip4 := ip.As4() + err = tf.bpfObjs.NbMapDnsIp.Put(mapKeyDNSIP, binary.BigEndian.Uint32(ip4[:])) if err != nil { return err } @@ -45,7 +50,3 @@ func (tf *GeneralManager) FreeDNSFwd() error { return tf.unsetFeatureFlag(featureFlagDnsForwarder) } -func ip2int(ipString string) uint32 { - ip := net.ParseIP(ipString) - return binary.BigEndian.Uint32(ip.To4()) -} diff --git a/client/internal/ebpf/manager/manager.go b/client/internal/ebpf/manager/manager.go index af10142d5..25a767090 100644 --- a/client/internal/ebpf/manager/manager.go +++ b/client/internal/ebpf/manager/manager.go @@ -1,8 +1,10 @@ package manager +import "net/netip" + // Manager is used to load multiple eBPF programs. E.g., current DNS programs and WireGuard proxy type Manager interface { - LoadDNSFwd(ip string, dnsPort int) error + LoadDNSFwd(ip netip.Addr, dnsPort int) error FreeDNSFwd() error LoadWgProxy(proxyPort, wgPort int) error FreeWGProxy() error diff --git a/client/internal/engine.go b/client/internal/engine.go index 35c49fe3f..16410519b 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -630,7 +630,7 @@ func (e *Engine) initFirewall() error { rosenpassPort := e.rpManager.GetAddress().Port port := firewallManager.Port{Values: []uint16{uint16(rosenpassPort)}} - // this rule is static and will be torn down on engine down by the firewall manager + // IPv4-only: rosenpass peers connect via AllowedIps[0] which is always v4. if _, err := e.firewall.AddPeerFiltering( nil, net.IP{0, 0, 0, 0}, @@ -682,10 +682,15 @@ func (e *Engine) blockLanAccess() { log.Infof("blocking route LAN access for networks: %v", toBlock) v4 := netip.PrefixFrom(netip.IPv4Unspecified(), 0) + v6 := netip.PrefixFrom(netip.IPv6Unspecified(), 0) for _, network := range toBlock { + source := v4 + if network.Addr().Is6() { + source = v6 + } if _, err := e.firewall.AddRouteFiltering( nil, - []netip.Prefix{v4}, + []netip.Prefix{source}, firewallManager.Network{Prefix: network}, firewallManager.ProtocolALL, nil, @@ -1494,10 +1499,10 @@ func (e *Engine) updateOfflinePeers(offlinePeers []*mgmProto.RemotePeerConfig) { replacement := make([]peer.State, len(offlinePeers)) for i, offlinePeer := range offlinePeers { log.Debugf("added offline peer %s", offlinePeer.Fqdn) - v4, v6 := splitAllowedIPs(offlinePeer.GetAllowedIps(), e.wgInterface.Address().IPv6Net) + v4, v6 := overlayAddrsFromAllowedIPs(offlinePeer.GetAllowedIps(), e.wgInterface.Address().IPv6Net) replacement[i] = peer.State{ - IP: v4, - IPv6: v6, + IP: addrToString(v4), + IPv6: addrToString(v6), PubKey: offlinePeer.GetWgPubKey(), FQDN: offlinePeer.GetFqdn(), ConnStatus: peer.StatusIdle, @@ -1508,30 +1513,37 @@ func (e *Engine) updateOfflinePeers(offlinePeers []*mgmProto.RemotePeerConfig) { e.statusRecorder.ReplaceOfflinePeers(replacement) } -// splitAllowedIPs separates the peer's overlay v4 (/32) and v6 (/128) addresses -// from a list of AllowedIPs CIDRs. The v6 address is only matched if it falls -// within ourV6Net (the local overlay v6 subnet), to avoid confusing routed /128 -// prefixes with the peer's overlay address. -func splitAllowedIPs(allowedIPs []string, ourV6Net netip.Prefix) (v4, v6 string) { +// overlayAddrsFromAllowedIPs extracts the peer's v4 and v6 overlay addresses +// from AllowedIPs strings. Only host routes (/32, /128) are considered; v6 must +// fall within ourV6Net to distinguish overlay addresses from routed prefixes. +func overlayAddrsFromAllowedIPs(allowedIPs []string, ourV6Net netip.Prefix) (v4, v6 netip.Addr) { for _, cidr := range allowedIPs { prefix, err := netip.ParsePrefix(cidr) if err != nil { log.Warnf("failed to parse AllowedIP %q: %v", cidr, err) continue } + addr := prefix.Addr().Unmap() switch { - case prefix.Addr().Is4() && prefix.Bits() == 32 && v4 == "": - v4 = prefix.Addr().String() - case prefix.Addr().Is6() && prefix.Bits() == 128 && ourV6Net.Contains(prefix.Addr()) && v6 == "": - v6 = prefix.Addr().String() + case addr.Is4() && prefix.Bits() == 32 && !v4.IsValid(): + v4 = addr + case addr.Is6() && prefix.Bits() == 128 && ourV6Net.Contains(addr) && !v6.IsValid(): + v6 = addr } - if v4 != "" && v6 != "" { + if v4.IsValid() && v6.IsValid() { break } } return } +func addrToString(addr netip.Addr) string { + if !addr.IsValid() { + return "" + } + return addr.String() +} + // addNewPeers adds peers that were not know before but arrived from the Management service with the update func (e *Engine) addNewPeers(peersUpdate []*mgmProto.RemotePeerConfig) error { for _, p := range peersUpdate { @@ -1572,8 +1584,8 @@ func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error { return fmt.Errorf("create peer connection: %w", err) } - peerV4, peerV6 := splitAllowedIPs(peerConfig.GetAllowedIps(), e.wgInterface.Address().IPv6Net) - err = e.statusRecorder.AddPeer(peerKey, peerConfig.Fqdn, peerV4, peerV6) + peerV4, peerV6 := overlayAddrsFromAllowedIPs(peerConfig.GetAllowedIps(), e.wgInterface.Address().IPv6Net) + err = e.statusRecorder.AddPeer(peerKey, peerConfig.Fqdn, addrToString(peerV4), addrToString(peerV6)) if err != nil { log.Warnf("error adding peer %s to status recorder, got error: %v", peerKey, err) } @@ -2355,8 +2367,7 @@ func getInterfacePrefixes() ([]netip.Prefix, error) { prefix := netip.PrefixFrom(addr.Unmap(), ones).Masked() ip := prefix.Addr() - // TODO: add IPv6 - if !ip.Is4() || ip.IsLoopback() || ip.IsMulticast() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() { + if ip.IsLoopback() || ip.IsMulticast() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() { continue } diff --git a/client/internal/engine_ssh.go b/client/internal/engine_ssh.go index 9ef70bf6e..53d2c1122 100644 --- a/client/internal/engine_ssh.go +++ b/client/internal/engine_ssh.go @@ -145,13 +145,13 @@ func (e *Engine) extractPeerSSHInfo(remotePeers []*mgmProto.RemotePeerConfig) [] continue } - peerIP, peerIPv6 := e.extractPeerIPs(peerConfig) + peerV4, peerV6 := overlayAddrsFromAllowedIPs(peerConfig.GetAllowedIps(), e.wgInterface.Address().IPv6Net) hostname := e.extractHostname(peerConfig) peerInfo = append(peerInfo, sshconfig.PeerSSHInfo{ Hostname: hostname, - IP: peerIP, - IPv6: peerIPv6, + IP: peerV4, + IPv6: peerV6, FQDN: peerConfig.GetFqdn(), }) } @@ -159,28 +159,6 @@ func (e *Engine) extractPeerSSHInfo(remotePeers []*mgmProto.RemotePeerConfig) [] return peerInfo } -// extractPeerIPs extracts IPv4 and IPv6 overlay addresses from peer's allowed IPs. -// Only considers host routes (/32, /128) within the overlay networks to avoid -// picking up routed prefixes or static routes like 2620:fe::fe/128. -func (e *Engine) extractPeerIPs(peerConfig *mgmProto.RemotePeerConfig) (v4, v6 netip.Addr) { - wgAddr := e.wgInterface.Address() - for _, allowedIP := range peerConfig.GetAllowedIps() { - prefix, err := netip.ParsePrefix(allowedIP) - if err != nil { - log.Warnf("failed to parse AllowedIP %q: %v", allowedIP, err) - continue - } - addr := prefix.Addr().Unmap() - switch { - case addr.Is4() && prefix.Bits() == 32 && wgAddr.Network.Contains(addr) && !v4.IsValid(): - v4 = addr - case addr.Is6() && prefix.Bits() == 128 && wgAddr.IPv6Net.IsValid() && wgAddr.IPv6Net.Contains(addr) && !v6.IsValid(): - v6 = addr - } - } - return v4, v6 -} - // extractHostname extracts short hostname from FQDN func (e *Engine) extractHostname(peerConfig *mgmProto.RemotePeerConfig) string { fqdn := peerConfig.GetFqdn() diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 010ad3b77..bf1bf6c89 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -1837,7 +1837,7 @@ func TestFilterAllowedIPs(t *testing.T) { } } -func TestSplitAllowedIPs(t *testing.T) { +func TestOverlayAddrsFromAllowedIPs(t *testing.T) { ourV6Net := netip.MustParsePrefix("fd00:1234:5678:abcd::/64") tests := []struct { @@ -1900,9 +1900,17 @@ func TestSplitAllowedIPs(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - v4, v6 := splitAllowedIPs(tt.allowedIPs, tt.ourV6Net) - assert.Equal(t, tt.wantV4, v4, "v4") - assert.Equal(t, tt.wantV6, v6, "v6") + v4, v6 := overlayAddrsFromAllowedIPs(tt.allowedIPs, tt.ourV6Net) + if tt.wantV4 == "" { + assert.False(t, v4.IsValid(), "expected no v4") + } else { + assert.Equal(t, tt.wantV4, v4.String(), "v4") + } + if tt.wantV6 == "" { + assert.False(t, v6.IsValid(), "expected no v6") + } else { + assert.Equal(t, tt.wantV6, v6.String(), "v6") + } }) } } diff --git a/client/internal/lazyconn/activity/listener_bind.go b/client/internal/lazyconn/activity/listener_bind.go index 792d04215..60b8baadb 100644 --- a/client/internal/lazyconn/activity/listener_bind.go +++ b/client/internal/lazyconn/activity/listener_bind.go @@ -57,6 +57,7 @@ func NewBindListener(wgIface WgInterface, bind device.EndpointManager, cfg lazyc // deriveFakeIP creates a deterministic fake IP for bind mode based on peer's NetBird IP. // Maps peer IP 100.64.x.y to fake IP 127.2.x.y (similar to relay proxy using 127.1.x.y). // It finds the peer's actual NetBird IP by checking which allowedIP is in the same subnet as our WG interface. +// For IPv6-only peers, the last two bytes of the v6 address are used. func deriveFakeIP(wgIface WgInterface, allowedIPs []netip.Prefix) (netip.Addr, error) { if len(allowedIPs) == 0 { return netip.Addr{}, fmt.Errorf("no allowed IPs for peer") @@ -64,6 +65,7 @@ func deriveFakeIP(wgIface WgInterface, allowedIPs []netip.Prefix) (netip.Addr, e ourNetwork := wgIface.Address().Network + // Try v4 first (preferred: deterministic from overlay IP) var peerIP netip.Addr for _, allowedIP := range allowedIPs { ip := allowedIP.Addr() @@ -76,13 +78,24 @@ func deriveFakeIP(wgIface WgInterface, allowedIPs []netip.Prefix) (netip.Addr, e } } - if !peerIP.IsValid() { - return netip.Addr{}, fmt.Errorf("no peer NetBird IP found in allowed IPs") + if peerIP.IsValid() { + octets := peerIP.As4() + return netip.AddrFrom4([4]byte{127, 2, octets[2], octets[3]}), nil } - octets := peerIP.As4() - fakeIP := netip.AddrFrom4([4]byte{127, 2, octets[2], octets[3]}) - return fakeIP, nil + // Fallback: use last two bytes of first v6 overlay IP + addr := wgIface.Address() + if addr.IPv6Net.IsValid() { + for _, allowedIP := range allowedIPs { + ip := allowedIP.Addr() + if ip.Is6() && addr.IPv6Net.Contains(ip) { + raw := ip.As16() + return netip.AddrFrom4([4]byte{127, 2, raw[14], raw[15]}), nil + } + } + } + + return netip.Addr{}, fmt.Errorf("no peer NetBird IP found in allowed IPs") } func (d *BindListener) setupLazyConn() error { diff --git a/client/internal/peer/status.go b/client/internal/peer/status.go index fbf95de21..f4db95c8a 100644 --- a/client/internal/peer/status.go +++ b/client/internal/peer/status.go @@ -1055,7 +1055,11 @@ func (d *Status) notifyPeerListChanged() { } func (d *Status) notifyAddressChanged() { - d.notifier.localAddressChanged(d.localPeer.FQDN, d.localPeer.IP) + addr := d.localPeer.IP + if d.localPeer.IPv6 != "" { + addr = addr + "\n" + d.localPeer.IPv6 + } + d.notifier.localAddressChanged(d.localPeer.FQDN, addr) } func (d *Status) numOfPeers() int { diff --git a/client/internal/routemanager/client/client.go b/client/internal/routemanager/client/client.go index e6ef8b876..c691c54f8 100644 --- a/client/internal/routemanager/client/client.go +++ b/client/internal/routemanager/client/client.go @@ -3,9 +3,8 @@ package client import ( "context" "fmt" - "net" + "net/netip" "reflect" - "strconv" "time" log "github.com/sirupsen/logrus" @@ -566,7 +565,7 @@ func HandlerFromRoute(params common.HandlerParams) RouteHandler { return dnsinterceptor.New(params) case handlerTypeDynamic: dns := nbdns.NewServiceViaMemory(params.WgInterface) - dnsAddr := net.JoinHostPort(dns.RuntimeIP().String(), strconv.Itoa(dns.RuntimePort())) + dnsAddr := netip.AddrPortFrom(dns.RuntimeIP(), uint16(dns.RuntimePort())) return dynamic.NewRoute(params, dnsAddr) default: return static.NewRoute(params) diff --git a/client/internal/routemanager/dnsinterceptor/handler.go b/client/internal/routemanager/dnsinterceptor/handler.go index 64f2a8789..e25cc2a5c 100644 --- a/client/internal/routemanager/dnsinterceptor/handler.go +++ b/client/internal/routemanager/dnsinterceptor/handler.go @@ -582,7 +582,7 @@ func (d *DnsInterceptor) queryUpstreamDNS(ctx context.Context, w dns.ResponseWri if nsNet != nil { reply, err = nbdns.ExchangeWithNetstack(ctx, nsNet, r, upstream) } else { - client, clientErr := nbdns.GetClientPrivate(d.wgInterface.Address().IP, d.wgInterface.Name(), dnsTimeout) + client, clientErr := nbdns.GetClientPrivate(d.wgInterface, upstreamIP, dnsTimeout) if clientErr != nil { d.writeDNSError(w, r, logger, fmt.Sprintf("create DNS client: %v", clientErr)) return nil diff --git a/client/internal/routemanager/dynamic/route.go b/client/internal/routemanager/dynamic/route.go index 8d1398a7a..f0efd7b22 100644 --- a/client/internal/routemanager/dynamic/route.go +++ b/client/internal/routemanager/dynamic/route.go @@ -50,10 +50,10 @@ type Route struct { cancel context.CancelFunc statusRecorder *peer.Status wgInterface iface.WGIface - resolverAddr string + resolverAddr netip.AddrPort } -func NewRoute(params common.HandlerParams, resolverAddr string) *Route { +func NewRoute(params common.HandlerParams, resolverAddr netip.AddrPort) *Route { return &Route{ route: params.Route, routeRefCounter: params.RouteRefCounter, diff --git a/client/internal/routemanager/dynamic/route_ios.go b/client/internal/routemanager/dynamic/route_ios.go index 8fed1c8f9..1ae281d56 100644 --- a/client/internal/routemanager/dynamic/route_ios.go +++ b/client/internal/routemanager/dynamic/route_ios.go @@ -17,37 +17,47 @@ import ( const dialTimeout = 10 * time.Second func (r *Route) getIPsFromResolver(domain domain.Domain) ([]net.IP, error) { - privateClient, err := nbdns.GetClientPrivate(r.wgInterface.Address().IP, r.wgInterface.Name(), dialTimeout) + privateClient, err := nbdns.GetClientPrivate(r.wgInterface, r.resolverAddr.Addr(), dialTimeout) if err != nil { return nil, fmt.Errorf("error while creating private client: %s", err) } - msg := new(dns.Msg) - msg.SetQuestion(dns.Fqdn(domain.PunycodeString()), dns.TypeA) - + fqdn := dns.Fqdn(domain.PunycodeString()) startTime := time.Now() - response, _, err := nbdns.ExchangeWithFallback(nil, privateClient, msg, r.resolverAddr) - if err != nil { - return nil, fmt.Errorf("DNS query for %s failed after %s: %s ", domain.SafeString(), time.Since(startTime), err) - } + var ips []net.IP + var queryErr error - if response.Rcode != dns.RcodeSuccess { - return nil, fmt.Errorf("dns response code: %s", dns.RcodeToString[response.Rcode]) - } + for _, qtype := range []uint16{dns.TypeA, dns.TypeAAAA} { + msg := new(dns.Msg) + msg.SetQuestion(fqdn, qtype) - ips := make([]net.IP, 0) - - for _, answ := range response.Answer { - if aRecord, ok := answ.(*dns.A); ok { - ips = append(ips, aRecord.A) + response, _, err := nbdns.ExchangeWithFallback(nil, privateClient, msg, r.resolverAddr.String()) + if err != nil { + if queryErr == nil { + queryErr = fmt.Errorf("DNS query for %s (type %d) after %s: %w", domain.SafeString(), qtype, time.Since(startTime), err) + } + continue } - if aaaaRecord, ok := answ.(*dns.AAAA); ok { - ips = append(ips, aaaaRecord.AAAA) + + if response.Rcode != dns.RcodeSuccess { + continue + } + + for _, answ := range response.Answer { + if aRecord, ok := answ.(*dns.A); ok { + ips = append(ips, aRecord.A) + } + if aaaaRecord, ok := answ.(*dns.AAAA); ok { + ips = append(ips, aaaaRecord.AAAA) + } } } if len(ips) == 0 { + if queryErr != nil { + return nil, queryErr + } return nil, fmt.Errorf("no A or AAAA records found for %s", domain.SafeString()) } diff --git a/client/internal/routemanager/fakeip/fakeip.go b/client/internal/routemanager/fakeip/fakeip.go index 1592045d2..5be4ca12e 100644 --- a/client/internal/routemanager/fakeip/fakeip.go +++ b/client/internal/routemanager/fakeip/fakeip.go @@ -1,93 +1,145 @@ package fakeip import ( + "errors" "fmt" "net/netip" "sync" ) -// Manager manages allocation of fake IPs from the 240.0.0.0/8 block -type Manager struct { - mu sync.Mutex - nextIP netip.Addr // Next IP to allocate +var ( + // 240.0.0.1 - 240.255.255.254, block 240.0.0.0/8 (reserved, RFC 1112) + v4Base = netip.AddrFrom4([4]byte{240, 0, 0, 1}) + v4Max = netip.AddrFrom4([4]byte{240, 255, 255, 254}) + v4Block = netip.PrefixFrom(netip.AddrFrom4([4]byte{240, 0, 0, 0}), 8) + + // 0100::1 - 0100::ffff:ffff:ffff:fffe, block 0100::/64 (discard, RFC 6666) + v6Base = netip.AddrFrom16([16]byte{0x01, 0x00, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x01}) + v6Max = netip.AddrFrom16([16]byte{0x01, 0x00, 0, 0, 0, 0, 0, 0, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe}) + v6Block = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x01, 0x00}), 64) +) + +// fakeIPPool holds the allocation state for a single address family. +type fakeIPPool struct { + nextIP netip.Addr + baseIP netip.Addr + maxIP netip.Addr + block netip.Prefix allocated map[netip.Addr]netip.Addr // real IP -> fake IP fakeToReal map[netip.Addr]netip.Addr // fake IP -> real IP - baseIP netip.Addr // First usable IP: 240.0.0.1 - maxIP netip.Addr // Last usable IP: 240.255.255.254 } -// NewManager creates a new fake IP manager using 240.0.0.0/8 block -func NewManager() *Manager { - baseIP := netip.AddrFrom4([4]byte{240, 0, 0, 1}) - maxIP := netip.AddrFrom4([4]byte{240, 255, 255, 254}) - - return &Manager{ - nextIP: baseIP, +func newPool(base, maxAddr netip.Addr, block netip.Prefix) *fakeIPPool { + return &fakeIPPool{ + nextIP: base, + baseIP: base, + maxIP: maxAddr, + block: block, allocated: make(map[netip.Addr]netip.Addr), fakeToReal: make(map[netip.Addr]netip.Addr), - baseIP: baseIP, - maxIP: maxIP, } } -// AllocateFakeIP allocates a fake IP for the given real IP -// Returns the fake IP, or existing fake IP if already allocated -func (m *Manager) AllocateFakeIP(realIP netip.Addr) (netip.Addr, error) { - if !realIP.Is4() { - return netip.Addr{}, fmt.Errorf("only IPv4 addresses supported") - } - - m.mu.Lock() - defer m.mu.Unlock() - - if fakeIP, exists := m.allocated[realIP]; exists { +// allocate allocates a fake IP for the given real IP. +// Returns the existing fake IP if already allocated. +func (p *fakeIPPool) allocate(realIP netip.Addr) (netip.Addr, error) { + if fakeIP, exists := p.allocated[realIP]; exists { return fakeIP, nil } - startIP := m.nextIP + startIP := p.nextIP for { - currentIP := m.nextIP + currentIP := p.nextIP // Advance to next IP, wrapping at boundary - if m.nextIP.Compare(m.maxIP) >= 0 { - m.nextIP = m.baseIP + if p.nextIP.Compare(p.maxIP) >= 0 { + p.nextIP = p.baseIP } else { - m.nextIP = m.nextIP.Next() + p.nextIP = p.nextIP.Next() } - // Check if current IP is available - if _, inUse := m.fakeToReal[currentIP]; !inUse { - m.allocated[realIP] = currentIP - m.fakeToReal[currentIP] = realIP + if _, inUse := p.fakeToReal[currentIP]; !inUse { + p.allocated[realIP] = currentIP + p.fakeToReal[currentIP] = realIP return currentIP, nil } - // Prevent infinite loop if all IPs exhausted - if m.nextIP.Compare(startIP) == 0 { - return netip.Addr{}, fmt.Errorf("no more fake IPs available in 240.0.0.0/8 block") + if p.nextIP.Compare(startIP) == 0 { + return netip.Addr{}, fmt.Errorf("no more fake IPs available in %s block", p.block) } } } -// GetFakeIP returns the fake IP for a real IP if it exists +// Manager manages allocation of fake IPs for dynamic DNS routes. +// IPv4 uses 240.0.0.0/8 (reserved), IPv6 uses 0100::/64 (discard, RFC 6666). +type Manager struct { + mu sync.Mutex + v4 *fakeIPPool + v6 *fakeIPPool +} + +// NewManager creates a new fake IP manager. +func NewManager() *Manager { + return &Manager{ + v4: newPool(v4Base, v4Max, v4Block), + v6: newPool(v6Base, v6Max, v6Block), + } +} + +func (m *Manager) pool(ip netip.Addr) *fakeIPPool { + if ip.Is6() { + return m.v6 + } + return m.v4 +} + +// AllocateFakeIP allocates a fake IP for the given real IP. +func (m *Manager) AllocateFakeIP(realIP netip.Addr) (netip.Addr, error) { + realIP = realIP.Unmap() + if !realIP.IsValid() { + return netip.Addr{}, errors.New("invalid IP address") + } + + m.mu.Lock() + defer m.mu.Unlock() + + return m.pool(realIP).allocate(realIP) +} + +// GetFakeIP returns the fake IP for a real IP if it exists. func (m *Manager) GetFakeIP(realIP netip.Addr) (netip.Addr, bool) { + realIP = realIP.Unmap() + if !realIP.IsValid() { + return netip.Addr{}, false + } + m.mu.Lock() defer m.mu.Unlock() - fakeIP, exists := m.allocated[realIP] - return fakeIP, exists + fakeIP, ok := m.pool(realIP).allocated[realIP] + return fakeIP, ok } -// GetRealIP returns the real IP for a fake IP if it exists, otherwise false +// GetRealIP returns the real IP for a fake IP if it exists. func (m *Manager) GetRealIP(fakeIP netip.Addr) (netip.Addr, bool) { + fakeIP = fakeIP.Unmap() + if !fakeIP.IsValid() { + return netip.Addr{}, false + } + m.mu.Lock() defer m.mu.Unlock() - realIP, exists := m.fakeToReal[fakeIP] - return realIP, exists + realIP, ok := m.pool(fakeIP).fakeToReal[fakeIP] + return realIP, ok } -// GetFakeIPBlock returns the fake IP block used by this manager +// GetFakeIPBlock returns the v4 fake IP block used by this manager. func (m *Manager) GetFakeIPBlock() netip.Prefix { - return netip.MustParsePrefix("240.0.0.0/8") + return m.v4.block +} + +// GetFakeIPv6Block returns the v6 fake IP block used by this manager. +func (m *Manager) GetFakeIPv6Block() netip.Prefix { + return m.v6.block } diff --git a/client/internal/routemanager/fakeip/fakeip_test.go b/client/internal/routemanager/fakeip/fakeip_test.go index ad3e4bd4e..f554f970d 100644 --- a/client/internal/routemanager/fakeip/fakeip_test.go +++ b/client/internal/routemanager/fakeip/fakeip_test.go @@ -9,16 +9,16 @@ import ( func TestNewManager(t *testing.T) { manager := NewManager() - if manager.baseIP.String() != "240.0.0.1" { - t.Errorf("Expected base IP 240.0.0.1, got %s", manager.baseIP.String()) + if manager.v4.baseIP.String() != "240.0.0.1" { + t.Errorf("Expected v4 base IP 240.0.0.1, got %s", manager.v4.baseIP.String()) } - if manager.maxIP.String() != "240.255.255.254" { - t.Errorf("Expected max IP 240.255.255.254, got %s", manager.maxIP.String()) + if manager.v4.maxIP.String() != "240.255.255.254" { + t.Errorf("Expected v4 max IP 240.255.255.254, got %s", manager.v4.maxIP.String()) } - if manager.nextIP.Compare(manager.baseIP) != 0 { - t.Errorf("Expected nextIP to start at baseIP") + if manager.v6.baseIP.String() != "100::1" { + t.Errorf("Expected v6 base IP 100::1, got %s", manager.v6.baseIP.String()) } } @@ -35,7 +35,6 @@ func TestAllocateFakeIP(t *testing.T) { t.Error("Fake IP should be IPv4") } - // Check it's in the correct range if fakeIP.As4()[0] != 240 { t.Errorf("Fake IP should be in 240.0.0.0/8 range, got %s", fakeIP.String()) } @@ -51,13 +50,31 @@ func TestAllocateFakeIP(t *testing.T) { } } -func TestAllocateFakeIPIPv6Rejection(t *testing.T) { +func TestAllocateFakeIPv6(t *testing.T) { manager := NewManager() - realIPv6 := netip.MustParseAddr("2001:db8::1") + realIP := netip.MustParseAddr("2001:db8::1") - _, err := manager.AllocateFakeIP(realIPv6) - if err == nil { - t.Error("Expected error for IPv6 address") + fakeIP, err := manager.AllocateFakeIP(realIP) + if err != nil { + t.Fatalf("Failed to allocate fake IPv6: %v", err) + } + + if !fakeIP.Is6() { + t.Error("Fake IP should be IPv6") + } + + if !netip.MustParsePrefix("100::/64").Contains(fakeIP) { + t.Errorf("Fake IP should be in 100::/64 range, got %s", fakeIP.String()) + } + + // Should return same fake IP for same real IP + fakeIP2, err := manager.AllocateFakeIP(realIP) + if err != nil { + t.Fatalf("Failed to get existing fake IPv6: %v", err) + } + + if fakeIP.Compare(fakeIP2) != 0 { + t.Errorf("Expected same fake IP, got %s and %s", fakeIP.String(), fakeIP2.String()) } } @@ -65,13 +82,11 @@ func TestGetFakeIP(t *testing.T) { manager := NewManager() realIP := netip.MustParseAddr("1.1.1.1") - // Should not exist initially _, exists := manager.GetFakeIP(realIP) if exists { t.Error("Fake IP should not exist before allocation") } - // Allocate and check expectedFakeIP, err := manager.AllocateFakeIP(realIP) if err != nil { t.Fatalf("Failed to allocate: %v", err) @@ -87,12 +102,30 @@ func TestGetFakeIP(t *testing.T) { } } +func TestGetRealIPv6(t *testing.T) { + manager := NewManager() + realIP := netip.MustParseAddr("2001:db8::1") + + fakeIP, err := manager.AllocateFakeIP(realIP) + if err != nil { + t.Fatalf("Failed to allocate: %v", err) + } + + gotReal, exists := manager.GetRealIP(fakeIP) + if !exists { + t.Error("Real IP should exist for allocated fake IP") + } + + if gotReal.Compare(realIP) != 0 { + t.Errorf("Expected real IP %s, got %s", realIP, gotReal) + } +} + func TestMultipleAllocations(t *testing.T) { manager := NewManager() allocations := make(map[netip.Addr]netip.Addr) - // Allocate multiple IPs for i := 1; i <= 100; i++ { realIP := netip.AddrFrom4([4]byte{10, 0, byte(i / 256), byte(i % 256)}) fakeIP, err := manager.AllocateFakeIP(realIP) @@ -100,7 +133,6 @@ func TestMultipleAllocations(t *testing.T) { t.Fatalf("Failed to allocate fake IP for %s: %v", realIP.String(), err) } - // Check for duplicates for _, existingFake := range allocations { if fakeIP.Compare(existingFake) == 0 { t.Errorf("Duplicate fake IP allocated: %s", fakeIP.String()) @@ -110,7 +142,6 @@ func TestMultipleAllocations(t *testing.T) { allocations[realIP] = fakeIP } - // Verify all allocations can be retrieved for realIP, expectedFake := range allocations { actualFake, exists := manager.GetFakeIP(realIP) if !exists { @@ -124,11 +155,13 @@ func TestMultipleAllocations(t *testing.T) { func TestGetFakeIPBlock(t *testing.T) { manager := NewManager() - block := manager.GetFakeIPBlock() - expected := "240.0.0.0/8" - if block.String() != expected { - t.Errorf("Expected %s, got %s", expected, block.String()) + if block := manager.GetFakeIPBlock(); block.String() != "240.0.0.0/8" { + t.Errorf("Expected 240.0.0.0/8, got %s", block.String()) + } + + if block := manager.GetFakeIPv6Block(); block.String() != "100::/64" { + t.Errorf("Expected 100::/64, got %s", block.String()) } } @@ -141,7 +174,6 @@ func TestConcurrentAccess(t *testing.T) { var wg sync.WaitGroup results := make(chan netip.Addr, numGoroutines*allocationsPerGoroutine) - // Concurrent allocations for i := 0; i < numGoroutines; i++ { wg.Add(1) go func(goroutineID int) { @@ -161,7 +193,6 @@ func TestConcurrentAccess(t *testing.T) { wg.Wait() close(results) - // Check for duplicates seen := make(map[netip.Addr]bool) count := 0 for fakeIP := range results { @@ -178,47 +209,61 @@ func TestConcurrentAccess(t *testing.T) { } func TestIPExhaustion(t *testing.T) { - // Create a manager with limited range for testing manager := &Manager{ - nextIP: netip.AddrFrom4([4]byte{240, 0, 0, 1}), - allocated: make(map[netip.Addr]netip.Addr), - fakeToReal: make(map[netip.Addr]netip.Addr), - baseIP: netip.AddrFrom4([4]byte{240, 0, 0, 1}), - maxIP: netip.AddrFrom4([4]byte{240, 0, 0, 3}), // Only 3 IPs available + v4: newPool( + netip.AddrFrom4([4]byte{240, 0, 0, 1}), + netip.AddrFrom4([4]byte{240, 0, 0, 3}), + netip.MustParsePrefix("240.0.0.0/8"), + ), + v6: newPool( + netip.MustParseAddr("100::1"), + netip.MustParseAddr("100::3"), + netip.MustParsePrefix("100::/64"), + ), } - // Allocate all available IPs - realIPs := []netip.Addr{ - netip.MustParseAddr("1.0.0.1"), - netip.MustParseAddr("1.0.0.2"), - netip.MustParseAddr("1.0.0.3"), - } - - for _, realIP := range realIPs { - _, err := manager.AllocateFakeIP(realIP) + for _, realIP := range []string{"1.0.0.1", "1.0.0.2", "1.0.0.3"} { + _, err := manager.AllocateFakeIP(netip.MustParseAddr(realIP)) if err != nil { t.Fatalf("Failed to allocate fake IP: %v", err) } } - // Try to allocate one more - should fail _, err := manager.AllocateFakeIP(netip.MustParseAddr("1.0.0.4")) if err == nil { - t.Error("Expected exhaustion error") + t.Error("Expected v4 exhaustion error") + } + + // Same for v6 + for _, realIP := range []string{"2001:db8::1", "2001:db8::2", "2001:db8::3"} { + _, err := manager.AllocateFakeIP(netip.MustParseAddr(realIP)) + if err != nil { + t.Fatalf("Failed to allocate fake IPv6: %v", err) + } + } + + _, err = manager.AllocateFakeIP(netip.MustParseAddr("2001:db8::4")) + if err == nil { + t.Error("Expected v6 exhaustion error") } } func TestWrapAround(t *testing.T) { - // Create manager starting near the end of range manager := &Manager{ - nextIP: netip.AddrFrom4([4]byte{240, 0, 0, 254}), - allocated: make(map[netip.Addr]netip.Addr), - fakeToReal: make(map[netip.Addr]netip.Addr), - baseIP: netip.AddrFrom4([4]byte{240, 0, 0, 1}), - maxIP: netip.AddrFrom4([4]byte{240, 0, 0, 254}), + v4: newPool( + netip.AddrFrom4([4]byte{240, 0, 0, 1}), + netip.AddrFrom4([4]byte{240, 0, 0, 254}), + netip.MustParsePrefix("240.0.0.0/8"), + ), + v6: newPool( + netip.MustParseAddr("100::1"), + netip.MustParseAddr("100::ffff:ffff:ffff:fffe"), + netip.MustParsePrefix("100::/64"), + ), } + // Start near the end + manager.v4.nextIP = netip.AddrFrom4([4]byte{240, 0, 0, 254}) - // Allocate the last IP fakeIP1, err := manager.AllocateFakeIP(netip.MustParseAddr("1.0.0.1")) if err != nil { t.Fatalf("Failed to allocate first IP: %v", err) @@ -228,7 +273,6 @@ func TestWrapAround(t *testing.T) { t.Errorf("Expected 240.0.0.254, got %s", fakeIP1.String()) } - // Next allocation should wrap around to the beginning fakeIP2, err := manager.AllocateFakeIP(netip.MustParseAddr("1.0.0.2")) if err != nil { t.Fatalf("Failed to allocate second IP: %v", err) @@ -238,3 +282,32 @@ func TestWrapAround(t *testing.T) { t.Errorf("Expected 240.0.0.1 after wrap, got %s", fakeIP2.String()) } } + +func TestMixedV4V6(t *testing.T) { + manager := NewManager() + + v4Fake, err := manager.AllocateFakeIP(netip.MustParseAddr("8.8.8.8")) + if err != nil { + t.Fatalf("Failed to allocate v4: %v", err) + } + + v6Fake, err := manager.AllocateFakeIP(netip.MustParseAddr("2001:db8::1")) + if err != nil { + t.Fatalf("Failed to allocate v6: %v", err) + } + + if !v4Fake.Is4() || !v6Fake.Is6() { + t.Errorf("Wrong families: v4=%s v6=%s", v4Fake, v6Fake) + } + + // Reverse lookups should work for both + gotV4, ok := manager.GetRealIP(v4Fake) + if !ok || gotV4.String() != "8.8.8.8" { + t.Errorf("v4 reverse lookup failed: got %s, ok=%v", gotV4, ok) + } + + gotV6, ok := manager.GetRealIP(v6Fake) + if !ok || gotV6.String() != "2001:db8::1" { + t.Errorf("v6 reverse lookup failed: got %s, ok=%v", gotV6, ok) + } +} diff --git a/client/internal/routemanager/ipfwdstate/ipfwdstate.go b/client/internal/routemanager/ipfwdstate/ipfwdstate.go index da81c18f9..2be1c2ae7 100644 --- a/client/internal/routemanager/ipfwdstate/ipfwdstate.go +++ b/client/internal/routemanager/ipfwdstate/ipfwdstate.go @@ -9,7 +9,11 @@ import ( ) // IPForwardingState is a struct that keeps track of the IP forwarding state. -// todo: read initial state of the IP forwarding from the system and reset the state based on it +// todo: read initial state of the IP forwarding from the system and reset the state based on it. +// todo: separate v4/v6 forwarding state, since the sysctls are independent +// (net.ipv4.ip_forward vs net.ipv6.conf.all.forwarding). Currently the nftables +// manager shares one instance between both routers, which works only because +// EnableIPForwarding enables both sysctls in a single call. type IPForwardingState struct { enabledCounter int } diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index e7ca44239..bf89296d3 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -159,15 +159,23 @@ func (m *DefaultManager) setupAndroidRoutes(config ManagerConfig) { if config.DNSFeatureFlag { m.fakeIPManager = fakeip.NewManager() - id := uuid.NewString() - fakeIPRoute := &route.Route{ - ID: route.ID(id), + v4ID := uuid.NewString() + cr = append(cr, &route.Route{ + ID: route.ID(v4ID), Network: m.fakeIPManager.GetFakeIPBlock(), - NetID: route.NetID(id), + NetID: route.NetID(v4ID), Peer: m.pubKey, NetworkType: route.IPv4Network, - } - cr = append(cr, fakeIPRoute) + }) + + v6ID := uuid.NewString() + cr = append(cr, &route.Route{ + ID: route.ID(v6ID), + Network: m.fakeIPManager.GetFakeIPv6Block(), + NetID: route.NetID(v6ID), + Peer: m.pubKey, + NetworkType: route.IPv6Network, + }) } m.notifier.SetInitialClientRoutes(cr, routesForComparison) diff --git a/client/internal/routemanager/server/server.go b/client/internal/routemanager/server/server.go index e674c80cd..d35b44f5b 100644 --- a/client/internal/routemanager/server/server.go +++ b/client/internal/routemanager/server/server.go @@ -146,8 +146,7 @@ func routeToRouterPair(route *route.Route, useNewDNSRoute bool) firewall.RouterP if useNewDNSRoute { destination.Set = firewall.NewDomainSet(route.Domains) } else { - // TODO: add ipv6 additionally - destination = getDefaultPrefix(destination.Prefix) + destination = getDefaultPrefix(route.Network) } } else { destination.Prefix = route.Network.Masked() diff --git a/client/internal/routemanager/systemops/systemops.go b/client/internal/routemanager/systemops/systemops.go index c0ca21d22..8724ed1ba 100644 --- a/client/internal/routemanager/systemops/systemops.go +++ b/client/internal/routemanager/systemops/systemops.go @@ -107,8 +107,13 @@ func (r *SysOps) validateRoute(prefix netip.Prefix) error { addr.IsInterfaceLocalMulticast(), addr.IsMulticast(), addr.IsUnspecified() && prefix.Bits() != 0, - r.wgInterface.Address().Network.Contains(addr): + r.isOwnAddress(addr): return vars.ErrRouteNotAllowed } return nil } + +func (r *SysOps) isOwnAddress(addr netip.Addr) bool { + wgAddr := r.wgInterface.Address() + return wgAddr.Network.Contains(addr) || (wgAddr.IPv6Net.IsValid() && wgAddr.IPv6Net.Contains(addr)) +} diff --git a/client/internal/routemanager/systemops/systemops_generic.go b/client/internal/routemanager/systemops/systemops_generic.go index ec219c7fe..07bd2c118 100644 --- a/client/internal/routemanager/systemops/systemops_generic.go +++ b/client/internal/routemanager/systemops/systemops_generic.go @@ -222,30 +222,20 @@ func (r *SysOps) genericAddVPNRoute(prefix netip.Prefix, intf *net.Interface) er return err } - // TODO: remove once IPv6 is supported on the interface - if err := r.addToRouteTable(splitDefaultv6_1, nextHop); err != nil { - return fmt.Errorf("add unreachable route split 1: %w", err) - } - if err := r.addToRouteTable(splitDefaultv6_2, nextHop); err != nil { - if err2 := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err2 != nil { - log.Warnf("Failed to rollback route addition: %s", err2) + // When the interface has no v6, add v6 split-default as blackhole so + // unroutable v6 goes to WG (dropped, no AllowedIPs) instead of leaking + // to the system default route. When v6 is active, management sends ::/0 + // as a separate route that the dedicated handler adds. + // Soft-fail: v6 blackhole is best-effort, don't abort v4 routing on failure. + if !r.wgInterface.Address().HasIPv6() { + if err := r.addV6SplitDefault(nextHop); err != nil { + log.Warnf("failed to add v6 split-default blackhole: %s", err) } - return fmt.Errorf("add unreachable route split 2: %w", err) } return nil case vars.Defaultv6: - if err := r.addToRouteTable(splitDefaultv6_1, nextHop); err != nil { - return fmt.Errorf("add unreachable route split 1: %w", err) - } - if err := r.addToRouteTable(splitDefaultv6_2, nextHop); err != nil { - if err2 := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err2 != nil { - log.Warnf("Failed to rollback route addition: %s", err2) - } - return fmt.Errorf("add unreachable route split 2: %w", err) - } - - return nil + return r.addV6SplitDefault(nextHop) } return r.addToRouteTable(prefix, nextHop) @@ -266,30 +256,42 @@ func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) result = multierror.Append(result, err) } - // TODO: remove once IPv6 is supported on the interface - if err := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err != nil { - result = multierror.Append(result, err) - } - if err := r.removeFromRouteTable(splitDefaultv6_2, nextHop); err != nil { - result = multierror.Append(result, err) + if !r.wgInterface.Address().HasIPv6() { + result = multierror.Append(result, r.removeV6SplitDefault(nextHop)) } return nberrors.FormatErrorOrNil(result) case vars.Defaultv6: - var result *multierror.Error - if err := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err != nil { - result = multierror.Append(result, err) - } - if err := r.removeFromRouteTable(splitDefaultv6_2, nextHop); err != nil { - result = multierror.Append(result, err) - } - - return nberrors.FormatErrorOrNil(result) + return nberrors.FormatErrorOrNil(r.removeV6SplitDefault(nextHop)) default: return r.removeFromRouteTable(prefix, nextHop) } } +func (r *SysOps) addV6SplitDefault(nextHop Nexthop) error { + if err := r.addToRouteTable(splitDefaultv6_1, nextHop); err != nil { + return fmt.Errorf("add split 1: %w", err) + } + if err := r.addToRouteTable(splitDefaultv6_2, nextHop); err != nil { + if err2 := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err2 != nil { + log.Warnf("Failed to rollback v6 split-default: %s", err2) + } + return fmt.Errorf("add split 2: %w", err) + } + return nil +} + +func (r *SysOps) removeV6SplitDefault(nextHop Nexthop) *multierror.Error { + var result *multierror.Error + if err := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err != nil { + result = multierror.Append(result, err) + } + if err := r.removeFromRouteTable(splitDefaultv6_2, nextHop); err != nil { + result = multierror.Append(result, err) + } + return result +} + func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.Manager) error { beforeHook := func(connID hooks.ConnectionID, prefix netip.Prefix) error { if _, err := r.refCounter.IncrementWithID(string(connID), prefix, struct{}{}); err != nil { diff --git a/client/internal/routemanager/systemops/systemops_linux.go b/client/internal/routemanager/systemops/systemops_linux.go index bd10f131f..55e45279c 100644 --- a/client/internal/routemanager/systemops/systemops_linux.go +++ b/client/internal/routemanager/systemops/systemops_linux.go @@ -53,6 +53,8 @@ const ( // ipv4ForwardingPath is the path to the file containing the IP forwarding setting. ipv4ForwardingPath = "net.ipv4.ip_forward" + // ipv6ForwardingPath is the path to the file containing the IPv6 forwarding setting. + ipv6ForwardingPath = "net.ipv6.conf.all.forwarding" ) var ErrTableIDExists = errors.New("ID exists with different name") @@ -185,10 +187,11 @@ func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error { // No need to check if routes exist as main table takes precedence over the VPN table via Rule 1 - // TODO remove this once we have ipv6 support - if prefix == vars.Defaultv4 { + // When the peer has no IPv6, blackhole v6 to prevent leaking. + // When IPv6 is enabled, management sends ::/0 as a separate route. + if prefix == vars.Defaultv4 && (r.wgInterface == nil || !r.wgInterface.Address().HasIPv6()) { if err := addUnreachableRoute(vars.Defaultv6, NetbirdVPNTableID); err != nil { - return fmt.Errorf("add blackhole: %w", err) + return fmt.Errorf("add v6 blackhole: %w", err) } } if err := addRoute(prefix, Nexthop{netip.Addr{}, intf}, NetbirdVPNTableID); err != nil { @@ -206,10 +209,9 @@ func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error return r.genericRemoveVPNRoute(prefix, intf) } - // TODO remove this once we have ipv6 support - if prefix == vars.Defaultv4 { + if prefix == vars.Defaultv4 && (r.wgInterface == nil || !r.wgInterface.Address().HasIPv6()) { if err := removeUnreachableRoute(vars.Defaultv6, NetbirdVPNTableID); err != nil { - return fmt.Errorf("remove unreachable route: %w", err) + log.Debugf("remove v6 blackhole: %v", err) } } if err := removeRoute(prefix, Nexthop{netip.Addr{}, intf}, NetbirdVPNTableID); err != nil { @@ -762,8 +764,13 @@ func flushRoutes(tableID, family int) error { } func EnableIPForwarding() error { - _, err := sysctl.Set(ipv4ForwardingPath, 1, false) - return err + if _, err := sysctl.Set(ipv4ForwardingPath, 1, false); err != nil { + return err + } + if _, err := sysctl.Set(ipv6ForwardingPath, 1, false); err != nil { + log.Warnf("failed to enable IPv6 forwarding: %v", err) + } + return nil } // entryExists checks if the specified ID or name already exists in the rt_tables file diff --git a/client/ios/NetBirdSDK/client.go b/client/ios/NetBirdSDK/client.go index 990e03034..c73a0dcd1 100644 --- a/client/ios/NetBirdSDK/client.go +++ b/client/ios/NetBirdSDK/client.go @@ -50,10 +50,11 @@ type CustomLogger interface { } type selectRoute struct { - NetID string - Network netip.Prefix - Domains domain.List - Selected bool + NetID string + Network netip.Prefix + Domains domain.List + Selected bool + extraNetworks []netip.Prefix } func init() { @@ -363,48 +364,60 @@ func (c *Client) GetRoutesSelectionDetails() (*RoutesSelectionDetails, error) { } routeManager := engine.GetRouteManager() - routesMap := routeManager.GetClientRoutesWithNetID() if routeManager == nil { return nil, fmt.Errorf("could not get route manager") } + routesMap := routeManager.GetClientRoutesWithNetID() routeSelector := routeManager.GetRouteSelector() if routeSelector == nil { return nil, fmt.Errorf("could not get route selector") } + v6ExitMerged := route.V6ExitMergeSet(routesMap) + routes := buildSelectRoutes(routesMap, routeSelector.IsSelected, v6ExitMerged) + resolvedDomains := c.recorder.GetResolvedDomainsStates() + + return prepareRouteSelectionDetails(routes, resolvedDomains), nil +} + +func buildSelectRoutes(routesMap map[route.NetID][]*route.Route, isSelected func(route.NetID) bool, v6Merged map[route.NetID]struct{}) []*selectRoute { var routes []*selectRoute for id, rt := range routesMap { if len(rt) == 0 { continue } - route := &selectRoute{ + if _, ok := v6Merged[id]; ok { + continue + } + + r := &selectRoute{ NetID: string(id), Network: rt[0].Network, Domains: rt[0].Domains, - Selected: routeSelector.IsSelected(id), + Selected: isSelected(id), } - routes = append(routes, route) + + v6ID := route.NetID(string(id) + route.V6ExitSuffix) + if _, ok := v6Merged[v6ID]; ok { + r.extraNetworks = []netip.Prefix{routesMap[v6ID][0].Network} + } + + routes = append(routes, r) } sort.Slice(routes, func(i, j int) bool { - iPrefix := routes[i].Network.Bits() - jPrefix := routes[j].Network.Bits() - - if iPrefix == jPrefix { - iAddr := routes[i].Network.Addr() - jAddr := routes[j].Network.Addr() - if iAddr == jAddr { - return routes[i].NetID < routes[j].NetID - } - return iAddr.String() < jAddr.String() + iBits, jBits := routes[i].Network.Bits(), routes[j].Network.Bits() + if iBits != jBits { + return iBits < jBits } - return iPrefix < jPrefix + iAddr, jAddr := routes[i].Network.Addr(), routes[j].Network.Addr() + if iAddr != jAddr { + return iAddr.Less(jAddr) + } + return routes[i].NetID < routes[j].NetID }) - resolvedDomains := c.recorder.GetResolvedDomainsStates() - - return prepareRouteSelectionDetails(routes, resolvedDomains), nil - + return routes } func prepareRouteSelectionDetails(routes []*selectRoute, resolvedDomains map[domain.Domain]peer.ResolvedDomainInfo) *RoutesSelectionDetails { @@ -425,10 +438,15 @@ func prepareRouteSelectionDetails(routes []*selectRoute, resolvedDomains map[dom } domainList = append(domainList, domainResp) } + rangeStr := r.Network.String() + for _, extra := range r.extraNetworks { + rangeStr += ", " + extra.String() + } + domainDetails := DomainDetails{items: domainList} routeSelection = append(routeSelection, RoutesSelectionInfo{ ID: r.NetID, - Network: r.Network.String(), + Network: rangeStr, Domains: &domainDetails, Selected: r.Selected, }) @@ -456,7 +474,9 @@ func (c *Client) SelectRoute(id string) error { } else { log.Debugf("select route with id: %s", id) routes := toNetIDs([]string{id}) - if err := routeSelector.SelectRoutes(routes, true, maps.Keys(routeManager.GetClientRoutesWithNetID())); err != nil { + routesMap := routeManager.GetClientRoutesWithNetID() + routes = route.ExpandV6ExitPairs(routes, routesMap) + if err := routeSelector.SelectRoutes(routes, true, maps.Keys(routesMap)); err != nil { log.Debugf("error when selecting routes: %s", err) return fmt.Errorf("select routes: %w", err) } @@ -483,7 +503,9 @@ func (c *Client) DeselectRoute(id string) error { } else { log.Debugf("deselect route with id: %s", id) routes := toNetIDs([]string{id}) - if err := routeSelector.DeselectRoutes(routes, maps.Keys(routeManager.GetClientRoutesWithNetID())); err != nil { + routesMap := routeManager.GetClientRoutesWithNetID() + routes = route.ExpandV6ExitPairs(routes, routesMap) + if err := routeSelector.DeselectRoutes(routes, maps.Keys(routesMap)); err != nil { log.Debugf("error when deselecting routes: %s", err) return fmt.Errorf("deselect routes: %w", err) } diff --git a/client/server/network.go b/client/server/network.go index bb1cce56c..4a439d8cf 100644 --- a/client/server/network.go +++ b/client/server/network.go @@ -16,10 +16,11 @@ import ( ) type selectRoute struct { - NetID route.NetID - Network netip.Prefix - Domains domain.List - Selected bool + NetID route.NetID + Network netip.Prefix + Domains domain.List + Selected bool + extraNetworks []netip.Prefix } // ListNetworks returns a list of all available networks. @@ -44,18 +45,33 @@ func (s *Server) ListNetworks(context.Context, *proto.ListNetworksRequest) (*pro routesMap := routeMgr.GetClientRoutesWithNetID() routeSelector := routeMgr.GetRouteSelector() + v6ExitMerged := route.V6ExitMergeSet(routesMap) + var routes []*selectRoute for id, rt := range routesMap { if len(rt) == 0 { continue } - route := &selectRoute{ + // Skip v6 exit nodes that are merged into their v4 counterpart. + if _, ok := v6ExitMerged[id]; ok { + continue + } + + r := &selectRoute{ NetID: id, Network: rt[0].Network, Domains: rt[0].Domains, Selected: routeSelector.IsSelected(id), } - routes = append(routes, route) + + // Merge paired v6 exit node prefix into this entry. + v6ID := route.NetID(string(id) + route.V6ExitSuffix) + if _, ok := v6ExitMerged[v6ID]; ok { + v6Prefix := routesMap[v6ID][0].Network + r.extraNetworks = []netip.Prefix{v6Prefix} + } + + routes = append(routes, r) } sort.Slice(routes, func(i, j int) bool { @@ -76,9 +92,13 @@ func (s *Server) ListNetworks(context.Context, *proto.ListNetworksRequest) (*pro resolvedDomains := s.statusRecorder.GetResolvedDomainsStates() var pbRoutes []*proto.Network for _, route := range routes { + rangeStr := route.Network.String() + for _, extra := range route.extraNetworks { + rangeStr += ", " + extra.String() + } pbRoute := &proto.Network{ ID: string(route.NetID), - Range: route.Network.String(), + Range: rangeStr, Domains: route.Domains.ToSafeStringList(), ResolvedIPs: map[string]*proto.IPList{}, Selected: route.Selected, @@ -137,7 +157,9 @@ func (s *Server) SelectNetworks(_ context.Context, req *proto.SelectNetworksRequ routeSelector.SelectAllRoutes() } else { routes := toNetIDs(req.GetNetworkIDs()) - netIdRoutes := maps.Keys(routeManager.GetClientRoutesWithNetID()) + routesMap := routeManager.GetClientRoutesWithNetID() + routes = route.ExpandV6ExitPairs(routes, routesMap) + netIdRoutes := maps.Keys(routesMap) if err := routeSelector.SelectRoutes(routes, req.GetAppend(), netIdRoutes); err != nil { return nil, fmt.Errorf("select routes: %w", err) } @@ -183,7 +205,9 @@ func (s *Server) DeselectNetworks(_ context.Context, req *proto.SelectNetworksRe routeSelector.DeselectAllRoutes() } else { routes := toNetIDs(req.GetNetworkIDs()) - netIdRoutes := maps.Keys(routeManager.GetClientRoutesWithNetID()) + routesMap := routeManager.GetClientRoutesWithNetID() + routes = route.ExpandV6ExitPairs(routes, routesMap) + netIdRoutes := maps.Keys(routesMap) if err := routeSelector.DeselectRoutes(routes, netIdRoutes); err != nil { return nil, fmt.Errorf("deselect routes: %w", err) } diff --git a/client/ui/network.go b/client/ui/network.go index 6ae57122e..4bb0b7611 100644 --- a/client/ui/network.go +++ b/client/ui/network.go @@ -195,7 +195,7 @@ func getOverlappingNetworks(routes []*proto.Network) []*proto.Network { func getExitNodeNetworks(routes []*proto.Network) []*proto.Network { var filteredRoutes []*proto.Network for _, route := range routes { - if route.Range == "0.0.0.0/0" || route.Range == "::/0" { + if strings.Contains(route.Range, "0.0.0.0/0") || route.Range == "::/0" { filteredRoutes = append(filteredRoutes, route) } } diff --git a/client/wasm/cmd/main.go b/client/wasm/cmd/main.go index 0c1a5dc69..6fa0eeb2a 100644 --- a/client/wasm/cmd/main.go +++ b/client/wasm/cmd/main.go @@ -6,6 +6,7 @@ import ( "context" "fmt" "net" + "strconv" "syscall/js" "time" @@ -166,39 +167,58 @@ func createSSHMethod(client *netbird.Client) js.Func { }) } - var jwtToken string - if len(args) > 3 && !args[3].IsNull() && !args[3].IsUndefined() { - jwtToken = args[3].String() - } + jwtToken, ipVersion := parseSSHOptions(args) return createPromise(func(resolve, reject js.Value) { - sshClient := ssh.NewClient(client) - - if err := sshClient.Connect(host, port, username, jwtToken); err != nil { + jsInterface, err := connectSSH(client, host, port, username, jwtToken, ipVersion) + if err != nil { reject.Invoke(err.Error()) return } - - if err := sshClient.StartSession(80, 24); err != nil { - if closeErr := sshClient.Close(); closeErr != nil { - log.Errorf("Error closing SSH client: %v", closeErr) - } - reject.Invoke(err.Error()) - return - } - - jsInterface := ssh.CreateJSInterface(sshClient) resolve.Invoke(jsInterface) }) }) } -func performPing(client *netbird.Client, hostname string) { +func parseSSHOptions(args []js.Value) (jwtToken string, ipVersion int) { + if len(args) > 3 && !args[3].IsNull() && !args[3].IsUndefined() { + jwtToken = args[3].String() + } + if len(args) > 4 { + ipVersion = jsIPVersion(args[4]) + } + return +} + +func connectSSH(client *netbird.Client, host string, port int, username, jwtToken string, ipVersion int) (js.Value, error) { + sshClient := ssh.NewClient(client) + + if err := sshClient.Connect(host, port, username, jwtToken, ipVersion); err != nil { + return js.Undefined(), err + } + + if err := sshClient.StartSession(80, 24); err != nil { + if closeErr := sshClient.Close(); closeErr != nil { + log.Errorf("Error closing SSH client: %v", closeErr) + } + return js.Undefined(), err + } + + return ssh.CreateJSInterface(sshClient), nil +} + +func performPing(client *netbird.Client, hostname string, ipVersion int) { ctx, cancel := context.WithTimeout(context.Background(), pingTimeout) defer cancel() + // Default to ping4 to avoid dual-stack ICMP endpoint issues in wireguard-go netstack. + network := "ping4" + if ipVersion == 6 { + network = "ping6" + } + start := time.Now() - conn, err := client.Dial(ctx, "ping", hostname) + conn, err := client.Dial(ctx, network, hostname) if err != nil { js.Global().Get("console").Call("log", fmt.Sprintf("Ping to %s failed: %v", hostname, err)) return @@ -225,27 +245,39 @@ func performPing(client *netbird.Client, hostname string) { } latency := time.Since(start) - js.Global().Get("console").Call("log", fmt.Sprintf("Ping to %s: %dms", hostname, latency.Milliseconds())) + remote := conn.RemoteAddr().String() + msg := fmt.Sprintf("Ping to %s: %dms", hostname, latency.Milliseconds()) + if remote != hostname { + msg += fmt.Sprintf(" (via %s)", remote) + } + js.Global().Get("console").Call("log", msg) } -func performPingTCP(client *netbird.Client, hostname string, port int) { +func performPingTCP(client *netbird.Client, hostname string, port, ipVersion int) { ctx, cancel := context.WithTimeout(context.Background(), pingTimeout) defer cancel() + network := ipVersionNetwork("tcp", ipVersion) + address := net.JoinHostPort(hostname, fmt.Sprintf("%d", port)) start := time.Now() - conn, err := client.Dial(ctx, "tcp", address) + conn, err := client.Dial(ctx, network, address) if err != nil { js.Global().Get("console").Call("log", fmt.Sprintf("TCP ping to %s failed: %v", address, err)) return } latency := time.Since(start) + remote := conn.RemoteAddr().String() if err := conn.Close(); err != nil { log.Debugf("failed to close TCP connection: %v", err) } - js.Global().Get("console").Call("log", fmt.Sprintf("TCP ping to %s succeeded: %dms", address, latency.Milliseconds())) + msg := fmt.Sprintf("TCP ping to %s succeeded: %dms", address, latency.Milliseconds()) + if remote != address { + msg += fmt.Sprintf(" (via %s)", remote) + } + js.Global().Get("console").Call("log", msg) } // createPingMethod creates the ping method @@ -262,8 +294,12 @@ func createPingMethod(client *netbird.Client) js.Func { } hostname := args[0].String() + var ipVersion int + if len(args) > 1 { + ipVersion = jsIPVersion(args[1]) + } return createPromise(func(resolve, reject js.Value) { - performPing(client, hostname) + performPing(client, hostname, ipVersion) resolve.Invoke(js.Undefined()) }) }) @@ -290,8 +326,12 @@ func createPingTCPMethod(client *netbird.Client) js.Func { hostname := args[0].String() port := args[1].Int() + var ipVersion int + if len(args) > 2 { + ipVersion = jsIPVersion(args[2]) + } return createPromise(func(resolve, reject js.Value) { - performPingTCP(client, hostname, port) + performPingTCP(client, hostname, port, ipVersion) resolve.Invoke(js.Undefined()) }) }) @@ -464,6 +504,31 @@ func createSetLogLevelMethod(client *netbird.Client) js.Func { }) } +// ipVersionNetwork appends "4" or "6" to a base network string (e.g. "tcp" -> "tcp4"). +func ipVersionNetwork(base string, ipVersion int) string { + switch ipVersion { + case 4: + return base + "4" + case 6: + return base + "6" + default: + return base + } +} + +// jsIPVersion extracts an IP version (4 or 6) from a JS string or number. +func jsIPVersion(v js.Value) int { + switch v.Type() { + case js.TypeNumber: + return v.Int() + case js.TypeString: + n, _ := strconv.Atoi(v.String()) + return n + default: + return 0 + } +} + // createPromise is a helper to create JavaScript promises func createPromise(handler func(resolve, reject js.Value)) js.Value { return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, promiseArgs []js.Value) any { diff --git a/client/wasm/internal/ssh/client.go b/client/wasm/internal/ssh/client.go index 2f425c614..9cfe65266 100644 --- a/client/wasm/internal/ssh/client.go +++ b/client/wasm/internal/ssh/client.go @@ -46,8 +46,9 @@ func NewClient(nbClient *netbird.Client) *Client { } } -// Connect establishes an SSH connection through NetBird network -func (c *Client) Connect(host string, port int, username, jwtToken string) error { +// Connect establishes an SSH connection through NetBird network. +// ipVersion may be 4, 6, or 0 for automatic selection. +func (c *Client) Connect(host string, port int, username, jwtToken string, ipVersion int) error { addr := net.JoinHostPort(host, fmt.Sprintf("%d", port)) logrus.Infof("SSH: Connecting to %s as %s", addr, username) @@ -63,10 +64,18 @@ func (c *Client) Connect(host string, port int, username, jwtToken string) error Timeout: sshDialTimeout, } + network := "tcp" + switch ipVersion { + case 4: + network = "tcp4" + case 6: + network = "tcp6" + } + ctx, cancel := context.WithTimeout(context.Background(), sshDialTimeout) defer cancel() - conn, err := c.nbClient.Dial(ctx, "tcp", addr) + conn, err := c.nbClient.Dial(ctx, network, addr) if err != nil { return fmt.Errorf("dial %s: %w", addr, err) } diff --git a/proxy/cmd/proxy/cmd/debug.go b/proxy/cmd/proxy/cmd/debug.go index 59f7a6b65..81879e404 100644 --- a/proxy/cmd/proxy/cmd/debug.go +++ b/proxy/cmd/proxy/cmd/debug.go @@ -3,6 +3,7 @@ package cmd import ( "fmt" "strconv" + "time" "github.com/spf13/cobra" @@ -57,7 +58,11 @@ var debugSyncCmd = &cobra.Command{ SilenceUsage: true, } -var pingTimeout string +var ( + pingTimeout time.Duration + pingIPv4 bool + pingIPv6 bool +) var debugPingCmd = &cobra.Command{ Use: "ping [port]", @@ -108,7 +113,10 @@ func init() { debugStatusCmd.Flags().StringVar(&statusFilterByStatus, "filter-by-status", "", "Filter by status (idle|connecting|connected)") debugStatusCmd.Flags().StringVar(&statusFilterByConnectionType, "filter-by-connection-type", "", "Filter by connection type (P2P|Relayed)") - debugPingCmd.Flags().StringVar(&pingTimeout, "timeout", "", "Ping timeout (e.g., 10s)") + debugPingCmd.Flags().DurationVar(&pingTimeout, "timeout", 0, "Ping timeout (e.g., 10s)") + debugPingCmd.Flags().BoolVarP(&pingIPv4, "ipv4", "4", false, "Force IPv4") + debugPingCmd.Flags().BoolVarP(&pingIPv6, "ipv6", "6", false, "Force IPv6") + debugPingCmd.MarkFlagsMutuallyExclusive("ipv4", "ipv6") debugCmd.AddCommand(debugHealthCmd) debugCmd.AddCommand(debugClientsCmd) @@ -157,7 +165,14 @@ func runDebugPing(cmd *cobra.Command, args []string) error { } port = p } - return getDebugClient(cmd).PingTCP(cmd.Context(), args[0], args[1], port, pingTimeout) + var ipVersion string + switch { + case pingIPv4: + ipVersion = "4" + case pingIPv6: + ipVersion = "6" + } + return getDebugClient(cmd).PingTCP(cmd.Context(), args[0], args[1], port, pingTimeout, ipVersion) } func runDebugLogLevel(cmd *cobra.Command, args []string) error { diff --git a/proxy/internal/debug/client.go b/proxy/internal/debug/client.go index 01b0bc8e6..2ce721eb8 100644 --- a/proxy/internal/debug/client.go +++ b/proxy/internal/debug/client.go @@ -6,10 +6,12 @@ import ( "encoding/json" "fmt" "io" + "net" "net/http" "net/url" "strings" "time" + ) // StatusFilters contains filter options for status queries. @@ -230,12 +232,16 @@ func (c *Client) ClientSyncResponse(ctx context.Context, accountID string) error } // PingTCP performs a TCP ping through a client. -func (c *Client) PingTCP(ctx context.Context, accountID, host string, port int, timeout string) error { +// ipVersion may be "4", "6", or "" for automatic. +func (c *Client) PingTCP(ctx context.Context, accountID, host string, port int, timeout time.Duration, ipVersion string) error { params := url.Values{} params.Set("host", host) params.Set("port", fmt.Sprintf("%d", port)) - if timeout != "" { - params.Set("timeout", timeout) + if timeout > 0 { + params.Set("timeout", timeout.String()) + } + if ipVersion != "" { + params.Set("ip_version", ipVersion) } path := fmt.Sprintf("/debug/clients/%s/pingtcp?%s", url.PathEscape(accountID), params.Encode()) @@ -244,11 +250,17 @@ func (c *Client) PingTCP(ctx context.Context, accountID, host string, port int, func (c *Client) printPingResult(data map[string]any) { success, _ := data["success"].(bool) + host := net.JoinHostPort(fmt.Sprint(data["host"]), fmt.Sprint(data["port"])) if success { - _, _ = fmt.Fprintf(c.out, "Success: %v:%v\n", data["host"], data["port"]) + remote, _ := data["remote"].(string) + if remote != "" && remote != host { + _, _ = fmt.Fprintf(c.out, "Success: %s (via %s)\n", host, remote) + } else { + _, _ = fmt.Fprintf(c.out, "Success: %s\n", host) + } _, _ = fmt.Fprintf(c.out, "Latency: %v\n", data["latency"]) } else { - _, _ = fmt.Fprintf(c.out, "Failed: %v:%v\n", data["host"], data["port"]) + _, _ = fmt.Fprintf(c.out, "Failed: %s\n", host) c.printError(data) } } diff --git a/proxy/internal/debug/handler.go b/proxy/internal/debug/handler.go index c507cfad9..c1d145204 100644 --- a/proxy/internal/debug/handler.go +++ b/proxy/internal/debug/handler.go @@ -9,6 +9,7 @@ import ( "fmt" "html/template" "maps" + "net" "net/http" "slices" "strconv" @@ -525,13 +526,18 @@ func (h *Handler) handlePingTCP(w http.ResponseWriter, r *http.Request, accountI } } + network := "tcp" + if v := r.URL.Query().Get("ip_version"); v == "4" || v == "6" { + network += v + } + ctx, cancel := context.WithTimeout(r.Context(), timeout) defer cancel() - address := fmt.Sprintf("%s:%d", host, port) + address := net.JoinHostPort(host, strconv.Itoa(port)) start := time.Now() - conn, err := client.Dial(ctx, "tcp", address) + conn, err := client.Dial(ctx, network, address) if err != nil { h.writeJSON(w, map[string]interface{}{ "success": false, @@ -541,18 +547,22 @@ func (h *Handler) handlePingTCP(w http.ResponseWriter, r *http.Request, accountI }) return } + + remote := conn.RemoteAddr().String() if err := conn.Close(); err != nil { h.logger.Debugf("close tcp ping connection: %v", err) } latency := time.Since(start) - h.writeJSON(w, map[string]interface{}{ + resp := map[string]interface{}{ "success": true, "host": host, "port": port, + "remote": remote, "latency_ms": latency.Milliseconds(), "latency": formatDuration(latency), - }) + } + h.writeJSON(w, resp) } func (h *Handler) handleLogLevel(w http.ResponseWriter, r *http.Request, accountID types.AccountID) { diff --git a/route/route.go b/route/route.go index c724e7c7d..97b9721f6 100644 --- a/route/route.go +++ b/route/route.go @@ -20,6 +20,9 @@ const ( MaxMetric = 9999 // MaxNetIDChar Max Network Identifier MaxNetIDChar = 40 + + // V6ExitSuffix is appended to a v4 exit node NetID to form its v6 counterpart. + V6ExitSuffix = "-v6" ) const ( @@ -215,3 +218,61 @@ func ParseNetwork(networkString string) (NetworkType, netip.Prefix, error) { return IPv4Network, masked, nil } + +var ( + v4Default = netip.PrefixFrom(netip.IPv4Unspecified(), 0) + v6Default = netip.PrefixFrom(netip.IPv6Unspecified(), 0) +) + +// IsV4DefaultRoute reports whether p is the IPv4 default route (0.0.0.0/0). +func IsV4DefaultRoute(p netip.Prefix) bool { return p == v4Default } + +// IsV6DefaultRoute reports whether p is the IPv6 default route (::/0). +func IsV6DefaultRoute(p netip.Prefix) bool { return p == v6Default } + +// ExpandV6ExitPairs appends the paired "-v6" exit node NetID for any v4 exit +// node (0.0.0.0/0) in ids that has a matching v6 counterpart (::/0) in routesMap. +// It modifies and returns the input slice. +func ExpandV6ExitPairs(ids []NetID, routesMap map[NetID][]*Route) []NetID { + for _, id := range ids { + rt, ok := routesMap[id] + if !ok || len(rt) == 0 || !IsV4DefaultRoute(rt[0].Network) { + continue + } + v6ID := NetID(string(id) + V6ExitSuffix) + if v6Rt, ok := routesMap[v6ID]; ok && len(v6Rt) > 0 && IsV6DefaultRoute(v6Rt[0].Network) { + if !slices.Contains(ids, v6ID) { + ids = append(ids, v6ID) + } + } + } + return ids +} + +// V6ExitMergeSet scans routesMap and returns the set of v6 exit node NetIDs +// that should be hidden from the UI because they are paired with a v4 exit node. +// A v6 ID is paired when it has suffix "-v6", its route is ::/0, and the base +// name (without "-v6") exists with route 0.0.0.0/0. +func V6ExitMergeSet(routesMap map[NetID][]*Route) map[NetID]struct{} { + merged := make(map[NetID]struct{}) + for id, rt := range routesMap { + if len(rt) == 0 { + continue + } + name := string(id) + if !IsV6DefaultRoute(rt[0].Network) || !strings.HasSuffix(name, V6ExitSuffix) { + continue + } + baseName := NetID(strings.TrimSuffix(name, V6ExitSuffix)) + if baseRt, ok := routesMap[baseName]; ok && len(baseRt) > 0 && IsV4DefaultRoute(baseRt[0].Network) { + merged[id] = struct{}{} + } + } + return merged +} + +// HasV6ExitPair reports whether id has a paired v6 exit node in the merge set. +func HasV6ExitPair(id NetID, v6Merged map[NetID]struct{}) bool { + _, ok := v6Merged[NetID(string(id)+"-v6")] + return ok +} diff --git a/route/route_test.go b/route/route_test.go new file mode 100644 index 000000000..dab707ed3 --- /dev/null +++ b/route/route_test.go @@ -0,0 +1,108 @@ +package route + +import ( + "net/netip" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestExpandV6ExitPairs(t *testing.T) { + v4ExitRoute := &Route{Network: netip.MustParsePrefix("0.0.0.0/0")} + v6ExitRoute := &Route{Network: netip.MustParsePrefix("::/0")} + regularRoute := &Route{Network: netip.MustParsePrefix("10.0.0.0/8")} + + tests := []struct { + name string + ids []NetID + routesMap map[NetID][]*Route + expected []NetID + }{ + { + name: "v4 exit node with matching v6 pair", + ids: []NetID{"exit-node"}, + routesMap: map[NetID][]*Route{ + "exit-node": {v4ExitRoute}, + "exit-node-v6": {v6ExitRoute}, + }, + expected: []NetID{"exit-node", "exit-node-v6"}, + }, + { + name: "v4 exit node without v6 pair", + ids: []NetID{"exit-node"}, + routesMap: map[NetID][]*Route{ + "exit-node": {v4ExitRoute}, + }, + expected: []NetID{"exit-node"}, + }, + { + name: "regular route is not expanded", + ids: []NetID{"office"}, + routesMap: map[NetID][]*Route{ + "office": {regularRoute}, + "office-v6": {v6ExitRoute}, + }, + expected: []NetID{"office"}, + }, + { + name: "v6 already included is not duplicated", + ids: []NetID{"exit-node", "exit-node-v6"}, + routesMap: map[NetID][]*Route{ + "exit-node": {v4ExitRoute}, + "exit-node-v6": {v6ExitRoute}, + }, + expected: []NetID{"exit-node", "exit-node-v6"}, + }, + { + name: "multiple exit nodes expanded independently", + ids: []NetID{"exit-a", "exit-b"}, + routesMap: map[NetID][]*Route{ + "exit-a": {v4ExitRoute}, + "exit-a-v6": {v6ExitRoute}, + "exit-b": {v4ExitRoute}, + "exit-b-v6": {v6ExitRoute}, + }, + expected: []NetID{"exit-a", "exit-b", "exit-a-v6", "exit-b-v6"}, + }, + { + name: "v6 suffix but not exit node network", + ids: []NetID{"office"}, + routesMap: map[NetID][]*Route{ + "office": {regularRoute}, + "office-v6": {regularRoute}, + }, + expected: []NetID{"office"}, + }, + { + name: "user-chosen name for exit node with v6 pair", + ids: []NetID{"my-exit"}, + routesMap: map[NetID][]*Route{ + "my-exit": {v4ExitRoute}, + "my-exit-v6": {v6ExitRoute}, + }, + expected: []NetID{"my-exit", "my-exit-v6"}, + }, + { + name: "real-world management-generated IDs", + ids: []NetID{"0.0.0.0/0"}, + routesMap: map[NetID][]*Route{ + "0.0.0.0/0": {v4ExitRoute}, + "0.0.0.0/0-v6": {v6ExitRoute}, + }, + expected: []NetID{"0.0.0.0/0", "0.0.0.0/0-v6"}, + }, + { + name: "empty input", + ids: []NetID{}, + routesMap: map[NetID][]*Route{}, + expected: []NetID{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ExpandV6ExitPairs(tt.ids, tt.routesMap) + assert.ElementsMatch(t, tt.expected, result) + }) + } +} diff --git a/shared/relay/client/dialer/quic/quic.go b/shared/relay/client/dialer/quic/quic.go index 2d7b00a80..7d413d4c1 100644 --- a/shared/relay/client/dialer/quic/quic.go +++ b/shared/relay/client/dialer/quic/quic.go @@ -46,7 +46,7 @@ func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) { InitialPacketSize: nbRelay.QUICInitialPacketSize, } - udpConn, err := nbnet.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + udpConn, err := nbnet.ListenUDP("udp", &net.UDPAddr{Port: 0}) if err != nil { log.Errorf("failed to listen on UDP: %s", err) return nil, err