[client] Add IPv6 support to ACL manager, USP filter, and forwarder (#5688)

This commit is contained in:
Viktor Liu
2026-04-09 16:56:08 +08:00
committed by GitHub
parent a1e7db2713
commit 1c4e5e71d7
78 changed files with 3606 additions and 1071 deletions

View File

@@ -238,43 +238,84 @@ func (c *Client) Networks() *NetworkArray {
return nil return nil
} }
routesMap := routeManager.GetClientRoutesWithNetID()
v6Merged := route.V6ExitMergeSet(routesMap)
resolvedDomains := c.recorder.GetResolvedDomainsStates()
networkArray := &NetworkArray{ networkArray := &NetworkArray{
items: make([]Network, 0), items: make([]Network, 0),
} }
resolvedDomains := c.recorder.GetResolvedDomainsStates() for id, routes := range routesMap {
for id, routes := range routeManager.GetClientRoutesWithNetID() {
if len(routes) == 0 { if len(routes) == 0 {
continue continue
} }
if _, skip := v6Merged[id]; skip {
r := routes[0]
domains := c.getNetworkDomainsFromRoute(r, resolvedDomains)
netStr := r.Network.String()
if r.IsDynamic() {
netStr = r.Domains.SafeString()
}
routePeer, err := c.recorder.GetPeer(routes[0].Peer)
if err != nil {
log.Errorf("could not get peer info for %s: %v", routes[0].Peer, err)
continue continue
} }
network := Network{
Name: string(id), network := c.buildNetwork(id, routes, routeSelector.IsSelected(id), resolvedDomains, v6Merged)
Network: netStr, if network == nil {
Peer: routePeer.FQDN, continue
Status: routePeer.ConnStatus.String(),
IsSelected: routeSelector.IsSelected(id),
Domains: domains,
} }
networkArray.Add(network) networkArray.Add(*network)
} }
return networkArray return networkArray
} }
func (c *Client) buildNetwork(id route.NetID, routes []*route.Route, selected bool, resolvedDomains map[domain.Domain]peer.ResolvedDomainInfo, v6Merged map[route.NetID]struct{}) *Network {
r := routes[0]
netStr := r.Network.String()
if r.IsDynamic() {
netStr = r.Domains.SafeString()
}
routePeer, err := c.findBestRoutePeer(routes)
if err != nil {
log.Errorf("could not get peer info for route %s: %v", id, err)
return nil
}
network := &Network{
Name: string(id),
Network: netStr,
Peer: routePeer.FQDN,
Status: routePeer.ConnStatus.String(),
IsSelected: selected,
Domains: c.getNetworkDomainsFromRoute(r, resolvedDomains),
}
if route.IsV4DefaultRoute(r.Network) && route.HasV6ExitPair(id, v6Merged) {
network.Network = "0.0.0.0/0, ::/0"
}
return network
}
// findBestRoutePeer returns the peer actively routing traffic for the given
// HA route group. Falls back to the first connected peer, then the first peer.
func (c *Client) findBestRoutePeer(routes []*route.Route) (peer.State, error) {
netStr := routes[0].Network.String()
fullStatus := c.recorder.GetFullStatus()
for _, p := range fullStatus.Peers {
if _, ok := p.GetRoutes()[netStr]; ok {
return p, nil
}
}
for _, r := range routes {
p, err := c.recorder.GetPeer(r.Peer)
if err != nil {
continue
}
if p.ConnStatus == peer.StatusConnected {
return p, nil
}
}
return c.recorder.GetPeer(routes[0].Peer)
}
// OnUpdatedHostDNS update the DNS servers addresses for root zones // OnUpdatedHostDNS update the DNS servers addresses for root zones
func (c *Client) OnUpdatedHostDNS(list *DNSList) error { func (c *Client) OnUpdatedHostDNS(list *DNSList) error {
dnsServer, err := dns.GetServerDns() dnsServer, err := dns.GetServerDns()

View File

@@ -18,9 +18,12 @@ func executeRouteToggle(id string, manager routemanager.Manager,
netID := route.NetID(id) netID := route.NetID(id)
routes := []route.NetID{netID} routes := []route.NetID{netID}
log.Debugf("%s with id: %s", operationName, id) routesMap := manager.GetClientRoutesWithNetID()
routes = route.ExpandV6ExitPairs(routes, routesMap)
if err := routeOperation(routes, maps.Keys(manager.GetClientRoutesWithNetID())); err != nil { log.Debugf("%s with ids: %v", operationName, routes)
if err := routeOperation(routes, maps.Keys(routesMap)); err != nil {
log.Debugf("error when %s: %s", operationName, err) log.Debugf("error when %s: %s", operationName, err)
return fmt.Errorf("error %s: %w", operationName, err) return fmt.Errorf("error %s: %w", operationName, err)
} }

View File

@@ -9,6 +9,7 @@ import (
"net/url" "net/url"
"regexp" "regexp"
"slices" "slices"
"strconv"
"strings" "strings"
) )
@@ -26,8 +27,9 @@ type Anonymizer struct {
} }
func DefaultAddresses() (netip.Addr, netip.Addr) { func DefaultAddresses() (netip.Addr, netip.Addr) {
// 198.51.100.0, 100:: // 198.51.100.0 (RFC 5737 TEST-NET-2), 2001:db8:ffff:: (RFC 3849 documentation, last /48)
return netip.AddrFrom4([4]byte{198, 51, 100, 0}), netip.AddrFrom16([16]byte{0x01}) // The old start 100:: (discard, RFC 6666) is now used for fake IPs on Android.
return netip.AddrFrom4([4]byte{198, 51, 100, 0}), netip.MustParseAddr("2001:db8:ffff::")
} }
func NewAnonymizer(startIPv4, startIPv6 netip.Addr) *Anonymizer { func NewAnonymizer(startIPv4, startIPv6 netip.Addr) *Anonymizer {
@@ -96,6 +98,11 @@ func (a *Anonymizer) isInAnonymizedRange(ip netip.Addr) bool {
} }
func (a *Anonymizer) AnonymizeIPString(ip string) string { func (a *Anonymizer) AnonymizeIPString(ip string) string {
// Handle CIDR notation (e.g. "2001:db8::/32")
if prefix, err := netip.ParsePrefix(ip); err == nil {
return a.AnonymizeIP(prefix.Addr()).String() + "/" + strconv.Itoa(prefix.Bits())
}
addr, err := netip.ParseAddr(ip) addr, err := netip.ParseAddr(ip)
if err != nil { if err != nil {
return ip return ip

View File

@@ -13,7 +13,7 @@ import (
func TestAnonymizeIP(t *testing.T) { func TestAnonymizeIP(t *testing.T) {
startIPv4 := netip.MustParseAddr("198.51.100.0") startIPv4 := netip.MustParseAddr("198.51.100.0")
startIPv6 := netip.MustParseAddr("100::") startIPv6 := netip.MustParseAddr("2001:db8:ffff::")
anonymizer := anonymize.NewAnonymizer(startIPv4, startIPv6) anonymizer := anonymize.NewAnonymizer(startIPv4, startIPv6)
tests := []struct { tests := []struct {
@@ -26,9 +26,9 @@ func TestAnonymizeIP(t *testing.T) {
{"Second Public IPv4", "4.3.2.1", "198.51.100.1"}, {"Second Public IPv4", "4.3.2.1", "198.51.100.1"},
{"Repeated IPv4", "1.2.3.4", "198.51.100.0"}, {"Repeated IPv4", "1.2.3.4", "198.51.100.0"},
{"Private IPv4", "192.168.1.1", "192.168.1.1"}, {"Private IPv4", "192.168.1.1", "192.168.1.1"},
{"First Public IPv6", "2607:f8b0:4005:805::200e", "100::"}, {"First Public IPv6", "2607:f8b0:4005:805::200e", "2001:db8:ffff::"},
{"Second Public IPv6", "a::b", "100::1"}, {"Second Public IPv6", "a::b", "2001:db8:ffff::1"},
{"Repeated IPv6", "2607:f8b0:4005:805::200e", "100::"}, {"Repeated IPv6", "2607:f8b0:4005:805::200e", "2001:db8:ffff::"},
{"Private IPv6", "fe80::1", "fe80::1"}, {"Private IPv6", "fe80::1", "fe80::1"},
{"In Range IPv4", "198.51.100.2", "198.51.100.2"}, {"In Range IPv4", "198.51.100.2", "198.51.100.2"},
} }
@@ -274,17 +274,17 @@ func TestAnonymizeString_IPAddresses(t *testing.T) {
{ {
name: "IPv6 Address", name: "IPv6 Address",
input: "Access attempted from 2001:db8::ff00:42", input: "Access attempted from 2001:db8::ff00:42",
expect: "Access attempted from 100::", expect: "Access attempted from 2001:db8:ffff::",
}, },
{ {
name: "IPv6 Address with Port", name: "IPv6 Address with Port",
input: "Access attempted from [2001:db8::ff00:42]:8080", input: "Access attempted from [2001:db8::ff00:42]:8080",
expect: "Access attempted from [100::]:8080", expect: "Access attempted from [2001:db8:ffff::]:8080",
}, },
{ {
name: "Both IPv4 and IPv6", name: "Both IPv4 and IPv6",
input: "IPv4: 142.108.0.1 and IPv6: 2001:db8::ff00:43", input: "IPv4: 142.108.0.1 and IPv6: 2001:db8::ff00:43",
expect: "IPv4: 198.51.100.1 and IPv6: 100::1", expect: "IPv4: 198.51.100.1 and IPv6: 2001:db8:ffff::1",
}, },
} }

View File

@@ -787,10 +787,10 @@ func isUnixSocket(path string) bool {
return strings.HasPrefix(path, "/") || strings.HasPrefix(path, "./") return strings.HasPrefix(path, "/") || strings.HasPrefix(path, "./")
} }
// normalizeLocalHost converts "*" to "0.0.0.0" for binding to all interfaces. // normalizeLocalHost converts "*" to "" for binding to all interfaces (dual-stack).
func normalizeLocalHost(host string) string { func normalizeLocalHost(host string) string {
if host == "*" { if host == "*" {
return "0.0.0.0" return ""
} }
return host return host
} }

View File

@@ -527,10 +527,10 @@ func TestParsePortForward(t *testing.T) {
{ {
name: "wildcard bind all interfaces", name: "wildcard bind all interfaces",
spec: "*:8080:localhost:80", spec: "*:8080:localhost:80",
expectedLocal: "0.0.0.0:8080", expectedLocal: ":8080",
expectedRemote: "localhost:80", expectedRemote: "localhost:80",
expectError: false, expectError: false,
description: "Wildcard * should bind to all interfaces (0.0.0.0)", description: "Wildcard * should bind to all interfaces (dual-stack)",
}, },
{ {
name: "wildcard for port only", name: "wildcard for port only",

View File

@@ -36,6 +36,7 @@ type aclManager struct {
entries aclEntries entries aclEntries
optionalEntries map[string][]entry optionalEntries map[string][]entry
ipsetStore *ipsetStore ipsetStore *ipsetStore
v6 bool
stateManager *statemanager.Manager stateManager *statemanager.Manager
} }
@@ -47,6 +48,7 @@ func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*acl
entries: make(map[string][][]string), entries: make(map[string][][]string),
optionalEntries: make(map[string][]entry), optionalEntries: make(map[string][]entry),
ipsetStore: newIpsetStore(), ipsetStore: newIpsetStore(),
v6: iptablesClient.Proto() == iptables.ProtocolIPv6,
}, nil }, nil
} }
@@ -81,7 +83,11 @@ func (m *aclManager) AddPeerFiltering(
chain := chainNameInputRules chain := chainNameInputRules
ipsetName = transformIPsetName(ipsetName, sPort, dPort, action) ipsetName = transformIPsetName(ipsetName, sPort, dPort, action)
specs := filterRuleSpecs(ip, string(protocol), sPort, dPort, action, ipsetName) if m.v6 && ipsetName != "" {
ipsetName += "-v6"
}
proto := protoForFamily(protocol, m.v6)
specs := filterRuleSpecs(ip, proto, sPort, dPort, action, ipsetName)
mangleSpecs := slices.Clone(specs) mangleSpecs := slices.Clone(specs)
mangleSpecs = append(mangleSpecs, mangleSpecs = append(mangleSpecs,
@@ -105,6 +111,7 @@ func (m *aclManager) AddPeerFiltering(
ip: ip.String(), ip: ip.String(),
chain: chain, chain: chain,
specs: specs, specs: specs,
v6: m.v6,
}}, nil }}, nil
} }
@@ -157,6 +164,7 @@ func (m *aclManager) AddPeerFiltering(
ipsetName: ipsetName, ipsetName: ipsetName,
ip: ip.String(), ip: ip.String(),
chain: chain, chain: chain,
v6: m.v6,
} }
m.updateState() m.updateState()
@@ -376,8 +384,13 @@ func (m *aclManager) updateState() {
currentState.Lock() currentState.Lock()
defer currentState.Unlock() defer currentState.Unlock()
currentState.ACLEntries = m.entries if m.v6 {
currentState.ACLIPsetStore = m.ipsetStore currentState.ACLEntries6 = m.entries
currentState.ACLIPsetStore6 = m.ipsetStore
} else {
currentState.ACLEntries = m.entries
currentState.ACLIPsetStore = m.ipsetStore
}
if err := m.stateManager.UpdateState(currentState); err != nil { if err := m.stateManager.UpdateState(currentState); err != nil {
log.Errorf("failed to update state: %v", err) log.Errorf("failed to update state: %v", err)
@@ -385,6 +398,15 @@ func (m *aclManager) updateState() {
} }
// filterRuleSpecs returns the specs of a filtering rule // filterRuleSpecs returns the specs of a filtering rule
// protoForFamily translates ICMP to ICMPv6 for ip6tables.
// ip6tables requires "ipv6-icmp" (or "icmpv6") instead of "icmp".
func protoForFamily(protocol firewall.Protocol, v6 bool) string {
if v6 && protocol == firewall.ProtocolICMP {
return "ipv6-icmp"
}
return string(protocol)
}
func filterRuleSpecs(ip net.IP, protocol string, sPort, dPort *firewall.Port, action firewall.Action, ipsetName string) (specs []string) { func filterRuleSpecs(ip net.IP, protocol string, sPort, dPort *firewall.Port, action firewall.Action, ipsetName string) (specs []string) {
// don't use IP matching if IP is 0.0.0.0 // don't use IP matching if IP is 0.0.0.0
matchByIP := !ip.IsUnspecified() matchByIP := !ip.IsUnspecified()
@@ -437,6 +459,9 @@ func (m *aclManager) createIPSet(name string) error {
opts := ipset.CreateOptions{ opts := ipset.CreateOptions{
Replace: true, Replace: true,
} }
if m.v6 {
opts.Family = ipset.FamilyIPV6
}
if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil { if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil {
return fmt.Errorf("create ipset %s: %w", name, err) return fmt.Errorf("create ipset %s: %w", name, err)

View File

@@ -17,6 +17,10 @@ import (
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
) )
type resetter interface {
Reset() error
}
// Manager of iptables firewall // Manager of iptables firewall
type Manager struct { type Manager struct {
mutex sync.Mutex mutex sync.Mutex
@@ -27,6 +31,11 @@ type Manager struct {
aclMgr *aclManager aclMgr *aclManager
router *router router *router
rawSupported bool rawSupported bool
// IPv6 counterparts, nil when no v6 overlay
ipv6Client *iptables.IPTables
aclMgr6 *aclManager
router6 *router
} }
// iFaceMapper defines subset methods of interface required for manager // iFaceMapper defines subset methods of interface required for manager
@@ -58,9 +67,43 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) {
return nil, fmt.Errorf("create acl manager: %w", err) return nil, fmt.Errorf("create acl manager: %w", err)
} }
if wgIface.Address().HasIPv6() {
if err := m.createIPv6Components(wgIface, mtu); err != nil {
return nil, fmt.Errorf("create IPv6 firewall: %w", err)
}
}
return m, nil return m, nil
} }
func (m *Manager) createIPv6Components(wgIface iFaceMapper, mtu uint16) error {
ip6Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv6)
if err != nil {
return fmt.Errorf("init ip6tables: %w", err)
}
m.ipv6Client = ip6Client
m.router6, err = newRouter(ip6Client, wgIface, mtu)
if err != nil {
return fmt.Errorf("create v6 router: %w", err)
}
// Share the same IP forwarding state with the v4 router, since
// EnableIPForwarding controls both v4 and v6 sysctls.
m.router6.ipFwdState = m.router.ipFwdState
m.aclMgr6, err = newAclManager(ip6Client, wgIface)
if err != nil {
return fmt.Errorf("create v6 acl manager: %w", err)
}
return nil
}
func (m *Manager) hasIPv6() bool {
return m.ipv6Client != nil
}
func (m *Manager) Init(stateManager *statemanager.Manager) error { func (m *Manager) Init(stateManager *statemanager.Manager) error {
state := &ShutdownState{ state := &ShutdownState{
InterfaceState: &InterfaceState{ InterfaceState: &InterfaceState{
@@ -75,13 +118,8 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
log.Errorf("failed to update state: %v", err) log.Errorf("failed to update state: %v", err)
} }
if err := m.router.init(stateManager); err != nil { if err := m.initChains(stateManager); err != nil {
return fmt.Errorf("router init: %w", err) return err
}
if err := m.aclMgr.init(stateManager); err != nil {
// TODO: cleanup router
return fmt.Errorf("acl manager init: %w", err)
} }
if err := m.initNoTrackChain(); err != nil { if err := m.initNoTrackChain(); err != nil {
@@ -98,6 +136,41 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
return nil return nil
} }
// initChains initializes router and ACL chains for both address families,
// rolling back on failure.
func (m *Manager) initChains(stateManager *statemanager.Manager) error {
type initStep struct {
name string
init func(*statemanager.Manager) error
mgr resetter
}
steps := []initStep{
{"router", m.router.init, m.router},
{"acl manager", m.aclMgr.init, m.aclMgr},
}
if m.hasIPv6() {
steps = append(steps,
initStep{"v6 router", m.router6.init, m.router6},
initStep{"v6 acl manager", m.aclMgr6.init, m.aclMgr6},
)
}
var initialized []initStep
for _, s := range steps {
if err := s.init(stateManager); err != nil {
for i := len(initialized) - 1; i >= 0; i-- {
if rerr := initialized[i].mgr.Reset(); rerr != nil {
log.Warnf("rollback %s: %v", initialized[i].name, rerr)
}
}
return fmt.Errorf("%s init: %w", s.name, err)
}
initialized = append(initialized, s)
}
return nil
}
// AddPeerFiltering adds a rule to the firewall // AddPeerFiltering adds a rule to the firewall
// //
// Comment will be ignored because some system this feature is not supported // Comment will be ignored because some system this feature is not supported
@@ -113,7 +186,13 @@ func (m *Manager) AddPeerFiltering(
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
return m.aclMgr.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName) if ip.To4() != nil {
return m.aclMgr.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
}
if !m.hasIPv6() {
return nil, fmt.Errorf("IPv6 not initialized, cannot add rule for %s", ip)
}
return m.aclMgr6.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
} }
func (m *Manager) AddRouteFiltering( func (m *Manager) AddRouteFiltering(
@@ -127,25 +206,48 @@ func (m *Manager) AddRouteFiltering(
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
if destination.IsPrefix() && !destination.Prefix.Addr().Is4() { if isIPv6RouteRule(sources, destination) {
return nil, fmt.Errorf("unsupported IP version: %s", destination.Prefix.Addr().String()) if !m.hasIPv6() {
return nil, fmt.Errorf("IPv6 not initialized, cannot add route rule")
}
return m.router6.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
} }
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action) return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
} }
func isIPv6RouteRule(sources []netip.Prefix, destination firewall.Network) bool {
if destination.IsPrefix() {
return destination.Prefix.Addr().Is6()
}
return len(sources) > 0 && sources[0].Addr().Is6()
}
// DeletePeerRule from the firewall by rule definition // DeletePeerRule from the firewall by rule definition
func (m *Manager) DeletePeerRule(rule firewall.Rule) error { func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
if m.hasIPv6() && isIPv6IptRule(rule) {
return m.aclMgr6.DeletePeerRule(rule)
}
return m.aclMgr.DeletePeerRule(rule) return m.aclMgr.DeletePeerRule(rule)
} }
func isIPv6IptRule(rule firewall.Rule) bool {
r, ok := rule.(*Rule)
return ok && r.v6
}
// DeleteRouteRule deletes a routing rule.
// Route rules are keyed by content hash. Check v4 first, try v6 if not found.
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error { func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
if m.hasIPv6() && !m.router.hasRule(rule.ID()) {
return m.router6.DeleteRouteRule(rule)
}
return m.router.DeleteRouteRule(rule) return m.router.DeleteRouteRule(rule)
} }
@@ -161,18 +263,63 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
return m.router.AddNatRule(pair) if pair.Destination.IsPrefix() && pair.Destination.Prefix.Addr().Is6() {
if !m.hasIPv6() {
return fmt.Errorf("IPv6 not initialized, cannot add NAT rule")
}
return m.router6.AddNatRule(pair)
}
if err := m.router.AddNatRule(pair); err != nil {
return err
}
// Dynamic routes need NAT in both tables
if m.hasIPv6() && pair.Destination.IsSet() {
v6Pair := pair
v6Pair.Source = firewall.Network{Prefix: netip.PrefixFrom(netip.IPv6Unspecified(), 0)}
if err := m.router6.AddNatRule(v6Pair); err != nil {
return fmt.Errorf("add v6 NAT rule: %w", err)
}
}
return nil
} }
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error { func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
return m.router.RemoveNatRule(pair) if pair.Destination.IsPrefix() && pair.Destination.Prefix.Addr().Is6() {
if !m.hasIPv6() {
return nil
}
return m.router6.RemoveNatRule(pair)
}
if err := m.router.RemoveNatRule(pair); err != nil {
return err
}
if m.hasIPv6() && pair.Destination.IsSet() {
v6Pair := pair
v6Pair.Source = firewall.Network{Prefix: netip.PrefixFrom(netip.IPv6Unspecified(), 0)}
if err := m.router6.RemoveNatRule(v6Pair); err != nil {
return fmt.Errorf("remove v6 NAT rule: %w", err)
}
}
return nil
} }
func (m *Manager) SetLegacyManagement(isLegacy bool) error { func (m *Manager) SetLegacyManagement(isLegacy bool) error {
return firewall.SetLegacyManagement(m.router, isLegacy) if err := firewall.SetLegacyManagement(m.router, isLegacy); err != nil {
return err
}
if m.hasIPv6() {
return firewall.SetLegacyManagement(m.router6, isLegacy)
}
return nil
} }
// Reset firewall to the default state // Reset firewall to the default state
@@ -186,6 +333,15 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
merr = multierror.Append(merr, fmt.Errorf("cleanup notrack chain: %w", err)) merr = multierror.Append(merr, fmt.Errorf("cleanup notrack chain: %w", err))
} }
if m.hasIPv6() {
if err := m.aclMgr6.Reset(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("reset v6 acl manager: %w", err))
}
if err := m.router6.Reset(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("reset v6 router: %w", err))
}
}
if err := m.aclMgr.Reset(); err != nil { if err := m.aclMgr.Reset(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("reset acl manager: %w", err)) merr = multierror.Append(merr, fmt.Errorf("reset acl manager: %w", err))
} }
@@ -209,19 +365,16 @@ func (m *Manager) AllowNetbird() error {
return nil return nil
} }
_, err := m.AddPeerFiltering( var merr *multierror.Error
nil, if _, err := m.aclMgr.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolALL, nil, nil, firewall.ActionAccept, ""); err != nil {
net.IP{0, 0, 0, 0}, merr = multierror.Append(merr, fmt.Errorf("allow netbird interface traffic: %w", err))
firewall.ProtocolALL,
nil,
nil,
firewall.ActionAccept,
"",
)
if err != nil {
return fmt.Errorf("allow netbird interface traffic: %w", err)
} }
return nil if m.hasIPv6() {
if _, err := m.aclMgr6.AddPeerFiltering(nil, net.IPv6zero, firewall.ProtocolALL, nil, nil, firewall.ActionAccept, ""); err != nil {
merr = multierror.Append(merr, fmt.Errorf("allow v6 netbird interface traffic: %w", err))
}
}
return nberrors.FormatErrorOrNil(merr)
} }
// Flush doesn't need to be implemented for this manager // Flush doesn't need to be implemented for this manager
@@ -251,6 +404,9 @@ func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error)
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
if m.hasIPv6() && rule.TranslatedAddress.Is6() {
return m.router6.AddDNATRule(rule)
}
return m.router.AddDNATRule(rule) return m.router.AddDNATRule(rule)
} }
@@ -259,6 +415,9 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
if m.hasIPv6() && !m.router.hasRule(rule.ID()+dnatSuffix) {
return m.router6.DeleteDNATRule(rule)
}
return m.router.DeleteDNATRule(rule) return m.router.DeleteDNATRule(rule)
} }
@@ -267,7 +426,26 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
return m.router.UpdateSet(set, prefixes) var v4Prefixes, v6Prefixes []netip.Prefix
for _, p := range prefixes {
if p.Addr().Is6() {
v6Prefixes = append(v6Prefixes, p)
} else {
v4Prefixes = append(v4Prefixes, p)
}
}
if err := m.router.UpdateSet(set, v4Prefixes); err != nil {
return err
}
if m.hasIPv6() && len(v6Prefixes) > 0 {
if err := m.router6.UpdateSet(set, v6Prefixes); err != nil {
return fmt.Errorf("update v6 set: %w", err)
}
}
return nil
} }
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services. // AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
@@ -275,6 +453,9 @@ func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protoco
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
if m.hasIPv6() && localAddr.Is6() {
return m.router6.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort)
}
return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort) return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort)
} }
@@ -283,6 +464,9 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
if m.hasIPv6() && localAddr.Is6() {
return m.router6.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
}
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort) return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
} }

View File

@@ -54,8 +54,10 @@ const (
snatSuffix = "_snat" snatSuffix = "_snat"
fwdSuffix = "_fwd" fwdSuffix = "_fwd"
// ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation // ipv4TCPHeaderSize is the minimum IPv4 (20) + TCP (20) header size for MSS calculation.
ipTCPHeaderMinSize = 40 ipv4TCPHeaderSize = 40
// ipv6TCPHeaderSize is the minimum IPv6 (40) + TCP (20) header size for MSS calculation.
ipv6TCPHeaderSize = 60
) )
type ruleInfo struct { type ruleInfo struct {
@@ -86,6 +88,7 @@ type router struct {
wgIface iFaceMapper wgIface iFaceMapper
legacyManagement bool legacyManagement bool
mtu uint16 mtu uint16
v6 bool
stateManager *statemanager.Manager stateManager *statemanager.Manager
ipFwdState *ipfwdstate.IPForwardingState ipFwdState *ipfwdstate.IPForwardingState
@@ -97,6 +100,7 @@ func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper, mtu uint1
rules: make(map[string][]string), rules: make(map[string][]string),
wgIface: wgIface, wgIface: wgIface,
mtu: mtu, mtu: mtu,
v6: iptablesClient.Proto() == iptables.ProtocolIPv6,
ipFwdState: ipfwdstate.NewIPForwardingState(), ipFwdState: ipfwdstate.NewIPForwardingState(),
} }
@@ -186,6 +190,11 @@ func (r *router) AddRouteFiltering(
return ruleKey, nil return ruleKey, nil
} }
func (r *router) hasRule(id string) bool {
_, ok := r.rules[id]
return ok
}
func (r *router) DeleteRouteRule(rule firewall.Rule) error { func (r *router) DeleteRouteRule(rule firewall.Rule) error {
ruleKey := rule.ID() ruleKey := rule.ID()
@@ -434,6 +443,12 @@ func (r *router) createContainers() error {
{chainRTRDR, tableNat}, {chainRTRDR, tableNat},
{chainRTMSSCLAMP, tableMangle}, {chainRTMSSCLAMP, tableMangle},
} { } {
// Fallback: clear chains that survived an unclean shutdown.
if ok, _ := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain); ok {
if err := r.iptablesClient.ClearAndDeleteChain(chainInfo.table, chainInfo.chain); err != nil {
log.Warnf("clear stale chain %s in %s: %v", chainInfo.chain, chainInfo.table, err)
}
}
if err := r.iptablesClient.NewChain(chainInfo.table, chainInfo.chain); err != nil { if err := r.iptablesClient.NewChain(chainInfo.table, chainInfo.chain); err != nil {
return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err) return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
} }
@@ -540,9 +555,12 @@ func (r *router) addPostroutingRules() error {
} }
// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic. // addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic.
// TODO: Add IPv6 support
func (r *router) addMSSClampingRules() error { func (r *router) addMSSClampingRules() error {
mss := r.mtu - ipTCPHeaderMinSize overhead := uint16(ipv4TCPHeaderSize)
if r.v6 {
overhead = ipv6TCPHeaderSize
}
mss := r.mtu - overhead
// Add jump rule from FORWARD chain in mangle table to our custom chain // Add jump rule from FORWARD chain in mangle table to our custom chain
jumpRule := []string{ jumpRule := []string{
@@ -727,8 +745,13 @@ func (r *router) updateState() {
currentState.Lock() currentState.Lock()
defer currentState.Unlock() defer currentState.Unlock()
currentState.RouteRules = r.rules if r.v6 {
currentState.RouteIPsetCounter = r.ipsetCounter currentState.RouteRules6 = r.rules
currentState.RouteIPsetCounter6 = r.ipsetCounter
} else {
currentState.RouteRules = r.rules
currentState.RouteIPsetCounter = r.ipsetCounter
}
if err := r.stateManager.UpdateState(currentState); err != nil { if err := r.stateManager.UpdateState(currentState); err != nil {
log.Errorf("failed to update state: %v", err) log.Errorf("failed to update state: %v", err)
@@ -856,7 +879,7 @@ func (r *router) DeleteDNATRule(rule firewall.Rule) error {
} }
if fwdRule, exists := r.rules[ruleKey+fwdSuffix]; exists { if fwdRule, exists := r.rules[ruleKey+fwdSuffix]; exists {
if err := r.iptablesClient.Delete(tableFilter, chainRTFWDIN, fwdRule...); err != nil { if err := r.iptablesClient.Delete(tableFilter, chainRTFWDOUT, fwdRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete forward rule: %w", err)) merr = multierror.Append(merr, fmt.Errorf("delete forward rule: %w", err))
} }
delete(r.rules, ruleKey+fwdSuffix) delete(r.rules, ruleKey+fwdSuffix)
@@ -883,7 +906,7 @@ func (r *router) genRouteRuleSpec(params routeFilteringRuleParams, sources []net
rule = append(rule, destExp...) rule = append(rule, destExp...)
if params.Proto != firewall.ProtocolALL { if params.Proto != firewall.ProtocolALL {
rule = append(rule, "-p", strings.ToLower(string(params.Proto))) rule = append(rule, "-p", strings.ToLower(protoForFamily(params.Proto, r.v6)))
rule = append(rule, applyPort("--sport", params.SPort)...) rule = append(rule, applyPort("--sport", params.SPort)...)
rule = append(rule, applyPort("--dport", params.DPort)...) rule = append(rule, applyPort("--dport", params.DPort)...)
} }
@@ -900,11 +923,12 @@ func (r *router) applyNetwork(flag string, network firewall.Network, prefixes []
} }
if network.IsSet() { if network.IsSet() {
if _, err := r.ipsetCounter.Increment(network.Set.HashedName(), prefixes); err != nil { name := r.ipsetName(network.Set.HashedName())
if _, err := r.ipsetCounter.Increment(name, prefixes); err != nil {
return nil, fmt.Errorf("create or get ipset: %w", err) return nil, fmt.Errorf("create or get ipset: %w", err)
} }
return []string{"-m", "set", matchSet, network.Set.HashedName(), direction}, nil return []string{"-m", "set", matchSet, name, direction}, nil
} }
if network.IsPrefix() { if network.IsPrefix() {
return []string{flag, network.Prefix.String()}, nil return []string{flag, network.Prefix.String()}, nil
@@ -915,19 +939,15 @@ func (r *router) applyNetwork(flag string, network firewall.Network, prefixes []
} }
func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
name := r.ipsetName(set.HashedName())
var merr *multierror.Error var merr *multierror.Error
for _, prefix := range prefixes { for _, prefix := range prefixes {
// TODO: Implement IPv6 support if err := r.addPrefixToIPSet(name, prefix); err != nil {
if prefix.Addr().Is6() {
log.Tracef("skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix)
continue
}
if err := r.addPrefixToIPSet(set.HashedName(), prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add prefix to ipset: %w", err)) merr = multierror.Append(merr, fmt.Errorf("add prefix to ipset: %w", err))
} }
} }
if merr == nil { if merr == nil {
log.Debugf("updated set %s with prefixes %v", set.HashedName(), prefixes) log.Debugf("updated set %s with prefixes %v", name, prefixes)
} }
return nberrors.FormatErrorOrNil(merr) return nberrors.FormatErrorOrNil(merr)
@@ -943,7 +963,7 @@ func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol
dnatRule := []string{ dnatRule := []string{
"-i", r.wgIface.Name(), "-i", r.wgIface.Name(),
"-p", strings.ToLower(string(protocol)), "-p", strings.ToLower(protoForFamily(protocol, r.v6)),
"--dport", strconv.Itoa(int(sourcePort)), "--dport", strconv.Itoa(int(sourcePort)),
"-d", localAddr.String(), "-d", localAddr.String(),
"-m", "addrtype", "--dst-type", "LOCAL", "-m", "addrtype", "--dst-type", "LOCAL",
@@ -1076,10 +1096,22 @@ func applyPort(flag string, port *firewall.Port) []string {
return []string{flag, strconv.Itoa(int(port.Values[0]))} return []string{flag, strconv.Itoa(int(port.Values[0]))}
} }
// ipsetName returns the ipset name, suffixed with "-v6" for the v6 router
// to avoid collisions since ipsets are global in the kernel.
func (r *router) ipsetName(name string) string {
if r.v6 {
return name + "-v6"
}
return name
}
func (r *router) createIPSet(name string) error { func (r *router) createIPSet(name string) error {
opts := ipset.CreateOptions{ opts := ipset.CreateOptions{
Replace: true, Replace: true,
} }
if r.v6 {
opts.Family = ipset.FamilyIPV6
}
if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil { if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil {
return fmt.Errorf("create ipset %s: %w", name, err) return fmt.Errorf("create ipset %s: %w", name, err)

View File

@@ -9,6 +9,7 @@ type Rule struct {
mangleSpecs []string mangleSpecs []string
ip string ip string
chain string chain string
v6 bool
} }
// GetRuleID returns the rule id // GetRuleID returns the rule id

View File

@@ -4,6 +4,8 @@ import (
"fmt" "fmt"
"sync" "sync"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
) )
@@ -37,6 +39,12 @@ type ShutdownState struct {
ACLEntries aclEntries `json:"acl_entries,omitempty"` ACLEntries aclEntries `json:"acl_entries,omitempty"`
ACLIPsetStore *ipsetStore `json:"acl_ipset_store,omitempty"` ACLIPsetStore *ipsetStore `json:"acl_ipset_store,omitempty"`
// IPv6 counterparts
RouteRules6 routeRules `json:"route_rules_v6,omitempty"`
RouteIPsetCounter6 *ipsetCounter `json:"route_ipset_counter_v6,omitempty"`
ACLEntries6 aclEntries `json:"acl_entries_v6,omitempty"`
ACLIPsetStore6 *ipsetStore `json:"acl_ipset_store_v6,omitempty"`
} }
func (s *ShutdownState) Name() string { func (s *ShutdownState) Name() string {
@@ -67,6 +75,28 @@ func (s *ShutdownState) Cleanup() error {
ipt.aclMgr.ipsetStore = s.ACLIPsetStore ipt.aclMgr.ipsetStore = s.ACLIPsetStore
} }
// Clean up v6 state even if the current run has no IPv6.
// The previous run may have left ip6tables rules behind.
if !ipt.hasIPv6() {
if err := ipt.createIPv6Components(s.InterfaceState, mtu); err != nil {
log.Warnf("failed to create v6 components for cleanup: %v", err)
}
}
if ipt.hasIPv6() {
if s.RouteRules6 != nil {
ipt.router6.rules = s.RouteRules6
}
if s.RouteIPsetCounter6 != nil {
ipt.router6.ipsetCounter.LoadData(s.RouteIPsetCounter6)
}
if s.ACLEntries6 != nil {
ipt.aclMgr6.entries = s.ACLEntries6
}
if s.ACLIPsetStore6 != nil {
ipt.aclMgr6.ipsetStore = s.ACLIPsetStore6
}
}
if err := ipt.Close(nil); err != nil { if err := ipt.Close(nil); err != nil {
return fmt.Errorf("reset iptables manager: %w", err) return fmt.Errorf("reset iptables manager: %w", err)
} }

View File

@@ -33,15 +33,12 @@ const (
const flushError = "flush: %w" const flushError = "flush: %w"
var (
anyIP = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
)
type AclManager struct { type AclManager struct {
rConn *nftables.Conn rConn *nftables.Conn
sConn *nftables.Conn sConn *nftables.Conn
wgIface iFaceMapper wgIface iFaceMapper
routingFwChainName string routingFwChainName string
af addrFamily
workTable *nftables.Table workTable *nftables.Table
chainInputRules *nftables.Chain chainInputRules *nftables.Chain
@@ -67,6 +64,7 @@ func newAclManager(table *nftables.Table, wgIface iFaceMapper, routingFwChainNam
wgIface: wgIface, wgIface: wgIface,
workTable: table, workTable: table,
routingFwChainName: routingFwChainName, routingFwChainName: routingFwChainName,
af: familyForAddr(table.Family == nftables.TableFamilyIPv4),
ipsetStore: newIpsetStore(), ipsetStore: newIpsetStore(),
rules: make(map[string]*Rule), rules: make(map[string]*Rule),
@@ -145,7 +143,7 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
} }
if _, ok := ips[r.ip.String()]; ok { if _, ok := ips[r.ip.String()]; ok {
err := m.sConn.SetDeleteElements(r.nftSet, []nftables.SetElement{{Key: r.ip.To4()}}) err := m.sConn.SetDeleteElements(r.nftSet, []nftables.SetElement{{Key: ipToBytes(r.ip, m.af)}})
if err != nil { if err != nil {
log.Errorf("delete elements for set %q: %v", r.nftSet.Name, err) log.Errorf("delete elements for set %q: %v", r.nftSet.Name, err)
} }
@@ -254,11 +252,11 @@ func (m *AclManager) addIOFiltering(
expressions = append(expressions, &expr.Payload{ expressions = append(expressions, &expr.Payload{
DestRegister: 1, DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader, Base: expr.PayloadBaseNetworkHeader,
Offset: uint32(9), Offset: m.af.protoOffset,
Len: uint32(1), Len: uint32(1),
}) })
protoData, err := protoToInt(proto) protoData, err := m.af.protoNum(proto)
if err != nil { if err != nil {
return nil, fmt.Errorf("convert protocol to number: %v", err) return nil, fmt.Errorf("convert protocol to number: %v", err)
} }
@@ -270,19 +268,16 @@ func (m *AclManager) addIOFiltering(
}) })
} }
rawIP := ip.To4() rawIP := ipToBytes(ip, m.af)
// check if rawIP contains zeroed IPv4 0.0.0.0 value // check if rawIP contains zeroed IPv4 0.0.0.0 value
// in that case not add IP match expression into the rule definition // in that case not add IP match expression into the rule definition
if !bytes.HasPrefix(anyIP, rawIP) { if slices.ContainsFunc(rawIP, func(v byte) bool { return v != 0 }) {
// source address position
addrOffset := uint32(12)
expressions = append(expressions, expressions = append(expressions,
&expr.Payload{ &expr.Payload{
DestRegister: 1, DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader, Base: expr.PayloadBaseNetworkHeader,
Offset: addrOffset, Offset: m.af.srcAddrOffset,
Len: 4, Len: m.af.addrLen,
}, },
) )
// add individual IP for match if no ipset defined // add individual IP for match if no ipset defined
@@ -587,7 +582,7 @@ func (m *AclManager) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr
func (m *AclManager) addIpToSet(ipsetName string, ip net.IP) (*nftables.Set, error) { func (m *AclManager) addIpToSet(ipsetName string, ip net.IP) (*nftables.Set, error) {
ipset, err := m.rConn.GetSetByName(m.workTable, ipsetName) ipset, err := m.rConn.GetSetByName(m.workTable, ipsetName)
rawIP := ip.To4() rawIP := ipToBytes(ip, m.af)
if err != nil { if err != nil {
if ipset, err = m.createSet(m.workTable, ipsetName); err != nil { if ipset, err = m.createSet(m.workTable, ipsetName); err != nil {
return nil, fmt.Errorf("get set name: %v", err) return nil, fmt.Errorf("get set name: %v", err)
@@ -619,7 +614,7 @@ func (m *AclManager) createSet(table *nftables.Table, name string) (*nftables.Se
Name: name, Name: name,
Table: table, Table: table,
Dynamic: true, Dynamic: true,
KeyType: nftables.TypeIPAddr, KeyType: m.af.setKeyType,
} }
if err := m.rConn.AddSet(ipset, nil); err != nil { if err := m.rConn.AddSet(ipset, nil); err != nil {
@@ -707,15 +702,12 @@ func ifname(n string) []byte {
return b return b
} }
func protoToInt(protocol firewall.Protocol) (uint8, error) {
switch protocol {
case firewall.ProtocolTCP:
return unix.IPPROTO_TCP, nil
case firewall.ProtocolUDP:
return unix.IPPROTO_UDP, nil
case firewall.ProtocolICMP:
return unix.IPPROTO_ICMP, nil
}
return 0, fmt.Errorf("unsupported protocol: %s", protocol) // ipToBytes converts net.IP to the correct byte length for the address family.
func ipToBytes(ip net.IP, af addrFamily) []byte {
if af.addrLen == 4 {
return ip.To4()
}
return ip.To16()
} }

View File

@@ -0,0 +1,81 @@
package nftables
import (
"fmt"
"net"
"github.com/google/nftables"
"golang.org/x/sys/unix"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
)
var (
// afIPv4 defines IPv4 header layout and nftables types.
afIPv4 = addrFamily{
protoOffset: 9,
srcAddrOffset: 12,
dstAddrOffset: 16,
addrLen: net.IPv4len,
totalBits: 8 * net.IPv4len,
setKeyType: nftables.TypeIPAddr,
tableFamily: nftables.TableFamilyIPv4,
icmpProto: unix.IPPROTO_ICMP,
}
// afIPv6 defines IPv6 header layout and nftables types.
afIPv6 = addrFamily{
protoOffset: 6,
srcAddrOffset: 8,
dstAddrOffset: 24,
addrLen: net.IPv6len,
totalBits: 8 * net.IPv6len,
setKeyType: nftables.TypeIP6Addr,
tableFamily: nftables.TableFamilyIPv6,
icmpProto: unix.IPPROTO_ICMPV6,
}
)
// addrFamily holds protocol-specific constants for nftables expression building.
type addrFamily struct {
// protoOffset is the IP header offset for the protocol/next-header field (9 for v4, 6 for v6)
protoOffset uint32
// srcAddrOffset is the IP header offset for the source address (12 for v4, 8 for v6)
srcAddrOffset uint32
// dstAddrOffset is the IP header offset for the destination address (16 for v4, 24 for v6)
dstAddrOffset uint32
// addrLen is the byte length of addresses (4 for v4, 16 for v6)
addrLen uint32
// totalBits is the address size in bits (32 for v4, 128 for v6)
totalBits int
// setKeyType is the nftables set data type for addresses
setKeyType nftables.SetDatatype
// tableFamily is the nftables table family
tableFamily nftables.TableFamily
// icmpProto is the ICMP protocol number for this family (1 for v4, 58 for v6)
icmpProto uint8
}
// familyForAddr returns the address family for the given IP.
func familyForAddr(is4 bool) addrFamily {
if is4 {
return afIPv4
}
return afIPv6
}
// protoNum converts a firewall protocol to the IP protocol number,
// using the correct ICMP variant for the address family.
func (af addrFamily) protoNum(protocol firewall.Protocol) (uint8, error) {
switch protocol {
case firewall.ProtocolTCP:
return unix.IPPROTO_TCP, nil
case firewall.ProtocolUDP:
return unix.IPPROTO_UDP, nil
case firewall.ProtocolICMP:
return af.icmpProto, nil
case firewall.ProtocolALL:
return 0, nil
default:
return 0, fmt.Errorf("unsupported protocol: %s", protocol)
}
}

View File

@@ -11,9 +11,11 @@ import (
"github.com/google/nftables" "github.com/google/nftables"
"github.com/google/nftables/binaryutil" "github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr" "github.com/google/nftables/expr"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
@@ -49,8 +51,13 @@ type Manager struct {
rConn *nftables.Conn rConn *nftables.Conn
wgIface iFaceMapper wgIface iFaceMapper
router *router router *router
aclManager *AclManager aclManager *AclManager
// IPv6 counterparts, nil when no v6 overlay
router6 *router
aclManager6 *AclManager
notrackOutputChain *nftables.Chain notrackOutputChain *nftables.Chain
notrackPreroutingChain *nftables.Chain notrackPreroutingChain *nftables.Chain
} }
@@ -62,7 +69,8 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) {
wgIface: wgIface, wgIface: wgIface,
} }
workTable := &nftables.Table{Name: getTableName(), Family: nftables.TableFamilyIPv4} tableName := getTableName()
workTable := &nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4}
var err error var err error
m.router, err = newRouter(workTable, wgIface, mtu) m.router, err = newRouter(workTable, wgIface, mtu)
@@ -75,11 +83,70 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) {
return nil, fmt.Errorf("create acl manager: %w", err) return nil, fmt.Errorf("create acl manager: %w", err)
} }
if wgIface.Address().HasIPv6() {
if err := m.createIPv6Components(tableName, wgIface, mtu); err != nil {
return nil, fmt.Errorf("create IPv6 firewall: %w", err)
}
}
return m, nil return m, nil
} }
func (m *Manager) createIPv6Components(tableName string, wgIface iFaceMapper, mtu uint16) error {
workTable6 := &nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv6}
var err error
m.router6, err = newRouter(workTable6, wgIface, mtu)
if err != nil {
return fmt.Errorf("create v6 router: %w", err)
}
// Share the same IP forwarding state with the v4 router, since
// EnableIPForwarding controls both v4 and v6 sysctls.
m.router6.ipFwdState = m.router.ipFwdState
m.aclManager6, err = newAclManager(workTable6, wgIface, chainNameRoutingFw)
if err != nil {
return fmt.Errorf("create v6 acl manager: %w", err)
}
return nil
}
// hasIPv6 reports whether the manager has IPv6 components initialized.
func (m *Manager) hasIPv6() bool {
return m.router6 != nil
}
func (m *Manager) initIPv6() error {
workTable6, err := m.createWorkTableFamily(nftables.TableFamilyIPv6)
if err != nil {
return fmt.Errorf("create v6 work table: %w", err)
}
if err := m.router6.init(workTable6); err != nil {
return fmt.Errorf("v6 router init: %w", err)
}
if err := m.aclManager6.init(workTable6); err != nil {
return fmt.Errorf("v6 acl manager init: %w", err)
}
return nil
}
// Init nftables firewall manager // Init nftables firewall manager
func (m *Manager) Init(stateManager *statemanager.Manager) error { func (m *Manager) Init(stateManager *statemanager.Manager) error {
if err := m.initFirewall(); err != nil {
return err
}
m.persistState(stateManager)
return nil
}
func (m *Manager) initFirewall() error {
workTable, err := m.createWorkTable() workTable, err := m.createWorkTable()
if err != nil { if err != nil {
return fmt.Errorf("create work table: %w", err) return fmt.Errorf("create work table: %w", err)
@@ -90,20 +157,32 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
} }
if err := m.aclManager.init(workTable); err != nil { if err := m.aclManager.init(workTable); err != nil {
// TODO: cleanup router m.rollbackInit()
return fmt.Errorf("acl manager init: %w", err) return fmt.Errorf("acl manager init: %w", err)
} }
if m.hasIPv6() {
if err := m.initIPv6(); err != nil {
// Peer has a v6 address: v6 firewall MUST work or we risk fail-open.
m.rollbackInit()
return fmt.Errorf("init IPv6 firewall (required because peer has IPv6 address): %w", err)
}
}
if err := m.initNoTrackChains(workTable); err != nil { if err := m.initNoTrackChains(workTable); err != nil {
log.Warnf("raw priority chains not available, notrack rules will be disabled: %v", err) log.Warnf("raw priority chains not available, notrack rules will be disabled: %v", err)
} }
return nil
}
// persistState saves the current interface state for potential recreation on restart.
// Unlike iptables, which requires tracking individual rules, nftables maintains
// a known state (our netbird table plus a few static rules). This allows for easy
// cleanup using Close() without needing to store specific rules.
func (m *Manager) persistState(stateManager *statemanager.Manager) {
stateManager.RegisterState(&ShutdownState{}) stateManager.RegisterState(&ShutdownState{})
// We only need to record minimal interface state for potential recreation.
// Unlike iptables, which requires tracking individual rules, nftables maintains
// a known state (our netbird table plus a few static rules). This allows for easy
// cleanup using Close() without needing to store specific rules.
if err := stateManager.UpdateState(&ShutdownState{ if err := stateManager.UpdateState(&ShutdownState{
InterfaceState: &InterfaceState{ InterfaceState: &InterfaceState{
NameStr: m.wgIface.Name(), NameStr: m.wgIface.Name(),
@@ -115,14 +194,29 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
log.Errorf("failed to update state: %v", err) log.Errorf("failed to update state: %v", err)
} }
// persist early
go func() { go func() {
if err := stateManager.PersistState(context.Background()); err != nil { if err := stateManager.PersistState(context.Background()); err != nil {
log.Errorf("failed to persist state: %v", err) log.Errorf("failed to persist state: %v", err)
} }
}() }()
}
return nil // rollbackInit performs best-effort cleanup of already-initialized state when Init fails partway through.
func (m *Manager) rollbackInit() {
if err := m.router.Reset(); err != nil {
log.Warnf("rollback router: %v", err)
}
if m.hasIPv6() {
if err := m.router6.Reset(); err != nil {
log.Warnf("rollback v6 router: %v", err)
}
}
if err := m.cleanupNetbirdTables(); err != nil {
log.Warnf("cleanup tables: %v", err)
}
if err := m.rConn.Flush(); err != nil {
log.Warnf("flush: %v", err)
}
} }
// AddPeerFiltering rule to the firewall // AddPeerFiltering rule to the firewall
@@ -141,12 +235,14 @@ func (m *Manager) AddPeerFiltering(
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
rawIP := ip.To4() if ip.To4() != nil {
if rawIP == nil { return m.aclManager.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
return nil, fmt.Errorf("unsupported IP version: %s", ip.String())
} }
return m.aclManager.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName) if !m.hasIPv6() {
return nil, fmt.Errorf("IPv6 not initialized, cannot add rule for %s", ip)
}
return m.aclManager6.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
} }
func (m *Manager) AddRouteFiltering( func (m *Manager) AddRouteFiltering(
@@ -160,8 +256,11 @@ func (m *Manager) AddRouteFiltering(
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
if destination.IsPrefix() && !destination.Prefix.Addr().Is4() { if isIPv6RouteRule(sources, destination) {
return nil, fmt.Errorf("unsupported IP version: %s", destination.Prefix.Addr().String()) if !m.hasIPv6() {
return nil, fmt.Errorf("IPv6 not initialized, cannot add route rule")
}
return m.router6.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
} }
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action) return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
@@ -172,14 +271,38 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
if m.hasIPv6() && isIPv6Rule(rule) {
return m.aclManager6.DeletePeerRule(rule)
}
return m.aclManager.DeletePeerRule(rule) return m.aclManager.DeletePeerRule(rule)
} }
// DeleteRouteRule deletes a routing rule func isIPv6Rule(rule firewall.Rule) bool {
r, ok := rule.(*Rule)
return ok && r.nftRule != nil && r.nftRule.Table != nil && r.nftRule.Table.Family == nftables.TableFamilyIPv6
}
// isIPv6RouteRule determines whether a route rule belongs to the v6 table.
// For static routes, the destination prefix determines the family. For dynamic
// routes (DomainSet), the sources determine the family since management
// duplicates dynamic rules per family.
func isIPv6RouteRule(sources []netip.Prefix, destination firewall.Network) bool {
if destination.IsPrefix() {
return destination.Prefix.Addr().Is6()
}
return len(sources) > 0 && sources[0].Addr().Is6()
}
// DeleteRouteRule deletes a routing rule.
// Route rules are keyed by content hash, so the rule exists in exactly one
// router. We check v4 first; if the key isn't there, try v6.
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error { func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
if m.hasIPv6() && !m.router.hasRule(rule.ID()) {
return m.router6.DeleteRouteRule(rule)
}
return m.router.DeleteRouteRule(rule) return m.router.DeleteRouteRule(rule)
} }
@@ -195,17 +318,63 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
return m.router.AddNatRule(pair) if pair.Destination.IsPrefix() && pair.Destination.Prefix.Addr().Is6() {
if !m.hasIPv6() {
return fmt.Errorf("IPv6 not initialized, cannot add NAT rule")
}
return m.router6.AddNatRule(pair)
}
if err := m.router.AddNatRule(pair); err != nil {
return err
}
// Dynamic routes (DomainSet) need NAT in both tables since resolved IPs
// can be either v4 or v6.
if m.hasIPv6() && pair.Destination.IsSet() {
v6Pair := pair
v6Pair.Source = firewall.Network{Prefix: netip.PrefixFrom(netip.IPv6Unspecified(), 0)}
if err := m.router6.AddNatRule(v6Pair); err != nil {
return fmt.Errorf("add v6 NAT rule: %w", err)
}
}
return nil
} }
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error { func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
return m.router.RemoveNatRule(pair) if pair.Destination.IsPrefix() && pair.Destination.Prefix.Addr().Is6() {
if !m.hasIPv6() {
return nil
}
return m.router6.RemoveNatRule(pair)
}
if err := m.router.RemoveNatRule(pair); err != nil {
return err
}
if m.hasIPv6() && pair.Destination.IsSet() {
v6Pair := pair
v6Pair.Source = firewall.Network{Prefix: netip.PrefixFrom(netip.IPv6Unspecified(), 0)}
if err := m.router6.RemoveNatRule(v6Pair); err != nil {
return fmt.Errorf("remove v6 NAT rule: %w", err)
}
}
return nil
} }
// AllowNetbird allows netbird interface traffic // AllowNetbird allows netbird interface traffic.
// TODO: In USP mode this only adds ACCEPT to the netbird table's own chains,
// which doesn't override DROP rules in external tables (e.g. firewalld).
// Should add passthrough rules to external chains (like the native mode router's
// addExternalChainsRules does) for both the netbird table family and inet tables.
// The netbird table itself is fine (routing chains already exist there), but
// non-netbird tables with INPUT/FORWARD hooks can still DROP our WG traffic.
func (m *Manager) AllowNetbird() error { func (m *Manager) AllowNetbird() error {
if !m.wgIface.IsUserspaceBind() { if !m.wgIface.IsUserspaceBind() {
return nil return nil
@@ -217,6 +386,11 @@ func (m *Manager) AllowNetbird() error {
if err := m.aclManager.createDefaultAllowRules(); err != nil { if err := m.aclManager.createDefaultAllowRules(); err != nil {
return fmt.Errorf("create default allow rules: %w", err) return fmt.Errorf("create default allow rules: %w", err)
} }
if m.hasIPv6() {
if err := m.aclManager6.createDefaultAllowRules(); err != nil {
return fmt.Errorf("create v6 default allow rules: %w", err)
}
}
if err := m.rConn.Flush(); err != nil { if err := m.rConn.Flush(); err != nil {
return fmt.Errorf("flush allow input netbird rules: %w", err) return fmt.Errorf("flush allow input netbird rules: %w", err)
} }
@@ -226,7 +400,13 @@ func (m *Manager) AllowNetbird() error {
// SetLegacyManagement sets the route manager to use legacy management // SetLegacyManagement sets the route manager to use legacy management
func (m *Manager) SetLegacyManagement(isLegacy bool) error { func (m *Manager) SetLegacyManagement(isLegacy bool) error {
return firewall.SetLegacyManagement(m.router, isLegacy) if err := firewall.SetLegacyManagement(m.router, isLegacy); err != nil {
return err
}
if m.hasIPv6() {
return firewall.SetLegacyManagement(m.router6, isLegacy)
}
return nil
} }
// Close closes the firewall manager // Close closes the firewall manager
@@ -234,23 +414,31 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
var merr *multierror.Error
if err := m.router.Reset(); err != nil { if err := m.router.Reset(); err != nil {
return fmt.Errorf("reset router: %v", err) merr = multierror.Append(merr, fmt.Errorf("reset router: %v", err))
}
if m.hasIPv6() {
if err := m.router6.Reset(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("reset v6 router: %v", err))
}
} }
if err := m.cleanupNetbirdTables(); err != nil { if err := m.cleanupNetbirdTables(); err != nil {
return fmt.Errorf("cleanup netbird tables: %v", err) merr = multierror.Append(merr, fmt.Errorf("cleanup netbird tables: %v", err))
} }
if err := m.rConn.Flush(); err != nil { if err := m.rConn.Flush(); err != nil {
return fmt.Errorf(flushError, err) merr = multierror.Append(merr, fmt.Errorf(flushError, err))
} }
if err := stateManager.DeleteState(&ShutdownState{}); err != nil { if err := stateManager.DeleteState(&ShutdownState{}); err != nil {
return fmt.Errorf("delete state: %v", err) merr = multierror.Append(merr, fmt.Errorf("delete state: %v", err))
} }
return nil return nberrors.FormatErrorOrNil(merr)
} }
func (m *Manager) cleanupNetbirdTables() error { func (m *Manager) cleanupNetbirdTables() error {
@@ -299,6 +487,12 @@ func (m *Manager) Flush() error {
return err return err
} }
if m.hasIPv6() {
if err := m.aclManager6.Flush(); err != nil {
return fmt.Errorf("flush v6 acl: %w", err)
}
}
if err := m.refreshNoTrackChains(); err != nil { if err := m.refreshNoTrackChains(); err != nil {
log.Errorf("failed to refresh notrack chains: %v", err) log.Errorf("failed to refresh notrack chains: %v", err)
} }
@@ -311,6 +505,9 @@ func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error)
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
if m.hasIPv6() && rule.TranslatedAddress.Is6() {
return m.router6.AddDNATRule(rule)
}
return m.router.AddDNATRule(rule) return m.router.AddDNATRule(rule)
} }
@@ -319,6 +516,9 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
if m.hasIPv6() && !m.router.hasDNATRule(rule.ID()) {
return m.router6.DeleteDNATRule(rule)
}
return m.router.DeleteDNATRule(rule) return m.router.DeleteDNATRule(rule)
} }
@@ -327,7 +527,26 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
return m.router.UpdateSet(set, prefixes) var v4Prefixes, v6Prefixes []netip.Prefix
for _, p := range prefixes {
if p.Addr().Is6() {
v6Prefixes = append(v6Prefixes, p)
} else {
v4Prefixes = append(v4Prefixes, p)
}
}
if err := m.router.UpdateSet(set, v4Prefixes); err != nil {
return err
}
if m.hasIPv6() && len(v6Prefixes) > 0 {
if err := m.router6.UpdateSet(set, v6Prefixes); err != nil {
return fmt.Errorf("update v6 set: %w", err)
}
}
return nil
} }
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services. // AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
@@ -335,6 +554,9 @@ func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protoco
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
if m.hasIPv6() && localAddr.Is6() {
return m.router6.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort)
}
return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort) return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort)
} }
@@ -343,6 +565,9 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
if m.hasIPv6() && localAddr.Is6() {
return m.router6.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
}
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort) return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
} }
@@ -533,7 +758,11 @@ func (m *Manager) refreshNoTrackChains() error {
} }
func (m *Manager) createWorkTable() (*nftables.Table, error) { func (m *Manager) createWorkTable() (*nftables.Table, error) {
tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4) return m.createWorkTableFamily(nftables.TableFamilyIPv4)
}
func (m *Manager) createWorkTableFamily(family nftables.TableFamily) (*nftables.Table, error) {
tables, err := m.rConn.ListTablesOfFamily(family)
if err != nil { if err != nil {
return nil, fmt.Errorf("list of tables: %w", err) return nil, fmt.Errorf("list of tables: %w", err)
} }
@@ -545,7 +774,7 @@ func (m *Manager) createWorkTable() (*nftables.Table, error) {
} }
} }
table := m.rConn.AddTable(&nftables.Table{Name: getTableName(), Family: nftables.TableFamilyIPv4}) table := m.rConn.AddTable(&nftables.Table{Name: tableName, Family: family})
err = m.rConn.Flush() err = m.rConn.Flush()
return table, err return table, err
} }

View File

@@ -385,10 +385,134 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
err = manager.AddNatRule(pair) err = manager.AddNatRule(pair)
require.NoError(t, err, "failed to add NAT rule") require.NoError(t, err, "failed to add NAT rule")
dnatRule, err := manager.AddDNATRule(fw.ForwardRule{
Protocol: fw.ProtocolTCP,
DestinationPort: fw.Port{Values: []uint16{8080}},
TranslatedAddress: netip.MustParseAddr("100.96.0.2"),
TranslatedPort: fw.Port{Values: []uint16{80}},
})
require.NoError(t, err, "failed to add DNAT rule")
t.Cleanup(func() {
require.NoError(t, manager.DeleteDNATRule(dnatRule), "failed to delete DNAT rule")
})
stdout, stderr = runIptablesSave(t) stdout, stderr = runIptablesSave(t)
verifyIptablesOutput(t, stdout, stderr) verifyIptablesOutput(t, stdout, stderr)
} }
func TestNftablesManagerIPv6CompatibilityWithIp6tables(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
}
for _, bin := range []string{"ip6tables", "ip6tables-save", "iptables-save"} {
if _, err := exec.LookPath(bin); err != nil {
t.Skipf("%s not available on this system: %v", bin, err)
}
}
// Seed ip6 tables in the nft backend. Docker may not create them.
seedIp6tables(t)
ifaceMockV6 := &iFaceMock{
NameFunc: func() string { return "wt-test" },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: netip.MustParseAddr("100.96.0.1"),
Network: netip.MustParsePrefix("100.96.0.0/16"),
IPv6: netip.MustParseAddr("fd00::1"),
IPv6Net: netip.MustParsePrefix("fd00::/64"),
}
},
}
manager, err := Create(ifaceMockV6, iface.DefaultMTU)
require.NoError(t, err, "create manager")
require.NoError(t, manager.Init(nil))
t.Cleanup(func() {
require.NoError(t, manager.Close(nil), "close manager")
stdout, stderr := runIp6tablesSave(t)
verifyIp6tablesOutput(t, stdout, stderr)
})
ip := netip.MustParseAddr("fd00::2")
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
require.NoError(t, err, "add v6 peer filtering rule")
_, err = manager.AddRouteFiltering(
nil,
[]netip.Prefix{netip.MustParsePrefix("fd00:1::/64")},
fw.Network{Prefix: netip.MustParsePrefix("2001:db8::/48")},
fw.ProtocolTCP,
nil,
&fw.Port{Values: []uint16{443}},
fw.ActionAccept,
)
require.NoError(t, err, "add v6 route filtering rule")
err = manager.AddNatRule(fw.RouterPair{
Source: fw.Network{Prefix: netip.MustParsePrefix("fd00::/64")},
Destination: fw.Network{Prefix: netip.MustParsePrefix("2001:db8::/48")},
Masquerade: true,
})
require.NoError(t, err, "add v6 NAT rule")
dnatRule, err := manager.AddDNATRule(fw.ForwardRule{
Protocol: fw.ProtocolTCP,
DestinationPort: fw.Port{Values: []uint16{8080}},
TranslatedAddress: netip.MustParseAddr("fd00::2"),
TranslatedPort: fw.Port{Values: []uint16{80}},
})
require.NoError(t, err, "add v6 DNAT rule")
t.Cleanup(func() {
require.NoError(t, manager.DeleteDNATRule(dnatRule), "delete v6 DNAT rule")
})
stdout, stderr := runIptablesSave(t)
verifyIptablesOutput(t, stdout, stderr)
stdout, stderr = runIp6tablesSave(t)
verifyIp6tablesOutput(t, stdout, stderr)
}
func seedIp6tables(t *testing.T) {
t.Helper()
for _, tc := range []struct{ table, chain string }{
{"filter", "FORWARD"},
{"nat", "POSTROUTING"},
{"mangle", "FORWARD"},
} {
add := exec.Command("ip6tables", "-t", tc.table, "-A", tc.chain, "-j", "ACCEPT")
require.NoError(t, add.Run(), "seed ip6tables -t %s", tc.table)
del := exec.Command("ip6tables", "-t", tc.table, "-D", tc.chain, "-j", "ACCEPT")
require.NoError(t, del.Run(), "unseed ip6tables -t %s", tc.table)
}
}
func runIp6tablesSave(t *testing.T) (string, string) {
t.Helper()
var stdout, stderr bytes.Buffer
cmd := exec.Command("ip6tables-save")
cmd.Stdout = &stdout
cmd.Stderr = &stderr
require.NoError(t, cmd.Run(), "ip6tables-save failed")
return stdout.String(), stderr.String()
}
func verifyIp6tablesOutput(t *testing.T, stdout, stderr string) {
t.Helper()
require.NotContains(t, stdout, "Table `nat' is incompatible",
"ip6tables-save: nat table incompatible. Full output: %s", stdout)
require.NotContains(t, stdout, "Table `mangle' is incompatible",
"ip6tables-save: mangle table incompatible. Full output: %s", stdout)
require.NotContains(t, stdout, "Table `filter' is incompatible",
"ip6tables-save: filter table incompatible. Full output: %s", stdout)
}
func TestNftablesManagerCompatibilityWithIptablesFor6kPrefixes(t *testing.T) { func TestNftablesManagerCompatibilityWithIptablesFor6kPrefixes(t *testing.T) {
if check() != NFTABLES { if check() != NFTABLES {
t.Skip("nftables not supported on this system") t.Skip("nftables not supported on this system")

View File

@@ -47,8 +47,10 @@ const (
dnatSuffix = "_dnat" dnatSuffix = "_dnat"
snatSuffix = "_snat" snatSuffix = "_snat"
// ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation // ipv4TCPHeaderSize is the minimum IPv4 (20) + TCP (20) header size for MSS calculation.
ipTCPHeaderMinSize = 40 ipv4TCPHeaderSize = 40
// ipv6TCPHeaderSize is the minimum IPv6 (40) + TCP (20) header size for MSS calculation.
ipv6TCPHeaderSize = 60
// maxPrefixesSet 1638 prefixes start to fail, taking some margin // maxPrefixesSet 1638 prefixes start to fail, taking some margin
maxPrefixesSet = 1500 maxPrefixesSet = 1500
@@ -73,6 +75,7 @@ type router struct {
rules map[string]*nftables.Rule rules map[string]*nftables.Rule
ipsetCounter *refcounter.Counter[string, setInput, *nftables.Set] ipsetCounter *refcounter.Counter[string, setInput, *nftables.Set]
af addrFamily
wgIface iFaceMapper wgIface iFaceMapper
ipFwdState *ipfwdstate.IPForwardingState ipFwdState *ipfwdstate.IPForwardingState
legacyManagement bool legacyManagement bool
@@ -85,6 +88,7 @@ func newRouter(workTable *nftables.Table, wgIface iFaceMapper, mtu uint16) (*rou
workTable: workTable, workTable: workTable,
chains: make(map[string]*nftables.Chain), chains: make(map[string]*nftables.Chain),
rules: make(map[string]*nftables.Rule), rules: make(map[string]*nftables.Rule),
af: familyForAddr(workTable.Family == nftables.TableFamilyIPv4),
wgIface: wgIface, wgIface: wgIface,
ipFwdState: ipfwdstate.NewIPForwardingState(), ipFwdState: ipfwdstate.NewIPForwardingState(),
mtu: mtu, mtu: mtu,
@@ -143,7 +147,7 @@ func (r *router) Reset() error {
func (r *router) removeNatPreroutingRules() error { func (r *router) removeNatPreroutingRules() error {
table := &nftables.Table{ table := &nftables.Table{
Name: tableNat, Name: tableNat,
Family: nftables.TableFamilyIPv4, Family: r.af.tableFamily,
} }
chain := &nftables.Chain{ chain := &nftables.Chain{
Name: chainNameNatPrerouting, Name: chainNameNatPrerouting,
@@ -176,7 +180,7 @@ func (r *router) removeNatPreroutingRules() error {
} }
func (r *router) loadFilterTable() (*nftables.Table, error) { func (r *router) loadFilterTable() (*nftables.Table, error) {
tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4) tables, err := r.conn.ListTablesOfFamily(r.af.tableFamily)
if err != nil { if err != nil {
return nil, fmt.Errorf("list tables: %w", err) return nil, fmt.Errorf("list tables: %w", err)
} }
@@ -408,7 +412,7 @@ func (r *router) AddRouteFiltering(
// Handle protocol // Handle protocol
if proto != firewall.ProtocolALL { if proto != firewall.ProtocolALL {
protoNum, err := protoToInt(proto) protoNum, err := r.af.protoNum(proto)
if err != nil { if err != nil {
return nil, fmt.Errorf("convert protocol to number: %w", err) return nil, fmt.Errorf("convert protocol to number: %w", err)
} }
@@ -468,7 +472,24 @@ func (r *router) getIpSet(set firewall.Set, prefixes []netip.Prefix, isSource bo
return nil, fmt.Errorf("create or get ipset: %w", err) return nil, fmt.Errorf("create or get ipset: %w", err)
} }
return getIpSetExprs(ref, isSource) return r.getIpSetExprs(ref, isSource)
}
func (r *router) iptablesProto() iptables.Protocol {
if r.af.tableFamily == nftables.TableFamilyIPv6 {
return iptables.ProtocolIPv6
}
return iptables.ProtocolIPv4
}
func (r *router) hasRule(id string) bool {
_, ok := r.rules[id]
return ok
}
func (r *router) hasDNATRule(id string) bool {
_, ok := r.rules[id+dnatSuffix]
return ok
} }
func (r *router) DeleteRouteRule(rule firewall.Rule) error { func (r *router) DeleteRouteRule(rule firewall.Rule) error {
@@ -517,10 +538,10 @@ func (r *router) createIpSet(setName string, input setInput) (*nftables.Set, err
Table: r.workTable, Table: r.workTable,
// required for prefixes // required for prefixes
Interval: true, Interval: true,
KeyType: nftables.TypeIPAddr, KeyType: r.af.setKeyType,
} }
elements := convertPrefixesToSet(prefixes) elements := r.convertPrefixesToSet(prefixes)
nElements := len(elements) nElements := len(elements)
maxElements := maxPrefixesSet * 2 maxElements := maxPrefixesSet * 2
@@ -553,23 +574,17 @@ func (r *router) createIpSet(setName string, input setInput) (*nftables.Set, err
return nfset, nil return nfset, nil
} }
func convertPrefixesToSet(prefixes []netip.Prefix) []nftables.SetElement { func (r *router) convertPrefixesToSet(prefixes []netip.Prefix) []nftables.SetElement {
var elements []nftables.SetElement var elements []nftables.SetElement
for _, prefix := range prefixes { for _, prefix := range prefixes {
// TODO: Implement IPv6 support
if prefix.Addr().Is6() {
log.Tracef("skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix)
continue
}
// nftables needs half-open intervals [firstIP, lastIP) for prefixes // nftables needs half-open intervals [firstIP, lastIP) for prefixes
// e.g. 10.0.0.0/24 becomes [10.0.0.0, 10.0.1.0), 10.1.1.1/32 becomes [10.1.1.1, 10.1.1.2) etc // e.g. 10.0.0.0/24 becomes [10.0.0.0, 10.0.1.0), 10.1.1.1/32 becomes [10.1.1.1, 10.1.1.2) etc
firstIP := prefix.Addr() firstIP := prefix.Addr()
lastIP := calculateLastIP(prefix).Next() lastIP := calculateLastIP(prefix).Next()
elements = append(elements, elements = append(elements,
// the nft tool also adds a line like this, see https://github.com/google/nftables/issues/247 // the nft tool also adds a zero-address IntervalEnd element, see https://github.com/google/nftables/issues/247
// nftables.SetElement{Key: []byte{0, 0, 0, 0}, IntervalEnd: true}, // nftables.SetElement{Key: make([]byte, r.af.addrLen), IntervalEnd: true},
nftables.SetElement{Key: firstIP.AsSlice()}, nftables.SetElement{Key: firstIP.AsSlice()},
nftables.SetElement{Key: lastIP.AsSlice(), IntervalEnd: true}, nftables.SetElement{Key: lastIP.AsSlice(), IntervalEnd: true},
) )
@@ -579,10 +594,20 @@ func convertPrefixesToSet(prefixes []netip.Prefix) []nftables.SetElement {
// calculateLastIP determines the last IP in a given prefix. // calculateLastIP determines the last IP in a given prefix.
func calculateLastIP(prefix netip.Prefix) netip.Addr { func calculateLastIP(prefix netip.Prefix) netip.Addr {
hostMask := ^uint32(0) >> prefix.Masked().Bits() masked := prefix.Masked()
lastIP := uint32FromNetipAddr(prefix.Addr()) | hostMask if masked.Addr().Is4() {
hostMask := ^uint32(0) >> masked.Bits()
lastIP := uint32FromNetipAddr(masked.Addr()) | hostMask
return netip.AddrFrom4(uint32ToBytes(lastIP))
}
return netip.AddrFrom4(uint32ToBytes(lastIP)) // IPv6: set host bits to all 1s
b := masked.Addr().As16()
bits := masked.Bits()
for i := bits; i < 128; i++ {
b[i/8] |= 1 << (7 - i%8)
}
return netip.AddrFrom16(b)
} }
// Utility function to convert netip.Addr to uint32. // Utility function to convert netip.Addr to uint32.
@@ -834,9 +859,12 @@ func (r *router) addPostroutingRules() {
} }
// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic. // addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic.
// TODO: Add IPv6 support
func (r *router) addMSSClampingRules() error { func (r *router) addMSSClampingRules() error {
mss := r.mtu - ipTCPHeaderMinSize overhead := uint16(ipv4TCPHeaderSize)
if r.af.tableFamily == nftables.TableFamilyIPv6 {
overhead = ipv6TCPHeaderSize
}
mss := r.mtu - overhead
exprsOut := []expr.Any{ exprsOut := []expr.Any{
&expr.Meta{ &expr.Meta{
@@ -1043,17 +1071,22 @@ func (r *router) acceptFilterTableRules() error {
log.Debugf("Used %s to add accept forward and input rules", fw) log.Debugf("Used %s to add accept forward and input rules", fw)
}() }()
// Try iptables first and fallback to nftables if iptables is not available // Try iptables first and fallback to nftables if iptables is not available.
ipt, err := iptables.New() // Use the correct protocol (iptables vs ip6tables) for the address family.
ipt, err := iptables.NewWithProtocol(r.iptablesProto())
if err != nil { if err != nil {
// iptables is not available but the filter table exists
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err) log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
fw = "nftables" fw = "nftables"
return r.acceptFilterRulesNftables(r.filterTable) return r.acceptFilterRulesNftables(r.filterTable)
} }
return r.acceptFilterRulesIptables(ipt) if err := r.acceptFilterRulesIptables(ipt); err != nil {
log.Warnf("iptables failed (table may be incompatible), falling back to nftables: %v", err)
fw = "nftables"
return r.acceptFilterRulesNftables(r.filterTable)
}
return nil
} }
func (r *router) acceptFilterRulesIptables(ipt *iptables.IPTables) error { func (r *router) acceptFilterRulesIptables(ipt *iptables.IPTables) error {
@@ -1222,13 +1255,17 @@ func (r *router) removeFilterTableRules() error {
return nil return nil
} }
ipt, err := iptables.New() ipt, err := iptables.NewWithProtocol(r.iptablesProto())
if err != nil { if err != nil {
log.Debugf("iptables not available, using nftables to remove filter rules: %v", err) log.Debugf("iptables not available, using nftables to remove filter rules: %v", err)
return r.removeAcceptRulesFromTable(r.filterTable) return r.removeAcceptRulesFromTable(r.filterTable)
} }
return r.removeAcceptFilterRulesIptables(ipt) if err := r.removeAcceptFilterRulesIptables(ipt); err != nil {
log.Debugf("iptables removal failed (table may be incompatible), falling back to nftables: %v", err)
return r.removeAcceptRulesFromTable(r.filterTable)
}
return nil
} }
func (r *router) removeAcceptRulesFromTable(table *nftables.Table) error { func (r *router) removeAcceptRulesFromTable(table *nftables.Table) error {
@@ -1295,7 +1332,7 @@ func (r *router) removeExternalChainsRules() error {
func (r *router) findExternalChains() []*nftables.Chain { func (r *router) findExternalChains() []*nftables.Chain {
var chains []*nftables.Chain var chains []*nftables.Chain
families := []nftables.TableFamily{nftables.TableFamilyIPv4, nftables.TableFamilyINet} families := []nftables.TableFamily{r.af.tableFamily, nftables.TableFamilyINet}
for _, family := range families { for _, family := range families {
allChains, err := r.conn.ListChainsOfTableFamily(family) allChains, err := r.conn.ListChainsOfTableFamily(family)
@@ -1319,8 +1356,8 @@ func (r *router) isExternalChain(chain *nftables.Chain) bool {
return false return false
} }
// Skip all iptables-managed tables in the ip family // Skip iptables/ip6tables-managed tables (adding nft-native rules breaks iptables-save compat)
if chain.Table.Family == nftables.TableFamilyIPv4 && isIptablesTable(chain.Table.Name) { if (chain.Table.Family == nftables.TableFamilyIPv4 || chain.Table.Family == nftables.TableFamilyIPv6) && isIptablesTable(chain.Table.Name) {
return false return false
} }
@@ -1461,7 +1498,7 @@ func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
return rule, nil return rule, nil
} }
protoNum, err := protoToInt(rule.Protocol) protoNum, err := r.af.protoNum(rule.Protocol)
if err != nil { if err != nil {
return nil, fmt.Errorf("convert protocol to number: %w", err) return nil, fmt.Errorf("convert protocol to number: %w", err)
} }
@@ -1524,7 +1561,7 @@ func (r *router) addDnatRedirect(rule firewall.ForwardRule, protoNum uint8, rule
dnatExprs = append(dnatExprs, dnatExprs = append(dnatExprs,
&expr.NAT{ &expr.NAT{
Type: expr.NATTypeDestNAT, Type: expr.NATTypeDestNAT,
Family: uint32(nftables.TableFamilyIPv4), Family: uint32(r.af.tableFamily),
RegAddrMin: 1, RegAddrMin: 1,
RegProtoMin: regProtoMin, RegProtoMin: regProtoMin,
RegProtoMax: regProtoMax, RegProtoMax: regProtoMax,
@@ -1620,7 +1657,7 @@ func (r *router) addXTablesRedirect(dnatExprs []expr.Any, ruleKey string, rule f
dnatRule := &nftables.Rule{ dnatRule := &nftables.Rule{
Table: &nftables.Table{ Table: &nftables.Table{
Name: tableNat, Name: tableNat,
Family: nftables.TableFamilyIPv4, Family: r.af.tableFamily,
}, },
Chain: &nftables.Chain{ Chain: &nftables.Chain{
Name: chainNameNatPrerouting, Name: chainNameNatPrerouting,
@@ -1655,8 +1692,8 @@ func (r *router) addDnatMasq(rule firewall.ForwardRule, protoNum uint8, ruleKey
&expr.Payload{ &expr.Payload{
DestRegister: 1, DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader, Base: expr.PayloadBaseNetworkHeader,
Offset: 16, Offset: r.af.dstAddrOffset,
Len: 4, Len: r.af.addrLen,
}, },
&expr.Cmp{ &expr.Cmp{
Op: expr.CmpOpEq, Op: expr.CmpOpEq,
@@ -1734,7 +1771,7 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
return fmt.Errorf("get set %s: %w", set.HashedName(), err) return fmt.Errorf("get set %s: %w", set.HashedName(), err)
} }
elements := convertPrefixesToSet(prefixes) elements := r.convertPrefixesToSet(prefixes)
if err := r.conn.SetAddElements(nfset, elements); err != nil { if err := r.conn.SetAddElements(nfset, elements); err != nil {
return fmt.Errorf("add elements to set %s: %w", set.HashedName(), err) return fmt.Errorf("add elements to set %s: %w", set.HashedName(), err)
} }
@@ -1756,7 +1793,7 @@ func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol
return nil return nil
} }
protoNum, err := protoToInt(protocol) protoNum, err := r.af.protoNum(protocol)
if err != nil { if err != nil {
return fmt.Errorf("convert protocol to number: %w", err) return fmt.Errorf("convert protocol to number: %w", err)
} }
@@ -1787,7 +1824,11 @@ func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol
}, },
} }
exprs = append(exprs, applyPrefix(netip.PrefixFrom(localAddr, 32), false)...) bits := 32
if localAddr.Is6() {
bits = 128
}
exprs = append(exprs, r.applyPrefix(netip.PrefixFrom(localAddr, bits), false)...)
exprs = append(exprs, exprs = append(exprs,
&expr.Immediate{ &expr.Immediate{
@@ -1800,7 +1841,7 @@ func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol
}, },
&expr.NAT{ &expr.NAT{
Type: expr.NATTypeDestNAT, Type: expr.NATTypeDestNAT,
Family: uint32(nftables.TableFamilyIPv4), Family: uint32(r.af.tableFamily),
RegAddrMin: 1, RegAddrMin: 1,
RegProtoMin: 2, RegProtoMin: 2,
RegProtoMax: 0, RegProtoMax: 0,
@@ -1887,7 +1928,7 @@ func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol,
return err return err
} }
protoNum, err := protoToInt(protocol) protoNum, err := r.af.protoNum(protocol)
if err != nil { if err != nil {
return fmt.Errorf("convert protocol to number: %w", err) return fmt.Errorf("convert protocol to number: %w", err)
} }
@@ -1912,7 +1953,11 @@ func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol,
}, },
} }
exprs = append(exprs, applyPrefix(netip.PrefixFrom(localAddr, 32), false)...) bits := 32
if localAddr.Is6() {
bits = 128
}
exprs = append(exprs, r.applyPrefix(netip.PrefixFrom(localAddr, bits), false)...)
exprs = append(exprs, exprs = append(exprs,
&expr.Immediate{ &expr.Immediate{
@@ -1925,7 +1970,7 @@ func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol,
}, },
&expr.NAT{ &expr.NAT{
Type: expr.NATTypeDestNAT, Type: expr.NATTypeDestNAT,
Family: uint32(nftables.TableFamilyIPv4), Family: uint32(r.af.tableFamily),
RegAddrMin: 1, RegAddrMin: 1,
RegProtoMin: 2, RegProtoMin: 2,
}, },
@@ -1993,45 +2038,44 @@ func (r *router) applyNetwork(
} }
if network.IsPrefix() { if network.IsPrefix() {
return applyPrefix(network.Prefix, isSource), nil return r.applyPrefix(network.Prefix, isSource), nil
} }
return nil, nil return nil, nil
} }
// applyPrefix generates nftables expressions for a CIDR prefix // applyPrefix generates nftables expressions for a CIDR prefix
func applyPrefix(prefix netip.Prefix, isSource bool) []expr.Any { func (r *router) applyPrefix(prefix netip.Prefix, isSource bool) []expr.Any {
// dst offset // dst offset by default
offset := uint32(16) offset := r.af.dstAddrOffset
if isSource { if isSource {
// src offset // src offset
offset = 12 offset = r.af.srcAddrOffset
} }
ones := prefix.Bits() ones := prefix.Bits()
// 0.0.0.0/0 doesn't need extra expressions // unspecified address (/0) doesn't need extra expressions
if ones == 0 { if ones == 0 {
return nil return nil
} }
mask := net.CIDRMask(ones, 32) mask := net.CIDRMask(ones, r.af.totalBits)
xor := make([]byte, r.af.addrLen)
return []expr.Any{ return []expr.Any{
&expr.Payload{ &expr.Payload{
DestRegister: 1, DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader, Base: expr.PayloadBaseNetworkHeader,
Offset: offset, Offset: offset,
Len: 4, Len: r.af.addrLen,
}, },
// netmask
&expr.Bitwise{ &expr.Bitwise{
DestRegister: 1, DestRegister: 1,
SourceRegister: 1, SourceRegister: 1,
Len: 4, Len: r.af.addrLen,
Mask: mask, Mask: mask,
Xor: []byte{0, 0, 0, 0}, Xor: xor,
}, },
// net address
&expr.Cmp{ &expr.Cmp{
Op: expr.CmpOpEq, Op: expr.CmpOpEq,
Register: 1, Register: 1,
@@ -2114,13 +2158,12 @@ func getCtNewExprs() []expr.Any {
} }
} }
func getIpSetExprs(ref refcounter.Ref[*nftables.Set], isSource bool) ([]expr.Any, error) { func (r *router) getIpSetExprs(ref refcounter.Ref[*nftables.Set], isSource bool) ([]expr.Any, error) {
// dst offset by default
// dst offset offset := r.af.dstAddrOffset
offset := uint32(16)
if isSource { if isSource {
// src offset // src offset
offset = 12 offset = r.af.srcAddrOffset
} }
return []expr.Any{ return []expr.Any{
@@ -2128,7 +2171,7 @@ func getIpSetExprs(ref refcounter.Ref[*nftables.Set], isSource bool) ([]expr.Any
DestRegister: 1, DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader, Base: expr.PayloadBaseNetworkHeader,
Offset: offset, Offset: offset,
Len: 4, Len: r.af.addrLen,
}, },
&expr.Lookup{ &expr.Lookup{
SourceRegister: 1, SourceRegister: 1,

View File

@@ -90,8 +90,9 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
} }
// Build CIDR matching expressions // Build CIDR matching expressions
sourceExp := applyPrefix(testCase.InputPair.Source.Prefix, true) testRouter := &router{af: afIPv4}
destExp := applyPrefix(testCase.InputPair.Destination.Prefix, false) sourceExp := testRouter.applyPrefix(testCase.InputPair.Source.Prefix, true)
destExp := testRouter.applyPrefix(testCase.InputPair.Destination.Prefix, false)
// Combine all expressions in the correct order // Combine all expressions in the correct order
// nolint:gocritic // nolint:gocritic
@@ -508,6 +509,136 @@ func TestNftablesCreateIpSet(t *testing.T) {
} }
} }
func TestNftablesCreateIpSet_IPv6(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
}
workTable, err := createWorkTableIPv6()
require.NoError(t, err, "Failed to create v6 work table")
defer deleteWorkTableIPv6()
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "Failed to create router")
require.NoError(t, r.init(workTable))
defer func() {
require.NoError(t, r.Reset(), "Failed to reset router")
}()
tests := []struct {
name string
sources []netip.Prefix
expected []netip.Prefix
}{
{
name: "Single IPv6",
sources: []netip.Prefix{netip.MustParsePrefix("2001:db8::1/128")},
},
{
name: "Multiple IPv6 Subnets",
sources: []netip.Prefix{
netip.MustParsePrefix("fd00::/64"),
netip.MustParsePrefix("2001:db8::/48"),
netip.MustParsePrefix("fe80::/10"),
},
},
{
name: "Overlapping IPv6",
sources: []netip.Prefix{
netip.MustParsePrefix("fd00::/48"),
netip.MustParsePrefix("fd00::/64"),
netip.MustParsePrefix("fd00::1/128"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("fd00::/48"),
},
},
{
name: "Mixed prefix lengths",
sources: []netip.Prefix{
netip.MustParsePrefix("2001:db8:1::/48"),
netip.MustParsePrefix("2001:db8:2::1/128"),
netip.MustParsePrefix("fd00:abcd::/32"),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
setName := firewall.NewPrefixSet(tt.sources).HashedName()
set, err := r.createIpSet(setName, setInput{prefixes: tt.sources})
require.NoError(t, err, "Failed to create IPv6 set")
require.NotNil(t, set)
assert.Equal(t, setName, set.Name)
assert.True(t, set.Interval)
assert.Equal(t, nftables.TypeIP6Addr, set.KeyType)
fetchedSet, err := r.conn.GetSetByName(r.workTable, setName)
require.NoError(t, err, "Failed to fetch created set")
elements, err := r.conn.GetSetElements(fetchedSet)
require.NoError(t, err, "Failed to get set elements")
uniquePrefixes := make(map[string]bool)
for _, elem := range elements {
if !elem.IntervalEnd && len(elem.Key) == 16 {
ip := netip.AddrFrom16([16]byte(elem.Key))
uniquePrefixes[ip.String()] = true
}
}
expectedCount := len(tt.expected)
if expectedCount == 0 {
expectedCount = len(tt.sources)
}
assert.Equal(t, expectedCount, len(uniquePrefixes), "unique prefix count mismatch")
r.conn.DelSet(set)
require.NoError(t, r.conn.Flush())
})
}
}
func createWorkTableIPv6() (*nftables.Table, error) {
sConn, err := nftables.New(nftables.AsLasting())
if err != nil {
return nil, err
}
tables, err := sConn.ListTablesOfFamily(nftables.TableFamilyIPv6)
if err != nil {
return nil, err
}
for _, t := range tables {
if t.Name == tableNameNetbird {
sConn.DelTable(t)
}
}
table := sConn.AddTable(&nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv6})
err = sConn.Flush()
return table, err
}
func deleteWorkTableIPv6() {
sConn, err := nftables.New(nftables.AsLasting())
if err != nil {
return
}
tables, err := sConn.ListTablesOfFamily(nftables.TableFamilyIPv6)
if err != nil {
return
}
for _, t := range tables {
if t.Name == tableNameNetbird {
sConn.DelTable(t)
_ = sConn.Flush()
}
}
}
func verifyRule(t *testing.T, rule *nftables.Rule, sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort, dPort *firewall.Port, direction firewall.RuleDirection, action firewall.Action, expectSet bool) { func verifyRule(t *testing.T, rule *nftables.Rule, sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort, dPort *firewall.Port, direction firewall.RuleDirection, action firewall.Action, expectSet bool) {
t.Helper() t.Helper()
@@ -627,7 +758,7 @@ func containsPort(exprs []expr.Any, port *firewall.Port, isSource bool) bool {
func containsProtocol(exprs []expr.Any, proto firewall.Protocol) bool { func containsProtocol(exprs []expr.Any, proto firewall.Protocol) bool {
var metaFound, cmpFound bool var metaFound, cmpFound bool
expectedProto, _ := protoToInt(proto) expectedProto, _ := afIPv4.protoNum(proto)
for _, e := range exprs { for _, e := range exprs {
switch ex := e.(type) { switch ex := e.(type) {
case *expr.Meta: case *expr.Meta:
@@ -854,3 +985,55 @@ func TestRouter_AddNatRule_WithStaleEntry(t *testing.T) {
} }
assert.Equal(t, 1, found, "NAT rule should exist in kernel") assert.Equal(t, 1, found, "NAT rule should exist in kernel")
} }
func TestCalculateLastIP(t *testing.T) {
tests := []struct {
prefix string
want string
}{
{"10.0.0.0/24", "10.0.0.255"},
{"10.0.0.0/32", "10.0.0.0"},
{"0.0.0.0/0", "255.255.255.255"},
{"192.168.1.0/28", "192.168.1.15"},
{"fd00::/64", "fd00::ffff:ffff:ffff:ffff"},
{"fd00::/128", "fd00::"},
{"2001:db8::/48", "2001:db8:0:ffff:ffff:ffff:ffff:ffff"},
{"::/0", "ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff"},
}
for _, tt := range tests {
t.Run(tt.prefix, func(t *testing.T) {
prefix := netip.MustParsePrefix(tt.prefix)
got := calculateLastIP(prefix)
assert.Equal(t, tt.want, got.String())
})
}
}
func TestConvertPrefixesToSet_IPv6(t *testing.T) {
r := &router{af: afIPv6}
prefixes := []netip.Prefix{
netip.MustParsePrefix("fd00::/64"),
netip.MustParsePrefix("2001:db8::1/128"),
}
elements := r.convertPrefixesToSet(prefixes)
// Each prefix produces 2 elements (start + end)
require.Len(t, elements, 4)
// fd00::/64 start
assert.Equal(t, netip.MustParseAddr("fd00::").As16(), [16]byte(elements[0].Key))
assert.False(t, elements[0].IntervalEnd)
// fd00::/64 end (fd00:0:0:1::, one past the last)
assert.Equal(t, netip.MustParseAddr("fd00:0:0:1::").As16(), [16]byte(elements[1].Key))
assert.True(t, elements[1].IntervalEnd)
// 2001:db8::1/128 start
assert.Equal(t, netip.MustParseAddr("2001:db8::1").As16(), [16]byte(elements[2].Key))
assert.False(t, elements[2].IntervalEnd)
// 2001:db8::1/128 end (2001:db8::2)
assert.Equal(t, netip.MustParseAddr("2001:db8::2").As16(), [16]byte(elements[3].Key))
assert.True(t, elements[3].IntervalEnd)
}

View File

@@ -5,8 +5,10 @@ import (
"os/exec" "os/exec"
"syscall" "syscall"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
) )
@@ -29,15 +31,20 @@ func (m *Manager) Close(*statemanager.Manager) error {
return nil return nil
} }
if !isFirewallRuleActive(firewallRuleName) { var merr *multierror.Error
return nil if isFirewallRuleActive(firewallRuleName) {
if err := manageFirewallRule(firewallRuleName, deleteRule); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove windows firewall rule: %w", err))
}
} }
if err := manageFirewallRule(firewallRuleName, deleteRule); err != nil { if isFirewallRuleActive(firewallRuleName + "-v6") {
return fmt.Errorf("couldn't remove windows firewall: %w", err) if err := manageFirewallRule(firewallRuleName+"-v6", deleteRule); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove windows v6 firewall rule: %w", err))
}
} }
return nil return nberrors.FormatErrorOrNil(merr)
} }
// AllowNetbird allows netbird interface traffic // AllowNetbird allows netbird interface traffic
@@ -46,17 +53,33 @@ func (m *Manager) AllowNetbird() error {
return nil return nil
} }
if isFirewallRuleActive(firewallRuleName) { if !isFirewallRuleActive(firewallRuleName) {
return nil if err := manageFirewallRule(firewallRuleName,
addRule,
"dir=in",
"enable=yes",
"action=allow",
"profile=any",
"localip="+m.wgIface.Address().IP.String(),
); err != nil {
return err
}
} }
return manageFirewallRule(firewallRuleName,
addRule, if v6 := m.wgIface.Address().IPv6; v6.IsValid() && !isFirewallRuleActive(firewallRuleName+"-v6") {
"dir=in", if err := manageFirewallRule(firewallRuleName+"-v6",
"enable=yes", addRule,
"action=allow", "dir=in",
"profile=any", "enable=yes",
"localip="+m.wgIface.Address().IP.String(), "action=allow",
) "profile=any",
"localip="+v6.String(),
); err != nil {
return err
}
}
return nil
} }
func manageFirewallRule(ruleName string, action action, extraArgs ...string) error { func manageFirewallRule(ruleName string, action action, extraArgs ...string) error {

View File

@@ -1,8 +1,9 @@
package conntrack package conntrack
import ( import (
"fmt" "net"
"net/netip" "net/netip"
"strconv"
"sync/atomic" "sync/atomic"
"time" "time"
@@ -64,5 +65,7 @@ type ConnKey struct {
} }
func (c ConnKey) String() string { func (c ConnKey) String() string {
return fmt.Sprintf("%s:%d → %s:%d", c.SrcIP.Unmap(), c.SrcPort, c.DstIP.Unmap(), c.DstPort) return net.JoinHostPort(c.SrcIP.Unmap().String(), strconv.Itoa(int(c.SrcPort))) +
" → " +
net.JoinHostPort(c.DstIP.Unmap().String(), strconv.Itoa(int(c.DstPort)))
} }

View File

@@ -21,9 +21,10 @@ const (
// ICMPCleanupInterval is how often we check for stale ICMP connections // ICMPCleanupInterval is how often we check for stale ICMP connections
ICMPCleanupInterval = 15 * time.Second ICMPCleanupInterval = 15 * time.Second
// MaxICMPPayloadLength is the maximum length of ICMP payload we consider for original packet info, // MaxICMPPayloadLength is the maximum length of ICMP payload we consider for original packet info.
// which includes the IP header (20 bytes) and transport header (8 bytes) // IPv4: 20-byte header + 8-byte transport = 28 bytes.
MaxICMPPayloadLength = 28 // IPv6: 40-byte header + 8-byte transport = 48 bytes.
MaxICMPPayloadLength = 48
) )
// ICMPConnKey uniquely identifies an ICMP connection // ICMPConnKey uniquely identifies an ICMP connection
@@ -74,32 +75,64 @@ func (info ICMPInfo) String() string {
return info.TypeCode.String() return info.TypeCode.String()
} }
// isErrorMessage returns true if this ICMP type carries original packet info // isErrorMessage returns true if this ICMP type carries original packet info.
// Covers both ICMPv4 and ICMPv6 error types. Without a family field we match
// both sets; type 3 overlaps (v4 DestUnreachable / v6 TimeExceeded) so it's
// kept as a literal.
func (info ICMPInfo) isErrorMessage() bool { func (info ICMPInfo) isErrorMessage() bool {
typ := info.TypeCode.Type() typ := info.TypeCode.Type()
return typ == 3 || // Destination Unreachable // ICMPv4 error types
typ == 5 || // Redirect if typ == layers.ICMPv4TypeDestinationUnreachable ||
typ == 11 || // Time Exceeded typ == layers.ICMPv4TypeRedirect ||
typ == 12 // Parameter Problem typ == layers.ICMPv4TypeTimeExceeded ||
typ == layers.ICMPv4TypeParameterProblem {
return true
}
// ICMPv6 error types (type 3 already matched above as v4 DestUnreachable)
if typ == layers.ICMPv6TypeDestinationUnreachable ||
typ == layers.ICMPv6TypePacketTooBig ||
typ == layers.ICMPv6TypeParameterProblem {
return true
}
return false
} }
// parseOriginalPacket extracts info about the original packet from ICMP payload // parseOriginalPacket extracts info about the original packet from ICMP payload
func (info ICMPInfo) parseOriginalPacket() string { func (info ICMPInfo) parseOriginalPacket() string {
if info.PayloadLen < MaxICMPPayloadLength { if info.PayloadLen == 0 {
return "" return ""
} }
// TODO: handle IPv6 version := (info.PayloadData[0] >> 4) & 0xF
if version := (info.PayloadData[0] >> 4) & 0xF; version != 4 {
var protocol uint8
var srcIP, dstIP net.IP
var transportData []byte
switch version {
case 4:
// 20-byte IPv4 header + 8-byte transport minimum
if info.PayloadLen < 28 {
return ""
}
protocol = info.PayloadData[9]
srcIP = net.IP(info.PayloadData[12:16])
dstIP = net.IP(info.PayloadData[16:20])
transportData = info.PayloadData[20:]
case 6:
// 40-byte IPv6 header + 8-byte transport minimum
if info.PayloadLen < 48 {
return ""
}
// Next Header field in IPv6 header
protocol = info.PayloadData[6]
srcIP = net.IP(info.PayloadData[8:24])
dstIP = net.IP(info.PayloadData[24:40])
transportData = info.PayloadData[40:]
default:
return "" return ""
} }
protocol := info.PayloadData[9]
srcIP := net.IP(info.PayloadData[12:16])
dstIP := net.IP(info.PayloadData[16:20])
transportData := info.PayloadData[20:]
switch nftypes.Protocol(protocol) { switch nftypes.Protocol(protocol) {
case nftypes.TCP: case nftypes.TCP:
srcPort := uint16(transportData[0])<<8 | uint16(transportData[1]) srcPort := uint16(transportData[0])<<8 | uint16(transportData[1])
@@ -247,9 +280,10 @@ func (t *ICMPTracker) track(
t.sendEvent(nftypes.TypeStart, conn, ruleId) t.sendEvent(nftypes.TypeStart, conn, ruleId)
} }
// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request // IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request.
// Accepts both ICMPv4 (type 0) and ICMPv6 (type 129) echo replies.
func (t *ICMPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, icmpType uint8, size int) bool { func (t *ICMPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, icmpType uint8, size int) bool {
if icmpType != uint8(layers.ICMPv4TypeEchoReply) { if icmpType != uint8(layers.ICMPv4TypeEchoReply) && icmpType != uint8(layers.ICMPv6TypeEchoReply) {
return false return false
} }
@@ -301,6 +335,13 @@ func (t *ICMPTracker) cleanup() {
} }
} }
func icmpProtocolForAddr(ip netip.Addr) nftypes.Protocol {
if ip.Is6() {
return nftypes.ICMPv6
}
return nftypes.ICMP
}
// Close stops the cleanup routine and releases resources // Close stops the cleanup routine and releases resources
func (t *ICMPTracker) Close() { func (t *ICMPTracker) Close() {
t.tickerCancel() t.tickerCancel()
@@ -316,7 +357,7 @@ func (t *ICMPTracker) sendEvent(typ nftypes.Type, conn *ICMPConnTrack, ruleID []
Type: typ, Type: typ,
RuleID: ruleID, RuleID: ruleID,
Direction: conn.Direction, Direction: conn.Direction,
Protocol: nftypes.ICMP, // TODO: adjust for IPv6/icmpv6 Protocol: icmpProtocolForAddr(conn.SourceIP),
SourceIP: conn.SourceIP, SourceIP: conn.SourceIP,
DestIP: conn.DestIP, DestIP: conn.DestIP,
ICMPType: conn.ICMPType, ICMPType: conn.ICMPType,
@@ -334,7 +375,7 @@ func (t *ICMPTracker) sendStartEvent(direction nftypes.Direction, srcIP netip.Ad
Type: nftypes.TypeStart, Type: nftypes.TypeStart,
RuleID: ruleID, RuleID: ruleID,
Direction: direction, Direction: direction,
Protocol: nftypes.ICMP, Protocol: icmpProtocolForAddr(srcIP),
SourceIP: srcIP, SourceIP: srcIP,
DestIP: dstIP, DestIP: dstIP,
ICMPType: typ, ICMPType: typ,

View File

@@ -35,8 +35,10 @@ import (
const ( const (
layerTypeAll = 255 layerTypeAll = 255
// ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation // ipv4TCPHeaderMinSize represents minimum IPv4 (20) + TCP (20) header size for MSS calculation
ipTCPHeaderMinSize = 40 ipv4TCPHeaderMinSize = 40
// ipv6TCPHeaderMinSize represents minimum IPv6 (40) + TCP (20) header size for MSS calculation
ipv6TCPHeaderMinSize = 60
) )
// serviceKey represents a protocol/port combination for netstack service registry // serviceKey represents a protocol/port combination for netstack service registry
@@ -137,9 +139,10 @@ type Manager struct {
netstackServices map[serviceKey]struct{} netstackServices map[serviceKey]struct{}
netstackServiceMutex sync.RWMutex netstackServiceMutex sync.RWMutex
mtu uint16 mtu uint16
mssClampValue uint16 mssClampValueIPv4 uint16
mssClampEnabled bool mssClampValueIPv6 uint16
mssClampEnabled bool
// Only one hook per protocol is supported. Outbound direction only. // Only one hook per protocol is supported. Outbound direction only.
udpHookOut atomic.Pointer[packetHook] udpHookOut atomic.Pointer[packetHook]
@@ -163,11 +166,28 @@ type decoder struct {
icmp4 layers.ICMPv4 icmp4 layers.ICMPv4
icmp6 layers.ICMPv6 icmp6 layers.ICMPv6
decoded []gopacket.LayerType decoded []gopacket.LayerType
parser *gopacket.DecodingLayerParser parser4 *gopacket.DecodingLayerParser
parser6 *gopacket.DecodingLayerParser
dnatOrigPort uint16 dnatOrigPort uint16
} }
// decodePacket decodes packet data using the appropriate parser based on IP version.
func (d *decoder) decodePacket(data []byte) error {
if len(data) == 0 {
return errors.New("empty packet")
}
version := data[0] >> 4
switch version {
case 4:
return d.parser4.DecodeLayers(data, &d.decoded)
case 6:
return d.parser6.DecodeLayers(data, &d.decoded)
default:
return fmt.Errorf("unknown IP version %d", version)
}
}
// Create userspace firewall manager constructor // Create userspace firewall manager constructor
func Create(iface common.IFaceMapper, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) { func Create(iface common.IFaceMapper, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) {
return create(iface, nil, disableServerRoutes, flowLogger, mtu) return create(iface, nil, disableServerRoutes, flowLogger, mtu)
@@ -225,11 +245,17 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
d := &decoder{ d := &decoder{
decoded: []gopacket.LayerType{}, decoded: []gopacket.LayerType{},
} }
d.parser = gopacket.NewDecodingLayerParser( d.parser4 = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv4, layers.LayerTypeIPv4,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
) )
d.parser.IgnoreUnsupported = true d.parser4.IgnoreUnsupported = true
d.parser6 = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv6,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
)
d.parser6.IgnoreUnsupported = true
return d return d
}, },
}, },
@@ -255,7 +281,8 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
if !disableMSSClamping { if !disableMSSClamping {
m.mssClampEnabled = true m.mssClampEnabled = true
m.mssClampValue = mtu - ipTCPHeaderMinSize m.mssClampValueIPv4 = mtu - ipv4TCPHeaderMinSize
m.mssClampValueIPv6 = mtu - ipv6TCPHeaderMinSize
} }
if err := m.localipmanager.UpdateLocalIPs(iface); err != nil { if err := m.localipmanager.UpdateLocalIPs(iface); err != nil {
return nil, fmt.Errorf("update local IPs: %w", err) return nil, fmt.Errorf("update local IPs: %w", err)
@@ -282,9 +309,14 @@ func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) (firewall.Rule, e
wgPrefix := iface.Address().Network wgPrefix := iface.Address().Network
log.Debugf("blocking invalid routed traffic for %s", wgPrefix) log.Debugf("blocking invalid routed traffic for %s", wgPrefix)
sources := []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
if v6 := iface.Address().IPv6Net; v6.IsValid() {
sources = append(sources, netip.PrefixFrom(netip.IPv6Unspecified(), 0))
}
rule, err := m.addRouteFiltering( rule, err := m.addRouteFiltering(
nil, nil,
[]netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}, sources,
firewall.Network{Prefix: wgPrefix}, firewall.Network{Prefix: wgPrefix},
firewall.ProtocolALL, firewall.ProtocolALL,
nil, nil,
@@ -292,7 +324,22 @@ func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) (firewall.Rule, e
firewall.ActionDrop, firewall.ActionDrop,
) )
if err != nil { if err != nil {
return nil, fmt.Errorf("block wg nte : %w", err) return nil, fmt.Errorf("block wg v4 net: %w", err)
}
if v6Net := iface.Address().IPv6Net; v6Net.IsValid() {
log.Debugf("blocking invalid routed traffic for %s", v6Net)
if _, err := m.addRouteFiltering(
nil,
sources,
firewall.Network{Prefix: v6Net},
firewall.ProtocolALL,
nil,
nil,
firewall.ActionDrop,
); err != nil {
return nil, fmt.Errorf("block wg v6 net: %w", err)
}
} }
// TODO: Block networks that we're a client of // TODO: Block networks that we're a client of
@@ -509,7 +556,7 @@ func (m *Manager) addRouteFiltering(
mgmtId: id, mgmtId: id,
sources: sources, sources: sources,
dstSet: destination.Set, dstSet: destination.Set,
protoLayer: protoToLayer(proto, layers.LayerTypeIPv4), protoLayer: protoToLayer(proto, ipLayerFromPrefix(destination.Prefix)),
srcPort: sPort, srcPort: sPort,
dstPort: dPort, dstPort: dPort,
action: action, action: action,
@@ -663,11 +710,7 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
} }
destinations := matches[0].destinations destinations := matches[0].destinations
for _, prefix := range prefixes { destinations = append(destinations, prefixes...)
if prefix.Addr().Is4() {
destinations = append(destinations, prefix)
}
}
slices.SortFunc(destinations, func(a, b netip.Prefix) int { slices.SortFunc(destinations, func(a, b netip.Prefix) int {
cmp := a.Addr().Compare(b.Addr()) cmp := a.Addr().Compare(b.Addr())
@@ -706,7 +749,7 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool {
d := m.decoders.Get().(*decoder) d := m.decoders.Get().(*decoder)
defer m.decoders.Put(d) defer m.decoders.Put(d)
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { if err := d.decodePacket(packetData); err != nil {
return false return false
} }
@@ -790,12 +833,28 @@ func (m *Manager) clampTCPMSS(packetData []byte, d *decoder) bool {
return false return false
} }
var mssClampValue uint16
var ipHeaderSize int
switch d.decoded[0] {
case layers.LayerTypeIPv4:
mssClampValue = m.mssClampValueIPv4
ipHeaderSize = int(d.ip4.IHL) * 4
if ipHeaderSize < 20 {
return false
}
case layers.LayerTypeIPv6:
mssClampValue = m.mssClampValueIPv6
ipHeaderSize = 40
default:
return false
}
mssOptionIndex := -1 mssOptionIndex := -1
var currentMSS uint16 var currentMSS uint16
for i, opt := range d.tcp.Options { for i, opt := range d.tcp.Options {
if opt.OptionType == layers.TCPOptionKindMSS && len(opt.OptionData) == 2 { if opt.OptionType == layers.TCPOptionKindMSS && len(opt.OptionData) == 2 {
currentMSS = binary.BigEndian.Uint16(opt.OptionData) currentMSS = binary.BigEndian.Uint16(opt.OptionData)
if currentMSS > m.mssClampValue { if currentMSS > mssClampValue {
mssOptionIndex = i mssOptionIndex = i
break break
} }
@@ -806,20 +865,15 @@ func (m *Manager) clampTCPMSS(packetData []byte, d *decoder) bool {
return false return false
} }
ipHeaderSize := int(d.ip4.IHL) * 4 if !m.updateMSSOption(packetData, d, mssOptionIndex, mssClampValue, ipHeaderSize) {
if ipHeaderSize < 20 {
return false return false
} }
if !m.updateMSSOption(packetData, d, mssOptionIndex, ipHeaderSize) { m.logger.Trace2("Clamped TCP MSS from %d to %d", currentMSS, mssClampValue)
return false
}
m.logger.Trace2("Clamped TCP MSS from %d to %d", currentMSS, m.mssClampValue)
return true return true
} }
func (m *Manager) updateMSSOption(packetData []byte, d *decoder, mssOptionIndex, ipHeaderSize int) bool { func (m *Manager) updateMSSOption(packetData []byte, d *decoder, mssOptionIndex int, mssClampValue uint16, ipHeaderSize int) bool {
tcpHeaderStart := ipHeaderSize tcpHeaderStart := ipHeaderSize
tcpOptionsStart := tcpHeaderStart + 20 tcpOptionsStart := tcpHeaderStart + 20
@@ -834,7 +888,7 @@ func (m *Manager) updateMSSOption(packetData []byte, d *decoder, mssOptionIndex,
} }
mssValueOffset := optOffset + 2 mssValueOffset := optOffset + 2
binary.BigEndian.PutUint16(packetData[mssValueOffset:mssValueOffset+2], m.mssClampValue) binary.BigEndian.PutUint16(packetData[mssValueOffset:mssValueOffset+2], mssClampValue)
m.recalculateTCPChecksum(packetData, d, tcpHeaderStart) m.recalculateTCPChecksum(packetData, d, tcpHeaderStart)
return true return true
@@ -844,18 +898,32 @@ func (m *Manager) recalculateTCPChecksum(packetData []byte, d *decoder, tcpHeade
tcpLayer := packetData[tcpHeaderStart:] tcpLayer := packetData[tcpHeaderStart:]
tcpLength := len(packetData) - tcpHeaderStart tcpLength := len(packetData) - tcpHeaderStart
// Zero out existing checksum
tcpLayer[16] = 0 tcpLayer[16] = 0
tcpLayer[17] = 0 tcpLayer[17] = 0
// Build pseudo-header checksum based on IP version
var pseudoSum uint32 var pseudoSum uint32
pseudoSum += uint32(d.ip4.SrcIP[0])<<8 | uint32(d.ip4.SrcIP[1]) switch d.decoded[0] {
pseudoSum += uint32(d.ip4.SrcIP[2])<<8 | uint32(d.ip4.SrcIP[3]) case layers.LayerTypeIPv4:
pseudoSum += uint32(d.ip4.DstIP[0])<<8 | uint32(d.ip4.DstIP[1]) pseudoSum += uint32(d.ip4.SrcIP[0])<<8 | uint32(d.ip4.SrcIP[1])
pseudoSum += uint32(d.ip4.DstIP[2])<<8 | uint32(d.ip4.DstIP[3]) pseudoSum += uint32(d.ip4.SrcIP[2])<<8 | uint32(d.ip4.SrcIP[3])
pseudoSum += uint32(d.ip4.Protocol) pseudoSum += uint32(d.ip4.DstIP[0])<<8 | uint32(d.ip4.DstIP[1])
pseudoSum += uint32(tcpLength) pseudoSum += uint32(d.ip4.DstIP[2])<<8 | uint32(d.ip4.DstIP[3])
pseudoSum += uint32(d.ip4.Protocol)
pseudoSum += uint32(tcpLength)
case layers.LayerTypeIPv6:
for i := 0; i < 16; i += 2 {
pseudoSum += uint32(d.ip6.SrcIP[i])<<8 | uint32(d.ip6.SrcIP[i+1])
}
for i := 0; i < 16; i += 2 {
pseudoSum += uint32(d.ip6.DstIP[i])<<8 | uint32(d.ip6.DstIP[i+1])
}
pseudoSum += uint32(tcpLength)
pseudoSum += uint32(layers.IPProtocolTCP)
}
var sum = pseudoSum sum := pseudoSum
for i := 0; i < tcpLength-1; i += 2 { for i := 0; i < tcpLength-1; i += 2 {
sum += uint32(tcpLayer[i])<<8 | uint32(tcpLayer[i+1]) sum += uint32(tcpLayer[i])<<8 | uint32(tcpLayer[i+1])
} }
@@ -893,6 +961,9 @@ func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, packetData
} }
case layers.LayerTypeICMPv4: case layers.LayerTypeICMPv4:
m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, d.icmp4.Payload, size) m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, d.icmp4.Payload, size)
case layers.LayerTypeICMPv6:
id, tc := icmpv6EchoFields(d)
m.icmpTracker.TrackOutbound(srcIP, dstIP, id, tc, d.icmp6.Payload, size)
} }
} }
@@ -906,6 +977,9 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byt
m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size, d.dnatOrigPort) m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size, d.dnatOrigPort)
case layers.LayerTypeICMPv4: case layers.LayerTypeICMPv4:
m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, d.icmp4.Payload, size) m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, d.icmp4.Payload, size)
case layers.LayerTypeICMPv6:
id, tc := icmpv6EchoFields(d)
m.icmpTracker.TrackInbound(srcIP, dstIP, id, tc, ruleID, d.icmp6.Payload, size)
} }
d.dnatOrigPort = 0 d.dnatOrigPort = 0
@@ -948,15 +1022,19 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool {
// TODO: pass fragments of routed packets to forwarder // TODO: pass fragments of routed packets to forwarder
if fragment { if fragment {
m.logger.Trace4("packet is a fragment: src=%v dst=%v id=%v flags=%v", if d.decoded[0] == layers.LayerTypeIPv4 {
srcIP, dstIP, d.ip4.Id, d.ip4.Flags) m.logger.Trace4("packet is a fragment: src=%v dst=%v id=%v flags=%v",
srcIP, dstIP, d.ip4.Id, d.ip4.Flags)
} else {
m.logger.Trace2("packet is an IPv6 fragment: src=%v dst=%v", srcIP, dstIP)
}
return false return false
} }
// TODO: optimize port DNAT by caching matched rules in conntrack // TODO: optimize port DNAT by caching matched rules in conntrack
if translated := m.translateInboundPortDNAT(packetData, d, srcIP, dstIP); translated { if translated := m.translateInboundPortDNAT(packetData, d, srcIP, dstIP); translated {
// Re-decode after port DNAT translation to update port information // Re-decode after port DNAT translation to update port information
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { if err := d.decodePacket(packetData); err != nil {
m.logger.Error1("failed to re-decode packet after port DNAT: %v", err) m.logger.Error1("failed to re-decode packet after port DNAT: %v", err)
return true return true
} }
@@ -965,7 +1043,7 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool {
if translated := m.translateInboundReverse(packetData, d); translated { if translated := m.translateInboundReverse(packetData, d); translated {
// Re-decode after translation to get original addresses // Re-decode after translation to get original addresses
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { if err := d.decodePacket(packetData); err != nil {
m.logger.Error1("failed to re-decode packet after reverse DNAT: %v", err) m.logger.Error1("failed to re-decode packet after reverse DNAT: %v", err)
return true return true
} }
@@ -1097,6 +1175,48 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe
return true return true
} }
// icmpv6EchoFields extracts the echo identifier from an ICMPv6 packet and maps
// the ICMPv6 type code to an ICMPv4TypeCode so the ICMP conntrack can handle
// both families uniformly. The echo ID is in the first two payload bytes.
func icmpv6EchoFields(d *decoder) (id uint16, tc layers.ICMPv4TypeCode) {
if len(d.icmp6.Payload) >= 2 {
id = uint16(d.icmp6.Payload[0])<<8 | uint16(d.icmp6.Payload[1])
}
// Map ICMPv6 echo types to ICMPv4 equivalents for unified tracking.
switch d.icmp6.TypeCode.Type() {
case layers.ICMPv6TypeEchoRequest:
tc = layers.CreateICMPv4TypeCode(layers.ICMPv4TypeEchoRequest, 0)
case layers.ICMPv6TypeEchoReply:
tc = layers.CreateICMPv4TypeCode(layers.ICMPv4TypeEchoReply, 0)
default:
tc = layers.CreateICMPv4TypeCode(d.icmp6.TypeCode.Type(), d.icmp6.TypeCode.Code())
}
return id, tc
}
// protoLayerMatches checks if a packet's protocol layer matches a rule's expected
// protocol layer. ICMPv4 and ICMPv6 are treated as equivalent when matching
// ICMP rules since management sends a single ICMP rule for both families.
func protoLayerMatches(ruleLayer, packetLayer gopacket.LayerType) bool {
if ruleLayer == packetLayer {
return true
}
if ruleLayer == layers.LayerTypeICMPv4 && packetLayer == layers.LayerTypeICMPv6 {
return true
}
if ruleLayer == layers.LayerTypeICMPv6 && packetLayer == layers.LayerTypeICMPv4 {
return true
}
return false
}
func ipLayerFromPrefix(p netip.Prefix) gopacket.LayerType {
if p.Addr().Is6() {
return layers.LayerTypeIPv6
}
return layers.LayerTypeIPv4
}
func protoToLayer(proto firewall.Protocol, ipLayer gopacket.LayerType) gopacket.LayerType { func protoToLayer(proto firewall.Protocol, ipLayer gopacket.LayerType) gopacket.LayerType {
switch proto { switch proto {
case firewall.ProtocolTCP: case firewall.ProtocolTCP:
@@ -1120,8 +1240,10 @@ func getProtocolFromPacket(d *decoder) nftypes.Protocol {
return nftypes.TCP return nftypes.TCP
case layers.LayerTypeUDP: case layers.LayerTypeUDP:
return nftypes.UDP return nftypes.UDP
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6: case layers.LayerTypeICMPv4:
return nftypes.ICMP return nftypes.ICMP
case layers.LayerTypeICMPv6:
return nftypes.ICMPv6
default: default:
return nftypes.ProtocolUnknown return nftypes.ProtocolUnknown
} }
@@ -1142,7 +1264,7 @@ func getPortsFromPacket(d *decoder) (srcPort, dstPort uint16) {
// It returns true, false if the packet is valid and not a fragment. // It returns true, false if the packet is valid and not a fragment.
// It returns true, true if the packet is a fragment and valid. // It returns true, true if the packet is a fragment and valid.
func (m *Manager) isValidPacket(d *decoder, packetData []byte) (bool, bool) { func (m *Manager) isValidPacket(d *decoder, packetData []byte) (bool, bool) {
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { if err := d.decodePacket(packetData); err != nil {
m.logger.Trace1("couldn't decode packet, err: %s", err) m.logger.Trace1("couldn't decode packet, err: %s", err)
return false, false return false, false
} }
@@ -1155,10 +1277,18 @@ func (m *Manager) isValidPacket(d *decoder, packetData []byte) (bool, bool) {
} }
// Fragments are also valid // Fragments are also valid
if l == 1 && d.decoded[0] == layers.LayerTypeIPv4 { if l == 1 {
ip4 := d.ip4 switch d.decoded[0] {
if ip4.Flags&layers.IPv4MoreFragments != 0 || ip4.FragOffset != 0 { case layers.LayerTypeIPv4:
return true, true if d.ip4.Flags&layers.IPv4MoreFragments != 0 || d.ip4.FragOffset != 0 {
return true, true
}
case layers.LayerTypeIPv6:
// IPv6 uses Fragment extension header (NextHeader=44). If gopacket
// only decoded the IPv6 layer, the transport is in a fragment.
if d.ip6.NextHeader == layers.IPProtocolIPv6Fragment {
return true, true
}
} }
} }
@@ -1196,21 +1326,34 @@ func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP netip.Addr,
size, size,
) )
// TODO: ICMPv6 case layers.LayerTypeICMPv6:
id, _ := icmpv6EchoFields(d)
return m.icmpTracker.IsValidInbound(
srcIP,
dstIP,
id,
d.icmp6.TypeCode.Type(),
size,
)
} }
return false return false
} }
// isSpecialICMP returns true if the packet is a special ICMP packet that should be allowed // isSpecialICMP returns true if the packet is a special ICMP error packet that should be allowed.
func (m *Manager) isSpecialICMP(d *decoder) bool { func (m *Manager) isSpecialICMP(d *decoder) bool {
if d.decoded[1] != layers.LayerTypeICMPv4 { switch d.decoded[1] {
return false case layers.LayerTypeICMPv4:
icmpType := d.icmp4.TypeCode.Type()
return icmpType == layers.ICMPv4TypeDestinationUnreachable ||
icmpType == layers.ICMPv4TypeTimeExceeded
case layers.LayerTypeICMPv6:
icmpType := d.icmp6.TypeCode.Type()
return icmpType == layers.ICMPv6TypeDestinationUnreachable ||
icmpType == layers.ICMPv6TypePacketTooBig ||
icmpType == layers.ICMPv6TypeTimeExceeded
} }
return false
icmpType := d.icmp4.TypeCode.Type()
return icmpType == layers.ICMPv4TypeDestinationUnreachable ||
icmpType == layers.ICMPv4TypeTimeExceeded
} }
func (m *Manager) peerACLsBlock(srcIP netip.Addr, d *decoder, packetData []byte) ([]byte, bool) { func (m *Manager) peerACLsBlock(srcIP netip.Addr, d *decoder, packetData []byte) ([]byte, bool) {
@@ -1267,7 +1410,7 @@ func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d
return rule.mgmtId, rule.drop, true return rule.mgmtId, rule.drop, true
} }
if payloadLayer != rule.protoLayer { if !protoLayerMatches(rule.protoLayer, payloadLayer) {
continue continue
} }
@@ -1302,8 +1445,7 @@ func (m *Manager) routeACLsPass(srcIP, dstIP netip.Addr, protoLayer gopacket.Lay
} }
func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, protoLayer gopacket.LayerType, srcPort, dstPort uint16) bool { func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, protoLayer gopacket.LayerType, srcPort, dstPort uint16) bool {
// TODO: handle ipv6 vs ipv4 icmp rules if rule.protoLayer != layerTypeAll && !protoLayerMatches(rule.protoLayer, protoLayer) {
if rule.protoLayer != layerTypeAll && rule.protoLayer != protoLayer {
return false return false
} }
@@ -1473,7 +1615,8 @@ func (m *Manager) shouldForward(d *decoder, dstIP netip.Addr) bool {
} }
// traffic to our other local interfaces (not NetBird IP) - always forward // traffic to our other local interfaces (not NetBird IP) - always forward
if dstIP != m.wgIface.Address().IP { addr := m.wgIface.Address()
if dstIP != addr.IP && (!addr.IPv6.IsValid() || dstIP != addr.IPv6) {
return true return true
} }

View File

@@ -1023,7 +1023,8 @@ func BenchmarkMSSClamping(b *testing.B) {
}() }()
manager.mssClampEnabled = true manager.mssClampEnabled = true
manager.mssClampValue = 1240 manager.mssClampValueIPv4 = 1240
manager.mssClampValueIPv6 = 1220
srcIP := net.ParseIP("100.64.0.2") srcIP := net.ParseIP("100.64.0.2")
dstIP := net.ParseIP("8.8.8.8") dstIP := net.ParseIP("8.8.8.8")
@@ -1088,7 +1089,8 @@ func BenchmarkMSSClampingOverhead(b *testing.B) {
manager.mssClampEnabled = sc.enabled manager.mssClampEnabled = sc.enabled
if sc.enabled { if sc.enabled {
manager.mssClampValue = 1240 manager.mssClampValueIPv4 = 1240
manager.mssClampValueIPv6 = 1220
} }
srcIP := net.ParseIP("100.64.0.2") srcIP := net.ParseIP("100.64.0.2")
@@ -1141,7 +1143,8 @@ func BenchmarkMSSClampingMemory(b *testing.B) {
}() }()
manager.mssClampEnabled = true manager.mssClampEnabled = true
manager.mssClampValue = 1240 manager.mssClampValueIPv4 = 1240
manager.mssClampValueIPv6 = 1220
srcIP := net.ParseIP("100.64.0.2") srcIP := net.ParseIP("100.64.0.2")
dstIP := net.ParseIP("8.8.8.8") dstIP := net.ParseIP("8.8.8.8")

View File

@@ -539,53 +539,236 @@ func TestPeerACLFiltering(t *testing.T) {
} }
} }
func TestPeerACLFilteringIPv6(t *testing.T) {
localIP := netip.MustParseAddr("100.10.0.100")
localIPv6 := netip.MustParseAddr("fd00::100")
wgNet := netip.MustParsePrefix("100.10.0.0/16")
wgNetV6 := netip.MustParsePrefix("fd00::/64")
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: localIP,
Network: wgNet,
IPv6: localIPv6,
IPv6Net: wgNetV6,
}
},
}
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, manager.Close(nil)) })
err = manager.UpdateLocalIPs()
require.NoError(t, err)
testCases := []struct {
name string
srcIP string
dstIP string
proto fw.Protocol
srcPort uint16
dstPort uint16
ruleIP string
ruleProto fw.Protocol
ruleDstPort *fw.Port
ruleAction fw.Action
shouldBeBlocked bool
}{
{
name: "IPv6: allow TCP from peer",
srcIP: "fd00::1",
dstIP: "fd00::100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 443,
ruleIP: "fd00::1",
ruleProto: fw.ProtocolTCP,
ruleDstPort: &fw.Port{Values: []uint16{443}},
ruleAction: fw.ActionAccept,
shouldBeBlocked: false,
},
{
name: "IPv6: allow UDP from peer",
srcIP: "fd00::1",
dstIP: "fd00::100",
proto: fw.ProtocolUDP,
srcPort: 12345,
dstPort: 53,
ruleIP: "fd00::1",
ruleProto: fw.ProtocolUDP,
ruleDstPort: &fw.Port{Values: []uint16{53}},
ruleAction: fw.ActionAccept,
shouldBeBlocked: false,
},
{
name: "IPv6: allow ICMPv6 from peer",
srcIP: "fd00::1",
dstIP: "fd00::100",
proto: fw.ProtocolICMP,
ruleIP: "fd00::1",
ruleProto: fw.ProtocolICMP,
ruleAction: fw.ActionAccept,
shouldBeBlocked: false,
},
{
name: "IPv6: block TCP without rule",
srcIP: "fd00::2",
dstIP: "fd00::100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 443,
ruleIP: "fd00::1",
ruleProto: fw.ProtocolTCP,
ruleDstPort: &fw.Port{Values: []uint16{443}},
ruleAction: fw.ActionAccept,
shouldBeBlocked: true,
},
{
name: "IPv6: drop rule",
srcIP: "fd00::1",
dstIP: "fd00::100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 22,
ruleIP: "fd00::1",
ruleProto: fw.ProtocolTCP,
ruleDstPort: &fw.Port{Values: []uint16{22}},
ruleAction: fw.ActionDrop,
shouldBeBlocked: true,
},
{
name: "IPv6: allow all protocols",
srcIP: "fd00::1",
dstIP: "fd00::100",
proto: fw.ProtocolUDP,
srcPort: 12345,
dstPort: 9999,
ruleIP: "fd00::1",
ruleProto: fw.ProtocolALL,
ruleAction: fw.ActionAccept,
shouldBeBlocked: false,
},
{
name: "IPv6: v4 wildcard ICMP rule matches ICMPv6 via protoLayerMatches",
srcIP: "fd00::1",
dstIP: "fd00::100",
proto: fw.ProtocolICMP,
ruleIP: "0.0.0.0",
ruleProto: fw.ProtocolICMP,
ruleAction: fw.ActionAccept,
shouldBeBlocked: false,
},
}
t.Run("IPv6 implicit DROP (no rules)", func(t *testing.T) {
packet := createTestPacket(t, "fd00::1", "fd00::100", fw.ProtocolTCP, 12345, 443)
isDropped := manager.FilterInbound(packet, 0)
require.True(t, isDropped, "IPv6 packet should be dropped when no rules exist")
})
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
if tc.ruleAction == fw.ActionDrop {
rules, err := manager.AddPeerFiltering(nil, net.ParseIP(tc.ruleIP), fw.ProtocolALL, nil, nil, fw.ActionAccept, "")
require.NoError(t, err)
t.Cleanup(func() {
for _, rule := range rules {
require.NoError(t, manager.DeletePeerRule(rule))
}
})
}
rules, err := manager.AddPeerFiltering(nil, net.ParseIP(tc.ruleIP), tc.ruleProto, nil, tc.ruleDstPort, tc.ruleAction, "")
require.NoError(t, err)
require.NotEmpty(t, rules)
t.Cleanup(func() {
for _, rule := range rules {
require.NoError(t, manager.DeletePeerRule(rule))
}
})
packet := createTestPacket(t, tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort)
isDropped := manager.FilterInbound(packet, 0)
require.Equal(t, tc.shouldBeBlocked, isDropped, "packet filter result mismatch")
})
}
}
func createTestPacket(t *testing.T, srcIP, dstIP string, proto fw.Protocol, srcPort, dstPort uint16) []byte { func createTestPacket(t *testing.T, srcIP, dstIP string, proto fw.Protocol, srcPort, dstPort uint16) []byte {
t.Helper() t.Helper()
src := net.ParseIP(srcIP)
dst := net.ParseIP(dstIP)
buf := gopacket.NewSerializeBuffer() buf := gopacket.NewSerializeBuffer()
opts := gopacket.SerializeOptions{ opts := gopacket.SerializeOptions{
ComputeChecksums: true, ComputeChecksums: true,
FixLengths: true, FixLengths: true,
} }
ipLayer := &layers.IPv4{ // Detect address family
Version: 4, isV6 := src.To4() == nil
TTL: 64,
SrcIP: net.ParseIP(srcIP),
DstIP: net.ParseIP(dstIP),
}
var err error var err error
switch proto {
case fw.ProtocolTCP:
ipLayer.Protocol = layers.IPProtocolTCP
tcp := &layers.TCP{
SrcPort: layers.TCPPort(srcPort),
DstPort: layers.TCPPort(dstPort),
}
err = tcp.SetNetworkLayerForChecksum(ipLayer)
require.NoError(t, err)
err = gopacket.SerializeLayers(buf, opts, ipLayer, tcp)
case fw.ProtocolUDP: if isV6 {
ipLayer.Protocol = layers.IPProtocolUDP ip6 := &layers.IPv6{
udp := &layers.UDP{ Version: 6,
SrcPort: layers.UDPPort(srcPort), HopLimit: 64,
DstPort: layers.UDPPort(dstPort), SrcIP: src,
DstIP: dst,
} }
err = udp.SetNetworkLayerForChecksum(ipLayer)
require.NoError(t, err)
err = gopacket.SerializeLayers(buf, opts, ipLayer, udp)
case fw.ProtocolICMP: switch proto {
ipLayer.Protocol = layers.IPProtocolICMPv4 case fw.ProtocolTCP:
icmp := &layers.ICMPv4{ ip6.NextHeader = layers.IPProtocolTCP
TypeCode: layers.CreateICMPv4TypeCode(layers.ICMPv4TypeEchoRequest, 0), tcp := &layers.TCP{SrcPort: layers.TCPPort(srcPort), DstPort: layers.TCPPort(dstPort)}
_ = tcp.SetNetworkLayerForChecksum(ip6)
err = gopacket.SerializeLayers(buf, opts, ip6, tcp)
case fw.ProtocolUDP:
ip6.NextHeader = layers.IPProtocolUDP
udp := &layers.UDP{SrcPort: layers.UDPPort(srcPort), DstPort: layers.UDPPort(dstPort)}
_ = udp.SetNetworkLayerForChecksum(ip6)
err = gopacket.SerializeLayers(buf, opts, ip6, udp)
case fw.ProtocolICMP:
ip6.NextHeader = layers.IPProtocolICMPv6
icmp := &layers.ICMPv6{
TypeCode: layers.CreateICMPv6TypeCode(layers.ICMPv6TypeEchoRequest, 0),
}
_ = icmp.SetNetworkLayerForChecksum(ip6)
err = gopacket.SerializeLayers(buf, opts, ip6, icmp)
default:
err = gopacket.SerializeLayers(buf, opts, ip6)
}
} else {
ip4 := &layers.IPv4{
Version: 4,
TTL: 64,
SrcIP: src,
DstIP: dst,
} }
err = gopacket.SerializeLayers(buf, opts, ipLayer, icmp)
default: switch proto {
err = gopacket.SerializeLayers(buf, opts, ipLayer) case fw.ProtocolTCP:
ip4.Protocol = layers.IPProtocolTCP
tcp := &layers.TCP{SrcPort: layers.TCPPort(srcPort), DstPort: layers.TCPPort(dstPort)}
_ = tcp.SetNetworkLayerForChecksum(ip4)
err = gopacket.SerializeLayers(buf, opts, ip4, tcp)
case fw.ProtocolUDP:
ip4.Protocol = layers.IPProtocolUDP
udp := &layers.UDP{SrcPort: layers.UDPPort(srcPort), DstPort: layers.UDPPort(dstPort)}
_ = udp.SetNetworkLayerForChecksum(ip4)
err = gopacket.SerializeLayers(buf, opts, ip4, udp)
case fw.ProtocolICMP:
ip4.Protocol = layers.IPProtocolICMPv4
icmp := &layers.ICMPv4{TypeCode: layers.CreateICMPv4TypeCode(layers.ICMPv4TypeEchoRequest, 0)}
err = gopacket.SerializeLayers(buf, opts, ip4, icmp)
default:
err = gopacket.SerializeLayers(buf, opts, ip4)
}
} }
require.NoError(t, err) require.NoError(t, err)
@@ -1498,3 +1681,103 @@ func TestRouteACLSet(t *testing.T) {
_, isAllowed = manager.routeACLsPass(srcIP, dstIP, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80) _, isAllowed = manager.routeACLsPass(srcIP, dstIP, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
require.True(t, isAllowed, "After set update, traffic to the added network should be allowed") require.True(t, isAllowed, "After set update, traffic to the added network should be allowed")
} }
// TestRouteACLFilteringIPv6 tests IPv6 route ACL matching directly via routeACLsPass.
// Note: full FilterInbound for routed IPv6 traffic drops at the forwarder stage (IPv4-only)
// but the ACL decision itself is correct.
func TestRouteACLFilteringIPv6(t *testing.T) {
manager := setupRoutedManager(t, "10.10.0.100/16")
v6Dst := netip.MustParsePrefix("fd00:dead:beef::/48")
_, err := manager.AddRouteFiltering(
nil,
[]netip.Prefix{netip.MustParsePrefix("fd00::/16")},
fw.Network{Prefix: v6Dst},
fw.ProtocolTCP,
nil,
&fw.Port{Values: []uint16{80}},
fw.ActionAccept,
)
require.NoError(t, err)
_, err = manager.AddRouteFiltering(
nil,
[]netip.Prefix{netip.MustParsePrefix("fd00::/16")},
fw.Network{Prefix: netip.MustParsePrefix("fd00:dead:beef:1::/64")},
fw.ProtocolALL,
nil,
nil,
fw.ActionDrop,
)
require.NoError(t, err)
tests := []struct {
name string
srcIP netip.Addr
dstIP netip.Addr
proto gopacket.LayerType
srcPort uint16
dstPort uint16
allowed bool
}{
{
name: "IPv6 TCP to allowed dest",
srcIP: netip.MustParseAddr("fd00::1"),
dstIP: netip.MustParseAddr("fd00:dead:beef::80"),
proto: layers.LayerTypeTCP,
srcPort: 12345,
dstPort: 80,
allowed: true,
},
{
name: "IPv6 TCP wrong port",
srcIP: netip.MustParseAddr("fd00::1"),
dstIP: netip.MustParseAddr("fd00:dead:beef::80"),
proto: layers.LayerTypeTCP,
srcPort: 12345,
dstPort: 443,
allowed: false,
},
{
name: "IPv6 UDP not matched by TCP rule",
srcIP: netip.MustParseAddr("fd00::1"),
dstIP: netip.MustParseAddr("fd00:dead:beef::80"),
proto: layers.LayerTypeUDP,
srcPort: 12345,
dstPort: 80,
allowed: false,
},
{
name: "IPv6 ICMPv6 matches ICMP rule via protoLayerMatches",
srcIP: netip.MustParseAddr("fd00::1"),
dstIP: netip.MustParseAddr("fd00:dead:beef::80"),
proto: layers.LayerTypeICMPv6,
allowed: false,
},
{
name: "IPv6 to denied subnet",
srcIP: netip.MustParseAddr("fd00::1"),
dstIP: netip.MustParseAddr("fd00:dead:beef:1::1"),
proto: layers.LayerTypeTCP,
srcPort: 12345,
dstPort: 80,
allowed: false,
},
{
name: "IPv6 source outside allowed range",
srcIP: netip.MustParseAddr("fe80::1"),
dstIP: netip.MustParseAddr("fd00:dead:beef::80"),
proto: layers.LayerTypeTCP,
srcPort: 12345,
dstPort: 80,
allowed: false,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
_, pass := manager.routeACLsPass(tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort)
require.Equal(t, tc.allowed, pass, "route ACL result mismatch")
})
}
}

View File

@@ -527,11 +527,16 @@ func TestProcessOutgoingHooks(t *testing.T) {
d := &decoder{ d := &decoder{
decoded: []gopacket.LayerType{}, decoded: []gopacket.LayerType{},
} }
d.parser = gopacket.NewDecodingLayerParser( d.parser4 = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv4, layers.LayerTypeIPv4,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
) )
d.parser.IgnoreUnsupported = true d.parser4.IgnoreUnsupported = true
d.parser6 = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv6,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
)
d.parser6.IgnoreUnsupported = true
return d return d
}, },
} }
@@ -630,11 +635,16 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
d := &decoder{ d := &decoder{
decoded: []gopacket.LayerType{}, decoded: []gopacket.LayerType{},
} }
d.parser = gopacket.NewDecodingLayerParser( d.parser4 = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv4, layers.LayerTypeIPv4,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
) )
d.parser.IgnoreUnsupported = true d.parser4.IgnoreUnsupported = true
d.parser6 = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv6,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
)
d.parser6.IgnoreUnsupported = true
return d return d
}, },
} }
@@ -1040,8 +1050,8 @@ func TestMSSClamping(t *testing.T) {
}() }()
require.True(t, manager.mssClampEnabled, "MSS clamping should be enabled by default") require.True(t, manager.mssClampEnabled, "MSS clamping should be enabled by default")
expectedMSSValue := uint16(1280 - ipTCPHeaderMinSize) require.Equal(t, uint16(1280-ipv4TCPHeaderMinSize), manager.mssClampValueIPv4, "IPv4 MSS clamp value should be MTU - 40")
require.Equal(t, expectedMSSValue, manager.mssClampValue, "MSS clamp value should be MTU - 40") require.Equal(t, uint16(1280-ipv6TCPHeaderMinSize), manager.mssClampValueIPv6, "IPv6 MSS clamp value should be MTU - 60")
err = manager.UpdateLocalIPs() err = manager.UpdateLocalIPs()
require.NoError(t, err) require.NoError(t, err)
@@ -1059,7 +1069,7 @@ func TestMSSClamping(t *testing.T) {
require.Len(t, d.tcp.Options, 1, "Should have MSS option") require.Len(t, d.tcp.Options, 1, "Should have MSS option")
require.Equal(t, uint8(layers.TCPOptionKindMSS), uint8(d.tcp.Options[0].OptionType)) require.Equal(t, uint8(layers.TCPOptionKindMSS), uint8(d.tcp.Options[0].OptionType))
actualMSS := binary.BigEndian.Uint16(d.tcp.Options[0].OptionData) actualMSS := binary.BigEndian.Uint16(d.tcp.Options[0].OptionData)
require.Equal(t, expectedMSSValue, actualMSS, "MSS should be clamped to MTU - 40") require.Equal(t, manager.mssClampValueIPv4, actualMSS, "MSS should be clamped to MTU - 40")
}) })
t.Run("SYN packet with low MSS unchanged", func(t *testing.T) { t.Run("SYN packet with low MSS unchanged", func(t *testing.T) {
@@ -1083,7 +1093,7 @@ func TestMSSClamping(t *testing.T) {
d := parsePacket(t, packet) d := parsePacket(t, packet)
require.Len(t, d.tcp.Options, 1, "Should have MSS option") require.Len(t, d.tcp.Options, 1, "Should have MSS option")
actualMSS := binary.BigEndian.Uint16(d.tcp.Options[0].OptionData) actualMSS := binary.BigEndian.Uint16(d.tcp.Options[0].OptionData)
require.Equal(t, expectedMSSValue, actualMSS, "MSS in SYN-ACK should be clamped") require.Equal(t, manager.mssClampValueIPv4, actualMSS, "MSS in SYN-ACK should be clamped")
}) })
t.Run("Non-SYN packet unchanged", func(t *testing.T) { t.Run("Non-SYN packet unchanged", func(t *testing.T) {
@@ -1255,13 +1265,18 @@ func TestShouldForward(t *testing.T) {
d := &decoder{ d := &decoder{
decoded: []gopacket.LayerType{}, decoded: []gopacket.LayerType{},
} }
d.parser = gopacket.NewDecodingLayerParser( d.parser4 = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv4, layers.LayerTypeIPv4,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
) )
d.parser.IgnoreUnsupported = true d.parser4.IgnoreUnsupported = true
d.parser6 = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv6,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
)
d.parser6.IgnoreUnsupported = true
err = d.parser.DecodeLayers(buf.Bytes(), &d.decoded) err = d.decodePacket(buf.Bytes())
require.NoError(t, err) require.NoError(t, err)
return d return d
@@ -1321,6 +1336,44 @@ func TestShouldForward(t *testing.T) {
}, },
} }
// Add IPv6 to the interface and test dual-stack cases
wgIPv6 := netip.MustParseAddr("fd00::1")
otherIPv6 := netip.MustParseAddr("fd00::2")
ifaceMock.AddressFunc = func() wgaddr.Address {
return wgaddr.Address{
IP: wgIP,
Network: netip.PrefixFrom(wgIP, 24),
IPv6: wgIPv6,
IPv6Net: netip.PrefixFrom(wgIPv6, 64),
}
}
// Re-create manager to pick up the new address with IPv6
require.NoError(t, manager.Close(nil))
manager, err = Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
v6Cases := []struct {
name string
dstIP netip.Addr
expected bool
description string
}{
{"v6 traffic to other address", otherIPv6, true, "should forward v6 traffic not destined to our v6 address"},
{"v6 traffic to our v6 IP", wgIPv6, false, "should not forward traffic destined to our v6 address"},
{"v4 traffic to other with v6 configured", otherIP, true, "should forward v4 traffic when v6 configured"},
{"v4 traffic to our v4 IP with v6 configured", wgIP, false, "should not forward traffic to our v4 address"},
}
for _, tt := range v6Cases {
t.Run(tt.name, func(t *testing.T) {
manager.localForwarding = true
manager.netstack = false
decoder := createTCPDecoder(8080)
result := manager.shouldForward(decoder, tt.dstIP)
require.Equal(t, tt.expected, result, tt.description)
})
}
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
// Configure manager // Configure manager

View File

@@ -1,7 +1,8 @@
package forwarder package forwarder
import ( import (
"fmt" "net"
"strconv"
"sync/atomic" "sync/atomic"
wgdevice "golang.zx2c4.com/wireguard/device" wgdevice "golang.zx2c4.com/wireguard/device"
@@ -47,17 +48,23 @@ func (e *endpoint) LinkAddress() tcpip.LinkAddress {
func (e *endpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) { func (e *endpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) {
var written int var written int
for _, pkt := range pkts.AsSlice() { for _, pkt := range pkts.AsSlice() {
netHeader := header.IPv4(pkt.NetworkHeader().View().AsSlice())
data := stack.PayloadSince(pkt.NetworkHeader()) data := stack.PayloadSince(pkt.NetworkHeader())
if data == nil { if data == nil {
continue continue
} }
// Send the packet through WireGuard raw := pkt.NetworkHeader().View().AsSlice()
address := netHeader.DestinationAddress() if len(raw) == 0 {
err := e.device.CreateOutboundPacket(data.AsSlice(), address.AsSlice()) continue
if err != nil { }
var address tcpip.Address
if raw[0]>>4 == 6 {
address = header.IPv6(raw).DestinationAddress()
} else {
address = header.IPv4(raw).DestinationAddress()
}
if err := e.device.CreateOutboundPacket(data.AsSlice(), address.AsSlice()); err != nil {
e.logger.Error1("CreateOutboundPacket: %v", err) e.logger.Error1("CreateOutboundPacket: %v", err)
continue continue
} }
@@ -103,5 +110,7 @@ type epID stack.TransportEndpointID
func (i epID) String() string { func (i epID) String() string {
// src and remote is swapped // src and remote is swapped
return fmt.Sprintf("%s:%d → %s:%d", i.RemoteAddress, i.RemotePort, i.LocalAddress, i.LocalPort) return net.JoinHostPort(i.RemoteAddress.String(), strconv.Itoa(int(i.RemotePort))) +
" → " +
net.JoinHostPort(i.LocalAddress.String(), strconv.Itoa(int(i.LocalPort)))
} }

View File

@@ -14,6 +14,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
@@ -36,25 +37,31 @@ type Forwarder struct {
logger *nblog.Logger logger *nblog.Logger
flowLogger nftypes.FlowLogger flowLogger nftypes.FlowLogger
// ruleIdMap is used to store the rule ID for a given connection // ruleIdMap is used to store the rule ID for a given connection
ruleIdMap sync.Map ruleIdMap sync.Map
stack *stack.Stack stack *stack.Stack
endpoint *endpoint endpoint *endpoint
udpForwarder *udpForwarder udpForwarder *udpForwarder
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
ip tcpip.Address ip tcpip.Address
netstack bool ipv6 tcpip.Address
hasRawICMPAccess bool netstack bool
pingSemaphore chan struct{} hasRawICMPAccess bool
hasRawICMPv6Access bool
pingSemaphore chan struct{}
} }
func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool, mtu uint16) (*Forwarder, error) { func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool, mtu uint16) (*Forwarder, error) {
s := stack.New(stack.Options{ s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, NetworkProtocols: []stack.NetworkProtocolFactory{
ipv4.NewProtocol,
ipv6.NewProtocol,
},
TransportProtocols: []stack.TransportProtocolFactory{ TransportProtocols: []stack.TransportProtocolFactory{
tcp.NewProtocol, tcp.NewProtocol,
udp.NewProtocol, udp.NewProtocol,
icmp.NewProtocol4, icmp.NewProtocol4,
icmp.NewProtocol6,
}, },
HandleLocal: false, HandleLocal: false,
}) })
@@ -73,7 +80,7 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
protoAddr := tcpip.ProtocolAddress{ protoAddr := tcpip.ProtocolAddress{
Protocol: ipv4.ProtocolNumber, Protocol: ipv4.ProtocolNumber,
AddressWithPrefix: tcpip.AddressWithPrefix{ AddressWithPrefix: tcpip.AddressWithPrefix{
Address: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()), Address: tcpip.AddrFrom4(iface.Address().IP.As4()),
PrefixLen: iface.Address().Network.Bits(), PrefixLen: iface.Address().Network.Bits(),
}, },
} }
@@ -82,6 +89,19 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
return nil, fmt.Errorf("failed to add protocol address: %s", err) return nil, fmt.Errorf("failed to add protocol address: %s", err)
} }
if v6 := iface.Address().IPv6; v6.IsValid() {
v6Addr := tcpip.ProtocolAddress{
Protocol: ipv6.ProtocolNumber,
AddressWithPrefix: tcpip.AddressWithPrefix{
Address: tcpip.AddrFrom16(v6.As16()),
PrefixLen: iface.Address().IPv6Net.Bits(),
},
}
if err := s.AddProtocolAddress(nicID, v6Addr, stack.AddressProperties{}); err != nil {
return nil, fmt.Errorf("add IPv6 protocol address: %s", err)
}
}
defaultSubnet, err := tcpip.NewSubnet( defaultSubnet, err := tcpip.NewSubnet(
tcpip.AddrFrom4([4]byte{0, 0, 0, 0}), tcpip.AddrFrom4([4]byte{0, 0, 0, 0}),
tcpip.MaskFromBytes([]byte{0, 0, 0, 0}), tcpip.MaskFromBytes([]byte{0, 0, 0, 0}),
@@ -90,6 +110,14 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
return nil, fmt.Errorf("creating default subnet: %w", err) return nil, fmt.Errorf("creating default subnet: %w", err)
} }
defaultSubnetV6, err := tcpip.NewSubnet(
tcpip.AddrFrom16([16]byte{}),
tcpip.MaskFromBytes(make([]byte, 16)),
)
if err != nil {
return nil, fmt.Errorf("creating default v6 subnet: %w", err)
}
if err := s.SetPromiscuousMode(nicID, true); err != nil { if err := s.SetPromiscuousMode(nicID, true); err != nil {
return nil, fmt.Errorf("set promiscuous mode: %s", err) return nil, fmt.Errorf("set promiscuous mode: %s", err)
} }
@@ -98,10 +126,8 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
} }
s.SetRouteTable([]tcpip.Route{ s.SetRouteTable([]tcpip.Route{
{ {Destination: defaultSubnet, NIC: nicID},
Destination: defaultSubnet, {Destination: defaultSubnetV6, NIC: nicID},
NIC: nicID,
},
}) })
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
@@ -114,7 +140,8 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
ctx: ctx, ctx: ctx,
cancel: cancel, cancel: cancel,
netstack: netstack, netstack: netstack,
ip: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()), ip: tcpip.AddrFrom4(iface.Address().IP.As4()),
ipv6: addrFromNetipAddr(iface.Address().IPv6),
pingSemaphore: make(chan struct{}, 3), pingSemaphore: make(chan struct{}, 3),
} }
@@ -131,7 +158,10 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
udpForwarder := udp.NewForwarder(s, f.handleUDP) udpForwarder := udp.NewForwarder(s, f.handleUDP)
s.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket) s.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket)
s.SetTransportProtocolHandler(icmp.ProtocolNumber4, f.handleICMP) // ICMP is handled directly in InjectIncomingPacket, bypassing gVisor's
// network layer. This avoids duplicate echo replies (v4) and the v6
// auto-reply bug where gVisor responds at the network layer before
// our transport handler fires.
f.checkICMPCapability() f.checkICMPCapability()
@@ -140,8 +170,30 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
} }
func (f *Forwarder) InjectIncomingPacket(payload []byte) error { func (f *Forwarder) InjectIncomingPacket(payload []byte) error {
if len(payload) < header.IPv4MinimumSize { if len(payload) == 0 {
return fmt.Errorf("packet too small: %d bytes", len(payload)) return fmt.Errorf("empty packet")
}
var protoNum tcpip.NetworkProtocolNumber
switch payload[0] >> 4 {
case 4:
if len(payload) < header.IPv4MinimumSize {
return fmt.Errorf("IPv4 packet too small: %d bytes", len(payload))
}
if f.handleICMPDirect(payload) {
return nil
}
protoNum = ipv4.ProtocolNumber
case 6:
if len(payload) < header.IPv6MinimumSize {
return fmt.Errorf("IPv6 packet too small: %d bytes", len(payload))
}
if f.handleICMPDirect(payload) {
return nil
}
protoNum = ipv6.ProtocolNumber
default:
return fmt.Errorf("unknown IP version: %d", payload[0]>>4)
} }
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
@@ -150,11 +202,95 @@ func (f *Forwarder) InjectIncomingPacket(payload []byte) error {
defer pkt.DecRef() defer pkt.DecRef()
if f.endpoint.dispatcher != nil { if f.endpoint.dispatcher != nil {
f.endpoint.dispatcher.DeliverNetworkPacket(ipv4.ProtocolNumber, pkt) f.endpoint.dispatcher.DeliverNetworkPacket(protoNum, pkt)
} }
return nil return nil
} }
// handleICMPDirect intercepts ICMP packets from raw IP payloads before they
// enter gVisor. It synthesizes the TransportEndpointID and PacketBuffer that
// the existing handlers expect, then dispatches to handleICMP/handleICMPv6.
// This bypasses gVisor's network layer which causes duplicate v4 echo replies
// and auto-replies to all v6 echo requests in promiscuous mode.
//
// Unlike gVisor's network layer, this does not validate ICMP checksums or
// reassemble IP fragments. Fragmented ICMP packets fall through to gVisor.
func parseICMPv4(payload []byte) (ipHdrLen int, src, dst tcpip.Address, ok bool) {
ip := header.IPv4(payload)
if ip.Protocol() != uint8(header.ICMPv4ProtocolNumber) {
return 0, src, dst, false
}
if ip.FragmentOffset() != 0 || ip.Flags()&header.IPv4FlagMoreFragments != 0 {
return 0, src, dst, false
}
ipHdrLen = int(ip.HeaderLength())
if len(payload)-ipHdrLen < header.ICMPv4MinimumSize {
return 0, src, dst, false
}
return ipHdrLen, ip.SourceAddress(), ip.DestinationAddress(), true
}
func parseICMPv6(payload []byte) (ipHdrLen int, src, dst tcpip.Address, ok bool) {
ip := header.IPv6(payload)
if ip.NextHeader() != uint8(header.ICMPv6ProtocolNumber) {
return 0, src, dst, false
}
ipHdrLen = header.IPv6MinimumSize
if len(payload)-ipHdrLen < header.ICMPv6MinimumSize {
return 0, src, dst, false
}
return ipHdrLen, ip.SourceAddress(), ip.DestinationAddress(), true
}
func (f *Forwarder) handleICMPDirect(payload []byte) bool {
var (
ipHdrLen int
srcAddr tcpip.Address
dstAddr tcpip.Address
ok bool
)
switch payload[0] >> 4 {
case 4:
ipHdrLen, srcAddr, dstAddr, ok = parseICMPv4(payload)
case 6:
ipHdrLen, srcAddr, dstAddr, ok = parseICMPv6(payload)
}
if !ok {
return false
}
// Let gVisor handle ICMP destined for our own addresses natively.
// Its network-layer auto-reply is correct and efficient for local traffic.
if f.ip.Equal(dstAddr) || f.ipv6.Equal(dstAddr) {
return false
}
id := stack.TransportEndpointID{
LocalAddress: dstAddr,
RemoteAddress: srcAddr,
}
// Build a PacketBuffer with headers consumed the same way gVisor would.
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: buffer.MakeWithData(payload),
})
defer pkt.DecRef()
if _, ok := pkt.NetworkHeader().Consume(ipHdrLen); !ok {
return false
}
icmpPayload := payload[ipHdrLen:]
if _, ok := pkt.TransportHeader().Consume(len(icmpPayload)); !ok {
return false
}
if payload[0]>>4 == 6 {
return f.handleICMPv6(id, pkt)
}
return f.handleICMP(id, pkt)
}
// Stop gracefully shuts down the forwarder // Stop gracefully shuts down the forwarder
func (f *Forwarder) Stop() { func (f *Forwarder) Stop() {
f.cancel() f.cancel()
@@ -167,11 +303,14 @@ func (f *Forwarder) Stop() {
f.stack.Wait() f.stack.Wait()
} }
func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP { func (f *Forwarder) determineDialAddr(addr tcpip.Address) netip.Addr {
if f.netstack && f.ip.Equal(addr) { if f.netstack && f.ip.Equal(addr) {
return net.IPv4(127, 0, 0, 1) return netip.AddrFrom4([4]byte{127, 0, 0, 1})
} }
return addr.AsSlice() if f.netstack && f.ipv6.Equal(addr) {
return netip.IPv6Loopback()
}
return addrToNetipAddr(addr)
} }
func (f *Forwarder) RegisterRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, ruleID []byte) { func (f *Forwarder) RegisterRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, ruleID []byte) {
@@ -205,23 +344,50 @@ func buildKey(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) conntrack.ConnKe
} }
} }
// addrFromNetipAddr converts a netip.Addr to a gvisor tcpip.Address without allocating.
func addrFromNetipAddr(addr netip.Addr) tcpip.Address {
if !addr.IsValid() {
return tcpip.Address{}
}
if addr.Is4() {
return tcpip.AddrFrom4(addr.As4())
}
return tcpip.AddrFrom16(addr.As16())
}
// addrToNetipAddr converts a gvisor tcpip.Address to netip.Addr without allocating.
func addrToNetipAddr(addr tcpip.Address) netip.Addr {
switch addr.Len() {
case 4:
return netip.AddrFrom4(addr.As4())
case 16:
return netip.AddrFrom16(addr.As16())
default:
return netip.Addr{}
}
}
// checkICMPCapability tests whether we have raw ICMP socket access at startup. // checkICMPCapability tests whether we have raw ICMP socket access at startup.
func (f *Forwarder) checkICMPCapability() { func (f *Forwarder) checkICMPCapability() {
f.hasRawICMPAccess = probeRawICMP("ip4:icmp", "0.0.0.0", f.logger)
f.hasRawICMPv6Access = probeRawICMP("ip6:ipv6-icmp", "::", f.logger)
}
func probeRawICMP(network, addr string, logger *nblog.Logger) bool {
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel() defer cancel()
lc := net.ListenConfig{} lc := net.ListenConfig{}
conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0") conn, err := lc.ListenPacket(ctx, network, addr)
if err != nil { if err != nil {
f.hasRawICMPAccess = false logger.Debug1("forwarder: no raw %s socket access, will use ping binary fallback", network)
f.logger.Debug("forwarder: No raw ICMP socket access, will use ping binary fallback") return false
return
} }
if err := conn.Close(); err != nil { if err := conn.Close(); err != nil {
f.logger.Debug1("forwarder: Failed to close ICMP capability test socket: %v", err) logger.Debug2("forwarder: failed to close %s capability test socket: %v", network, err)
} }
f.hasRawICMPAccess = true logger.Debug1("forwarder: raw %s socket access available", network)
f.logger.Debug("forwarder: Raw ICMP socket access available") return true
} }

View File

@@ -35,7 +35,7 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt *stack.PacketBu
} }
icmpData := stack.PayloadSince(pkt.TransportHeader()).AsSlice() icmpData := stack.PayloadSince(pkt.TransportHeader()).AsSlice()
conn, err := f.forwardICMPPacket(id, icmpData, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()), 100*time.Millisecond) conn, err := f.forwardICMPPacket(id, icmpData, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()), false, 100*time.Millisecond)
if err != nil { if err != nil {
f.logger.Error2("forwarder: Failed to forward ICMP packet for %v: %v", epID(id), err) f.logger.Error2("forwarder: Failed to forward ICMP packet for %v: %v", epID(id), err)
return true return true
@@ -58,7 +58,7 @@ func (f *Forwarder) handleICMPEcho(flowID uuid.UUID, id stack.TransportEndpointI
defer func() { <-f.pingSemaphore }() defer func() { <-f.pingSemaphore }()
if f.hasRawICMPAccess { if f.hasRawICMPAccess {
f.handleICMPViaSocket(flowID, id, icmpType, icmpCode, icmpData, rxBytes) f.handleICMPViaSocket(flowID, id, icmpType, icmpCode, icmpData, rxBytes, false)
} else { } else {
f.handleICMPViaPing(flowID, id, icmpType, icmpCode, icmpData, rxBytes) f.handleICMPViaPing(flowID, id, icmpType, icmpCode, icmpData, rxBytes)
} }
@@ -72,18 +72,23 @@ func (f *Forwarder) handleICMPEcho(flowID uuid.UUID, id stack.TransportEndpointI
// forwardICMPPacket creates a raw ICMP socket and sends the packet, returning the connection. // forwardICMPPacket creates a raw ICMP socket and sends the packet, returning the connection.
// The caller is responsible for closing the returned connection. // The caller is responsible for closing the returned connection.
func (f *Forwarder) forwardICMPPacket(id stack.TransportEndpointID, payload []byte, icmpType, icmpCode uint8, timeout time.Duration) (net.PacketConn, error) { func (f *Forwarder) forwardICMPPacket(id stack.TransportEndpointID, payload []byte, icmpType, icmpCode uint8, v6 bool, timeout time.Duration) (net.PacketConn, error) {
ctx, cancel := context.WithTimeout(f.ctx, timeout) ctx, cancel := context.WithTimeout(f.ctx, timeout)
defer cancel() defer cancel()
network, listenAddr := "ip4:icmp", "0.0.0.0"
if v6 {
network, listenAddr = "ip6:ipv6-icmp", "::"
}
lc := net.ListenConfig{} lc := net.ListenConfig{}
conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0") conn, err := lc.ListenPacket(ctx, network, listenAddr)
if err != nil { if err != nil {
return nil, fmt.Errorf("create ICMP socket: %w", err) return nil, fmt.Errorf("create ICMP socket: %w", err)
} }
dstIP := f.determineDialAddr(id.LocalAddress) dstIP := f.determineDialAddr(id.LocalAddress)
dst := &net.IPAddr{IP: dstIP} dst := &net.IPAddr{IP: dstIP.AsSlice()}
if _, err = conn.WriteTo(payload, dst); err != nil { if _, err = conn.WriteTo(payload, dst); err != nil {
if closeErr := conn.Close(); closeErr != nil { if closeErr := conn.Close(); closeErr != nil {
@@ -98,11 +103,11 @@ func (f *Forwarder) forwardICMPPacket(id stack.TransportEndpointID, payload []by
return conn, nil return conn, nil
} }
// handleICMPViaSocket handles ICMP echo requests using raw sockets. // handleICMPViaSocket handles ICMP echo requests using raw sockets for both v4 and v6.
func (f *Forwarder) handleICMPViaSocket(flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8, icmpData []byte, rxBytes int) { func (f *Forwarder) handleICMPViaSocket(flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8, icmpData []byte, rxBytes int, v6 bool) {
sendTime := time.Now() sendTime := time.Now()
conn, err := f.forwardICMPPacket(id, icmpData, icmpType, icmpCode, 5*time.Second) conn, err := f.forwardICMPPacket(id, icmpData, icmpType, icmpCode, v6, 5*time.Second)
if err != nil { if err != nil {
f.logger.Error2("forwarder: Failed to send ICMP packet for %v: %v", epID(id), err) f.logger.Error2("forwarder: Failed to send ICMP packet for %v: %v", epID(id), err)
return return
@@ -113,16 +118,20 @@ func (f *Forwarder) handleICMPViaSocket(flowID uuid.UUID, id stack.TransportEndp
} }
}() }()
txBytes := f.handleEchoResponse(conn, id) txBytes := f.handleEchoResponse(conn, id, v6)
rtt := time.Since(sendTime).Round(10 * time.Microsecond) rtt := time.Since(sendTime).Round(10 * time.Microsecond)
f.logger.Trace4("forwarder: Forwarded ICMP echo reply %v type %v code %v (rtt=%v, raw socket)", proto := "ICMP"
epID(id), icmpType, icmpCode, rtt) if v6 {
proto = "ICMPv6"
}
f.logger.Trace5("forwarder: Forwarded %s echo reply %v type %v code %v (rtt=%v, raw socket)",
proto, epID(id), icmpType, icmpCode, rtt)
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes)) f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
} }
func (f *Forwarder) handleEchoResponse(conn net.PacketConn, id stack.TransportEndpointID) int { func (f *Forwarder) handleEchoResponse(conn net.PacketConn, id stack.TransportEndpointID, v6 bool) int {
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil { if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
f.logger.Error1("forwarder: Failed to set read deadline for ICMP response: %v", err) f.logger.Error1("forwarder: Failed to set read deadline for ICMP response: %v", err)
return 0 return 0
@@ -137,6 +146,19 @@ func (f *Forwarder) handleEchoResponse(conn net.PacketConn, id stack.TransportEn
return 0 return 0
} }
if v6 {
// Recompute checksum: the raw socket response has a checksum computed
// over the real endpoint addresses, but we inject with overlay addresses.
icmpHdr := header.ICMPv6(response[:n])
icmpHdr.SetChecksum(0)
icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
Header: icmpHdr,
Src: id.LocalAddress,
Dst: id.RemoteAddress,
}))
return f.injectICMPv6Reply(id, response[:n])
}
return f.injectICMPReply(id, response[:n]) return f.injectICMPReply(id, response[:n])
} }
@@ -150,19 +172,23 @@ func (f *Forwarder) sendICMPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.T
txPackets = 1 txPackets = 1
} }
srcIp := netip.AddrFrom4(id.RemoteAddress.As4()) srcIp := addrToNetipAddr(id.RemoteAddress)
dstIp := netip.AddrFrom4(id.LocalAddress.As4()) dstIp := addrToNetipAddr(id.LocalAddress)
proto := nftypes.ICMP
if srcIp.Is6() {
proto = nftypes.ICMPv6
}
fields := nftypes.EventFields{ fields := nftypes.EventFields{
FlowID: flowID, FlowID: flowID,
Type: typ, Type: typ,
Direction: nftypes.Ingress, Direction: nftypes.Ingress,
Protocol: nftypes.ICMP, Protocol: proto,
// TODO: handle ipv6 SourceIP: srcIp,
SourceIP: srcIp, DestIP: dstIp,
DestIP: dstIp, ICMPType: icmpType,
ICMPType: icmpType, ICMPCode: icmpCode,
ICMPCode: icmpCode,
RxBytes: rxBytes, RxBytes: rxBytes,
TxBytes: txBytes, TxBytes: txBytes,
@@ -209,26 +235,164 @@ func (f *Forwarder) handleICMPViaPing(flowID uuid.UUID, id stack.TransportEndpoi
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes)) f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
} }
// handleICMPv6 handles ICMPv6 packets from the network stack.
func (f *Forwarder) handleICMPv6(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
icmpHdr := header.ICMPv6(pkt.TransportHeader().View().AsSlice())
flowID := uuid.New()
f.sendICMPEvent(nftypes.TypeStart, flowID, id, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()), 0, 0)
if icmpHdr.Type() == header.ICMPv6EchoRequest {
return f.handleICMPv6Echo(flowID, id, pkt, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()))
}
// For non-echo types (Destination Unreachable, Packet Too Big, etc), forward without waiting
if !f.hasRawICMPv6Access {
f.logger.Debug2("forwarder: Cannot handle ICMPv6 type %v without raw socket access for %v", icmpHdr.Type(), epID(id))
return false
}
icmpData := stack.PayloadSince(pkt.TransportHeader()).AsSlice()
conn, err := f.forwardICMPPacket(id, icmpData, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()), true, 100*time.Millisecond)
if err != nil {
f.logger.Error2("forwarder: Failed to forward ICMPv6 packet for %v: %v", epID(id), err)
return true
}
if err := conn.Close(); err != nil {
f.logger.Debug1("forwarder: Failed to close ICMPv6 socket: %v", err)
}
return true
}
// handleICMPv6Echo handles ICMPv6 echo requests via raw socket or ping binary fallback.
func (f *Forwarder) handleICMPv6Echo(flowID uuid.UUID, id stack.TransportEndpointID, pkt *stack.PacketBuffer, icmpType, icmpCode uint8) bool {
select {
case f.pingSemaphore <- struct{}{}:
icmpData := stack.PayloadSince(pkt.TransportHeader()).ToSlice()
rxBytes := pkt.Size()
go func() {
defer func() { <-f.pingSemaphore }()
if f.hasRawICMPv6Access {
f.handleICMPViaSocket(flowID, id, icmpType, icmpCode, icmpData, rxBytes, true)
} else {
f.handleICMPv6ViaPing(flowID, id, icmpType, icmpCode, icmpData, rxBytes)
}
}()
default:
f.logger.Debug3("forwarder: ICMPv6 rate limit exceeded for %v type %v code %v", epID(id), icmpType, icmpCode)
}
return true
}
// handleICMPv6ViaPing uses the system ping6 binary for ICMPv6 echo.
func (f *Forwarder) handleICMPv6ViaPing(flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8, icmpData []byte, rxBytes int) {
ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second)
defer cancel()
dstIP := f.determineDialAddr(id.LocalAddress)
cmd := buildPingCommand(ctx, dstIP, 5*time.Second)
pingStart := time.Now()
if err := cmd.Run(); err != nil {
f.logger.Warn4("forwarder: Ping6 failed for %v type %v code %v: %v", epID(id), icmpType, icmpCode, err)
return
}
rtt := time.Since(pingStart).Round(10 * time.Microsecond)
f.logger.Trace3("forwarder: Forwarded ICMPv6 echo request %v type %v code %v",
epID(id), icmpType, icmpCode)
txBytes := f.synthesizeICMPv6EchoReply(id, icmpData)
f.logger.Trace4("forwarder: Forwarded ICMPv6 echo reply %v type %v code %v (rtt=%v, ping binary)",
epID(id), icmpType, icmpCode, rtt)
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
}
// synthesizeICMPv6EchoReply creates an ICMPv6 echo reply and injects it back.
func (f *Forwarder) synthesizeICMPv6EchoReply(id stack.TransportEndpointID, icmpData []byte) int {
replyICMP := make([]byte, len(icmpData))
copy(replyICMP, icmpData)
replyHdr := header.ICMPv6(replyICMP)
replyHdr.SetType(header.ICMPv6EchoReply)
replyHdr.SetChecksum(0)
// ICMPv6Checksum computes the pseudo-header internally from Src/Dst.
// Header contains the full ICMP message, so PayloadCsum/PayloadLen are zero.
replyHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
Header: replyHdr,
Src: id.LocalAddress,
Dst: id.RemoteAddress,
}))
return f.injectICMPv6Reply(id, replyICMP)
}
// injectICMPv6Reply wraps an ICMPv6 payload in an IPv6 header and sends to the peer.
func (f *Forwarder) injectICMPv6Reply(id stack.TransportEndpointID, icmpPayload []byte) int {
ipHdr := make([]byte, header.IPv6MinimumSize)
ip := header.IPv6(ipHdr)
ip.Encode(&header.IPv6Fields{
PayloadLength: uint16(len(icmpPayload)),
TransportProtocol: header.ICMPv6ProtocolNumber,
HopLimit: 64,
SrcAddr: id.LocalAddress,
DstAddr: id.RemoteAddress,
})
fullPacket := make([]byte, 0, len(ipHdr)+len(icmpPayload))
fullPacket = append(fullPacket, ipHdr...)
fullPacket = append(fullPacket, icmpPayload...)
if err := f.endpoint.device.CreateOutboundPacket(fullPacket, id.RemoteAddress.AsSlice()); err != nil {
f.logger.Error1("forwarder: Failed to send ICMPv6 reply to peer: %v", err)
return 0
}
return len(fullPacket)
}
const (
pingBin = "ping"
ping6Bin = "ping6"
)
// buildPingCommand creates a platform-specific ping command. // buildPingCommand creates a platform-specific ping command.
func buildPingCommand(ctx context.Context, target net.IP, timeout time.Duration) *exec.Cmd { // Most platforms auto-detect IPv6 from raw addresses. macOS/iOS/OpenBSD require ping6.
func buildPingCommand(ctx context.Context, target netip.Addr, timeout time.Duration) *exec.Cmd {
timeoutSec := int(timeout.Seconds()) timeoutSec := int(timeout.Seconds())
if timeoutSec < 1 { if timeoutSec < 1 {
timeoutSec = 1 timeoutSec = 1
} }
isV6 := target.Is6()
timeoutStr := fmt.Sprintf("%d", timeoutSec)
switch runtime.GOOS { switch runtime.GOOS {
case "linux", "android": case "linux", "android":
return exec.CommandContext(ctx, "ping", "-c", "1", "-W", fmt.Sprintf("%d", timeoutSec), "-q", target.String()) return exec.CommandContext(ctx, pingBin, "-c", "1", "-W", timeoutStr, "-q", target.String())
case "darwin", "ios": case "darwin", "ios":
return exec.CommandContext(ctx, "ping", "-c", "1", "-t", fmt.Sprintf("%d", timeoutSec), "-q", target.String()) bin := pingBin
if isV6 {
bin = ping6Bin
}
return exec.CommandContext(ctx, bin, "-c", "1", "-t", timeoutStr, "-q", target.String())
case "freebsd": case "freebsd":
return exec.CommandContext(ctx, "ping", "-c", "1", "-t", fmt.Sprintf("%d", timeoutSec), target.String()) return exec.CommandContext(ctx, pingBin, "-c", "1", "-t", timeoutStr, target.String())
case "openbsd", "netbsd": case "openbsd", "netbsd":
return exec.CommandContext(ctx, "ping", "-c", "1", "-w", fmt.Sprintf("%d", timeoutSec), target.String()) bin := pingBin
if isV6 {
bin = ping6Bin
}
return exec.CommandContext(ctx, bin, "-c", "1", "-w", timeoutStr, target.String())
case "windows": case "windows":
return exec.CommandContext(ctx, "ping", "-n", "1", "-w", fmt.Sprintf("%d", timeoutSec*1000), target.String()) return exec.CommandContext(ctx, pingBin, "-n", "1", "-w", fmt.Sprintf("%d", timeoutSec*1000), target.String())
default: default:
return exec.CommandContext(ctx, "ping", "-c", "1", target.String()) return exec.CommandContext(ctx, pingBin, "-c", "1", target.String())
} }
} }

View File

@@ -2,10 +2,9 @@ package forwarder
import ( import (
"context" "context"
"fmt"
"io" "io"
"net" "net"
"net/netip" "strconv"
"sync" "sync"
"github.com/google/uuid" "github.com/google/uuid"
@@ -33,7 +32,7 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
} }
}() }()
dialAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort) dialAddr := net.JoinHostPort(f.determineDialAddr(id.LocalAddress).String(), strconv.Itoa(int(id.LocalPort)))
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr) outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr)
if err != nil { if err != nil {
@@ -133,15 +132,14 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
} }
func (f *Forwarder) sendTCPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, rxBytes, txBytes, rxPackets, txPackets uint64) { func (f *Forwarder) sendTCPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, rxBytes, txBytes, rxPackets, txPackets uint64) {
srcIp := netip.AddrFrom4(id.RemoteAddress.As4()) srcIp := addrToNetipAddr(id.RemoteAddress)
dstIp := netip.AddrFrom4(id.LocalAddress.As4()) dstIp := addrToNetipAddr(id.LocalAddress)
fields := nftypes.EventFields{ fields := nftypes.EventFields{
FlowID: flowID, FlowID: flowID,
Type: typ, Type: typ,
Direction: nftypes.Ingress, Direction: nftypes.Ingress,
Protocol: nftypes.TCP, Protocol: nftypes.TCP,
// TODO: handle ipv6
SourceIP: srcIp, SourceIP: srcIp,
DestIP: dstIp, DestIP: dstIp,
SourcePort: id.RemotePort, SourcePort: id.RemotePort,

View File

@@ -6,7 +6,7 @@ import (
"fmt" "fmt"
"io" "io"
"net" "net"
"net/netip" "strconv"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
@@ -158,7 +158,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool {
} }
}() }()
dstAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort) dstAddr := net.JoinHostPort(f.determineDialAddr(id.LocalAddress).String(), strconv.Itoa(int(id.LocalPort)))
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr) outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr)
if err != nil { if err != nil {
f.logger.Debug2("forwarder: UDP dial error for %v: %v", epID(id), err) f.logger.Debug2("forwarder: UDP dial error for %v: %v", epID(id), err)
@@ -276,15 +276,14 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
// sendUDPEvent stores flow events for UDP connections // sendUDPEvent stores flow events for UDP connections
func (f *Forwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, rxBytes, txBytes, rxPackets, txPackets uint64) { func (f *Forwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, rxBytes, txBytes, rxPackets, txPackets uint64) {
srcIp := netip.AddrFrom4(id.RemoteAddress.As4()) srcIp := addrToNetipAddr(id.RemoteAddress)
dstIp := netip.AddrFrom4(id.LocalAddress.As4()) dstIp := addrToNetipAddr(id.LocalAddress)
fields := nftypes.EventFields{ fields := nftypes.EventFields{
FlowID: flowID, FlowID: flowID,
Type: typ, Type: typ,
Direction: nftypes.Ingress, Direction: nftypes.Ingress,
Protocol: nftypes.UDP, Protocol: nftypes.UDP,
// TODO: handle ipv6
SourceIP: srcIp, SourceIP: srcIp,
DestIP: dstIp, DestIP: dstIp,
SourcePort: id.RemotePort, SourcePort: id.RemotePort,

View File

@@ -4,89 +4,32 @@ import (
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
"sync" "sync/atomic"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/uspfilter/common" "github.com/netbirdio/netbird/client/firewall/uspfilter/common"
) )
type localIPManager struct { // localIPSnapshot is an immutable snapshot of local IP addresses, swapped
mu sync.RWMutex // atomically so reads are lock-free.
type localIPSnapshot struct {
// fixed-size high array for upper byte of a IPv4 address ips map[netip.Addr]struct{}
ipv4Bitmap [256]*ipv4LowBitmap
} }
// ipv4LowBitmap is a map for the low 16 bits of a IPv4 address type localIPManager struct {
type ipv4LowBitmap struct { snapshot atomic.Pointer[localIPSnapshot]
bitmap [8192]uint32
} }
func newLocalIPManager() *localIPManager { func newLocalIPManager() *localIPManager {
return &localIPManager{} m := &localIPManager{}
m.snapshot.Store(&localIPSnapshot{
ips: make(map[netip.Addr]struct{}),
})
return m
} }
func (m *localIPManager) setBitmapBit(ip net.IP) { func processInterface(iface net.Interface, ips map[netip.Addr]struct{}, addresses *[]netip.Addr) {
ipv4 := ip.To4()
if ipv4 == nil {
return
}
high := uint16(ipv4[0])
low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3])
index := low / 32
bit := low % 32
if m.ipv4Bitmap[high] == nil {
m.ipv4Bitmap[high] = &ipv4LowBitmap{}
}
m.ipv4Bitmap[high].bitmap[index] |= 1 << bit
}
func (m *localIPManager) setBitInBitmap(ip netip.Addr, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) {
if !ip.Is4() {
return
}
ipv4 := ip.AsSlice()
high := uint16(ipv4[0])
low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3])
if bitmap[high] == nil {
bitmap[high] = &ipv4LowBitmap{}
}
index := low / 32
bit := low % 32
bitmap[high].bitmap[index] |= 1 << bit
if _, exists := ipv4Set[ip]; !exists {
ipv4Set[ip] = struct{}{}
*ipv4Addresses = append(*ipv4Addresses, ip)
}
}
func (m *localIPManager) checkBitmapBit(ip []byte) bool {
high := uint16(ip[0])
low := (uint16(ip[1]) << 8) | (uint16(ip[2]) << 4) | uint16(ip[3])
if m.ipv4Bitmap[high] == nil {
return false
}
index := low / 32
bit := low % 32
return (m.ipv4Bitmap[high].bitmap[index] & (1 << bit)) != 0
}
func (m *localIPManager) processIP(ip netip.Addr, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) error {
m.setBitInBitmap(ip, bitmap, ipv4Set, ipv4Addresses)
return nil
}
func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) {
addrs, err := iface.Addrs() addrs, err := iface.Addrs()
if err != nil { if err != nil {
log.Debugf("get addresses for interface %s failed: %v", iface.Name, err) log.Debugf("get addresses for interface %s failed: %v", iface.Name, err)
@@ -104,18 +47,19 @@ func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv
continue continue
} }
addr, ok := netip.AddrFromSlice(ip) parsed, ok := netip.AddrFromSlice(ip)
if !ok { if !ok {
log.Warnf("invalid IP address %s in interface %s", ip.String(), iface.Name) log.Warnf("invalid IP address %s in interface %s", ip.String(), iface.Name)
continue continue
} }
if err := m.processIP(addr.Unmap(), bitmap, ipv4Set, ipv4Addresses); err != nil { parsed = parsed.Unmap()
log.Debugf("process IP failed: %v", err) ips[parsed] = struct{}{}
} *addresses = append(*addresses, parsed)
} }
} }
// UpdateLocalIPs rebuilds the local IP snapshot and swaps it in atomically.
func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) { func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
@@ -123,20 +67,20 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
} }
}() }()
var newIPv4Bitmap [256]*ipv4LowBitmap ips := make(map[netip.Addr]struct{})
ipv4Set := make(map[netip.Addr]struct{}) var addresses []netip.Addr
var ipv4Addresses []netip.Addr
// 127.0.0.0/8 // loopback
newIPv4Bitmap[127] = &ipv4LowBitmap{} ips[netip.AddrFrom4([4]byte{127, 0, 0, 1})] = struct{}{}
for i := 0; i < 8192; i++ { ips[netip.IPv6Loopback()] = struct{}{}
// #nosec G602 -- bitmap is defined as [8192]uint32, loop range is correct
newIPv4Bitmap[127].bitmap[i] = 0xFFFFFFFF
}
if iface != nil { if iface != nil {
if err := m.processIP(iface.Address().IP, &newIPv4Bitmap, ipv4Set, &ipv4Addresses); err != nil { ip := iface.Address().IP
return err ips[ip] = struct{}{}
addresses = append(addresses, ip)
if v6 := iface.Address().IPv6; v6.IsValid() {
ips[v6] = struct{}{}
addresses = append(addresses, v6)
} }
} }
@@ -147,25 +91,24 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
// TODO: filter out down interfaces (net.FlagUp). Also handle the reverse // TODO: filter out down interfaces (net.FlagUp). Also handle the reverse
// case where an interface comes up between refreshes. // case where an interface comes up between refreshes.
for _, intf := range interfaces { for _, intf := range interfaces {
m.processInterface(intf, &newIPv4Bitmap, ipv4Set, &ipv4Addresses) processInterface(intf, ips, &addresses)
} }
} }
m.mu.Lock() m.snapshot.Store(&localIPSnapshot{ips: ips})
m.ipv4Bitmap = newIPv4Bitmap
m.mu.Unlock()
log.Debugf("Local IPv4 addresses: %v", ipv4Addresses) log.Debugf("Local IP addresses: %v", addresses)
return nil return nil
} }
// IsLocalIP checks if the given IP is a local address. Lock-free on the read path.
func (m *localIPManager) IsLocalIP(ip netip.Addr) bool { func (m *localIPManager) IsLocalIP(ip netip.Addr) bool {
if !ip.Is4() { s := m.snapshot.Load()
return false
if ip.Is4() && ip.As4()[0] == 127 {
return true
} }
m.mu.RLock() _, found := s.ips[ip]
defer m.mu.RUnlock() return found
return m.checkBitmapBit(ip.AsSlice())
} }

View File

@@ -0,0 +1,72 @@
package uspfilter
import (
"net/netip"
"testing"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
func setupManager(b *testing.B) *localIPManager {
b.Helper()
m := newLocalIPManager()
mock := &IFaceMock{
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: netip.MustParseAddr("100.64.0.1"),
Network: netip.MustParsePrefix("100.64.0.0/16"),
IPv6: netip.MustParseAddr("fd00::1"),
IPv6Net: netip.MustParsePrefix("fd00::/64"),
}
},
}
if err := m.UpdateLocalIPs(mock); err != nil {
b.Fatalf("UpdateLocalIPs: %v", err)
}
return m
}
func BenchmarkIsLocalIP_v4_hit(b *testing.B) {
m := setupManager(b)
ip := netip.MustParseAddr("100.64.0.1")
b.ResetTimer()
for i := 0; i < b.N; i++ {
m.IsLocalIP(ip)
}
}
func BenchmarkIsLocalIP_v4_miss(b *testing.B) {
m := setupManager(b)
ip := netip.MustParseAddr("8.8.8.8")
b.ResetTimer()
for i := 0; i < b.N; i++ {
m.IsLocalIP(ip)
}
}
func BenchmarkIsLocalIP_v6_hit(b *testing.B) {
m := setupManager(b)
ip := netip.MustParseAddr("fd00::1")
b.ResetTimer()
for i := 0; i < b.N; i++ {
m.IsLocalIP(ip)
}
}
func BenchmarkIsLocalIP_v6_miss(b *testing.B) {
m := setupManager(b)
ip := netip.MustParseAddr("2001:db8::1")
b.ResetTimer()
for i := 0; i < b.N; i++ {
m.IsLocalIP(ip)
}
}
func BenchmarkIsLocalIP_loopback(b *testing.B) {
m := setupManager(b)
ip := netip.MustParseAddr("127.0.0.1")
b.ResetTimer()
for i := 0; i < b.N; i++ {
m.IsLocalIP(ip)
}
}

View File

@@ -72,14 +72,45 @@ func TestLocalIPManager(t *testing.T) {
expected: false, expected: false,
}, },
{ {
name: "IPv6 address", name: "IPv6 address matches",
setupAddr: wgaddr.Address{ setupAddr: wgaddr.Address{
IP: netip.MustParseAddr("fe80::1"), IP: netip.MustParseAddr("100.64.0.1"),
Network: netip.MustParsePrefix("100.64.0.0/16"),
IPv6: netip.MustParseAddr("fd00::1"),
IPv6Net: netip.MustParsePrefix("fd00::/64"),
},
testIP: netip.MustParseAddr("fd00::1"),
expected: true,
},
{
name: "IPv6 address does not match",
setupAddr: wgaddr.Address{
IP: netip.MustParseAddr("100.64.0.1"),
Network: netip.MustParsePrefix("100.64.0.0/16"),
IPv6: netip.MustParseAddr("fd00::1"),
IPv6Net: netip.MustParsePrefix("fd00::/64"),
},
testIP: netip.MustParseAddr("fd00::99"),
expected: false,
},
{
name: "No aliasing between similar IPs",
setupAddr: wgaddr.Address{
IP: netip.MustParseAddr("192.168.1.1"),
Network: netip.MustParsePrefix("192.168.1.0/24"), Network: netip.MustParsePrefix("192.168.1.0/24"),
}, },
testIP: netip.MustParseAddr("fe80::1"), testIP: netip.MustParseAddr("192.168.0.17"),
expected: false, expected: false,
}, },
{
name: "IPv6 loopback",
setupAddr: wgaddr.Address{
IP: netip.MustParseAddr("100.64.0.1"),
Network: netip.MustParsePrefix("100.64.0.0/16"),
},
testIP: netip.MustParseAddr("::1"),
expected: true,
},
} }
for _, tt := range tests { for _, tt := range tests {
@@ -171,90 +202,3 @@ func TestLocalIPManager_AllInterfaces(t *testing.T) {
}) })
} }
} }
// MapImplementation is a version using map[string]struct{}
type MapImplementation struct {
localIPs map[string]struct{}
}
func BenchmarkIPChecks(b *testing.B) {
interfaces := make([]net.IP, 16)
for i := range interfaces {
interfaces[i] = net.IPv4(10, 0, byte(i>>8), byte(i))
}
// Setup bitmap
bitmapManager := newLocalIPManager()
for _, ip := range interfaces[:8] { // Add half of IPs
bitmapManager.setBitmapBit(ip)
}
// Setup map version
mapManager := &MapImplementation{
localIPs: make(map[string]struct{}),
}
for _, ip := range interfaces[:8] {
mapManager.localIPs[ip.String()] = struct{}{}
}
b.Run("Bitmap_Hit", func(b *testing.B) {
ip := interfaces[4]
b.ResetTimer()
for i := 0; i < b.N; i++ {
bitmapManager.checkBitmapBit(ip)
}
})
b.Run("Bitmap_Miss", func(b *testing.B) {
ip := interfaces[12]
b.ResetTimer()
for i := 0; i < b.N; i++ {
bitmapManager.checkBitmapBit(ip)
}
})
b.Run("Map_Hit", func(b *testing.B) {
ip := interfaces[4]
b.ResetTimer()
for i := 0; i < b.N; i++ {
// nolint:gosimple
_ = mapManager.localIPs[ip.String()]
}
})
b.Run("Map_Miss", func(b *testing.B) {
ip := interfaces[12]
b.ResetTimer()
for i := 0; i < b.N; i++ {
// nolint:gosimple
_ = mapManager.localIPs[ip.String()]
}
})
}
func BenchmarkWGPosition(b *testing.B) {
wgIP := net.ParseIP("10.10.0.1")
// Create two managers - one checks WG IP first, other checks it last
b.Run("WG_First", func(b *testing.B) {
bm := newLocalIPManager()
bm.setBitmapBit(wgIP)
b.ResetTimer()
for i := 0; i < b.N; i++ {
bm.checkBitmapBit(wgIP)
}
})
b.Run("WG_Last", func(b *testing.B) {
bm := newLocalIPManager()
// Fill with other IPs first
for i := 0; i < 15; i++ {
bm.setBitmapBit(net.IPv4(10, 0, byte(i>>8), byte(i)))
}
bm.setBitmapBit(wgIP) // Add WG IP last
b.ResetTimer()
for i := 0; i < b.N; i++ {
bm.checkBitmapBit(wgIP)
}
})
}

View File

@@ -13,8 +13,6 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
) )
var ErrIPv4Only = errors.New("only IPv4 is supported for DNAT")
var ( var (
errInvalidIPHeaderLength = errors.New("invalid IP header length") errInvalidIPHeaderLength = errors.New("invalid IP header length")
) )
@@ -25,10 +23,33 @@ const (
destinationPortOffset = 2 destinationPortOffset = 2
// IP address offsets in IPv4 header // IP address offsets in IPv4 header
sourceIPOffset = 12 ipv4SrcOffset = 12
destinationIPOffset = 16 ipv4DstOffset = 16
// IP address offsets in IPv6 header
ipv6SrcOffset = 8
ipv6DstOffset = 24
// IPv6 fixed header length
ipv6HeaderLen = 40
) )
// ipHeaderLen returns the IP header length based on the decoded layer type.
func ipHeaderLen(d *decoder) (int, error) {
switch d.decoded[0] {
case layers.LayerTypeIPv4:
n := int(d.ip4.IHL) * 4
if n < 20 {
return 0, errInvalidIPHeaderLength
}
return n, nil
case layers.LayerTypeIPv6:
return ipv6HeaderLen, nil
default:
return 0, fmt.Errorf("unknown IP layer: %v", d.decoded[0])
}
}
// ipv4Checksum calculates IPv4 header checksum. // ipv4Checksum calculates IPv4 header checksum.
func ipv4Checksum(header []byte) uint16 { func ipv4Checksum(header []byte) uint16 {
if len(header) < 20 { if len(header) < 20 {
@@ -234,14 +255,13 @@ func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool {
return false return false
} }
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]}) _, dstIP := extractPacketIPs(packetData, d)
translatedIP, exists := m.getDNATTranslation(dstIP) translatedIP, exists := m.getDNATTranslation(dstIP)
if !exists { if !exists {
return false return false
} }
if err := m.rewritePacketIP(packetData, d, translatedIP, destinationIPOffset); err != nil { if err := m.rewritePacketIP(packetData, d, translatedIP, false); err != nil {
m.logger.Error1("failed to rewrite packet destination: %v", err) m.logger.Error1("failed to rewrite packet destination: %v", err)
return false return false
} }
@@ -256,14 +276,13 @@ func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
return false return false
} }
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]}) srcIP, _ := extractPacketIPs(packetData, d)
originalIP, exists := m.findReverseDNATMapping(srcIP) originalIP, exists := m.findReverseDNATMapping(srcIP)
if !exists { if !exists {
return false return false
} }
if err := m.rewritePacketIP(packetData, d, originalIP, sourceIPOffset); err != nil { if err := m.rewritePacketIP(packetData, d, originalIP, true); err != nil {
m.logger.Error1("failed to rewrite packet source: %v", err) m.logger.Error1("failed to rewrite packet source: %v", err)
return false return false
} }
@@ -272,38 +291,96 @@ func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
return true return true
} }
// rewritePacketIP replaces an IP address (source or destination) in the packet and updates checksums. // extractPacketIPs extracts src and dst IP addresses directly from raw packet bytes.
func (m *Manager) rewritePacketIP(packetData []byte, d *decoder, newIP netip.Addr, ipOffset int) error { func extractPacketIPs(packetData []byte, d *decoder) (src, dst netip.Addr) {
switch d.decoded[0] {
case layers.LayerTypeIPv4:
src = netip.AddrFrom4([4]byte{packetData[ipv4SrcOffset], packetData[ipv4SrcOffset+1], packetData[ipv4SrcOffset+2], packetData[ipv4SrcOffset+3]})
dst = netip.AddrFrom4([4]byte{packetData[ipv4DstOffset], packetData[ipv4DstOffset+1], packetData[ipv4DstOffset+2], packetData[ipv4DstOffset+3]})
case layers.LayerTypeIPv6:
src = netip.AddrFrom16([16]byte(packetData[ipv6SrcOffset : ipv6SrcOffset+16]))
dst = netip.AddrFrom16([16]byte(packetData[ipv6DstOffset : ipv6DstOffset+16]))
}
return src, dst
}
// rewritePacketIP replaces a source (isSource=true) or destination IP address in the packet and updates checksums.
func (m *Manager) rewritePacketIP(packetData []byte, d *decoder, newIP netip.Addr, isSource bool) error {
hdrLen, err := ipHeaderLen(d)
if err != nil {
return err
}
switch d.decoded[0] {
case layers.LayerTypeIPv4:
return m.rewriteIPv4(packetData, d, newIP, hdrLen, isSource)
case layers.LayerTypeIPv6:
return m.rewriteIPv6(packetData, d, newIP, hdrLen, isSource)
default:
return fmt.Errorf("unknown IP layer: %v", d.decoded[0])
}
}
func (m *Manager) rewriteIPv4(packetData []byte, d *decoder, newIP netip.Addr, hdrLen int, isSource bool) error {
if !newIP.Is4() { if !newIP.Is4() {
return ErrIPv4Only return fmt.Errorf("cannot write IPv6 address into IPv4 packet")
}
offset := ipv4DstOffset
if isSource {
offset = ipv4SrcOffset
} }
var oldIP [4]byte var oldIP [4]byte
copy(oldIP[:], packetData[ipOffset:ipOffset+4]) copy(oldIP[:], packetData[offset:offset+4])
newIPBytes := newIP.As4() newIPBytes := newIP.As4()
copy(packetData[offset:offset+4], newIPBytes[:])
copy(packetData[ipOffset:ipOffset+4], newIPBytes[:]) // Recalculate IPv4 header checksum
ipHeaderLen := int(d.ip4.IHL) * 4
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
return errInvalidIPHeaderLength
}
binary.BigEndian.PutUint16(packetData[10:12], 0) binary.BigEndian.PutUint16(packetData[10:12], 0)
ipChecksum := ipv4Checksum(packetData[:ipHeaderLen]) binary.BigEndian.PutUint16(packetData[10:12], ipv4Checksum(packetData[:hdrLen]))
binary.BigEndian.PutUint16(packetData[10:12], ipChecksum)
// Update transport checksums incrementally
if len(d.decoded) > 1 { if len(d.decoded) > 1 {
switch d.decoded[1] { switch d.decoded[1] {
case layers.LayerTypeTCP: case layers.LayerTypeTCP:
m.updateTCPChecksum(packetData, ipHeaderLen, oldIP[:], newIPBytes[:]) m.updateTCPChecksum(packetData, hdrLen, oldIP[:], newIPBytes[:])
case layers.LayerTypeUDP: case layers.LayerTypeUDP:
m.updateUDPChecksum(packetData, ipHeaderLen, oldIP[:], newIPBytes[:]) m.updateUDPChecksum(packetData, hdrLen, oldIP[:], newIPBytes[:])
case layers.LayerTypeICMPv4: case layers.LayerTypeICMPv4:
m.updateICMPChecksum(packetData, ipHeaderLen) m.updateICMPChecksum(packetData, hdrLen)
} }
} }
return nil
}
func (m *Manager) rewriteIPv6(packetData []byte, d *decoder, newIP netip.Addr, hdrLen int, isSource bool) error {
if !newIP.Is6() {
return fmt.Errorf("cannot write IPv4 address into IPv6 packet")
}
offset := ipv6DstOffset
if isSource {
offset = ipv6SrcOffset
}
var oldIP [16]byte
copy(oldIP[:], packetData[offset:offset+16])
newIPBytes := newIP.As16()
copy(packetData[offset:offset+16], newIPBytes[:])
// IPv6 has no header checksum, only update transport checksums
if len(d.decoded) > 1 {
switch d.decoded[1] {
case layers.LayerTypeTCP:
m.updateTCPChecksum(packetData, hdrLen, oldIP[:], newIPBytes[:])
case layers.LayerTypeUDP:
m.updateUDPChecksum(packetData, hdrLen, oldIP[:], newIPBytes[:])
case layers.LayerTypeICMPv6:
// ICMPv6 checksum includes pseudo-header with addresses, use incremental update
m.updateICMPv6Checksum(packetData, hdrLen, oldIP[:], newIPBytes[:])
}
}
return nil return nil
} }
@@ -351,6 +428,20 @@ func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) {
binary.BigEndian.PutUint16(icmpData[2:4], checksum) binary.BigEndian.PutUint16(icmpData[2:4], checksum)
} }
// updateICMPv6Checksum updates ICMPv6 checksum after address change.
// ICMPv6 uses a pseudo-header (like TCP/UDP), so incremental update applies.
func (m *Manager) updateICMPv6Checksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) {
icmpStart := ipHeaderLen
if len(packetData) < icmpStart+4 {
return
}
checksumOffset := icmpStart + 2
oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2])
newChecksum := incrementalUpdate(oldChecksum, oldIP, newIP)
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
}
// incrementalUpdate performs incremental checksum update per RFC 1624. // incrementalUpdate performs incremental checksum update per RFC 1624.
func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 { func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 {
sum := uint32(^oldChecksum) sum := uint32(^oldChecksum)
@@ -532,12 +623,12 @@ func (m *Manager) applyPortRule(packetData []byte, d *decoder, srcIP, dstIP neti
// rewriteTCPPort rewrites a TCP port (source or destination) and updates checksum. // rewriteTCPPort rewrites a TCP port (source or destination) and updates checksum.
func (m *Manager) rewriteTCPPort(packetData []byte, d *decoder, newPort uint16, portOffset int) error { func (m *Manager) rewriteTCPPort(packetData []byte, d *decoder, newPort uint16, portOffset int) error {
ipHeaderLen := int(d.ip4.IHL) * 4 hdrLen, err := ipHeaderLen(d)
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) { if err != nil {
return errInvalidIPHeaderLength return err
} }
tcpStart := ipHeaderLen tcpStart := hdrLen
if len(packetData) < tcpStart+4 { if len(packetData) < tcpStart+4 {
return fmt.Errorf("packet too short for TCP header") return fmt.Errorf("packet too short for TCP header")
} }
@@ -563,12 +654,12 @@ func (m *Manager) rewriteTCPPort(packetData []byte, d *decoder, newPort uint16,
// rewriteUDPPort rewrites a UDP port (source or destination) and updates checksum. // rewriteUDPPort rewrites a UDP port (source or destination) and updates checksum.
func (m *Manager) rewriteUDPPort(packetData []byte, d *decoder, newPort uint16, portOffset int) error { func (m *Manager) rewriteUDPPort(packetData []byte, d *decoder, newPort uint16, portOffset int) error {
ipHeaderLen := int(d.ip4.IHL) * 4 hdrLen, err := ipHeaderLen(d)
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) { if err != nil {
return errInvalidIPHeaderLength return err
} }
udpStart := ipHeaderLen udpStart := hdrLen
if len(packetData) < udpStart+8 { if len(packetData) < udpStart+8 {
return fmt.Errorf("packet too short for UDP header") return fmt.Errorf("packet too short for UDP header")
} }

View File

@@ -342,12 +342,17 @@ func BenchmarkDNATMemoryAllocations(b *testing.B) {
// Parse the packet fresh each time to get a clean decoder // Parse the packet fresh each time to get a clean decoder
d := &decoder{decoded: []gopacket.LayerType{}} d := &decoder{decoded: []gopacket.LayerType{}}
d.parser = gopacket.NewDecodingLayerParser( d.parser4 = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv4, layers.LayerTypeIPv4,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
) )
d.parser.IgnoreUnsupported = true d.parser4.IgnoreUnsupported = true
err = d.parser.DecodeLayers(testPacket, &d.decoded) d.parser6 = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv6,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
)
d.parser6.IgnoreUnsupported = true
err = d.decodePacket(testPacket)
assert.NoError(b, err) assert.NoError(b, err)
manager.translateOutboundDNAT(testPacket, d) manager.translateOutboundDNAT(testPacket, d)
@@ -371,12 +376,17 @@ func BenchmarkDirectIPExtraction(b *testing.B) {
b.Run("decoder_extraction", func(b *testing.B) { b.Run("decoder_extraction", func(b *testing.B) {
// Create decoder once for comparison // Create decoder once for comparison
d := &decoder{decoded: []gopacket.LayerType{}} d := &decoder{decoded: []gopacket.LayerType{}}
d.parser = gopacket.NewDecodingLayerParser( d.parser4 = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv4, layers.LayerTypeIPv4,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
) )
d.parser.IgnoreUnsupported = true d.parser4.IgnoreUnsupported = true
err := d.parser.DecodeLayers(packet, &d.decoded) d.parser6 = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv6,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
)
d.parser6.IgnoreUnsupported = true
err := d.decodePacket(packet)
assert.NoError(b, err) assert.NoError(b, err)
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {

View File

@@ -86,13 +86,18 @@ func parsePacket(t testing.TB, packetData []byte) *decoder {
d := &decoder{ d := &decoder{
decoded: []gopacket.LayerType{}, decoded: []gopacket.LayerType{},
} }
d.parser = gopacket.NewDecodingLayerParser( d.parser4 = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv4, layers.LayerTypeIPv4,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
) )
d.parser.IgnoreUnsupported = true d.parser4.IgnoreUnsupported = true
d.parser6 = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv6,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
)
d.parser6.IgnoreUnsupported = true
err := d.parser.DecodeLayers(packetData, &d.decoded) err := d.decodePacket(packetData)
require.NoError(t, err) require.NoError(t, err)
return d return d
} }

View File

@@ -112,10 +112,13 @@ func (t *PacketTrace) AddResultWithForwarder(stage PacketStage, message string,
} }
func (p *PacketBuilder) Build() ([]byte, error) { func (p *PacketBuilder) Build() ([]byte, error) {
ip := p.buildIPLayer() ipLayer, err := p.buildIPLayer()
pktLayers := []gopacket.SerializableLayer{ip} if err != nil {
return nil, err
}
pktLayers := []gopacket.SerializableLayer{ipLayer}
transportLayer, err := p.buildTransportLayer(ip) transportLayer, err := p.buildTransportLayer(ipLayer)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -129,30 +132,43 @@ func (p *PacketBuilder) Build() ([]byte, error) {
return serializePacket(pktLayers) return serializePacket(pktLayers)
} }
func (p *PacketBuilder) buildIPLayer() *layers.IPv4 { func (p *PacketBuilder) buildIPLayer() (gopacket.SerializableLayer, error) {
if p.SrcIP.Is4() != p.DstIP.Is4() {
return nil, fmt.Errorf("mixed address families: src=%s dst=%s", p.SrcIP, p.DstIP)
}
proto := getIPProtocolNumber(p.Protocol, p.SrcIP.Is6())
if p.SrcIP.Is6() {
return &layers.IPv6{
Version: 6,
HopLimit: 64,
NextHeader: proto,
SrcIP: p.SrcIP.AsSlice(),
DstIP: p.DstIP.AsSlice(),
}, nil
}
return &layers.IPv4{ return &layers.IPv4{
Version: 4, Version: 4,
TTL: 64, TTL: 64,
Protocol: layers.IPProtocol(getIPProtocolNumber(p.Protocol)), Protocol: proto,
SrcIP: p.SrcIP.AsSlice(), SrcIP: p.SrcIP.AsSlice(),
DstIP: p.DstIP.AsSlice(), DstIP: p.DstIP.AsSlice(),
} }, nil
} }
func (p *PacketBuilder) buildTransportLayer(ip *layers.IPv4) ([]gopacket.SerializableLayer, error) { func (p *PacketBuilder) buildTransportLayer(ipLayer gopacket.SerializableLayer) ([]gopacket.SerializableLayer, error) {
switch p.Protocol { switch p.Protocol {
case "tcp": case "tcp":
return p.buildTCPLayer(ip) return p.buildTCPLayer(ipLayer)
case "udp": case "udp":
return p.buildUDPLayer(ip) return p.buildUDPLayer(ipLayer)
case "icmp": case "icmp":
return p.buildICMPLayer() return p.buildICMPLayer(ipLayer)
default: default:
return nil, fmt.Errorf("unsupported protocol: %s", p.Protocol) return nil, fmt.Errorf("unsupported protocol: %s", p.Protocol)
} }
} }
func (p *PacketBuilder) buildTCPLayer(ip *layers.IPv4) ([]gopacket.SerializableLayer, error) { func (p *PacketBuilder) buildTCPLayer(ipLayer gopacket.SerializableLayer) ([]gopacket.SerializableLayer, error) {
tcp := &layers.TCP{ tcp := &layers.TCP{
SrcPort: layers.TCPPort(p.SrcPort), SrcPort: layers.TCPPort(p.SrcPort),
DstPort: layers.TCPPort(p.DstPort), DstPort: layers.TCPPort(p.DstPort),
@@ -164,24 +180,44 @@ func (p *PacketBuilder) buildTCPLayer(ip *layers.IPv4) ([]gopacket.SerializableL
PSH: p.TCPState != nil && p.TCPState.PSH, PSH: p.TCPState != nil && p.TCPState.PSH,
URG: p.TCPState != nil && p.TCPState.URG, URG: p.TCPState != nil && p.TCPState.URG,
} }
if err := tcp.SetNetworkLayerForChecksum(ip); err != nil { if nl, ok := ipLayer.(gopacket.NetworkLayer); ok {
return nil, fmt.Errorf("set network layer for TCP checksum: %w", err) if err := tcp.SetNetworkLayerForChecksum(nl); err != nil {
return nil, fmt.Errorf("set network layer for TCP checksum: %w", err)
}
} }
return []gopacket.SerializableLayer{tcp}, nil return []gopacket.SerializableLayer{tcp}, nil
} }
func (p *PacketBuilder) buildUDPLayer(ip *layers.IPv4) ([]gopacket.SerializableLayer, error) { func (p *PacketBuilder) buildUDPLayer(ipLayer gopacket.SerializableLayer) ([]gopacket.SerializableLayer, error) {
udp := &layers.UDP{ udp := &layers.UDP{
SrcPort: layers.UDPPort(p.SrcPort), SrcPort: layers.UDPPort(p.SrcPort),
DstPort: layers.UDPPort(p.DstPort), DstPort: layers.UDPPort(p.DstPort),
} }
if err := udp.SetNetworkLayerForChecksum(ip); err != nil { if nl, ok := ipLayer.(gopacket.NetworkLayer); ok {
return nil, fmt.Errorf("set network layer for UDP checksum: %w", err) if err := udp.SetNetworkLayerForChecksum(nl); err != nil {
return nil, fmt.Errorf("set network layer for UDP checksum: %w", err)
}
} }
return []gopacket.SerializableLayer{udp}, nil return []gopacket.SerializableLayer{udp}, nil
} }
func (p *PacketBuilder) buildICMPLayer() ([]gopacket.SerializableLayer, error) { func (p *PacketBuilder) buildICMPLayer(ipLayer gopacket.SerializableLayer) ([]gopacket.SerializableLayer, error) {
if p.SrcIP.Is6() || p.DstIP.Is6() {
icmp := &layers.ICMPv6{
TypeCode: layers.CreateICMPv6TypeCode(p.ICMPType, p.ICMPCode),
}
if nl, ok := ipLayer.(gopacket.NetworkLayer); ok {
_ = icmp.SetNetworkLayerForChecksum(nl)
}
if p.ICMPType == layers.ICMPv6TypeEchoRequest || p.ICMPType == layers.ICMPv6TypeEchoReply {
echo := &layers.ICMPv6Echo{
Identifier: 1,
SeqNumber: 1,
}
return []gopacket.SerializableLayer{icmp, echo}, nil
}
return []gopacket.SerializableLayer{icmp}, nil
}
icmp := &layers.ICMPv4{ icmp := &layers.ICMPv4{
TypeCode: layers.CreateICMPv4TypeCode(p.ICMPType, p.ICMPCode), TypeCode: layers.CreateICMPv4TypeCode(p.ICMPType, p.ICMPCode),
} }
@@ -204,14 +240,17 @@ func serializePacket(layers []gopacket.SerializableLayer) ([]byte, error) {
return buf.Bytes(), nil return buf.Bytes(), nil
} }
func getIPProtocolNumber(protocol fw.Protocol) int { func getIPProtocolNumber(protocol fw.Protocol, isV6 bool) layers.IPProtocol {
switch protocol { switch protocol {
case fw.ProtocolTCP: case fw.ProtocolTCP:
return int(layers.IPProtocolTCP) return layers.IPProtocolTCP
case fw.ProtocolUDP: case fw.ProtocolUDP:
return int(layers.IPProtocolUDP) return layers.IPProtocolUDP
case fw.ProtocolICMP: case fw.ProtocolICMP:
return int(layers.IPProtocolICMPv4) if isV6 {
return layers.IPProtocolICMPv6
}
return layers.IPProtocolICMPv4
default: default:
return 0 return 0
} }
@@ -234,7 +273,7 @@ func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *Pa
trace := &PacketTrace{Direction: direction} trace := &PacketTrace{Direction: direction}
// Initial packet decoding // Initial packet decoding
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { if err := d.decodePacket(packetData); err != nil {
trace.AddResult(StageReceived, fmt.Sprintf("Failed to decode packet: %v", err), false) trace.AddResult(StageReceived, fmt.Sprintf("Failed to decode packet: %v", err), false)
return trace return trace
} }
@@ -256,6 +295,8 @@ func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *Pa
trace.DestinationPort = uint16(d.udp.DstPort) trace.DestinationPort = uint16(d.udp.DstPort)
case layers.LayerTypeICMPv4: case layers.LayerTypeICMPv4:
trace.Protocol = "ICMP" trace.Protocol = "ICMP"
case layers.LayerTypeICMPv6:
trace.Protocol = "ICMPv6"
} }
trace.AddResult(StageReceived, fmt.Sprintf("Received %s packet: %s:%d -> %s:%d", trace.AddResult(StageReceived, fmt.Sprintf("Received %s packet: %s:%d -> %s:%d",
@@ -319,6 +360,13 @@ func (m *Manager) buildConntrackStateMessage(d *decoder) string {
flags&conntrack.TCPFin != 0) flags&conntrack.TCPFin != 0)
case layers.LayerTypeICMPv4: case layers.LayerTypeICMPv4:
msg += fmt.Sprintf(" (ICMP ID=%d, Seq=%d)", d.icmp4.Id, d.icmp4.Seq) msg += fmt.Sprintf(" (ICMP ID=%d, Seq=%d)", d.icmp4.Id, d.icmp4.Seq)
case layers.LayerTypeICMPv6:
var id, seq uint16
if len(d.icmp6.Payload) >= 4 {
id = uint16(d.icmp6.Payload[0])<<8 | uint16(d.icmp6.Payload[1])
seq = uint16(d.icmp6.Payload[2])<<8 | uint16(d.icmp6.Payload[3])
}
msg += fmt.Sprintf(" (ICMPv6 ID=%d, Seq=%d)", id, seq)
} }
return msg return msg
} }
@@ -415,7 +463,7 @@ func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTr
d := m.decoders.Get().(*decoder) d := m.decoders.Get().(*decoder)
defer m.decoders.Put(d) defer m.decoders.Put(d)
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { if err := d.decodePacket(packetData); err != nil {
trace.AddResult(StageCompleted, "Packet dropped - decode error", false) trace.AddResult(StageCompleted, "Packet dropped - decode error", false)
return trace return trace
} }
@@ -434,7 +482,7 @@ func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTr
func (m *Manager) handleInboundDNAT(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP *netip.Addr) bool { func (m *Manager) handleInboundDNAT(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP *netip.Addr) bool {
portDNATApplied := m.traceInboundPortDNAT(trace, packetData, d) portDNATApplied := m.traceInboundPortDNAT(trace, packetData, d)
if portDNATApplied { if portDNATApplied {
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { if err := d.decodePacket(packetData); err != nil {
trace.AddResult(StageInboundPortDNAT, "Failed to re-decode after port DNAT", false) trace.AddResult(StageInboundPortDNAT, "Failed to re-decode after port DNAT", false)
return true return true
} }
@@ -444,7 +492,7 @@ func (m *Manager) handleInboundDNAT(trace *PacketTrace, packetData []byte, d *de
nat1to1Applied := m.traceInbound1to1NAT(trace, packetData, d) nat1to1Applied := m.traceInbound1to1NAT(trace, packetData, d)
if nat1to1Applied { if nat1to1Applied {
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { if err := d.decodePacket(packetData); err != nil {
trace.AddResult(StageInbound1to1NAT, "Failed to re-decode after 1:1 NAT", false) trace.AddResult(StageInbound1to1NAT, "Failed to re-decode after 1:1 NAT", false)
return true return true
} }
@@ -509,7 +557,7 @@ func (m *Manager) traceInbound1to1NAT(trace *PacketTrace, packetData []byte, d *
return false return false
} }
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]}) srcIP, _ := extractPacketIPs(packetData, d)
translated := m.translateInboundReverse(packetData, d) translated := m.translateInboundReverse(packetData, d)
if translated { if translated {
@@ -539,7 +587,7 @@ func (m *Manager) traceOutbound1to1NAT(trace *PacketTrace, packetData []byte, d
return false return false
} }
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]}) _, dstIP := extractPacketIPs(packetData, d)
translated := m.translateOutboundDNAT(packetData, d) translated := m.translateOutboundDNAT(packetData, d)
if translated { if translated {

View File

@@ -119,7 +119,7 @@ func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix,
if err != nil { if err != nil {
return fmt.Errorf("failed to parse endpoint address: %w", err) return fmt.Errorf("failed to parse endpoint address: %w", err)
} }
addrPort := netip.AddrPortFrom(addr, uint16(endpoint.Port)) addrPort := netip.AddrPortFrom(addr.Unmap(), uint16(endpoint.Port))
c.activityRecorder.UpsertAddress(peerKey, addrPort) c.activityRecorder.UpsertAddress(peerKey, addrPort)
} }
return nil return nil

View File

@@ -2,7 +2,7 @@ package device
// TunAdapter is an interface for create tun device from external service // TunAdapter is an interface for create tun device from external service
type TunAdapter interface { type TunAdapter interface {
ConfigureInterface(address string, mtu int, dns string, searchDomains string, routes string) (int, error) ConfigureInterface(address string, addressV6 string, mtu int, dns string, searchDomains string, routes string) (int, error)
UpdateAddr(address string) error UpdateAddr(address string) error
ProtectSocket(fd int32) bool ProtectSocket(fd int32) bool
} }

View File

@@ -63,7 +63,7 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string
searchDomainsToString = "" searchDomainsToString = ""
} }
fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), int(t.mtu), dns, searchDomainsToString, routesString) fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), t.address.IPv6String(), int(t.mtu), dns, searchDomainsToString, routesString)
if err != nil { if err != nil {
log.Errorf("failed to create Android interface: %s", err) log.Errorf("failed to create Android interface: %s", err)
return nil, err return nil, err

View File

@@ -6,7 +6,7 @@ import (
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
"strings"
"sync" "sync"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -196,18 +196,22 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) {
} }
} }
// fakeAddress returns a fake address that is used to as an identifier for the peer. // fakeAddress returns a fake address that is used as an identifier for the peer.
// The fake address is in the format of 127.1.x.x where x.x is the last two octets of the peer address. // The fake address is in the format of 127.1.x.x where x.x is derived from the
// last two bytes of the peer address (works for both IPv4 and IPv6).
func fakeAddress(peerAddress *net.UDPAddr) (*netip.AddrPort, error) { func fakeAddress(peerAddress *net.UDPAddr) (*netip.AddrPort, error) {
octets := strings.Split(peerAddress.IP.String(), ".") if peerAddress == nil {
if len(octets) != 4 { return nil, fmt.Errorf("nil peer address")
return nil, fmt.Errorf("invalid IP format")
} }
fakeIP, err := netip.ParseAddr(fmt.Sprintf("127.1.%s.%s", octets[2], octets[3])) addr, ok := netip.AddrFromSlice(peerAddress.IP)
if err != nil { if !ok {
return nil, fmt.Errorf("parse new IP: %w", err) return nil, fmt.Errorf("invalid IP format")
} }
addr = addr.Unmap()
raw := addr.As16()
fakeIP := netip.AddrFrom4([4]byte{127, 1, raw[14], raw[15]})
netipAddr := netip.AddrPortFrom(fakeIP, uint16(peerAddress.Port)) netipAddr := netip.AddrPortFrom(fakeIP, uint16(peerAddress.Port))
return &netipAddr, nil return &netipAddr, nil

View File

@@ -5,7 +5,6 @@ import (
"encoding/hex" "encoding/hex"
"errors" "errors"
"fmt" "fmt"
"net"
"net/netip" "net/netip"
"strconv" "strconv"
"sync" "sync"
@@ -19,6 +18,7 @@ import (
"github.com/netbirdio/netbird/client/internal/acl/id" "github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/domain"
mgmProto "github.com/netbirdio/netbird/shared/management/proto" mgmProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/netiputil"
) )
var ErrSourceRangesEmpty = errors.New("sources range is empty") var ErrSourceRangesEmpty = errors.New("sources range is empty")
@@ -105,6 +105,10 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
newRulePairs := make(map[id.RuleID][]firewall.Rule) newRulePairs := make(map[id.RuleID][]firewall.Rule)
ipsetByRuleSelectors := make(map[string]string) ipsetByRuleSelectors := make(map[string]string)
// TODO: deny rules should be fatal: if a deny rule fails to apply, we must
// roll back all allow rules to avoid a fail-open where allowed traffic bypasses
// the missing deny. Currently we accumulate errors and continue.
var merr *multierror.Error
for _, r := range rules { for _, r := range rules {
// if this rule is member of rule selection with more than DefaultIPsCountForSet // if this rule is member of rule selection with more than DefaultIPsCountForSet
// it's IP address can be used in the ipset for firewall manager which supports it // it's IP address can be used in the ipset for firewall manager which supports it
@@ -117,9 +121,8 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
} }
pairID, rulePair, err := d.protoRuleToFirewallRule(r, ipsetName) pairID, rulePair, err := d.protoRuleToFirewallRule(r, ipsetName)
if err != nil { if err != nil {
log.Errorf("failed to apply firewall rule: %+v, %v", r, err) merr = multierror.Append(merr, fmt.Errorf("apply firewall rule: %w", err))
d.rollBack(newRulePairs) continue
break
} }
if len(rulePair) > 0 { if len(rulePair) > 0 {
d.peerRulesPairs[pairID] = rulePair d.peerRulesPairs[pairID] = rulePair
@@ -127,6 +130,10 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
} }
} }
if merr != nil {
log.Errorf("failed to apply %d peer ACL rule(s): %v", merr.Len(), nberrors.FormatErrorOrNil(merr))
}
for pairID, rules := range d.peerRulesPairs { for pairID, rules := range d.peerRulesPairs {
if _, ok := newRulePairs[pairID]; !ok { if _, ok := newRulePairs[pairID]; !ok {
for _, rule := range rules { for _, rule := range rules {
@@ -216,10 +223,9 @@ func (d *DefaultManager) protoRuleToFirewallRule(
r *mgmProto.FirewallRule, r *mgmProto.FirewallRule,
ipsetName string, ipsetName string,
) (id.RuleID, []firewall.Rule, error) { ) (id.RuleID, []firewall.Rule, error) {
//nolint:staticcheck // PeerIP used for backward compatibility with old management ip, err := extractRuleIP(r)
ip := net.ParseIP(r.PeerIP) if err != nil {
if ip == nil { return "", nil, err
return "", nil, fmt.Errorf("invalid IP address, skipping firewall rule")
} }
protocol, err := convertToFirewallProtocol(r.Protocol) protocol, err := convertToFirewallProtocol(r.Protocol)
@@ -290,13 +296,13 @@ func portInfoEmpty(portInfo *mgmProto.PortInfo) bool {
func (d *DefaultManager) addInRules( func (d *DefaultManager) addInRules(
id []byte, id []byte,
ip net.IP, ip netip.Addr,
protocol firewall.Protocol, protocol firewall.Protocol,
port *firewall.Port, port *firewall.Port,
action firewall.Action, action firewall.Action,
ipsetName string, ipsetName string,
) ([]firewall.Rule, error) { ) ([]firewall.Rule, error) {
rule, err := d.firewall.AddPeerFiltering(id, ip, protocol, nil, port, action, ipsetName) rule, err := d.firewall.AddPeerFiltering(id, ip.AsSlice(), protocol, nil, port, action, ipsetName)
if err != nil { if err != nil {
return nil, fmt.Errorf("add firewall rule: %w", err) return nil, fmt.Errorf("add firewall rule: %w", err)
} }
@@ -306,7 +312,7 @@ func (d *DefaultManager) addInRules(
func (d *DefaultManager) addOutRules( func (d *DefaultManager) addOutRules(
id []byte, id []byte,
ip net.IP, ip netip.Addr,
protocol firewall.Protocol, protocol firewall.Protocol,
port *firewall.Port, port *firewall.Port,
action firewall.Action, action firewall.Action,
@@ -316,7 +322,7 @@ func (d *DefaultManager) addOutRules(
return nil, nil return nil, nil
} }
rule, err := d.firewall.AddPeerFiltering(id, ip, protocol, port, nil, action, ipsetName) rule, err := d.firewall.AddPeerFiltering(id, ip.AsSlice(), protocol, port, nil, action, ipsetName)
if err != nil { if err != nil {
return nil, fmt.Errorf("add firewall rule: %w", err) return nil, fmt.Errorf("add firewall rule: %w", err)
} }
@@ -324,9 +330,9 @@ func (d *DefaultManager) addOutRules(
return rule, nil return rule, nil
} }
// getPeerRuleID() returns unique ID for the rule based on its parameters. // getPeerRuleID returns unique ID for the rule based on its parameters.
func (d *DefaultManager) getPeerRuleID( func (d *DefaultManager) getPeerRuleID(
ip net.IP, ip netip.Addr,
proto firewall.Protocol, proto firewall.Protocol,
direction int, direction int,
port *firewall.Port, port *firewall.Port,
@@ -345,15 +351,25 @@ func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) st
return fmt.Sprintf("%v:%v:%v:%s:%v", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port, rule.PortInfo) return fmt.Sprintf("%v:%v:%v:%s:%v", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port, rule.PortInfo)
} }
func (d *DefaultManager) rollBack(newRulePairs map[id.RuleID][]firewall.Rule) {
log.Debugf("rollback ACL to previous state") // extractRuleIP extracts the peer IP from a firewall rule.
for _, rules := range newRulePairs { // If sourcePrefixes is populated (new management), decode the first entry and use its address.
for _, rule := range rules { // Otherwise fall back to the deprecated PeerIP string field (old management).
if err := d.firewall.DeletePeerRule(rule); err != nil { func extractRuleIP(r *mgmProto.FirewallRule) (netip.Addr, error) {
log.Errorf("failed to delete new firewall rule (id: %v) during rollback: %v", rule.ID(), err) if len(r.SourcePrefixes) > 0 {
} addr, err := netiputil.DecodeAddr(r.SourcePrefixes[0])
if err != nil {
return netip.Addr{}, fmt.Errorf("decode source prefix: %w", err)
} }
return addr.Unmap(), nil
} }
//nolint:staticcheck // PeerIP used for backward compatibility with old management
addr, err := netip.ParseAddr(r.PeerIP)
if err != nil {
return netip.Addr{}, fmt.Errorf("invalid IP address, skipping firewall rule")
}
return addr.Unmap(), nil
} }
func convertToFirewallProtocol(protocol mgmProto.RuleProtocol) (firewall.Protocol, error) { func convertToFirewallProtocol(protocol mgmProto.RuleProtocol) (firewall.Protocol, error) {

View File

@@ -430,8 +430,6 @@ func isInCGNATRange(ip net.IP) bool {
} }
func TestAnonymizeFirewallRules(t *testing.T) { func TestAnonymizeFirewallRules(t *testing.T) {
// TODO: Add ipv6
// Example iptables-save output // Example iptables-save output
iptablesSave := `# Generated by iptables-save v1.8.7 on Thu Dec 19 10:00:00 2024 iptablesSave := `# Generated by iptables-save v1.8.7 on Thu Dec 19 10:00:00 2024
*filter *filter
@@ -467,17 +465,31 @@ Chain FORWARD (policy ACCEPT 0 packets, 0 bytes)
Chain OUTPUT (policy ACCEPT 0 packets, 0 bytes) Chain OUTPUT (policy ACCEPT 0 packets, 0 bytes)
pkts bytes target prot opt in out source destination` pkts bytes target prot opt in out source destination`
// Example nftables output // Example ip6tables-save output
ip6tablesSave := `# Generated by ip6tables-save v1.8.7 on Thu Dec 19 10:00:00 2024
*filter
:INPUT ACCEPT [0:0]
:FORWARD ACCEPT [0:0]
:OUTPUT ACCEPT [0:0]
-A INPUT -s fd00:1234::1/128 -j ACCEPT
-A INPUT -s 2607:f8b0:4005::1/128 -j DROP
-A FORWARD -s 2001:db8::/32 -d 2607:f8b0:4005::200e/128 -j ACCEPT
COMMIT`
// Example nftables output with IPv6
nftablesRules := `table inet filter { nftablesRules := `table inet filter {
chain input { chain input {
type filter hook input priority filter; policy accept; type filter hook input priority filter; policy accept;
ip saddr 192.168.1.1 accept ip saddr 192.168.1.1 accept
ip saddr 44.192.140.1 drop ip saddr 44.192.140.1 drop
ip6 saddr 2607:f8b0:4005::1 drop
ip6 saddr fd00:1234::1 accept
} }
chain forward { chain forward {
type filter hook forward priority filter; policy accept; type filter hook forward priority filter; policy accept;
ip saddr 10.0.0.0/8 drop ip saddr 10.0.0.0/8 drop
ip saddr 44.192.140.0/24 ip daddr 52.84.12.34/24 accept ip saddr 44.192.140.0/24 ip daddr 52.84.12.34/24 accept
ip6 saddr 2001:db8::/32 ip6 daddr 2607:f8b0:4005::200e/128 accept
} }
}` }`
@@ -540,4 +552,35 @@ Chain OUTPUT (policy ACCEPT 0 packets, 0 bytes)
assert.Contains(t, anonNftables, "table inet filter {") assert.Contains(t, anonNftables, "table inet filter {")
assert.Contains(t, anonNftables, "chain input {") assert.Contains(t, anonNftables, "chain input {")
assert.Contains(t, anonNftables, "type filter hook input priority filter; policy accept;") assert.Contains(t, anonNftables, "type filter hook input priority filter; policy accept;")
// IPv6 public addresses in nftables should be anonymized
assert.NotContains(t, anonNftables, "2607:f8b0:4005::1")
assert.NotContains(t, anonNftables, "2607:f8b0:4005::200e")
assert.NotContains(t, anonNftables, "2001:db8::")
assert.Contains(t, anonNftables, "2001:db8:ffff::") // Default anonymous v6 range
// ULA addresses in nftables should remain unchanged (private)
assert.Contains(t, anonNftables, "fd00:1234::1")
// IPv6 nftables structure preserved
assert.Contains(t, anonNftables, "ip6 saddr")
assert.Contains(t, anonNftables, "ip6 daddr")
// Test ip6tables-save anonymization
anonIp6tablesSave := anonymizer.AnonymizeString(ip6tablesSave)
// ULA (private) IPv6 should remain unchanged
assert.Contains(t, anonIp6tablesSave, "fd00:1234::1/128")
// Public IPv6 addresses should be anonymized
assert.NotContains(t, anonIp6tablesSave, "2607:f8b0:4005::1")
assert.NotContains(t, anonIp6tablesSave, "2607:f8b0:4005::200e")
assert.NotContains(t, anonIp6tablesSave, "2001:db8::")
assert.Contains(t, anonIp6tablesSave, "2001:db8:ffff::") // Default anonymous v6 range
// Structure should be preserved
assert.Contains(t, anonIp6tablesSave, "*filter")
assert.Contains(t, anonIp6tablesSave, "COMMIT")
assert.Contains(t, anonIp6tablesSave, "-j DROP")
assert.Contains(t, anonIp6tablesSave, "-j ACCEPT")
} }

View File

@@ -189,10 +189,10 @@ func (s *serviceViaListener) RuntimeIP() netip.Addr {
} }
// evalListenAddress figure out the listen address for the DNS server // evalListenAddress figures out the listen address for the DNS server.
// first check the 53 port availability on WG interface or lo, if not success // IPv4-only: all peers have a v4 overlay address, and DNS config points to v4.
// pick a random port on WG interface for eBPF, if not success // First checks port 53 on WG interface or lo, then tries eBPF on a random port,
// check the 5053 port availability on WG interface or lo without eBPF usage, // then falls back to port 5053.
func (s *serviceViaListener) evalListenAddress() (netip.Addr, uint16, error) { func (s *serviceViaListener) evalListenAddress() (netip.Addr, uint16, error) {
if s.customAddr != nil { if s.customAddr != nil {
return s.customAddr.Addr(), s.customAddr.Port(), nil return s.customAddr.Addr(), s.customAddr.Port(), nil
@@ -278,7 +278,7 @@ func (s *serviceViaListener) tryToUseeBPF() (ebpfMgr.Manager, uint16, bool) {
} }
ebpfSrv := ebpf.GetEbpfManagerInstance() ebpfSrv := ebpf.GetEbpfManagerInstance()
err = ebpfSrv.LoadDNSFwd(s.wgInterface.Address().IP.String(), int(port)) err = ebpfSrv.LoadDNSFwd(s.wgInterface.Address().IP, int(port))
if err != nil { if err != nil {
log.Warnf("failed to load DNS forwarder eBPF program, error: %s", err) log.Warnf("failed to load DNS forwarder eBPF program, error: %s", err)
return nil, 0, false return nil, 0, false

View File

@@ -21,6 +21,7 @@ import (
"golang.zx2c4.com/wireguard/tun/netstack" "golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/dns/resutil" "github.com/netbirdio/netbird/client/internal/dns/resutil"
"github.com/netbirdio/netbird/client/internal/dns/types" "github.com/netbirdio/netbird/client/internal/dns/types"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
@@ -29,6 +30,12 @@ import (
var currentMTU uint16 = iface.DefaultMTU var currentMTU uint16 = iface.DefaultMTU
// privateClientIface is the subset of the WireGuard interface needed by GetClientPrivate.
type privateClientIface interface {
Name() string
Address() wgaddr.Address
}
func SetCurrentMTU(mtu uint16) { func SetCurrentMTU(mtu uint16) {
currentMTU = mtu currentMTU = mtu
} }

View File

@@ -86,7 +86,7 @@ func (u *upstreamResolver) isLocalResolver(upstream string) bool {
return false return false
} }
func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) { func GetClientPrivate(_ privateClientIface, _ netip.Addr, dialTimeout time.Duration) (*dns.Client, error) {
return &dns.Client{ return &dns.Client{
Timeout: dialTimeout, Timeout: dialTimeout,
Net: "udp", Net: "udp",

View File

@@ -52,7 +52,7 @@ func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns
return ExchangeWithFallback(ctx, client, r, upstream) return ExchangeWithFallback(ctx, client, r, upstream)
} }
func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) { func GetClientPrivate(_ privateClientIface, _ netip.Addr, dialTimeout time.Duration) (*dns.Client, error) {
return &dns.Client{ return &dns.Client{
Timeout: dialTimeout, Timeout: dialTimeout,
Net: "udp", Net: "udp",

View File

@@ -19,11 +19,7 @@ import (
type upstreamResolverIOS struct { type upstreamResolverIOS struct {
*upstreamResolverBase *upstreamResolverBase
lIP netip.Addr wgIface WGIface
lNet netip.Prefix
lIPv6 netip.Addr
lNetV6 netip.Prefix
interfaceName string
} }
func newUpstreamResolver( func newUpstreamResolver(
@@ -37,11 +33,7 @@ func newUpstreamResolver(
ios := &upstreamResolverIOS{ ios := &upstreamResolverIOS{
upstreamResolverBase: upstreamResolverBase, upstreamResolverBase: upstreamResolverBase,
lIP: wgIface.Address().IP, wgIface: wgIface,
lNet: wgIface.Address().Network,
lIPv6: wgIface.Address().IPv6,
lNetV6: wgIface.Address().IPv6Net,
interfaceName: wgIface.Name(),
} }
ios.upstreamClient = ios ios.upstreamClient = ios
@@ -69,24 +61,15 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
} else { } else {
upstreamIP = upstreamIP.Unmap() upstreamIP = upstreamIP.Unmap()
} }
needsPrivate := u.lNet.Contains(upstreamIP) || addr := u.wgIface.Address()
u.lNetV6.Contains(upstreamIP) || needsPrivate := addr.Network.Contains(upstreamIP) ||
addr.IPv6Net.Contains(upstreamIP) ||
(u.routeMatch != nil && u.routeMatch(upstreamIP)) (u.routeMatch != nil && u.routeMatch(upstreamIP))
if needsPrivate { if needsPrivate {
var bindIP netip.Addr log.Debugf("using private client to query %s via upstream %s", r.Question[0].Name, upstream)
switch { client, err = GetClientPrivate(u.wgIface, upstreamIP, timeout)
case upstreamIP.Is6() && u.lIPv6.IsValid(): if err != nil {
bindIP = u.lIPv6 return nil, 0, fmt.Errorf("create private client: %s", err)
case upstreamIP.Is4() && u.lIP.IsValid():
bindIP = u.lIP
}
if bindIP.IsValid() {
log.Debugf("using private client to query %s via upstream %s", r.Question[0].Name, upstream)
client, err = GetClientPrivate(bindIP, u.interfaceName, timeout)
if err != nil {
return nil, 0, fmt.Errorf("create private client: %s", err)
}
} }
} }
@@ -94,23 +77,29 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
return ExchangeWithFallback(nil, client, r, upstream) return ExchangeWithFallback(nil, client, r, upstream)
} }
// GetClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface // GetClientPrivate returns a new DNS client bound to the local IP of the Netbird interface.
// This method is needed for iOS // It selects the v6 bind address when the upstream is IPv6 and the interface has one, otherwise v4.
func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) { func GetClientPrivate(iface privateClientIface, upstreamIP netip.Addr, dialTimeout time.Duration) (*dns.Client, error) {
index, err := getInterfaceIndex(interfaceName) index, err := getInterfaceIndex(iface.Name())
if err != nil { if err != nil {
log.Debugf("unable to get interface index for %s: %s", interfaceName, err) log.Debugf("unable to get interface index for %s: %s", iface.Name(), err)
return nil, err return nil, err
} }
addr := iface.Address()
bindIP := addr.IP
if upstreamIP.Is6() && addr.HasIPv6() {
bindIP = addr.IPv6
}
proto, opt := unix.IPPROTO_IP, unix.IP_BOUND_IF proto, opt := unix.IPPROTO_IP, unix.IP_BOUND_IF
if ip.Is6() { if bindIP.Is6() {
proto, opt = unix.IPPROTO_IPV6, unix.IPV6_BOUND_IF proto, opt = unix.IPPROTO_IPV6, unix.IPV6_BOUND_IF
} }
dialer := &net.Dialer{ dialer := &net.Dialer{
LocalAddr: net.UDPAddrFromAddrPort(netip.AddrPortFrom(ip, 0)), LocalAddr: net.UDPAddrFromAddrPort(netip.AddrPortFrom(bindIP, 0)),
Timeout: dialTimeout, Timeout: dialTimeout,
Control: func(network, address string, c syscall.RawConn) error { Control: func(network, address string, c syscall.RawConn) error {
var operr error var operr error
fn := func(s uintptr) { fn := func(s uintptr) {

View File

@@ -80,6 +80,7 @@ func (m *Manager) Start(fwdEntries []*ForwarderEntry) error {
return err return err
} }
// IPv4-only: peers reach the forwarder via its v4 overlay address.
localAddr := m.wgIface.Address().IP localAddr := m.wgIface.Address().IP
if localAddr.IsValid() && m.firewall != nil { if localAddr.IsValid() && m.firewall != nil {

View File

@@ -2,7 +2,8 @@ package ebpf
import ( import (
"encoding/binary" "encoding/binary"
"net" "fmt"
"net/netip"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
@@ -12,7 +13,7 @@ const (
mapKeyDNSPort uint32 = 1 mapKeyDNSPort uint32 = 1
) )
func (tf *GeneralManager) LoadDNSFwd(ip string, dnsPort int) error { func (tf *GeneralManager) LoadDNSFwd(ip netip.Addr, dnsPort int) error {
log.Debugf("load eBPF DNS forwarder, watching addr: %s:53, redirect to port: %d", ip, dnsPort) log.Debugf("load eBPF DNS forwarder, watching addr: %s:53, redirect to port: %d", ip, dnsPort)
tf.lock.Lock() tf.lock.Lock()
defer tf.lock.Unlock() defer tf.lock.Unlock()
@@ -22,7 +23,11 @@ func (tf *GeneralManager) LoadDNSFwd(ip string, dnsPort int) error {
return err return err
} }
err = tf.bpfObjs.NbMapDnsIp.Put(mapKeyDNSIP, ip2int(ip)) if !ip.Is4() {
return fmt.Errorf("eBPF DNS forwarder only supports IPv4, got %s", ip)
}
ip4 := ip.As4()
err = tf.bpfObjs.NbMapDnsIp.Put(mapKeyDNSIP, binary.BigEndian.Uint32(ip4[:]))
if err != nil { if err != nil {
return err return err
} }
@@ -45,7 +50,3 @@ func (tf *GeneralManager) FreeDNSFwd() error {
return tf.unsetFeatureFlag(featureFlagDnsForwarder) return tf.unsetFeatureFlag(featureFlagDnsForwarder)
} }
func ip2int(ipString string) uint32 {
ip := net.ParseIP(ipString)
return binary.BigEndian.Uint32(ip.To4())
}

View File

@@ -1,8 +1,10 @@
package manager package manager
import "net/netip"
// Manager is used to load multiple eBPF programs. E.g., current DNS programs and WireGuard proxy // Manager is used to load multiple eBPF programs. E.g., current DNS programs and WireGuard proxy
type Manager interface { type Manager interface {
LoadDNSFwd(ip string, dnsPort int) error LoadDNSFwd(ip netip.Addr, dnsPort int) error
FreeDNSFwd() error FreeDNSFwd() error
LoadWgProxy(proxyPort, wgPort int) error LoadWgProxy(proxyPort, wgPort int) error
FreeWGProxy() error FreeWGProxy() error

View File

@@ -630,7 +630,7 @@ func (e *Engine) initFirewall() error {
rosenpassPort := e.rpManager.GetAddress().Port rosenpassPort := e.rpManager.GetAddress().Port
port := firewallManager.Port{Values: []uint16{uint16(rosenpassPort)}} port := firewallManager.Port{Values: []uint16{uint16(rosenpassPort)}}
// this rule is static and will be torn down on engine down by the firewall manager // IPv4-only: rosenpass peers connect via AllowedIps[0] which is always v4.
if _, err := e.firewall.AddPeerFiltering( if _, err := e.firewall.AddPeerFiltering(
nil, nil,
net.IP{0, 0, 0, 0}, net.IP{0, 0, 0, 0},
@@ -682,10 +682,15 @@ func (e *Engine) blockLanAccess() {
log.Infof("blocking route LAN access for networks: %v", toBlock) log.Infof("blocking route LAN access for networks: %v", toBlock)
v4 := netip.PrefixFrom(netip.IPv4Unspecified(), 0) v4 := netip.PrefixFrom(netip.IPv4Unspecified(), 0)
v6 := netip.PrefixFrom(netip.IPv6Unspecified(), 0)
for _, network := range toBlock { for _, network := range toBlock {
source := v4
if network.Addr().Is6() {
source = v6
}
if _, err := e.firewall.AddRouteFiltering( if _, err := e.firewall.AddRouteFiltering(
nil, nil,
[]netip.Prefix{v4}, []netip.Prefix{source},
firewallManager.Network{Prefix: network}, firewallManager.Network{Prefix: network},
firewallManager.ProtocolALL, firewallManager.ProtocolALL,
nil, nil,
@@ -1494,10 +1499,10 @@ func (e *Engine) updateOfflinePeers(offlinePeers []*mgmProto.RemotePeerConfig) {
replacement := make([]peer.State, len(offlinePeers)) replacement := make([]peer.State, len(offlinePeers))
for i, offlinePeer := range offlinePeers { for i, offlinePeer := range offlinePeers {
log.Debugf("added offline peer %s", offlinePeer.Fqdn) log.Debugf("added offline peer %s", offlinePeer.Fqdn)
v4, v6 := splitAllowedIPs(offlinePeer.GetAllowedIps(), e.wgInterface.Address().IPv6Net) v4, v6 := overlayAddrsFromAllowedIPs(offlinePeer.GetAllowedIps(), e.wgInterface.Address().IPv6Net)
replacement[i] = peer.State{ replacement[i] = peer.State{
IP: v4, IP: addrToString(v4),
IPv6: v6, IPv6: addrToString(v6),
PubKey: offlinePeer.GetWgPubKey(), PubKey: offlinePeer.GetWgPubKey(),
FQDN: offlinePeer.GetFqdn(), FQDN: offlinePeer.GetFqdn(),
ConnStatus: peer.StatusIdle, ConnStatus: peer.StatusIdle,
@@ -1508,30 +1513,37 @@ func (e *Engine) updateOfflinePeers(offlinePeers []*mgmProto.RemotePeerConfig) {
e.statusRecorder.ReplaceOfflinePeers(replacement) e.statusRecorder.ReplaceOfflinePeers(replacement)
} }
// splitAllowedIPs separates the peer's overlay v4 (/32) and v6 (/128) addresses // overlayAddrsFromAllowedIPs extracts the peer's v4 and v6 overlay addresses
// from a list of AllowedIPs CIDRs. The v6 address is only matched if it falls // from AllowedIPs strings. Only host routes (/32, /128) are considered; v6 must
// within ourV6Net (the local overlay v6 subnet), to avoid confusing routed /128 // fall within ourV6Net to distinguish overlay addresses from routed prefixes.
// prefixes with the peer's overlay address. func overlayAddrsFromAllowedIPs(allowedIPs []string, ourV6Net netip.Prefix) (v4, v6 netip.Addr) {
func splitAllowedIPs(allowedIPs []string, ourV6Net netip.Prefix) (v4, v6 string) {
for _, cidr := range allowedIPs { for _, cidr := range allowedIPs {
prefix, err := netip.ParsePrefix(cidr) prefix, err := netip.ParsePrefix(cidr)
if err != nil { if err != nil {
log.Warnf("failed to parse AllowedIP %q: %v", cidr, err) log.Warnf("failed to parse AllowedIP %q: %v", cidr, err)
continue continue
} }
addr := prefix.Addr().Unmap()
switch { switch {
case prefix.Addr().Is4() && prefix.Bits() == 32 && v4 == "": case addr.Is4() && prefix.Bits() == 32 && !v4.IsValid():
v4 = prefix.Addr().String() v4 = addr
case prefix.Addr().Is6() && prefix.Bits() == 128 && ourV6Net.Contains(prefix.Addr()) && v6 == "": case addr.Is6() && prefix.Bits() == 128 && ourV6Net.Contains(addr) && !v6.IsValid():
v6 = prefix.Addr().String() v6 = addr
} }
if v4 != "" && v6 != "" { if v4.IsValid() && v6.IsValid() {
break break
} }
} }
return return
} }
func addrToString(addr netip.Addr) string {
if !addr.IsValid() {
return ""
}
return addr.String()
}
// addNewPeers adds peers that were not know before but arrived from the Management service with the update // addNewPeers adds peers that were not know before but arrived from the Management service with the update
func (e *Engine) addNewPeers(peersUpdate []*mgmProto.RemotePeerConfig) error { func (e *Engine) addNewPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
for _, p := range peersUpdate { for _, p := range peersUpdate {
@@ -1572,8 +1584,8 @@ func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error {
return fmt.Errorf("create peer connection: %w", err) return fmt.Errorf("create peer connection: %w", err)
} }
peerV4, peerV6 := splitAllowedIPs(peerConfig.GetAllowedIps(), e.wgInterface.Address().IPv6Net) peerV4, peerV6 := overlayAddrsFromAllowedIPs(peerConfig.GetAllowedIps(), e.wgInterface.Address().IPv6Net)
err = e.statusRecorder.AddPeer(peerKey, peerConfig.Fqdn, peerV4, peerV6) err = e.statusRecorder.AddPeer(peerKey, peerConfig.Fqdn, addrToString(peerV4), addrToString(peerV6))
if err != nil { if err != nil {
log.Warnf("error adding peer %s to status recorder, got error: %v", peerKey, err) log.Warnf("error adding peer %s to status recorder, got error: %v", peerKey, err)
} }
@@ -2355,8 +2367,7 @@ func getInterfacePrefixes() ([]netip.Prefix, error) {
prefix := netip.PrefixFrom(addr.Unmap(), ones).Masked() prefix := netip.PrefixFrom(addr.Unmap(), ones).Masked()
ip := prefix.Addr() ip := prefix.Addr()
// TODO: add IPv6 if ip.IsLoopback() || ip.IsMulticast() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
if !ip.Is4() || ip.IsLoopback() || ip.IsMulticast() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
continue continue
} }

View File

@@ -145,13 +145,13 @@ func (e *Engine) extractPeerSSHInfo(remotePeers []*mgmProto.RemotePeerConfig) []
continue continue
} }
peerIP, peerIPv6 := e.extractPeerIPs(peerConfig) peerV4, peerV6 := overlayAddrsFromAllowedIPs(peerConfig.GetAllowedIps(), e.wgInterface.Address().IPv6Net)
hostname := e.extractHostname(peerConfig) hostname := e.extractHostname(peerConfig)
peerInfo = append(peerInfo, sshconfig.PeerSSHInfo{ peerInfo = append(peerInfo, sshconfig.PeerSSHInfo{
Hostname: hostname, Hostname: hostname,
IP: peerIP, IP: peerV4,
IPv6: peerIPv6, IPv6: peerV6,
FQDN: peerConfig.GetFqdn(), FQDN: peerConfig.GetFqdn(),
}) })
} }
@@ -159,28 +159,6 @@ func (e *Engine) extractPeerSSHInfo(remotePeers []*mgmProto.RemotePeerConfig) []
return peerInfo return peerInfo
} }
// extractPeerIPs extracts IPv4 and IPv6 overlay addresses from peer's allowed IPs.
// Only considers host routes (/32, /128) within the overlay networks to avoid
// picking up routed prefixes or static routes like 2620:fe::fe/128.
func (e *Engine) extractPeerIPs(peerConfig *mgmProto.RemotePeerConfig) (v4, v6 netip.Addr) {
wgAddr := e.wgInterface.Address()
for _, allowedIP := range peerConfig.GetAllowedIps() {
prefix, err := netip.ParsePrefix(allowedIP)
if err != nil {
log.Warnf("failed to parse AllowedIP %q: %v", allowedIP, err)
continue
}
addr := prefix.Addr().Unmap()
switch {
case addr.Is4() && prefix.Bits() == 32 && wgAddr.Network.Contains(addr) && !v4.IsValid():
v4 = addr
case addr.Is6() && prefix.Bits() == 128 && wgAddr.IPv6Net.IsValid() && wgAddr.IPv6Net.Contains(addr) && !v6.IsValid():
v6 = addr
}
}
return v4, v6
}
// extractHostname extracts short hostname from FQDN // extractHostname extracts short hostname from FQDN
func (e *Engine) extractHostname(peerConfig *mgmProto.RemotePeerConfig) string { func (e *Engine) extractHostname(peerConfig *mgmProto.RemotePeerConfig) string {
fqdn := peerConfig.GetFqdn() fqdn := peerConfig.GetFqdn()

View File

@@ -1837,7 +1837,7 @@ func TestFilterAllowedIPs(t *testing.T) {
} }
} }
func TestSplitAllowedIPs(t *testing.T) { func TestOverlayAddrsFromAllowedIPs(t *testing.T) {
ourV6Net := netip.MustParsePrefix("fd00:1234:5678:abcd::/64") ourV6Net := netip.MustParsePrefix("fd00:1234:5678:abcd::/64")
tests := []struct { tests := []struct {
@@ -1900,9 +1900,17 @@ func TestSplitAllowedIPs(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
v4, v6 := splitAllowedIPs(tt.allowedIPs, tt.ourV6Net) v4, v6 := overlayAddrsFromAllowedIPs(tt.allowedIPs, tt.ourV6Net)
assert.Equal(t, tt.wantV4, v4, "v4") if tt.wantV4 == "" {
assert.Equal(t, tt.wantV6, v6, "v6") assert.False(t, v4.IsValid(), "expected no v4")
} else {
assert.Equal(t, tt.wantV4, v4.String(), "v4")
}
if tt.wantV6 == "" {
assert.False(t, v6.IsValid(), "expected no v6")
} else {
assert.Equal(t, tt.wantV6, v6.String(), "v6")
}
}) })
} }
} }

View File

@@ -57,6 +57,7 @@ func NewBindListener(wgIface WgInterface, bind device.EndpointManager, cfg lazyc
// deriveFakeIP creates a deterministic fake IP for bind mode based on peer's NetBird IP. // deriveFakeIP creates a deterministic fake IP for bind mode based on peer's NetBird IP.
// Maps peer IP 100.64.x.y to fake IP 127.2.x.y (similar to relay proxy using 127.1.x.y). // Maps peer IP 100.64.x.y to fake IP 127.2.x.y (similar to relay proxy using 127.1.x.y).
// It finds the peer's actual NetBird IP by checking which allowedIP is in the same subnet as our WG interface. // It finds the peer's actual NetBird IP by checking which allowedIP is in the same subnet as our WG interface.
// For IPv6-only peers, the last two bytes of the v6 address are used.
func deriveFakeIP(wgIface WgInterface, allowedIPs []netip.Prefix) (netip.Addr, error) { func deriveFakeIP(wgIface WgInterface, allowedIPs []netip.Prefix) (netip.Addr, error) {
if len(allowedIPs) == 0 { if len(allowedIPs) == 0 {
return netip.Addr{}, fmt.Errorf("no allowed IPs for peer") return netip.Addr{}, fmt.Errorf("no allowed IPs for peer")
@@ -64,6 +65,7 @@ func deriveFakeIP(wgIface WgInterface, allowedIPs []netip.Prefix) (netip.Addr, e
ourNetwork := wgIface.Address().Network ourNetwork := wgIface.Address().Network
// Try v4 first (preferred: deterministic from overlay IP)
var peerIP netip.Addr var peerIP netip.Addr
for _, allowedIP := range allowedIPs { for _, allowedIP := range allowedIPs {
ip := allowedIP.Addr() ip := allowedIP.Addr()
@@ -76,13 +78,24 @@ func deriveFakeIP(wgIface WgInterface, allowedIPs []netip.Prefix) (netip.Addr, e
} }
} }
if !peerIP.IsValid() { if peerIP.IsValid() {
return netip.Addr{}, fmt.Errorf("no peer NetBird IP found in allowed IPs") octets := peerIP.As4()
return netip.AddrFrom4([4]byte{127, 2, octets[2], octets[3]}), nil
} }
octets := peerIP.As4() // Fallback: use last two bytes of first v6 overlay IP
fakeIP := netip.AddrFrom4([4]byte{127, 2, octets[2], octets[3]}) addr := wgIface.Address()
return fakeIP, nil if addr.IPv6Net.IsValid() {
for _, allowedIP := range allowedIPs {
ip := allowedIP.Addr()
if ip.Is6() && addr.IPv6Net.Contains(ip) {
raw := ip.As16()
return netip.AddrFrom4([4]byte{127, 2, raw[14], raw[15]}), nil
}
}
}
return netip.Addr{}, fmt.Errorf("no peer NetBird IP found in allowed IPs")
} }
func (d *BindListener) setupLazyConn() error { func (d *BindListener) setupLazyConn() error {

View File

@@ -1055,7 +1055,11 @@ func (d *Status) notifyPeerListChanged() {
} }
func (d *Status) notifyAddressChanged() { func (d *Status) notifyAddressChanged() {
d.notifier.localAddressChanged(d.localPeer.FQDN, d.localPeer.IP) addr := d.localPeer.IP
if d.localPeer.IPv6 != "" {
addr = addr + "\n" + d.localPeer.IPv6
}
d.notifier.localAddressChanged(d.localPeer.FQDN, addr)
} }
func (d *Status) numOfPeers() int { func (d *Status) numOfPeers() int {

View File

@@ -3,9 +3,8 @@ package client
import ( import (
"context" "context"
"fmt" "fmt"
"net" "net/netip"
"reflect" "reflect"
"strconv"
"time" "time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -566,7 +565,7 @@ func HandlerFromRoute(params common.HandlerParams) RouteHandler {
return dnsinterceptor.New(params) return dnsinterceptor.New(params)
case handlerTypeDynamic: case handlerTypeDynamic:
dns := nbdns.NewServiceViaMemory(params.WgInterface) dns := nbdns.NewServiceViaMemory(params.WgInterface)
dnsAddr := net.JoinHostPort(dns.RuntimeIP().String(), strconv.Itoa(dns.RuntimePort())) dnsAddr := netip.AddrPortFrom(dns.RuntimeIP(), uint16(dns.RuntimePort()))
return dynamic.NewRoute(params, dnsAddr) return dynamic.NewRoute(params, dnsAddr)
default: default:
return static.NewRoute(params) return static.NewRoute(params)

View File

@@ -582,7 +582,7 @@ func (d *DnsInterceptor) queryUpstreamDNS(ctx context.Context, w dns.ResponseWri
if nsNet != nil { if nsNet != nil {
reply, err = nbdns.ExchangeWithNetstack(ctx, nsNet, r, upstream) reply, err = nbdns.ExchangeWithNetstack(ctx, nsNet, r, upstream)
} else { } else {
client, clientErr := nbdns.GetClientPrivate(d.wgInterface.Address().IP, d.wgInterface.Name(), dnsTimeout) client, clientErr := nbdns.GetClientPrivate(d.wgInterface, upstreamIP, dnsTimeout)
if clientErr != nil { if clientErr != nil {
d.writeDNSError(w, r, logger, fmt.Sprintf("create DNS client: %v", clientErr)) d.writeDNSError(w, r, logger, fmt.Sprintf("create DNS client: %v", clientErr))
return nil return nil

View File

@@ -50,10 +50,10 @@ type Route struct {
cancel context.CancelFunc cancel context.CancelFunc
statusRecorder *peer.Status statusRecorder *peer.Status
wgInterface iface.WGIface wgInterface iface.WGIface
resolverAddr string resolverAddr netip.AddrPort
} }
func NewRoute(params common.HandlerParams, resolverAddr string) *Route { func NewRoute(params common.HandlerParams, resolverAddr netip.AddrPort) *Route {
return &Route{ return &Route{
route: params.Route, route: params.Route,
routeRefCounter: params.RouteRefCounter, routeRefCounter: params.RouteRefCounter,

View File

@@ -17,37 +17,47 @@ import (
const dialTimeout = 10 * time.Second const dialTimeout = 10 * time.Second
func (r *Route) getIPsFromResolver(domain domain.Domain) ([]net.IP, error) { func (r *Route) getIPsFromResolver(domain domain.Domain) ([]net.IP, error) {
privateClient, err := nbdns.GetClientPrivate(r.wgInterface.Address().IP, r.wgInterface.Name(), dialTimeout) privateClient, err := nbdns.GetClientPrivate(r.wgInterface, r.resolverAddr.Addr(), dialTimeout)
if err != nil { if err != nil {
return nil, fmt.Errorf("error while creating private client: %s", err) return nil, fmt.Errorf("error while creating private client: %s", err)
} }
msg := new(dns.Msg) fqdn := dns.Fqdn(domain.PunycodeString())
msg.SetQuestion(dns.Fqdn(domain.PunycodeString()), dns.TypeA)
startTime := time.Now() startTime := time.Now()
response, _, err := nbdns.ExchangeWithFallback(nil, privateClient, msg, r.resolverAddr) var ips []net.IP
if err != nil { var queryErr error
return nil, fmt.Errorf("DNS query for %s failed after %s: %s ", domain.SafeString(), time.Since(startTime), err)
}
if response.Rcode != dns.RcodeSuccess { for _, qtype := range []uint16{dns.TypeA, dns.TypeAAAA} {
return nil, fmt.Errorf("dns response code: %s", dns.RcodeToString[response.Rcode]) msg := new(dns.Msg)
} msg.SetQuestion(fqdn, qtype)
ips := make([]net.IP, 0) response, _, err := nbdns.ExchangeWithFallback(nil, privateClient, msg, r.resolverAddr.String())
if err != nil {
for _, answ := range response.Answer { if queryErr == nil {
if aRecord, ok := answ.(*dns.A); ok { queryErr = fmt.Errorf("DNS query for %s (type %d) after %s: %w", domain.SafeString(), qtype, time.Since(startTime), err)
ips = append(ips, aRecord.A) }
continue
} }
if aaaaRecord, ok := answ.(*dns.AAAA); ok {
ips = append(ips, aaaaRecord.AAAA) if response.Rcode != dns.RcodeSuccess {
continue
}
for _, answ := range response.Answer {
if aRecord, ok := answ.(*dns.A); ok {
ips = append(ips, aRecord.A)
}
if aaaaRecord, ok := answ.(*dns.AAAA); ok {
ips = append(ips, aaaaRecord.AAAA)
}
} }
} }
if len(ips) == 0 { if len(ips) == 0 {
if queryErr != nil {
return nil, queryErr
}
return nil, fmt.Errorf("no A or AAAA records found for %s", domain.SafeString()) return nil, fmt.Errorf("no A or AAAA records found for %s", domain.SafeString())
} }

View File

@@ -1,93 +1,145 @@
package fakeip package fakeip
import ( import (
"errors"
"fmt" "fmt"
"net/netip" "net/netip"
"sync" "sync"
) )
// Manager manages allocation of fake IPs from the 240.0.0.0/8 block var (
type Manager struct { // 240.0.0.1 - 240.255.255.254, block 240.0.0.0/8 (reserved, RFC 1112)
mu sync.Mutex v4Base = netip.AddrFrom4([4]byte{240, 0, 0, 1})
nextIP netip.Addr // Next IP to allocate v4Max = netip.AddrFrom4([4]byte{240, 255, 255, 254})
v4Block = netip.PrefixFrom(netip.AddrFrom4([4]byte{240, 0, 0, 0}), 8)
// 0100::1 - 0100::ffff:ffff:ffff:fffe, block 0100::/64 (discard, RFC 6666)
v6Base = netip.AddrFrom16([16]byte{0x01, 0x00, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x01})
v6Max = netip.AddrFrom16([16]byte{0x01, 0x00, 0, 0, 0, 0, 0, 0, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe})
v6Block = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x01, 0x00}), 64)
)
// fakeIPPool holds the allocation state for a single address family.
type fakeIPPool struct {
nextIP netip.Addr
baseIP netip.Addr
maxIP netip.Addr
block netip.Prefix
allocated map[netip.Addr]netip.Addr // real IP -> fake IP allocated map[netip.Addr]netip.Addr // real IP -> fake IP
fakeToReal map[netip.Addr]netip.Addr // fake IP -> real IP fakeToReal map[netip.Addr]netip.Addr // fake IP -> real IP
baseIP netip.Addr // First usable IP: 240.0.0.1
maxIP netip.Addr // Last usable IP: 240.255.255.254
} }
// NewManager creates a new fake IP manager using 240.0.0.0/8 block func newPool(base, maxAddr netip.Addr, block netip.Prefix) *fakeIPPool {
func NewManager() *Manager { return &fakeIPPool{
baseIP := netip.AddrFrom4([4]byte{240, 0, 0, 1}) nextIP: base,
maxIP := netip.AddrFrom4([4]byte{240, 255, 255, 254}) baseIP: base,
maxIP: maxAddr,
return &Manager{ block: block,
nextIP: baseIP,
allocated: make(map[netip.Addr]netip.Addr), allocated: make(map[netip.Addr]netip.Addr),
fakeToReal: make(map[netip.Addr]netip.Addr), fakeToReal: make(map[netip.Addr]netip.Addr),
baseIP: baseIP,
maxIP: maxIP,
} }
} }
// AllocateFakeIP allocates a fake IP for the given real IP // allocate allocates a fake IP for the given real IP.
// Returns the fake IP, or existing fake IP if already allocated // Returns the existing fake IP if already allocated.
func (m *Manager) AllocateFakeIP(realIP netip.Addr) (netip.Addr, error) { func (p *fakeIPPool) allocate(realIP netip.Addr) (netip.Addr, error) {
if !realIP.Is4() { if fakeIP, exists := p.allocated[realIP]; exists {
return netip.Addr{}, fmt.Errorf("only IPv4 addresses supported")
}
m.mu.Lock()
defer m.mu.Unlock()
if fakeIP, exists := m.allocated[realIP]; exists {
return fakeIP, nil return fakeIP, nil
} }
startIP := m.nextIP startIP := p.nextIP
for { for {
currentIP := m.nextIP currentIP := p.nextIP
// Advance to next IP, wrapping at boundary // Advance to next IP, wrapping at boundary
if m.nextIP.Compare(m.maxIP) >= 0 { if p.nextIP.Compare(p.maxIP) >= 0 {
m.nextIP = m.baseIP p.nextIP = p.baseIP
} else { } else {
m.nextIP = m.nextIP.Next() p.nextIP = p.nextIP.Next()
} }
// Check if current IP is available if _, inUse := p.fakeToReal[currentIP]; !inUse {
if _, inUse := m.fakeToReal[currentIP]; !inUse { p.allocated[realIP] = currentIP
m.allocated[realIP] = currentIP p.fakeToReal[currentIP] = realIP
m.fakeToReal[currentIP] = realIP
return currentIP, nil return currentIP, nil
} }
// Prevent infinite loop if all IPs exhausted if p.nextIP.Compare(startIP) == 0 {
if m.nextIP.Compare(startIP) == 0 { return netip.Addr{}, fmt.Errorf("no more fake IPs available in %s block", p.block)
return netip.Addr{}, fmt.Errorf("no more fake IPs available in 240.0.0.0/8 block")
} }
} }
} }
// GetFakeIP returns the fake IP for a real IP if it exists // Manager manages allocation of fake IPs for dynamic DNS routes.
// IPv4 uses 240.0.0.0/8 (reserved), IPv6 uses 0100::/64 (discard, RFC 6666).
type Manager struct {
mu sync.Mutex
v4 *fakeIPPool
v6 *fakeIPPool
}
// NewManager creates a new fake IP manager.
func NewManager() *Manager {
return &Manager{
v4: newPool(v4Base, v4Max, v4Block),
v6: newPool(v6Base, v6Max, v6Block),
}
}
func (m *Manager) pool(ip netip.Addr) *fakeIPPool {
if ip.Is6() {
return m.v6
}
return m.v4
}
// AllocateFakeIP allocates a fake IP for the given real IP.
func (m *Manager) AllocateFakeIP(realIP netip.Addr) (netip.Addr, error) {
realIP = realIP.Unmap()
if !realIP.IsValid() {
return netip.Addr{}, errors.New("invalid IP address")
}
m.mu.Lock()
defer m.mu.Unlock()
return m.pool(realIP).allocate(realIP)
}
// GetFakeIP returns the fake IP for a real IP if it exists.
func (m *Manager) GetFakeIP(realIP netip.Addr) (netip.Addr, bool) { func (m *Manager) GetFakeIP(realIP netip.Addr) (netip.Addr, bool) {
realIP = realIP.Unmap()
if !realIP.IsValid() {
return netip.Addr{}, false
}
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
fakeIP, exists := m.allocated[realIP] fakeIP, ok := m.pool(realIP).allocated[realIP]
return fakeIP, exists return fakeIP, ok
} }
// GetRealIP returns the real IP for a fake IP if it exists, otherwise false // GetRealIP returns the real IP for a fake IP if it exists.
func (m *Manager) GetRealIP(fakeIP netip.Addr) (netip.Addr, bool) { func (m *Manager) GetRealIP(fakeIP netip.Addr) (netip.Addr, bool) {
fakeIP = fakeIP.Unmap()
if !fakeIP.IsValid() {
return netip.Addr{}, false
}
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
realIP, exists := m.fakeToReal[fakeIP] realIP, ok := m.pool(fakeIP).fakeToReal[fakeIP]
return realIP, exists return realIP, ok
} }
// GetFakeIPBlock returns the fake IP block used by this manager // GetFakeIPBlock returns the v4 fake IP block used by this manager.
func (m *Manager) GetFakeIPBlock() netip.Prefix { func (m *Manager) GetFakeIPBlock() netip.Prefix {
return netip.MustParsePrefix("240.0.0.0/8") return m.v4.block
}
// GetFakeIPv6Block returns the v6 fake IP block used by this manager.
func (m *Manager) GetFakeIPv6Block() netip.Prefix {
return m.v6.block
} }

View File

@@ -9,16 +9,16 @@ import (
func TestNewManager(t *testing.T) { func TestNewManager(t *testing.T) {
manager := NewManager() manager := NewManager()
if manager.baseIP.String() != "240.0.0.1" { if manager.v4.baseIP.String() != "240.0.0.1" {
t.Errorf("Expected base IP 240.0.0.1, got %s", manager.baseIP.String()) t.Errorf("Expected v4 base IP 240.0.0.1, got %s", manager.v4.baseIP.String())
} }
if manager.maxIP.String() != "240.255.255.254" { if manager.v4.maxIP.String() != "240.255.255.254" {
t.Errorf("Expected max IP 240.255.255.254, got %s", manager.maxIP.String()) t.Errorf("Expected v4 max IP 240.255.255.254, got %s", manager.v4.maxIP.String())
} }
if manager.nextIP.Compare(manager.baseIP) != 0 { if manager.v6.baseIP.String() != "100::1" {
t.Errorf("Expected nextIP to start at baseIP") t.Errorf("Expected v6 base IP 100::1, got %s", manager.v6.baseIP.String())
} }
} }
@@ -35,7 +35,6 @@ func TestAllocateFakeIP(t *testing.T) {
t.Error("Fake IP should be IPv4") t.Error("Fake IP should be IPv4")
} }
// Check it's in the correct range
if fakeIP.As4()[0] != 240 { if fakeIP.As4()[0] != 240 {
t.Errorf("Fake IP should be in 240.0.0.0/8 range, got %s", fakeIP.String()) t.Errorf("Fake IP should be in 240.0.0.0/8 range, got %s", fakeIP.String())
} }
@@ -51,13 +50,31 @@ func TestAllocateFakeIP(t *testing.T) {
} }
} }
func TestAllocateFakeIPIPv6Rejection(t *testing.T) { func TestAllocateFakeIPv6(t *testing.T) {
manager := NewManager() manager := NewManager()
realIPv6 := netip.MustParseAddr("2001:db8::1") realIP := netip.MustParseAddr("2001:db8::1")
_, err := manager.AllocateFakeIP(realIPv6) fakeIP, err := manager.AllocateFakeIP(realIP)
if err == nil { if err != nil {
t.Error("Expected error for IPv6 address") t.Fatalf("Failed to allocate fake IPv6: %v", err)
}
if !fakeIP.Is6() {
t.Error("Fake IP should be IPv6")
}
if !netip.MustParsePrefix("100::/64").Contains(fakeIP) {
t.Errorf("Fake IP should be in 100::/64 range, got %s", fakeIP.String())
}
// Should return same fake IP for same real IP
fakeIP2, err := manager.AllocateFakeIP(realIP)
if err != nil {
t.Fatalf("Failed to get existing fake IPv6: %v", err)
}
if fakeIP.Compare(fakeIP2) != 0 {
t.Errorf("Expected same fake IP, got %s and %s", fakeIP.String(), fakeIP2.String())
} }
} }
@@ -65,13 +82,11 @@ func TestGetFakeIP(t *testing.T) {
manager := NewManager() manager := NewManager()
realIP := netip.MustParseAddr("1.1.1.1") realIP := netip.MustParseAddr("1.1.1.1")
// Should not exist initially
_, exists := manager.GetFakeIP(realIP) _, exists := manager.GetFakeIP(realIP)
if exists { if exists {
t.Error("Fake IP should not exist before allocation") t.Error("Fake IP should not exist before allocation")
} }
// Allocate and check
expectedFakeIP, err := manager.AllocateFakeIP(realIP) expectedFakeIP, err := manager.AllocateFakeIP(realIP)
if err != nil { if err != nil {
t.Fatalf("Failed to allocate: %v", err) t.Fatalf("Failed to allocate: %v", err)
@@ -87,12 +102,30 @@ func TestGetFakeIP(t *testing.T) {
} }
} }
func TestGetRealIPv6(t *testing.T) {
manager := NewManager()
realIP := netip.MustParseAddr("2001:db8::1")
fakeIP, err := manager.AllocateFakeIP(realIP)
if err != nil {
t.Fatalf("Failed to allocate: %v", err)
}
gotReal, exists := manager.GetRealIP(fakeIP)
if !exists {
t.Error("Real IP should exist for allocated fake IP")
}
if gotReal.Compare(realIP) != 0 {
t.Errorf("Expected real IP %s, got %s", realIP, gotReal)
}
}
func TestMultipleAllocations(t *testing.T) { func TestMultipleAllocations(t *testing.T) {
manager := NewManager() manager := NewManager()
allocations := make(map[netip.Addr]netip.Addr) allocations := make(map[netip.Addr]netip.Addr)
// Allocate multiple IPs
for i := 1; i <= 100; i++ { for i := 1; i <= 100; i++ {
realIP := netip.AddrFrom4([4]byte{10, 0, byte(i / 256), byte(i % 256)}) realIP := netip.AddrFrom4([4]byte{10, 0, byte(i / 256), byte(i % 256)})
fakeIP, err := manager.AllocateFakeIP(realIP) fakeIP, err := manager.AllocateFakeIP(realIP)
@@ -100,7 +133,6 @@ func TestMultipleAllocations(t *testing.T) {
t.Fatalf("Failed to allocate fake IP for %s: %v", realIP.String(), err) t.Fatalf("Failed to allocate fake IP for %s: %v", realIP.String(), err)
} }
// Check for duplicates
for _, existingFake := range allocations { for _, existingFake := range allocations {
if fakeIP.Compare(existingFake) == 0 { if fakeIP.Compare(existingFake) == 0 {
t.Errorf("Duplicate fake IP allocated: %s", fakeIP.String()) t.Errorf("Duplicate fake IP allocated: %s", fakeIP.String())
@@ -110,7 +142,6 @@ func TestMultipleAllocations(t *testing.T) {
allocations[realIP] = fakeIP allocations[realIP] = fakeIP
} }
// Verify all allocations can be retrieved
for realIP, expectedFake := range allocations { for realIP, expectedFake := range allocations {
actualFake, exists := manager.GetFakeIP(realIP) actualFake, exists := manager.GetFakeIP(realIP)
if !exists { if !exists {
@@ -124,11 +155,13 @@ func TestMultipleAllocations(t *testing.T) {
func TestGetFakeIPBlock(t *testing.T) { func TestGetFakeIPBlock(t *testing.T) {
manager := NewManager() manager := NewManager()
block := manager.GetFakeIPBlock()
expected := "240.0.0.0/8" if block := manager.GetFakeIPBlock(); block.String() != "240.0.0.0/8" {
if block.String() != expected { t.Errorf("Expected 240.0.0.0/8, got %s", block.String())
t.Errorf("Expected %s, got %s", expected, block.String()) }
if block := manager.GetFakeIPv6Block(); block.String() != "100::/64" {
t.Errorf("Expected 100::/64, got %s", block.String())
} }
} }
@@ -141,7 +174,6 @@ func TestConcurrentAccess(t *testing.T) {
var wg sync.WaitGroup var wg sync.WaitGroup
results := make(chan netip.Addr, numGoroutines*allocationsPerGoroutine) results := make(chan netip.Addr, numGoroutines*allocationsPerGoroutine)
// Concurrent allocations
for i := 0; i < numGoroutines; i++ { for i := 0; i < numGoroutines; i++ {
wg.Add(1) wg.Add(1)
go func(goroutineID int) { go func(goroutineID int) {
@@ -161,7 +193,6 @@ func TestConcurrentAccess(t *testing.T) {
wg.Wait() wg.Wait()
close(results) close(results)
// Check for duplicates
seen := make(map[netip.Addr]bool) seen := make(map[netip.Addr]bool)
count := 0 count := 0
for fakeIP := range results { for fakeIP := range results {
@@ -178,47 +209,61 @@ func TestConcurrentAccess(t *testing.T) {
} }
func TestIPExhaustion(t *testing.T) { func TestIPExhaustion(t *testing.T) {
// Create a manager with limited range for testing
manager := &Manager{ manager := &Manager{
nextIP: netip.AddrFrom4([4]byte{240, 0, 0, 1}), v4: newPool(
allocated: make(map[netip.Addr]netip.Addr), netip.AddrFrom4([4]byte{240, 0, 0, 1}),
fakeToReal: make(map[netip.Addr]netip.Addr), netip.AddrFrom4([4]byte{240, 0, 0, 3}),
baseIP: netip.AddrFrom4([4]byte{240, 0, 0, 1}), netip.MustParsePrefix("240.0.0.0/8"),
maxIP: netip.AddrFrom4([4]byte{240, 0, 0, 3}), // Only 3 IPs available ),
v6: newPool(
netip.MustParseAddr("100::1"),
netip.MustParseAddr("100::3"),
netip.MustParsePrefix("100::/64"),
),
} }
// Allocate all available IPs for _, realIP := range []string{"1.0.0.1", "1.0.0.2", "1.0.0.3"} {
realIPs := []netip.Addr{ _, err := manager.AllocateFakeIP(netip.MustParseAddr(realIP))
netip.MustParseAddr("1.0.0.1"),
netip.MustParseAddr("1.0.0.2"),
netip.MustParseAddr("1.0.0.3"),
}
for _, realIP := range realIPs {
_, err := manager.AllocateFakeIP(realIP)
if err != nil { if err != nil {
t.Fatalf("Failed to allocate fake IP: %v", err) t.Fatalf("Failed to allocate fake IP: %v", err)
} }
} }
// Try to allocate one more - should fail
_, err := manager.AllocateFakeIP(netip.MustParseAddr("1.0.0.4")) _, err := manager.AllocateFakeIP(netip.MustParseAddr("1.0.0.4"))
if err == nil { if err == nil {
t.Error("Expected exhaustion error") t.Error("Expected v4 exhaustion error")
}
// Same for v6
for _, realIP := range []string{"2001:db8::1", "2001:db8::2", "2001:db8::3"} {
_, err := manager.AllocateFakeIP(netip.MustParseAddr(realIP))
if err != nil {
t.Fatalf("Failed to allocate fake IPv6: %v", err)
}
}
_, err = manager.AllocateFakeIP(netip.MustParseAddr("2001:db8::4"))
if err == nil {
t.Error("Expected v6 exhaustion error")
} }
} }
func TestWrapAround(t *testing.T) { func TestWrapAround(t *testing.T) {
// Create manager starting near the end of range
manager := &Manager{ manager := &Manager{
nextIP: netip.AddrFrom4([4]byte{240, 0, 0, 254}), v4: newPool(
allocated: make(map[netip.Addr]netip.Addr), netip.AddrFrom4([4]byte{240, 0, 0, 1}),
fakeToReal: make(map[netip.Addr]netip.Addr), netip.AddrFrom4([4]byte{240, 0, 0, 254}),
baseIP: netip.AddrFrom4([4]byte{240, 0, 0, 1}), netip.MustParsePrefix("240.0.0.0/8"),
maxIP: netip.AddrFrom4([4]byte{240, 0, 0, 254}), ),
v6: newPool(
netip.MustParseAddr("100::1"),
netip.MustParseAddr("100::ffff:ffff:ffff:fffe"),
netip.MustParsePrefix("100::/64"),
),
} }
// Start near the end
manager.v4.nextIP = netip.AddrFrom4([4]byte{240, 0, 0, 254})
// Allocate the last IP
fakeIP1, err := manager.AllocateFakeIP(netip.MustParseAddr("1.0.0.1")) fakeIP1, err := manager.AllocateFakeIP(netip.MustParseAddr("1.0.0.1"))
if err != nil { if err != nil {
t.Fatalf("Failed to allocate first IP: %v", err) t.Fatalf("Failed to allocate first IP: %v", err)
@@ -228,7 +273,6 @@ func TestWrapAround(t *testing.T) {
t.Errorf("Expected 240.0.0.254, got %s", fakeIP1.String()) t.Errorf("Expected 240.0.0.254, got %s", fakeIP1.String())
} }
// Next allocation should wrap around to the beginning
fakeIP2, err := manager.AllocateFakeIP(netip.MustParseAddr("1.0.0.2")) fakeIP2, err := manager.AllocateFakeIP(netip.MustParseAddr("1.0.0.2"))
if err != nil { if err != nil {
t.Fatalf("Failed to allocate second IP: %v", err) t.Fatalf("Failed to allocate second IP: %v", err)
@@ -238,3 +282,32 @@ func TestWrapAround(t *testing.T) {
t.Errorf("Expected 240.0.0.1 after wrap, got %s", fakeIP2.String()) t.Errorf("Expected 240.0.0.1 after wrap, got %s", fakeIP2.String())
} }
} }
func TestMixedV4V6(t *testing.T) {
manager := NewManager()
v4Fake, err := manager.AllocateFakeIP(netip.MustParseAddr("8.8.8.8"))
if err != nil {
t.Fatalf("Failed to allocate v4: %v", err)
}
v6Fake, err := manager.AllocateFakeIP(netip.MustParseAddr("2001:db8::1"))
if err != nil {
t.Fatalf("Failed to allocate v6: %v", err)
}
if !v4Fake.Is4() || !v6Fake.Is6() {
t.Errorf("Wrong families: v4=%s v6=%s", v4Fake, v6Fake)
}
// Reverse lookups should work for both
gotV4, ok := manager.GetRealIP(v4Fake)
if !ok || gotV4.String() != "8.8.8.8" {
t.Errorf("v4 reverse lookup failed: got %s, ok=%v", gotV4, ok)
}
gotV6, ok := manager.GetRealIP(v6Fake)
if !ok || gotV6.String() != "2001:db8::1" {
t.Errorf("v6 reverse lookup failed: got %s, ok=%v", gotV6, ok)
}
}

View File

@@ -9,7 +9,11 @@ import (
) )
// IPForwardingState is a struct that keeps track of the IP forwarding state. // IPForwardingState is a struct that keeps track of the IP forwarding state.
// todo: read initial state of the IP forwarding from the system and reset the state based on it // todo: read initial state of the IP forwarding from the system and reset the state based on it.
// todo: separate v4/v6 forwarding state, since the sysctls are independent
// (net.ipv4.ip_forward vs net.ipv6.conf.all.forwarding). Currently the nftables
// manager shares one instance between both routers, which works only because
// EnableIPForwarding enables both sysctls in a single call.
type IPForwardingState struct { type IPForwardingState struct {
enabledCounter int enabledCounter int
} }

View File

@@ -159,15 +159,23 @@ func (m *DefaultManager) setupAndroidRoutes(config ManagerConfig) {
if config.DNSFeatureFlag { if config.DNSFeatureFlag {
m.fakeIPManager = fakeip.NewManager() m.fakeIPManager = fakeip.NewManager()
id := uuid.NewString() v4ID := uuid.NewString()
fakeIPRoute := &route.Route{ cr = append(cr, &route.Route{
ID: route.ID(id), ID: route.ID(v4ID),
Network: m.fakeIPManager.GetFakeIPBlock(), Network: m.fakeIPManager.GetFakeIPBlock(),
NetID: route.NetID(id), NetID: route.NetID(v4ID),
Peer: m.pubKey, Peer: m.pubKey,
NetworkType: route.IPv4Network, NetworkType: route.IPv4Network,
} })
cr = append(cr, fakeIPRoute)
v6ID := uuid.NewString()
cr = append(cr, &route.Route{
ID: route.ID(v6ID),
Network: m.fakeIPManager.GetFakeIPv6Block(),
NetID: route.NetID(v6ID),
Peer: m.pubKey,
NetworkType: route.IPv6Network,
})
} }
m.notifier.SetInitialClientRoutes(cr, routesForComparison) m.notifier.SetInitialClientRoutes(cr, routesForComparison)

View File

@@ -146,8 +146,7 @@ func routeToRouterPair(route *route.Route, useNewDNSRoute bool) firewall.RouterP
if useNewDNSRoute { if useNewDNSRoute {
destination.Set = firewall.NewDomainSet(route.Domains) destination.Set = firewall.NewDomainSet(route.Domains)
} else { } else {
// TODO: add ipv6 additionally destination = getDefaultPrefix(route.Network)
destination = getDefaultPrefix(destination.Prefix)
} }
} else { } else {
destination.Prefix = route.Network.Masked() destination.Prefix = route.Network.Masked()

View File

@@ -107,8 +107,13 @@ func (r *SysOps) validateRoute(prefix netip.Prefix) error {
addr.IsInterfaceLocalMulticast(), addr.IsInterfaceLocalMulticast(),
addr.IsMulticast(), addr.IsMulticast(),
addr.IsUnspecified() && prefix.Bits() != 0, addr.IsUnspecified() && prefix.Bits() != 0,
r.wgInterface.Address().Network.Contains(addr): r.isOwnAddress(addr):
return vars.ErrRouteNotAllowed return vars.ErrRouteNotAllowed
} }
return nil return nil
} }
func (r *SysOps) isOwnAddress(addr netip.Addr) bool {
wgAddr := r.wgInterface.Address()
return wgAddr.Network.Contains(addr) || (wgAddr.IPv6Net.IsValid() && wgAddr.IPv6Net.Contains(addr))
}

View File

@@ -222,30 +222,20 @@ func (r *SysOps) genericAddVPNRoute(prefix netip.Prefix, intf *net.Interface) er
return err return err
} }
// TODO: remove once IPv6 is supported on the interface // When the interface has no v6, add v6 split-default as blackhole so
if err := r.addToRouteTable(splitDefaultv6_1, nextHop); err != nil { // unroutable v6 goes to WG (dropped, no AllowedIPs) instead of leaking
return fmt.Errorf("add unreachable route split 1: %w", err) // to the system default route. When v6 is active, management sends ::/0
} // as a separate route that the dedicated handler adds.
if err := r.addToRouteTable(splitDefaultv6_2, nextHop); err != nil { // Soft-fail: v6 blackhole is best-effort, don't abort v4 routing on failure.
if err2 := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err2 != nil { if !r.wgInterface.Address().HasIPv6() {
log.Warnf("Failed to rollback route addition: %s", err2) if err := r.addV6SplitDefault(nextHop); err != nil {
log.Warnf("failed to add v6 split-default blackhole: %s", err)
} }
return fmt.Errorf("add unreachable route split 2: %w", err)
} }
return nil return nil
case vars.Defaultv6: case vars.Defaultv6:
if err := r.addToRouteTable(splitDefaultv6_1, nextHop); err != nil { return r.addV6SplitDefault(nextHop)
return fmt.Errorf("add unreachable route split 1: %w", err)
}
if err := r.addToRouteTable(splitDefaultv6_2, nextHop); err != nil {
if err2 := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err2 != nil {
log.Warnf("Failed to rollback route addition: %s", err2)
}
return fmt.Errorf("add unreachable route split 2: %w", err)
}
return nil
} }
return r.addToRouteTable(prefix, nextHop) return r.addToRouteTable(prefix, nextHop)
@@ -266,30 +256,42 @@ func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface)
result = multierror.Append(result, err) result = multierror.Append(result, err)
} }
// TODO: remove once IPv6 is supported on the interface if !r.wgInterface.Address().HasIPv6() {
if err := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err != nil { result = multierror.Append(result, r.removeV6SplitDefault(nextHop))
result = multierror.Append(result, err)
}
if err := r.removeFromRouteTable(splitDefaultv6_2, nextHop); err != nil {
result = multierror.Append(result, err)
} }
return nberrors.FormatErrorOrNil(result) return nberrors.FormatErrorOrNil(result)
case vars.Defaultv6: case vars.Defaultv6:
var result *multierror.Error return nberrors.FormatErrorOrNil(r.removeV6SplitDefault(nextHop))
if err := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err != nil {
result = multierror.Append(result, err)
}
if err := r.removeFromRouteTable(splitDefaultv6_2, nextHop); err != nil {
result = multierror.Append(result, err)
}
return nberrors.FormatErrorOrNil(result)
default: default:
return r.removeFromRouteTable(prefix, nextHop) return r.removeFromRouteTable(prefix, nextHop)
} }
} }
func (r *SysOps) addV6SplitDefault(nextHop Nexthop) error {
if err := r.addToRouteTable(splitDefaultv6_1, nextHop); err != nil {
return fmt.Errorf("add split 1: %w", err)
}
if err := r.addToRouteTable(splitDefaultv6_2, nextHop); err != nil {
if err2 := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err2 != nil {
log.Warnf("Failed to rollback v6 split-default: %s", err2)
}
return fmt.Errorf("add split 2: %w", err)
}
return nil
}
func (r *SysOps) removeV6SplitDefault(nextHop Nexthop) *multierror.Error {
var result *multierror.Error
if err := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err != nil {
result = multierror.Append(result, err)
}
if err := r.removeFromRouteTable(splitDefaultv6_2, nextHop); err != nil {
result = multierror.Append(result, err)
}
return result
}
func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.Manager) error { func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.Manager) error {
beforeHook := func(connID hooks.ConnectionID, prefix netip.Prefix) error { beforeHook := func(connID hooks.ConnectionID, prefix netip.Prefix) error {
if _, err := r.refCounter.IncrementWithID(string(connID), prefix, struct{}{}); err != nil { if _, err := r.refCounter.IncrementWithID(string(connID), prefix, struct{}{}); err != nil {

View File

@@ -53,6 +53,8 @@ const (
// ipv4ForwardingPath is the path to the file containing the IP forwarding setting. // ipv4ForwardingPath is the path to the file containing the IP forwarding setting.
ipv4ForwardingPath = "net.ipv4.ip_forward" ipv4ForwardingPath = "net.ipv4.ip_forward"
// ipv6ForwardingPath is the path to the file containing the IPv6 forwarding setting.
ipv6ForwardingPath = "net.ipv6.conf.all.forwarding"
) )
var ErrTableIDExists = errors.New("ID exists with different name") var ErrTableIDExists = errors.New("ID exists with different name")
@@ -185,10 +187,11 @@ func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
// No need to check if routes exist as main table takes precedence over the VPN table via Rule 1 // No need to check if routes exist as main table takes precedence over the VPN table via Rule 1
// TODO remove this once we have ipv6 support // When the peer has no IPv6, blackhole v6 to prevent leaking.
if prefix == vars.Defaultv4 { // When IPv6 is enabled, management sends ::/0 as a separate route.
if prefix == vars.Defaultv4 && (r.wgInterface == nil || !r.wgInterface.Address().HasIPv6()) {
if err := addUnreachableRoute(vars.Defaultv6, NetbirdVPNTableID); err != nil { if err := addUnreachableRoute(vars.Defaultv6, NetbirdVPNTableID); err != nil {
return fmt.Errorf("add blackhole: %w", err) return fmt.Errorf("add v6 blackhole: %w", err)
} }
} }
if err := addRoute(prefix, Nexthop{netip.Addr{}, intf}, NetbirdVPNTableID); err != nil { if err := addRoute(prefix, Nexthop{netip.Addr{}, intf}, NetbirdVPNTableID); err != nil {
@@ -206,10 +209,9 @@ func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error
return r.genericRemoveVPNRoute(prefix, intf) return r.genericRemoveVPNRoute(prefix, intf)
} }
// TODO remove this once we have ipv6 support if prefix == vars.Defaultv4 && (r.wgInterface == nil || !r.wgInterface.Address().HasIPv6()) {
if prefix == vars.Defaultv4 {
if err := removeUnreachableRoute(vars.Defaultv6, NetbirdVPNTableID); err != nil { if err := removeUnreachableRoute(vars.Defaultv6, NetbirdVPNTableID); err != nil {
return fmt.Errorf("remove unreachable route: %w", err) log.Debugf("remove v6 blackhole: %v", err)
} }
} }
if err := removeRoute(prefix, Nexthop{netip.Addr{}, intf}, NetbirdVPNTableID); err != nil { if err := removeRoute(prefix, Nexthop{netip.Addr{}, intf}, NetbirdVPNTableID); err != nil {
@@ -762,8 +764,13 @@ func flushRoutes(tableID, family int) error {
} }
func EnableIPForwarding() error { func EnableIPForwarding() error {
_, err := sysctl.Set(ipv4ForwardingPath, 1, false) if _, err := sysctl.Set(ipv4ForwardingPath, 1, false); err != nil {
return err return err
}
if _, err := sysctl.Set(ipv6ForwardingPath, 1, false); err != nil {
log.Warnf("failed to enable IPv6 forwarding: %v", err)
}
return nil
} }
// entryExists checks if the specified ID or name already exists in the rt_tables file // entryExists checks if the specified ID or name already exists in the rt_tables file

View File

@@ -50,10 +50,11 @@ type CustomLogger interface {
} }
type selectRoute struct { type selectRoute struct {
NetID string NetID string
Network netip.Prefix Network netip.Prefix
Domains domain.List Domains domain.List
Selected bool Selected bool
extraNetworks []netip.Prefix
} }
func init() { func init() {
@@ -363,48 +364,60 @@ func (c *Client) GetRoutesSelectionDetails() (*RoutesSelectionDetails, error) {
} }
routeManager := engine.GetRouteManager() routeManager := engine.GetRouteManager()
routesMap := routeManager.GetClientRoutesWithNetID()
if routeManager == nil { if routeManager == nil {
return nil, fmt.Errorf("could not get route manager") return nil, fmt.Errorf("could not get route manager")
} }
routesMap := routeManager.GetClientRoutesWithNetID()
routeSelector := routeManager.GetRouteSelector() routeSelector := routeManager.GetRouteSelector()
if routeSelector == nil { if routeSelector == nil {
return nil, fmt.Errorf("could not get route selector") return nil, fmt.Errorf("could not get route selector")
} }
v6ExitMerged := route.V6ExitMergeSet(routesMap)
routes := buildSelectRoutes(routesMap, routeSelector.IsSelected, v6ExitMerged)
resolvedDomains := c.recorder.GetResolvedDomainsStates()
return prepareRouteSelectionDetails(routes, resolvedDomains), nil
}
func buildSelectRoutes(routesMap map[route.NetID][]*route.Route, isSelected func(route.NetID) bool, v6Merged map[route.NetID]struct{}) []*selectRoute {
var routes []*selectRoute var routes []*selectRoute
for id, rt := range routesMap { for id, rt := range routesMap {
if len(rt) == 0 { if len(rt) == 0 {
continue continue
} }
route := &selectRoute{ if _, ok := v6Merged[id]; ok {
continue
}
r := &selectRoute{
NetID: string(id), NetID: string(id),
Network: rt[0].Network, Network: rt[0].Network,
Domains: rt[0].Domains, Domains: rt[0].Domains,
Selected: routeSelector.IsSelected(id), Selected: isSelected(id),
} }
routes = append(routes, route)
v6ID := route.NetID(string(id) + route.V6ExitSuffix)
if _, ok := v6Merged[v6ID]; ok {
r.extraNetworks = []netip.Prefix{routesMap[v6ID][0].Network}
}
routes = append(routes, r)
} }
sort.Slice(routes, func(i, j int) bool { sort.Slice(routes, func(i, j int) bool {
iPrefix := routes[i].Network.Bits() iBits, jBits := routes[i].Network.Bits(), routes[j].Network.Bits()
jPrefix := routes[j].Network.Bits() if iBits != jBits {
return iBits < jBits
if iPrefix == jPrefix {
iAddr := routes[i].Network.Addr()
jAddr := routes[j].Network.Addr()
if iAddr == jAddr {
return routes[i].NetID < routes[j].NetID
}
return iAddr.String() < jAddr.String()
} }
return iPrefix < jPrefix iAddr, jAddr := routes[i].Network.Addr(), routes[j].Network.Addr()
if iAddr != jAddr {
return iAddr.Less(jAddr)
}
return routes[i].NetID < routes[j].NetID
}) })
resolvedDomains := c.recorder.GetResolvedDomainsStates() return routes
return prepareRouteSelectionDetails(routes, resolvedDomains), nil
} }
func prepareRouteSelectionDetails(routes []*selectRoute, resolvedDomains map[domain.Domain]peer.ResolvedDomainInfo) *RoutesSelectionDetails { func prepareRouteSelectionDetails(routes []*selectRoute, resolvedDomains map[domain.Domain]peer.ResolvedDomainInfo) *RoutesSelectionDetails {
@@ -425,10 +438,15 @@ func prepareRouteSelectionDetails(routes []*selectRoute, resolvedDomains map[dom
} }
domainList = append(domainList, domainResp) domainList = append(domainList, domainResp)
} }
rangeStr := r.Network.String()
for _, extra := range r.extraNetworks {
rangeStr += ", " + extra.String()
}
domainDetails := DomainDetails{items: domainList} domainDetails := DomainDetails{items: domainList}
routeSelection = append(routeSelection, RoutesSelectionInfo{ routeSelection = append(routeSelection, RoutesSelectionInfo{
ID: r.NetID, ID: r.NetID,
Network: r.Network.String(), Network: rangeStr,
Domains: &domainDetails, Domains: &domainDetails,
Selected: r.Selected, Selected: r.Selected,
}) })
@@ -456,7 +474,9 @@ func (c *Client) SelectRoute(id string) error {
} else { } else {
log.Debugf("select route with id: %s", id) log.Debugf("select route with id: %s", id)
routes := toNetIDs([]string{id}) routes := toNetIDs([]string{id})
if err := routeSelector.SelectRoutes(routes, true, maps.Keys(routeManager.GetClientRoutesWithNetID())); err != nil { routesMap := routeManager.GetClientRoutesWithNetID()
routes = route.ExpandV6ExitPairs(routes, routesMap)
if err := routeSelector.SelectRoutes(routes, true, maps.Keys(routesMap)); err != nil {
log.Debugf("error when selecting routes: %s", err) log.Debugf("error when selecting routes: %s", err)
return fmt.Errorf("select routes: %w", err) return fmt.Errorf("select routes: %w", err)
} }
@@ -483,7 +503,9 @@ func (c *Client) DeselectRoute(id string) error {
} else { } else {
log.Debugf("deselect route with id: %s", id) log.Debugf("deselect route with id: %s", id)
routes := toNetIDs([]string{id}) routes := toNetIDs([]string{id})
if err := routeSelector.DeselectRoutes(routes, maps.Keys(routeManager.GetClientRoutesWithNetID())); err != nil { routesMap := routeManager.GetClientRoutesWithNetID()
routes = route.ExpandV6ExitPairs(routes, routesMap)
if err := routeSelector.DeselectRoutes(routes, maps.Keys(routesMap)); err != nil {
log.Debugf("error when deselecting routes: %s", err) log.Debugf("error when deselecting routes: %s", err)
return fmt.Errorf("deselect routes: %w", err) return fmt.Errorf("deselect routes: %w", err)
} }

View File

@@ -16,10 +16,11 @@ import (
) )
type selectRoute struct { type selectRoute struct {
NetID route.NetID NetID route.NetID
Network netip.Prefix Network netip.Prefix
Domains domain.List Domains domain.List
Selected bool Selected bool
extraNetworks []netip.Prefix
} }
// ListNetworks returns a list of all available networks. // ListNetworks returns a list of all available networks.
@@ -44,18 +45,33 @@ func (s *Server) ListNetworks(context.Context, *proto.ListNetworksRequest) (*pro
routesMap := routeMgr.GetClientRoutesWithNetID() routesMap := routeMgr.GetClientRoutesWithNetID()
routeSelector := routeMgr.GetRouteSelector() routeSelector := routeMgr.GetRouteSelector()
v6ExitMerged := route.V6ExitMergeSet(routesMap)
var routes []*selectRoute var routes []*selectRoute
for id, rt := range routesMap { for id, rt := range routesMap {
if len(rt) == 0 { if len(rt) == 0 {
continue continue
} }
route := &selectRoute{ // Skip v6 exit nodes that are merged into their v4 counterpart.
if _, ok := v6ExitMerged[id]; ok {
continue
}
r := &selectRoute{
NetID: id, NetID: id,
Network: rt[0].Network, Network: rt[0].Network,
Domains: rt[0].Domains, Domains: rt[0].Domains,
Selected: routeSelector.IsSelected(id), Selected: routeSelector.IsSelected(id),
} }
routes = append(routes, route)
// Merge paired v6 exit node prefix into this entry.
v6ID := route.NetID(string(id) + route.V6ExitSuffix)
if _, ok := v6ExitMerged[v6ID]; ok {
v6Prefix := routesMap[v6ID][0].Network
r.extraNetworks = []netip.Prefix{v6Prefix}
}
routes = append(routes, r)
} }
sort.Slice(routes, func(i, j int) bool { sort.Slice(routes, func(i, j int) bool {
@@ -76,9 +92,13 @@ func (s *Server) ListNetworks(context.Context, *proto.ListNetworksRequest) (*pro
resolvedDomains := s.statusRecorder.GetResolvedDomainsStates() resolvedDomains := s.statusRecorder.GetResolvedDomainsStates()
var pbRoutes []*proto.Network var pbRoutes []*proto.Network
for _, route := range routes { for _, route := range routes {
rangeStr := route.Network.String()
for _, extra := range route.extraNetworks {
rangeStr += ", " + extra.String()
}
pbRoute := &proto.Network{ pbRoute := &proto.Network{
ID: string(route.NetID), ID: string(route.NetID),
Range: route.Network.String(), Range: rangeStr,
Domains: route.Domains.ToSafeStringList(), Domains: route.Domains.ToSafeStringList(),
ResolvedIPs: map[string]*proto.IPList{}, ResolvedIPs: map[string]*proto.IPList{},
Selected: route.Selected, Selected: route.Selected,
@@ -137,7 +157,9 @@ func (s *Server) SelectNetworks(_ context.Context, req *proto.SelectNetworksRequ
routeSelector.SelectAllRoutes() routeSelector.SelectAllRoutes()
} else { } else {
routes := toNetIDs(req.GetNetworkIDs()) routes := toNetIDs(req.GetNetworkIDs())
netIdRoutes := maps.Keys(routeManager.GetClientRoutesWithNetID()) routesMap := routeManager.GetClientRoutesWithNetID()
routes = route.ExpandV6ExitPairs(routes, routesMap)
netIdRoutes := maps.Keys(routesMap)
if err := routeSelector.SelectRoutes(routes, req.GetAppend(), netIdRoutes); err != nil { if err := routeSelector.SelectRoutes(routes, req.GetAppend(), netIdRoutes); err != nil {
return nil, fmt.Errorf("select routes: %w", err) return nil, fmt.Errorf("select routes: %w", err)
} }
@@ -183,7 +205,9 @@ func (s *Server) DeselectNetworks(_ context.Context, req *proto.SelectNetworksRe
routeSelector.DeselectAllRoutes() routeSelector.DeselectAllRoutes()
} else { } else {
routes := toNetIDs(req.GetNetworkIDs()) routes := toNetIDs(req.GetNetworkIDs())
netIdRoutes := maps.Keys(routeManager.GetClientRoutesWithNetID()) routesMap := routeManager.GetClientRoutesWithNetID()
routes = route.ExpandV6ExitPairs(routes, routesMap)
netIdRoutes := maps.Keys(routesMap)
if err := routeSelector.DeselectRoutes(routes, netIdRoutes); err != nil { if err := routeSelector.DeselectRoutes(routes, netIdRoutes); err != nil {
return nil, fmt.Errorf("deselect routes: %w", err) return nil, fmt.Errorf("deselect routes: %w", err)
} }

View File

@@ -195,7 +195,7 @@ func getOverlappingNetworks(routes []*proto.Network) []*proto.Network {
func getExitNodeNetworks(routes []*proto.Network) []*proto.Network { func getExitNodeNetworks(routes []*proto.Network) []*proto.Network {
var filteredRoutes []*proto.Network var filteredRoutes []*proto.Network
for _, route := range routes { for _, route := range routes {
if route.Range == "0.0.0.0/0" || route.Range == "::/0" { if strings.Contains(route.Range, "0.0.0.0/0") || route.Range == "::/0" {
filteredRoutes = append(filteredRoutes, route) filteredRoutes = append(filteredRoutes, route)
} }
} }

View File

@@ -6,6 +6,7 @@ import (
"context" "context"
"fmt" "fmt"
"net" "net"
"strconv"
"syscall/js" "syscall/js"
"time" "time"
@@ -166,39 +167,58 @@ func createSSHMethod(client *netbird.Client) js.Func {
}) })
} }
var jwtToken string jwtToken, ipVersion := parseSSHOptions(args)
if len(args) > 3 && !args[3].IsNull() && !args[3].IsUndefined() {
jwtToken = args[3].String()
}
return createPromise(func(resolve, reject js.Value) { return createPromise(func(resolve, reject js.Value) {
sshClient := ssh.NewClient(client) jsInterface, err := connectSSH(client, host, port, username, jwtToken, ipVersion)
if err != nil {
if err := sshClient.Connect(host, port, username, jwtToken); err != nil {
reject.Invoke(err.Error()) reject.Invoke(err.Error())
return return
} }
if err := sshClient.StartSession(80, 24); err != nil {
if closeErr := sshClient.Close(); closeErr != nil {
log.Errorf("Error closing SSH client: %v", closeErr)
}
reject.Invoke(err.Error())
return
}
jsInterface := ssh.CreateJSInterface(sshClient)
resolve.Invoke(jsInterface) resolve.Invoke(jsInterface)
}) })
}) })
} }
func performPing(client *netbird.Client, hostname string) { func parseSSHOptions(args []js.Value) (jwtToken string, ipVersion int) {
if len(args) > 3 && !args[3].IsNull() && !args[3].IsUndefined() {
jwtToken = args[3].String()
}
if len(args) > 4 {
ipVersion = jsIPVersion(args[4])
}
return
}
func connectSSH(client *netbird.Client, host string, port int, username, jwtToken string, ipVersion int) (js.Value, error) {
sshClient := ssh.NewClient(client)
if err := sshClient.Connect(host, port, username, jwtToken, ipVersion); err != nil {
return js.Undefined(), err
}
if err := sshClient.StartSession(80, 24); err != nil {
if closeErr := sshClient.Close(); closeErr != nil {
log.Errorf("Error closing SSH client: %v", closeErr)
}
return js.Undefined(), err
}
return ssh.CreateJSInterface(sshClient), nil
}
func performPing(client *netbird.Client, hostname string, ipVersion int) {
ctx, cancel := context.WithTimeout(context.Background(), pingTimeout) ctx, cancel := context.WithTimeout(context.Background(), pingTimeout)
defer cancel() defer cancel()
// Default to ping4 to avoid dual-stack ICMP endpoint issues in wireguard-go netstack.
network := "ping4"
if ipVersion == 6 {
network = "ping6"
}
start := time.Now() start := time.Now()
conn, err := client.Dial(ctx, "ping", hostname) conn, err := client.Dial(ctx, network, hostname)
if err != nil { if err != nil {
js.Global().Get("console").Call("log", fmt.Sprintf("Ping to %s failed: %v", hostname, err)) js.Global().Get("console").Call("log", fmt.Sprintf("Ping to %s failed: %v", hostname, err))
return return
@@ -225,27 +245,39 @@ func performPing(client *netbird.Client, hostname string) {
} }
latency := time.Since(start) latency := time.Since(start)
js.Global().Get("console").Call("log", fmt.Sprintf("Ping to %s: %dms", hostname, latency.Milliseconds())) remote := conn.RemoteAddr().String()
msg := fmt.Sprintf("Ping to %s: %dms", hostname, latency.Milliseconds())
if remote != hostname {
msg += fmt.Sprintf(" (via %s)", remote)
}
js.Global().Get("console").Call("log", msg)
} }
func performPingTCP(client *netbird.Client, hostname string, port int) { func performPingTCP(client *netbird.Client, hostname string, port, ipVersion int) {
ctx, cancel := context.WithTimeout(context.Background(), pingTimeout) ctx, cancel := context.WithTimeout(context.Background(), pingTimeout)
defer cancel() defer cancel()
network := ipVersionNetwork("tcp", ipVersion)
address := net.JoinHostPort(hostname, fmt.Sprintf("%d", port)) address := net.JoinHostPort(hostname, fmt.Sprintf("%d", port))
start := time.Now() start := time.Now()
conn, err := client.Dial(ctx, "tcp", address) conn, err := client.Dial(ctx, network, address)
if err != nil { if err != nil {
js.Global().Get("console").Call("log", fmt.Sprintf("TCP ping to %s failed: %v", address, err)) js.Global().Get("console").Call("log", fmt.Sprintf("TCP ping to %s failed: %v", address, err))
return return
} }
latency := time.Since(start) latency := time.Since(start)
remote := conn.RemoteAddr().String()
if err := conn.Close(); err != nil { if err := conn.Close(); err != nil {
log.Debugf("failed to close TCP connection: %v", err) log.Debugf("failed to close TCP connection: %v", err)
} }
js.Global().Get("console").Call("log", fmt.Sprintf("TCP ping to %s succeeded: %dms", address, latency.Milliseconds())) msg := fmt.Sprintf("TCP ping to %s succeeded: %dms", address, latency.Milliseconds())
if remote != address {
msg += fmt.Sprintf(" (via %s)", remote)
}
js.Global().Get("console").Call("log", msg)
} }
// createPingMethod creates the ping method // createPingMethod creates the ping method
@@ -262,8 +294,12 @@ func createPingMethod(client *netbird.Client) js.Func {
} }
hostname := args[0].String() hostname := args[0].String()
var ipVersion int
if len(args) > 1 {
ipVersion = jsIPVersion(args[1])
}
return createPromise(func(resolve, reject js.Value) { return createPromise(func(resolve, reject js.Value) {
performPing(client, hostname) performPing(client, hostname, ipVersion)
resolve.Invoke(js.Undefined()) resolve.Invoke(js.Undefined())
}) })
}) })
@@ -290,8 +326,12 @@ func createPingTCPMethod(client *netbird.Client) js.Func {
hostname := args[0].String() hostname := args[0].String()
port := args[1].Int() port := args[1].Int()
var ipVersion int
if len(args) > 2 {
ipVersion = jsIPVersion(args[2])
}
return createPromise(func(resolve, reject js.Value) { return createPromise(func(resolve, reject js.Value) {
performPingTCP(client, hostname, port) performPingTCP(client, hostname, port, ipVersion)
resolve.Invoke(js.Undefined()) resolve.Invoke(js.Undefined())
}) })
}) })
@@ -464,6 +504,31 @@ func createSetLogLevelMethod(client *netbird.Client) js.Func {
}) })
} }
// ipVersionNetwork appends "4" or "6" to a base network string (e.g. "tcp" -> "tcp4").
func ipVersionNetwork(base string, ipVersion int) string {
switch ipVersion {
case 4:
return base + "4"
case 6:
return base + "6"
default:
return base
}
}
// jsIPVersion extracts an IP version (4 or 6) from a JS string or number.
func jsIPVersion(v js.Value) int {
switch v.Type() {
case js.TypeNumber:
return v.Int()
case js.TypeString:
n, _ := strconv.Atoi(v.String())
return n
default:
return 0
}
}
// createPromise is a helper to create JavaScript promises // createPromise is a helper to create JavaScript promises
func createPromise(handler func(resolve, reject js.Value)) js.Value { func createPromise(handler func(resolve, reject js.Value)) js.Value {
return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, promiseArgs []js.Value) any { return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, promiseArgs []js.Value) any {

View File

@@ -46,8 +46,9 @@ func NewClient(nbClient *netbird.Client) *Client {
} }
} }
// Connect establishes an SSH connection through NetBird network // Connect establishes an SSH connection through NetBird network.
func (c *Client) Connect(host string, port int, username, jwtToken string) error { // ipVersion may be 4, 6, or 0 for automatic selection.
func (c *Client) Connect(host string, port int, username, jwtToken string, ipVersion int) error {
addr := net.JoinHostPort(host, fmt.Sprintf("%d", port)) addr := net.JoinHostPort(host, fmt.Sprintf("%d", port))
logrus.Infof("SSH: Connecting to %s as %s", addr, username) logrus.Infof("SSH: Connecting to %s as %s", addr, username)
@@ -63,10 +64,18 @@ func (c *Client) Connect(host string, port int, username, jwtToken string) error
Timeout: sshDialTimeout, Timeout: sshDialTimeout,
} }
network := "tcp"
switch ipVersion {
case 4:
network = "tcp4"
case 6:
network = "tcp6"
}
ctx, cancel := context.WithTimeout(context.Background(), sshDialTimeout) ctx, cancel := context.WithTimeout(context.Background(), sshDialTimeout)
defer cancel() defer cancel()
conn, err := c.nbClient.Dial(ctx, "tcp", addr) conn, err := c.nbClient.Dial(ctx, network, addr)
if err != nil { if err != nil {
return fmt.Errorf("dial %s: %w", addr, err) return fmt.Errorf("dial %s: %w", addr, err)
} }

View File

@@ -3,6 +3,7 @@ package cmd
import ( import (
"fmt" "fmt"
"strconv" "strconv"
"time"
"github.com/spf13/cobra" "github.com/spf13/cobra"
@@ -57,7 +58,11 @@ var debugSyncCmd = &cobra.Command{
SilenceUsage: true, SilenceUsage: true,
} }
var pingTimeout string var (
pingTimeout time.Duration
pingIPv4 bool
pingIPv6 bool
)
var debugPingCmd = &cobra.Command{ var debugPingCmd = &cobra.Command{
Use: "ping <account-id> <host> [port]", Use: "ping <account-id> <host> [port]",
@@ -108,7 +113,10 @@ func init() {
debugStatusCmd.Flags().StringVar(&statusFilterByStatus, "filter-by-status", "", "Filter by status (idle|connecting|connected)") debugStatusCmd.Flags().StringVar(&statusFilterByStatus, "filter-by-status", "", "Filter by status (idle|connecting|connected)")
debugStatusCmd.Flags().StringVar(&statusFilterByConnectionType, "filter-by-connection-type", "", "Filter by connection type (P2P|Relayed)") debugStatusCmd.Flags().StringVar(&statusFilterByConnectionType, "filter-by-connection-type", "", "Filter by connection type (P2P|Relayed)")
debugPingCmd.Flags().StringVar(&pingTimeout, "timeout", "", "Ping timeout (e.g., 10s)") debugPingCmd.Flags().DurationVar(&pingTimeout, "timeout", 0, "Ping timeout (e.g., 10s)")
debugPingCmd.Flags().BoolVarP(&pingIPv4, "ipv4", "4", false, "Force IPv4")
debugPingCmd.Flags().BoolVarP(&pingIPv6, "ipv6", "6", false, "Force IPv6")
debugPingCmd.MarkFlagsMutuallyExclusive("ipv4", "ipv6")
debugCmd.AddCommand(debugHealthCmd) debugCmd.AddCommand(debugHealthCmd)
debugCmd.AddCommand(debugClientsCmd) debugCmd.AddCommand(debugClientsCmd)
@@ -157,7 +165,14 @@ func runDebugPing(cmd *cobra.Command, args []string) error {
} }
port = p port = p
} }
return getDebugClient(cmd).PingTCP(cmd.Context(), args[0], args[1], port, pingTimeout) var ipVersion string
switch {
case pingIPv4:
ipVersion = "4"
case pingIPv6:
ipVersion = "6"
}
return getDebugClient(cmd).PingTCP(cmd.Context(), args[0], args[1], port, pingTimeout, ipVersion)
} }
func runDebugLogLevel(cmd *cobra.Command, args []string) error { func runDebugLogLevel(cmd *cobra.Command, args []string) error {

View File

@@ -6,10 +6,12 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"net"
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
"time" "time"
) )
// StatusFilters contains filter options for status queries. // StatusFilters contains filter options for status queries.
@@ -230,12 +232,16 @@ func (c *Client) ClientSyncResponse(ctx context.Context, accountID string) error
} }
// PingTCP performs a TCP ping through a client. // PingTCP performs a TCP ping through a client.
func (c *Client) PingTCP(ctx context.Context, accountID, host string, port int, timeout string) error { // ipVersion may be "4", "6", or "" for automatic.
func (c *Client) PingTCP(ctx context.Context, accountID, host string, port int, timeout time.Duration, ipVersion string) error {
params := url.Values{} params := url.Values{}
params.Set("host", host) params.Set("host", host)
params.Set("port", fmt.Sprintf("%d", port)) params.Set("port", fmt.Sprintf("%d", port))
if timeout != "" { if timeout > 0 {
params.Set("timeout", timeout) params.Set("timeout", timeout.String())
}
if ipVersion != "" {
params.Set("ip_version", ipVersion)
} }
path := fmt.Sprintf("/debug/clients/%s/pingtcp?%s", url.PathEscape(accountID), params.Encode()) path := fmt.Sprintf("/debug/clients/%s/pingtcp?%s", url.PathEscape(accountID), params.Encode())
@@ -244,11 +250,17 @@ func (c *Client) PingTCP(ctx context.Context, accountID, host string, port int,
func (c *Client) printPingResult(data map[string]any) { func (c *Client) printPingResult(data map[string]any) {
success, _ := data["success"].(bool) success, _ := data["success"].(bool)
host := net.JoinHostPort(fmt.Sprint(data["host"]), fmt.Sprint(data["port"]))
if success { if success {
_, _ = fmt.Fprintf(c.out, "Success: %v:%v\n", data["host"], data["port"]) remote, _ := data["remote"].(string)
if remote != "" && remote != host {
_, _ = fmt.Fprintf(c.out, "Success: %s (via %s)\n", host, remote)
} else {
_, _ = fmt.Fprintf(c.out, "Success: %s\n", host)
}
_, _ = fmt.Fprintf(c.out, "Latency: %v\n", data["latency"]) _, _ = fmt.Fprintf(c.out, "Latency: %v\n", data["latency"])
} else { } else {
_, _ = fmt.Fprintf(c.out, "Failed: %v:%v\n", data["host"], data["port"]) _, _ = fmt.Fprintf(c.out, "Failed: %s\n", host)
c.printError(data) c.printError(data)
} }
} }

View File

@@ -9,6 +9,7 @@ import (
"fmt" "fmt"
"html/template" "html/template"
"maps" "maps"
"net"
"net/http" "net/http"
"slices" "slices"
"strconv" "strconv"
@@ -525,13 +526,18 @@ func (h *Handler) handlePingTCP(w http.ResponseWriter, r *http.Request, accountI
} }
} }
network := "tcp"
if v := r.URL.Query().Get("ip_version"); v == "4" || v == "6" {
network += v
}
ctx, cancel := context.WithTimeout(r.Context(), timeout) ctx, cancel := context.WithTimeout(r.Context(), timeout)
defer cancel() defer cancel()
address := fmt.Sprintf("%s:%d", host, port) address := net.JoinHostPort(host, strconv.Itoa(port))
start := time.Now() start := time.Now()
conn, err := client.Dial(ctx, "tcp", address) conn, err := client.Dial(ctx, network, address)
if err != nil { if err != nil {
h.writeJSON(w, map[string]interface{}{ h.writeJSON(w, map[string]interface{}{
"success": false, "success": false,
@@ -541,18 +547,22 @@ func (h *Handler) handlePingTCP(w http.ResponseWriter, r *http.Request, accountI
}) })
return return
} }
remote := conn.RemoteAddr().String()
if err := conn.Close(); err != nil { if err := conn.Close(); err != nil {
h.logger.Debugf("close tcp ping connection: %v", err) h.logger.Debugf("close tcp ping connection: %v", err)
} }
latency := time.Since(start) latency := time.Since(start)
h.writeJSON(w, map[string]interface{}{ resp := map[string]interface{}{
"success": true, "success": true,
"host": host, "host": host,
"port": port, "port": port,
"remote": remote,
"latency_ms": latency.Milliseconds(), "latency_ms": latency.Milliseconds(),
"latency": formatDuration(latency), "latency": formatDuration(latency),
}) }
h.writeJSON(w, resp)
} }
func (h *Handler) handleLogLevel(w http.ResponseWriter, r *http.Request, accountID types.AccountID) { func (h *Handler) handleLogLevel(w http.ResponseWriter, r *http.Request, accountID types.AccountID) {

View File

@@ -20,6 +20,9 @@ const (
MaxMetric = 9999 MaxMetric = 9999
// MaxNetIDChar Max Network Identifier // MaxNetIDChar Max Network Identifier
MaxNetIDChar = 40 MaxNetIDChar = 40
// V6ExitSuffix is appended to a v4 exit node NetID to form its v6 counterpart.
V6ExitSuffix = "-v6"
) )
const ( const (
@@ -215,3 +218,61 @@ func ParseNetwork(networkString string) (NetworkType, netip.Prefix, error) {
return IPv4Network, masked, nil return IPv4Network, masked, nil
} }
var (
v4Default = netip.PrefixFrom(netip.IPv4Unspecified(), 0)
v6Default = netip.PrefixFrom(netip.IPv6Unspecified(), 0)
)
// IsV4DefaultRoute reports whether p is the IPv4 default route (0.0.0.0/0).
func IsV4DefaultRoute(p netip.Prefix) bool { return p == v4Default }
// IsV6DefaultRoute reports whether p is the IPv6 default route (::/0).
func IsV6DefaultRoute(p netip.Prefix) bool { return p == v6Default }
// ExpandV6ExitPairs appends the paired "-v6" exit node NetID for any v4 exit
// node (0.0.0.0/0) in ids that has a matching v6 counterpart (::/0) in routesMap.
// It modifies and returns the input slice.
func ExpandV6ExitPairs(ids []NetID, routesMap map[NetID][]*Route) []NetID {
for _, id := range ids {
rt, ok := routesMap[id]
if !ok || len(rt) == 0 || !IsV4DefaultRoute(rt[0].Network) {
continue
}
v6ID := NetID(string(id) + V6ExitSuffix)
if v6Rt, ok := routesMap[v6ID]; ok && len(v6Rt) > 0 && IsV6DefaultRoute(v6Rt[0].Network) {
if !slices.Contains(ids, v6ID) {
ids = append(ids, v6ID)
}
}
}
return ids
}
// V6ExitMergeSet scans routesMap and returns the set of v6 exit node NetIDs
// that should be hidden from the UI because they are paired with a v4 exit node.
// A v6 ID is paired when it has suffix "-v6", its route is ::/0, and the base
// name (without "-v6") exists with route 0.0.0.0/0.
func V6ExitMergeSet(routesMap map[NetID][]*Route) map[NetID]struct{} {
merged := make(map[NetID]struct{})
for id, rt := range routesMap {
if len(rt) == 0 {
continue
}
name := string(id)
if !IsV6DefaultRoute(rt[0].Network) || !strings.HasSuffix(name, V6ExitSuffix) {
continue
}
baseName := NetID(strings.TrimSuffix(name, V6ExitSuffix))
if baseRt, ok := routesMap[baseName]; ok && len(baseRt) > 0 && IsV4DefaultRoute(baseRt[0].Network) {
merged[id] = struct{}{}
}
}
return merged
}
// HasV6ExitPair reports whether id has a paired v6 exit node in the merge set.
func HasV6ExitPair(id NetID, v6Merged map[NetID]struct{}) bool {
_, ok := v6Merged[NetID(string(id)+"-v6")]
return ok
}

108
route/route_test.go Normal file
View File

@@ -0,0 +1,108 @@
package route
import (
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
)
func TestExpandV6ExitPairs(t *testing.T) {
v4ExitRoute := &Route{Network: netip.MustParsePrefix("0.0.0.0/0")}
v6ExitRoute := &Route{Network: netip.MustParsePrefix("::/0")}
regularRoute := &Route{Network: netip.MustParsePrefix("10.0.0.0/8")}
tests := []struct {
name string
ids []NetID
routesMap map[NetID][]*Route
expected []NetID
}{
{
name: "v4 exit node with matching v6 pair",
ids: []NetID{"exit-node"},
routesMap: map[NetID][]*Route{
"exit-node": {v4ExitRoute},
"exit-node-v6": {v6ExitRoute},
},
expected: []NetID{"exit-node", "exit-node-v6"},
},
{
name: "v4 exit node without v6 pair",
ids: []NetID{"exit-node"},
routesMap: map[NetID][]*Route{
"exit-node": {v4ExitRoute},
},
expected: []NetID{"exit-node"},
},
{
name: "regular route is not expanded",
ids: []NetID{"office"},
routesMap: map[NetID][]*Route{
"office": {regularRoute},
"office-v6": {v6ExitRoute},
},
expected: []NetID{"office"},
},
{
name: "v6 already included is not duplicated",
ids: []NetID{"exit-node", "exit-node-v6"},
routesMap: map[NetID][]*Route{
"exit-node": {v4ExitRoute},
"exit-node-v6": {v6ExitRoute},
},
expected: []NetID{"exit-node", "exit-node-v6"},
},
{
name: "multiple exit nodes expanded independently",
ids: []NetID{"exit-a", "exit-b"},
routesMap: map[NetID][]*Route{
"exit-a": {v4ExitRoute},
"exit-a-v6": {v6ExitRoute},
"exit-b": {v4ExitRoute},
"exit-b-v6": {v6ExitRoute},
},
expected: []NetID{"exit-a", "exit-b", "exit-a-v6", "exit-b-v6"},
},
{
name: "v6 suffix but not exit node network",
ids: []NetID{"office"},
routesMap: map[NetID][]*Route{
"office": {regularRoute},
"office-v6": {regularRoute},
},
expected: []NetID{"office"},
},
{
name: "user-chosen name for exit node with v6 pair",
ids: []NetID{"my-exit"},
routesMap: map[NetID][]*Route{
"my-exit": {v4ExitRoute},
"my-exit-v6": {v6ExitRoute},
},
expected: []NetID{"my-exit", "my-exit-v6"},
},
{
name: "real-world management-generated IDs",
ids: []NetID{"0.0.0.0/0"},
routesMap: map[NetID][]*Route{
"0.0.0.0/0": {v4ExitRoute},
"0.0.0.0/0-v6": {v6ExitRoute},
},
expected: []NetID{"0.0.0.0/0", "0.0.0.0/0-v6"},
},
{
name: "empty input",
ids: []NetID{},
routesMap: map[NetID][]*Route{},
expected: []NetID{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := ExpandV6ExitPairs(tt.ids, tt.routesMap)
assert.ElementsMatch(t, tt.expected, result)
})
}
}

View File

@@ -46,7 +46,7 @@ func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) {
InitialPacketSize: nbRelay.QUICInitialPacketSize, InitialPacketSize: nbRelay.QUICInitialPacketSize,
} }
udpConn, err := nbnet.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) udpConn, err := nbnet.ListenUDP("udp", &net.UDPAddr{Port: 0})
if err != nil { if err != nil {
log.Errorf("failed to listen on UDP: %s", err) log.Errorf("failed to listen on UDP: %s", err)
return nil, err return nil, err