mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-15 23:06:38 +00:00
[client] Add IPv6 support to ACL manager, USP filter, and forwarder (#5688)
This commit is contained in:
@@ -238,43 +238,84 @@ func (c *Client) Networks() *NetworkArray {
|
||||
return nil
|
||||
}
|
||||
|
||||
routesMap := routeManager.GetClientRoutesWithNetID()
|
||||
v6Merged := route.V6ExitMergeSet(routesMap)
|
||||
resolvedDomains := c.recorder.GetResolvedDomainsStates()
|
||||
|
||||
networkArray := &NetworkArray{
|
||||
items: make([]Network, 0),
|
||||
}
|
||||
|
||||
resolvedDomains := c.recorder.GetResolvedDomainsStates()
|
||||
|
||||
for id, routes := range routeManager.GetClientRoutesWithNetID() {
|
||||
for id, routes := range routesMap {
|
||||
if len(routes) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
r := routes[0]
|
||||
domains := c.getNetworkDomainsFromRoute(r, resolvedDomains)
|
||||
netStr := r.Network.String()
|
||||
|
||||
if r.IsDynamic() {
|
||||
netStr = r.Domains.SafeString()
|
||||
}
|
||||
|
||||
routePeer, err := c.recorder.GetPeer(routes[0].Peer)
|
||||
if err != nil {
|
||||
log.Errorf("could not get peer info for %s: %v", routes[0].Peer, err)
|
||||
if _, skip := v6Merged[id]; skip {
|
||||
continue
|
||||
}
|
||||
network := Network{
|
||||
Name: string(id),
|
||||
Network: netStr,
|
||||
Peer: routePeer.FQDN,
|
||||
Status: routePeer.ConnStatus.String(),
|
||||
IsSelected: routeSelector.IsSelected(id),
|
||||
Domains: domains,
|
||||
|
||||
network := c.buildNetwork(id, routes, routeSelector.IsSelected(id), resolvedDomains, v6Merged)
|
||||
if network == nil {
|
||||
continue
|
||||
}
|
||||
networkArray.Add(network)
|
||||
networkArray.Add(*network)
|
||||
}
|
||||
return networkArray
|
||||
}
|
||||
|
||||
func (c *Client) buildNetwork(id route.NetID, routes []*route.Route, selected bool, resolvedDomains map[domain.Domain]peer.ResolvedDomainInfo, v6Merged map[route.NetID]struct{}) *Network {
|
||||
r := routes[0]
|
||||
netStr := r.Network.String()
|
||||
if r.IsDynamic() {
|
||||
netStr = r.Domains.SafeString()
|
||||
}
|
||||
|
||||
routePeer, err := c.findBestRoutePeer(routes)
|
||||
if err != nil {
|
||||
log.Errorf("could not get peer info for route %s: %v", id, err)
|
||||
return nil
|
||||
}
|
||||
|
||||
network := &Network{
|
||||
Name: string(id),
|
||||
Network: netStr,
|
||||
Peer: routePeer.FQDN,
|
||||
Status: routePeer.ConnStatus.String(),
|
||||
IsSelected: selected,
|
||||
Domains: c.getNetworkDomainsFromRoute(r, resolvedDomains),
|
||||
}
|
||||
|
||||
if route.IsV4DefaultRoute(r.Network) && route.HasV6ExitPair(id, v6Merged) {
|
||||
network.Network = "0.0.0.0/0, ::/0"
|
||||
}
|
||||
|
||||
return network
|
||||
}
|
||||
|
||||
// findBestRoutePeer returns the peer actively routing traffic for the given
|
||||
// HA route group. Falls back to the first connected peer, then the first peer.
|
||||
func (c *Client) findBestRoutePeer(routes []*route.Route) (peer.State, error) {
|
||||
netStr := routes[0].Network.String()
|
||||
|
||||
fullStatus := c.recorder.GetFullStatus()
|
||||
for _, p := range fullStatus.Peers {
|
||||
if _, ok := p.GetRoutes()[netStr]; ok {
|
||||
return p, nil
|
||||
}
|
||||
}
|
||||
|
||||
for _, r := range routes {
|
||||
p, err := c.recorder.GetPeer(r.Peer)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if p.ConnStatus == peer.StatusConnected {
|
||||
return p, nil
|
||||
}
|
||||
}
|
||||
return c.recorder.GetPeer(routes[0].Peer)
|
||||
}
|
||||
|
||||
// OnUpdatedHostDNS update the DNS servers addresses for root zones
|
||||
func (c *Client) OnUpdatedHostDNS(list *DNSList) error {
|
||||
dnsServer, err := dns.GetServerDns()
|
||||
|
||||
@@ -18,9 +18,12 @@ func executeRouteToggle(id string, manager routemanager.Manager,
|
||||
netID := route.NetID(id)
|
||||
routes := []route.NetID{netID}
|
||||
|
||||
log.Debugf("%s with id: %s", operationName, id)
|
||||
routesMap := manager.GetClientRoutesWithNetID()
|
||||
routes = route.ExpandV6ExitPairs(routes, routesMap)
|
||||
|
||||
if err := routeOperation(routes, maps.Keys(manager.GetClientRoutesWithNetID())); err != nil {
|
||||
log.Debugf("%s with ids: %v", operationName, routes)
|
||||
|
||||
if err := routeOperation(routes, maps.Keys(routesMap)); err != nil {
|
||||
log.Debugf("error when %s: %s", operationName, err)
|
||||
return fmt.Errorf("error %s: %w", operationName, err)
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"net/url"
|
||||
"regexp"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
@@ -26,8 +27,9 @@ type Anonymizer struct {
|
||||
}
|
||||
|
||||
func DefaultAddresses() (netip.Addr, netip.Addr) {
|
||||
// 198.51.100.0, 100::
|
||||
return netip.AddrFrom4([4]byte{198, 51, 100, 0}), netip.AddrFrom16([16]byte{0x01})
|
||||
// 198.51.100.0 (RFC 5737 TEST-NET-2), 2001:db8:ffff:: (RFC 3849 documentation, last /48)
|
||||
// The old start 100:: (discard, RFC 6666) is now used for fake IPs on Android.
|
||||
return netip.AddrFrom4([4]byte{198, 51, 100, 0}), netip.MustParseAddr("2001:db8:ffff::")
|
||||
}
|
||||
|
||||
func NewAnonymizer(startIPv4, startIPv6 netip.Addr) *Anonymizer {
|
||||
@@ -96,6 +98,11 @@ func (a *Anonymizer) isInAnonymizedRange(ip netip.Addr) bool {
|
||||
}
|
||||
|
||||
func (a *Anonymizer) AnonymizeIPString(ip string) string {
|
||||
// Handle CIDR notation (e.g. "2001:db8::/32")
|
||||
if prefix, err := netip.ParsePrefix(ip); err == nil {
|
||||
return a.AnonymizeIP(prefix.Addr()).String() + "/" + strconv.Itoa(prefix.Bits())
|
||||
}
|
||||
|
||||
addr, err := netip.ParseAddr(ip)
|
||||
if err != nil {
|
||||
return ip
|
||||
|
||||
@@ -13,7 +13,7 @@ import (
|
||||
|
||||
func TestAnonymizeIP(t *testing.T) {
|
||||
startIPv4 := netip.MustParseAddr("198.51.100.0")
|
||||
startIPv6 := netip.MustParseAddr("100::")
|
||||
startIPv6 := netip.MustParseAddr("2001:db8:ffff::")
|
||||
anonymizer := anonymize.NewAnonymizer(startIPv4, startIPv6)
|
||||
|
||||
tests := []struct {
|
||||
@@ -26,9 +26,9 @@ func TestAnonymizeIP(t *testing.T) {
|
||||
{"Second Public IPv4", "4.3.2.1", "198.51.100.1"},
|
||||
{"Repeated IPv4", "1.2.3.4", "198.51.100.0"},
|
||||
{"Private IPv4", "192.168.1.1", "192.168.1.1"},
|
||||
{"First Public IPv6", "2607:f8b0:4005:805::200e", "100::"},
|
||||
{"Second Public IPv6", "a::b", "100::1"},
|
||||
{"Repeated IPv6", "2607:f8b0:4005:805::200e", "100::"},
|
||||
{"First Public IPv6", "2607:f8b0:4005:805::200e", "2001:db8:ffff::"},
|
||||
{"Second Public IPv6", "a::b", "2001:db8:ffff::1"},
|
||||
{"Repeated IPv6", "2607:f8b0:4005:805::200e", "2001:db8:ffff::"},
|
||||
{"Private IPv6", "fe80::1", "fe80::1"},
|
||||
{"In Range IPv4", "198.51.100.2", "198.51.100.2"},
|
||||
}
|
||||
@@ -274,17 +274,17 @@ func TestAnonymizeString_IPAddresses(t *testing.T) {
|
||||
{
|
||||
name: "IPv6 Address",
|
||||
input: "Access attempted from 2001:db8::ff00:42",
|
||||
expect: "Access attempted from 100::",
|
||||
expect: "Access attempted from 2001:db8:ffff::",
|
||||
},
|
||||
{
|
||||
name: "IPv6 Address with Port",
|
||||
input: "Access attempted from [2001:db8::ff00:42]:8080",
|
||||
expect: "Access attempted from [100::]:8080",
|
||||
expect: "Access attempted from [2001:db8:ffff::]:8080",
|
||||
},
|
||||
{
|
||||
name: "Both IPv4 and IPv6",
|
||||
input: "IPv4: 142.108.0.1 and IPv6: 2001:db8::ff00:43",
|
||||
expect: "IPv4: 198.51.100.1 and IPv6: 100::1",
|
||||
expect: "IPv4: 198.51.100.1 and IPv6: 2001:db8:ffff::1",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -787,10 +787,10 @@ func isUnixSocket(path string) bool {
|
||||
return strings.HasPrefix(path, "/") || strings.HasPrefix(path, "./")
|
||||
}
|
||||
|
||||
// normalizeLocalHost converts "*" to "0.0.0.0" for binding to all interfaces.
|
||||
// normalizeLocalHost converts "*" to "" for binding to all interfaces (dual-stack).
|
||||
func normalizeLocalHost(host string) string {
|
||||
if host == "*" {
|
||||
return "0.0.0.0"
|
||||
return ""
|
||||
}
|
||||
return host
|
||||
}
|
||||
|
||||
@@ -527,10 +527,10 @@ func TestParsePortForward(t *testing.T) {
|
||||
{
|
||||
name: "wildcard bind all interfaces",
|
||||
spec: "*:8080:localhost:80",
|
||||
expectedLocal: "0.0.0.0:8080",
|
||||
expectedLocal: ":8080",
|
||||
expectedRemote: "localhost:80",
|
||||
expectError: false,
|
||||
description: "Wildcard * should bind to all interfaces (0.0.0.0)",
|
||||
description: "Wildcard * should bind to all interfaces (dual-stack)",
|
||||
},
|
||||
{
|
||||
name: "wildcard for port only",
|
||||
|
||||
@@ -36,6 +36,7 @@ type aclManager struct {
|
||||
entries aclEntries
|
||||
optionalEntries map[string][]entry
|
||||
ipsetStore *ipsetStore
|
||||
v6 bool
|
||||
|
||||
stateManager *statemanager.Manager
|
||||
}
|
||||
@@ -47,6 +48,7 @@ func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*acl
|
||||
entries: make(map[string][][]string),
|
||||
optionalEntries: make(map[string][]entry),
|
||||
ipsetStore: newIpsetStore(),
|
||||
v6: iptablesClient.Proto() == iptables.ProtocolIPv6,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -81,7 +83,11 @@ func (m *aclManager) AddPeerFiltering(
|
||||
chain := chainNameInputRules
|
||||
|
||||
ipsetName = transformIPsetName(ipsetName, sPort, dPort, action)
|
||||
specs := filterRuleSpecs(ip, string(protocol), sPort, dPort, action, ipsetName)
|
||||
if m.v6 && ipsetName != "" {
|
||||
ipsetName += "-v6"
|
||||
}
|
||||
proto := protoForFamily(protocol, m.v6)
|
||||
specs := filterRuleSpecs(ip, proto, sPort, dPort, action, ipsetName)
|
||||
|
||||
mangleSpecs := slices.Clone(specs)
|
||||
mangleSpecs = append(mangleSpecs,
|
||||
@@ -105,6 +111,7 @@ func (m *aclManager) AddPeerFiltering(
|
||||
ip: ip.String(),
|
||||
chain: chain,
|
||||
specs: specs,
|
||||
v6: m.v6,
|
||||
}}, nil
|
||||
}
|
||||
|
||||
@@ -157,6 +164,7 @@ func (m *aclManager) AddPeerFiltering(
|
||||
ipsetName: ipsetName,
|
||||
ip: ip.String(),
|
||||
chain: chain,
|
||||
v6: m.v6,
|
||||
}
|
||||
|
||||
m.updateState()
|
||||
@@ -376,8 +384,13 @@ func (m *aclManager) updateState() {
|
||||
currentState.Lock()
|
||||
defer currentState.Unlock()
|
||||
|
||||
currentState.ACLEntries = m.entries
|
||||
currentState.ACLIPsetStore = m.ipsetStore
|
||||
if m.v6 {
|
||||
currentState.ACLEntries6 = m.entries
|
||||
currentState.ACLIPsetStore6 = m.ipsetStore
|
||||
} else {
|
||||
currentState.ACLEntries = m.entries
|
||||
currentState.ACLIPsetStore = m.ipsetStore
|
||||
}
|
||||
|
||||
if err := m.stateManager.UpdateState(currentState); err != nil {
|
||||
log.Errorf("failed to update state: %v", err)
|
||||
@@ -385,6 +398,15 @@ func (m *aclManager) updateState() {
|
||||
}
|
||||
|
||||
// filterRuleSpecs returns the specs of a filtering rule
|
||||
// protoForFamily translates ICMP to ICMPv6 for ip6tables.
|
||||
// ip6tables requires "ipv6-icmp" (or "icmpv6") instead of "icmp".
|
||||
func protoForFamily(protocol firewall.Protocol, v6 bool) string {
|
||||
if v6 && protocol == firewall.ProtocolICMP {
|
||||
return "ipv6-icmp"
|
||||
}
|
||||
return string(protocol)
|
||||
}
|
||||
|
||||
func filterRuleSpecs(ip net.IP, protocol string, sPort, dPort *firewall.Port, action firewall.Action, ipsetName string) (specs []string) {
|
||||
// don't use IP matching if IP is 0.0.0.0
|
||||
matchByIP := !ip.IsUnspecified()
|
||||
@@ -437,6 +459,9 @@ func (m *aclManager) createIPSet(name string) error {
|
||||
opts := ipset.CreateOptions{
|
||||
Replace: true,
|
||||
}
|
||||
if m.v6 {
|
||||
opts.Family = ipset.FamilyIPV6
|
||||
}
|
||||
|
||||
if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil {
|
||||
return fmt.Errorf("create ipset %s: %w", name, err)
|
||||
|
||||
@@ -17,6 +17,10 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
type resetter interface {
|
||||
Reset() error
|
||||
}
|
||||
|
||||
// Manager of iptables firewall
|
||||
type Manager struct {
|
||||
mutex sync.Mutex
|
||||
@@ -27,6 +31,11 @@ type Manager struct {
|
||||
aclMgr *aclManager
|
||||
router *router
|
||||
rawSupported bool
|
||||
|
||||
// IPv6 counterparts, nil when no v6 overlay
|
||||
ipv6Client *iptables.IPTables
|
||||
aclMgr6 *aclManager
|
||||
router6 *router
|
||||
}
|
||||
|
||||
// iFaceMapper defines subset methods of interface required for manager
|
||||
@@ -58,9 +67,43 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) {
|
||||
return nil, fmt.Errorf("create acl manager: %w", err)
|
||||
}
|
||||
|
||||
if wgIface.Address().HasIPv6() {
|
||||
if err := m.createIPv6Components(wgIface, mtu); err != nil {
|
||||
return nil, fmt.Errorf("create IPv6 firewall: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m *Manager) createIPv6Components(wgIface iFaceMapper, mtu uint16) error {
|
||||
ip6Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv6)
|
||||
if err != nil {
|
||||
return fmt.Errorf("init ip6tables: %w", err)
|
||||
}
|
||||
m.ipv6Client = ip6Client
|
||||
|
||||
m.router6, err = newRouter(ip6Client, wgIface, mtu)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create v6 router: %w", err)
|
||||
}
|
||||
|
||||
// Share the same IP forwarding state with the v4 router, since
|
||||
// EnableIPForwarding controls both v4 and v6 sysctls.
|
||||
m.router6.ipFwdState = m.router.ipFwdState
|
||||
|
||||
m.aclMgr6, err = newAclManager(ip6Client, wgIface)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create v6 acl manager: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) hasIPv6() bool {
|
||||
return m.ipv6Client != nil
|
||||
}
|
||||
|
||||
func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||
state := &ShutdownState{
|
||||
InterfaceState: &InterfaceState{
|
||||
@@ -75,13 +118,8 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||
log.Errorf("failed to update state: %v", err)
|
||||
}
|
||||
|
||||
if err := m.router.init(stateManager); err != nil {
|
||||
return fmt.Errorf("router init: %w", err)
|
||||
}
|
||||
|
||||
if err := m.aclMgr.init(stateManager); err != nil {
|
||||
// TODO: cleanup router
|
||||
return fmt.Errorf("acl manager init: %w", err)
|
||||
if err := m.initChains(stateManager); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := m.initNoTrackChain(); err != nil {
|
||||
@@ -98,6 +136,41 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// initChains initializes router and ACL chains for both address families,
|
||||
// rolling back on failure.
|
||||
func (m *Manager) initChains(stateManager *statemanager.Manager) error {
|
||||
type initStep struct {
|
||||
name string
|
||||
init func(*statemanager.Manager) error
|
||||
mgr resetter
|
||||
}
|
||||
|
||||
steps := []initStep{
|
||||
{"router", m.router.init, m.router},
|
||||
{"acl manager", m.aclMgr.init, m.aclMgr},
|
||||
}
|
||||
if m.hasIPv6() {
|
||||
steps = append(steps,
|
||||
initStep{"v6 router", m.router6.init, m.router6},
|
||||
initStep{"v6 acl manager", m.aclMgr6.init, m.aclMgr6},
|
||||
)
|
||||
}
|
||||
|
||||
var initialized []initStep
|
||||
for _, s := range steps {
|
||||
if err := s.init(stateManager); err != nil {
|
||||
for i := len(initialized) - 1; i >= 0; i-- {
|
||||
if rerr := initialized[i].mgr.Reset(); rerr != nil {
|
||||
log.Warnf("rollback %s: %v", initialized[i].name, rerr)
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("%s init: %w", s.name, err)
|
||||
}
|
||||
initialized = append(initialized, s)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddPeerFiltering adds a rule to the firewall
|
||||
//
|
||||
// Comment will be ignored because some system this feature is not supported
|
||||
@@ -113,7 +186,13 @@ func (m *Manager) AddPeerFiltering(
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.aclMgr.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
|
||||
if ip.To4() != nil {
|
||||
return m.aclMgr.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
|
||||
}
|
||||
if !m.hasIPv6() {
|
||||
return nil, fmt.Errorf("IPv6 not initialized, cannot add rule for %s", ip)
|
||||
}
|
||||
return m.aclMgr6.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
|
||||
}
|
||||
|
||||
func (m *Manager) AddRouteFiltering(
|
||||
@@ -127,25 +206,48 @@ func (m *Manager) AddRouteFiltering(
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if destination.IsPrefix() && !destination.Prefix.Addr().Is4() {
|
||||
return nil, fmt.Errorf("unsupported IP version: %s", destination.Prefix.Addr().String())
|
||||
if isIPv6RouteRule(sources, destination) {
|
||||
if !m.hasIPv6() {
|
||||
return nil, fmt.Errorf("IPv6 not initialized, cannot add route rule")
|
||||
}
|
||||
return m.router6.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||
}
|
||||
|
||||
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||
}
|
||||
|
||||
func isIPv6RouteRule(sources []netip.Prefix, destination firewall.Network) bool {
|
||||
if destination.IsPrefix() {
|
||||
return destination.Prefix.Addr().Is6()
|
||||
}
|
||||
return len(sources) > 0 && sources[0].Addr().Is6()
|
||||
}
|
||||
|
||||
// DeletePeerRule from the firewall by rule definition
|
||||
func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if m.hasIPv6() && isIPv6IptRule(rule) {
|
||||
return m.aclMgr6.DeletePeerRule(rule)
|
||||
}
|
||||
return m.aclMgr.DeletePeerRule(rule)
|
||||
}
|
||||
|
||||
func isIPv6IptRule(rule firewall.Rule) bool {
|
||||
r, ok := rule.(*Rule)
|
||||
return ok && r.v6
|
||||
}
|
||||
|
||||
// DeleteRouteRule deletes a routing rule.
|
||||
// Route rules are keyed by content hash. Check v4 first, try v6 if not found.
|
||||
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if m.hasIPv6() && !m.router.hasRule(rule.ID()) {
|
||||
return m.router6.DeleteRouteRule(rule)
|
||||
}
|
||||
return m.router.DeleteRouteRule(rule)
|
||||
}
|
||||
|
||||
@@ -161,18 +263,63 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.AddNatRule(pair)
|
||||
if pair.Destination.IsPrefix() && pair.Destination.Prefix.Addr().Is6() {
|
||||
if !m.hasIPv6() {
|
||||
return fmt.Errorf("IPv6 not initialized, cannot add NAT rule")
|
||||
}
|
||||
return m.router6.AddNatRule(pair)
|
||||
}
|
||||
|
||||
if err := m.router.AddNatRule(pair); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Dynamic routes need NAT in both tables
|
||||
if m.hasIPv6() && pair.Destination.IsSet() {
|
||||
v6Pair := pair
|
||||
v6Pair.Source = firewall.Network{Prefix: netip.PrefixFrom(netip.IPv6Unspecified(), 0)}
|
||||
if err := m.router6.AddNatRule(v6Pair); err != nil {
|
||||
return fmt.Errorf("add v6 NAT rule: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.RemoveNatRule(pair)
|
||||
if pair.Destination.IsPrefix() && pair.Destination.Prefix.Addr().Is6() {
|
||||
if !m.hasIPv6() {
|
||||
return nil
|
||||
}
|
||||
return m.router6.RemoveNatRule(pair)
|
||||
}
|
||||
|
||||
if err := m.router.RemoveNatRule(pair); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if m.hasIPv6() && pair.Destination.IsSet() {
|
||||
v6Pair := pair
|
||||
v6Pair.Source = firewall.Network{Prefix: netip.PrefixFrom(netip.IPv6Unspecified(), 0)}
|
||||
if err := m.router6.RemoveNatRule(v6Pair); err != nil {
|
||||
return fmt.Errorf("remove v6 NAT rule: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
||||
return firewall.SetLegacyManagement(m.router, isLegacy)
|
||||
if err := firewall.SetLegacyManagement(m.router, isLegacy); err != nil {
|
||||
return err
|
||||
}
|
||||
if m.hasIPv6() {
|
||||
return firewall.SetLegacyManagement(m.router6, isLegacy)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reset firewall to the default state
|
||||
@@ -186,6 +333,15 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
||||
merr = multierror.Append(merr, fmt.Errorf("cleanup notrack chain: %w", err))
|
||||
}
|
||||
|
||||
if m.hasIPv6() {
|
||||
if err := m.aclMgr6.Reset(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("reset v6 acl manager: %w", err))
|
||||
}
|
||||
if err := m.router6.Reset(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("reset v6 router: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
if err := m.aclMgr.Reset(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("reset acl manager: %w", err))
|
||||
}
|
||||
@@ -209,19 +365,16 @@ func (m *Manager) AllowNetbird() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err := m.AddPeerFiltering(
|
||||
nil,
|
||||
net.IP{0, 0, 0, 0},
|
||||
firewall.ProtocolALL,
|
||||
nil,
|
||||
nil,
|
||||
firewall.ActionAccept,
|
||||
"",
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("allow netbird interface traffic: %w", err)
|
||||
var merr *multierror.Error
|
||||
if _, err := m.aclMgr.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolALL, nil, nil, firewall.ActionAccept, ""); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("allow netbird interface traffic: %w", err))
|
||||
}
|
||||
return nil
|
||||
if m.hasIPv6() {
|
||||
if _, err := m.aclMgr6.AddPeerFiltering(nil, net.IPv6zero, firewall.ProtocolALL, nil, nil, firewall.ActionAccept, ""); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("allow v6 netbird interface traffic: %w", err))
|
||||
}
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
// Flush doesn't need to be implemented for this manager
|
||||
@@ -251,6 +404,9 @@ func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error)
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if m.hasIPv6() && rule.TranslatedAddress.Is6() {
|
||||
return m.router6.AddDNATRule(rule)
|
||||
}
|
||||
return m.router.AddDNATRule(rule)
|
||||
}
|
||||
|
||||
@@ -259,6 +415,9 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if m.hasIPv6() && !m.router.hasRule(rule.ID()+dnatSuffix) {
|
||||
return m.router6.DeleteDNATRule(rule)
|
||||
}
|
||||
return m.router.DeleteDNATRule(rule)
|
||||
}
|
||||
|
||||
@@ -267,7 +426,26 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.UpdateSet(set, prefixes)
|
||||
var v4Prefixes, v6Prefixes []netip.Prefix
|
||||
for _, p := range prefixes {
|
||||
if p.Addr().Is6() {
|
||||
v6Prefixes = append(v6Prefixes, p)
|
||||
} else {
|
||||
v4Prefixes = append(v4Prefixes, p)
|
||||
}
|
||||
}
|
||||
|
||||
if err := m.router.UpdateSet(set, v4Prefixes); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if m.hasIPv6() && len(v6Prefixes) > 0 {
|
||||
if err := m.router6.UpdateSet(set, v6Prefixes); err != nil {
|
||||
return fmt.Errorf("update v6 set: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
|
||||
@@ -275,6 +453,9 @@ func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protoco
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if m.hasIPv6() && localAddr.Is6() {
|
||||
return m.router6.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
@@ -283,6 +464,9 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if m.hasIPv6() && localAddr.Is6() {
|
||||
return m.router6.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
|
||||
@@ -54,8 +54,10 @@ const (
|
||||
snatSuffix = "_snat"
|
||||
fwdSuffix = "_fwd"
|
||||
|
||||
// ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation
|
||||
ipTCPHeaderMinSize = 40
|
||||
// ipv4TCPHeaderSize is the minimum IPv4 (20) + TCP (20) header size for MSS calculation.
|
||||
ipv4TCPHeaderSize = 40
|
||||
// ipv6TCPHeaderSize is the minimum IPv6 (40) + TCP (20) header size for MSS calculation.
|
||||
ipv6TCPHeaderSize = 60
|
||||
)
|
||||
|
||||
type ruleInfo struct {
|
||||
@@ -86,6 +88,7 @@ type router struct {
|
||||
wgIface iFaceMapper
|
||||
legacyManagement bool
|
||||
mtu uint16
|
||||
v6 bool
|
||||
|
||||
stateManager *statemanager.Manager
|
||||
ipFwdState *ipfwdstate.IPForwardingState
|
||||
@@ -97,6 +100,7 @@ func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper, mtu uint1
|
||||
rules: make(map[string][]string),
|
||||
wgIface: wgIface,
|
||||
mtu: mtu,
|
||||
v6: iptablesClient.Proto() == iptables.ProtocolIPv6,
|
||||
ipFwdState: ipfwdstate.NewIPForwardingState(),
|
||||
}
|
||||
|
||||
@@ -186,6 +190,11 @@ func (r *router) AddRouteFiltering(
|
||||
return ruleKey, nil
|
||||
}
|
||||
|
||||
func (r *router) hasRule(id string) bool {
|
||||
_, ok := r.rules[id]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (r *router) DeleteRouteRule(rule firewall.Rule) error {
|
||||
ruleKey := rule.ID()
|
||||
|
||||
@@ -434,6 +443,12 @@ func (r *router) createContainers() error {
|
||||
{chainRTRDR, tableNat},
|
||||
{chainRTMSSCLAMP, tableMangle},
|
||||
} {
|
||||
// Fallback: clear chains that survived an unclean shutdown.
|
||||
if ok, _ := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain); ok {
|
||||
if err := r.iptablesClient.ClearAndDeleteChain(chainInfo.table, chainInfo.chain); err != nil {
|
||||
log.Warnf("clear stale chain %s in %s: %v", chainInfo.chain, chainInfo.table, err)
|
||||
}
|
||||
}
|
||||
if err := r.iptablesClient.NewChain(chainInfo.table, chainInfo.chain); err != nil {
|
||||
return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
|
||||
}
|
||||
@@ -540,9 +555,12 @@ func (r *router) addPostroutingRules() error {
|
||||
}
|
||||
|
||||
// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic.
|
||||
// TODO: Add IPv6 support
|
||||
func (r *router) addMSSClampingRules() error {
|
||||
mss := r.mtu - ipTCPHeaderMinSize
|
||||
overhead := uint16(ipv4TCPHeaderSize)
|
||||
if r.v6 {
|
||||
overhead = ipv6TCPHeaderSize
|
||||
}
|
||||
mss := r.mtu - overhead
|
||||
|
||||
// Add jump rule from FORWARD chain in mangle table to our custom chain
|
||||
jumpRule := []string{
|
||||
@@ -727,8 +745,13 @@ func (r *router) updateState() {
|
||||
currentState.Lock()
|
||||
defer currentState.Unlock()
|
||||
|
||||
currentState.RouteRules = r.rules
|
||||
currentState.RouteIPsetCounter = r.ipsetCounter
|
||||
if r.v6 {
|
||||
currentState.RouteRules6 = r.rules
|
||||
currentState.RouteIPsetCounter6 = r.ipsetCounter
|
||||
} else {
|
||||
currentState.RouteRules = r.rules
|
||||
currentState.RouteIPsetCounter = r.ipsetCounter
|
||||
}
|
||||
|
||||
if err := r.stateManager.UpdateState(currentState); err != nil {
|
||||
log.Errorf("failed to update state: %v", err)
|
||||
@@ -856,7 +879,7 @@ func (r *router) DeleteDNATRule(rule firewall.Rule) error {
|
||||
}
|
||||
|
||||
if fwdRule, exists := r.rules[ruleKey+fwdSuffix]; exists {
|
||||
if err := r.iptablesClient.Delete(tableFilter, chainRTFWDIN, fwdRule...); err != nil {
|
||||
if err := r.iptablesClient.Delete(tableFilter, chainRTFWDOUT, fwdRule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("delete forward rule: %w", err))
|
||||
}
|
||||
delete(r.rules, ruleKey+fwdSuffix)
|
||||
@@ -883,7 +906,7 @@ func (r *router) genRouteRuleSpec(params routeFilteringRuleParams, sources []net
|
||||
rule = append(rule, destExp...)
|
||||
|
||||
if params.Proto != firewall.ProtocolALL {
|
||||
rule = append(rule, "-p", strings.ToLower(string(params.Proto)))
|
||||
rule = append(rule, "-p", strings.ToLower(protoForFamily(params.Proto, r.v6)))
|
||||
rule = append(rule, applyPort("--sport", params.SPort)...)
|
||||
rule = append(rule, applyPort("--dport", params.DPort)...)
|
||||
}
|
||||
@@ -900,11 +923,12 @@ func (r *router) applyNetwork(flag string, network firewall.Network, prefixes []
|
||||
}
|
||||
|
||||
if network.IsSet() {
|
||||
if _, err := r.ipsetCounter.Increment(network.Set.HashedName(), prefixes); err != nil {
|
||||
name := r.ipsetName(network.Set.HashedName())
|
||||
if _, err := r.ipsetCounter.Increment(name, prefixes); err != nil {
|
||||
return nil, fmt.Errorf("create or get ipset: %w", err)
|
||||
}
|
||||
|
||||
return []string{"-m", "set", matchSet, network.Set.HashedName(), direction}, nil
|
||||
return []string{"-m", "set", matchSet, name, direction}, nil
|
||||
}
|
||||
if network.IsPrefix() {
|
||||
return []string{flag, network.Prefix.String()}, nil
|
||||
@@ -915,19 +939,15 @@ func (r *router) applyNetwork(flag string, network firewall.Network, prefixes []
|
||||
}
|
||||
|
||||
func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||
name := r.ipsetName(set.HashedName())
|
||||
var merr *multierror.Error
|
||||
for _, prefix := range prefixes {
|
||||
// TODO: Implement IPv6 support
|
||||
if prefix.Addr().Is6() {
|
||||
log.Tracef("skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix)
|
||||
continue
|
||||
}
|
||||
if err := r.addPrefixToIPSet(set.HashedName(), prefix); err != nil {
|
||||
if err := r.addPrefixToIPSet(name, prefix); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("add prefix to ipset: %w", err))
|
||||
}
|
||||
}
|
||||
if merr == nil {
|
||||
log.Debugf("updated set %s with prefixes %v", set.HashedName(), prefixes)
|
||||
log.Debugf("updated set %s with prefixes %v", name, prefixes)
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
@@ -943,7 +963,7 @@ func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol
|
||||
|
||||
dnatRule := []string{
|
||||
"-i", r.wgIface.Name(),
|
||||
"-p", strings.ToLower(string(protocol)),
|
||||
"-p", strings.ToLower(protoForFamily(protocol, r.v6)),
|
||||
"--dport", strconv.Itoa(int(sourcePort)),
|
||||
"-d", localAddr.String(),
|
||||
"-m", "addrtype", "--dst-type", "LOCAL",
|
||||
@@ -1076,10 +1096,22 @@ func applyPort(flag string, port *firewall.Port) []string {
|
||||
return []string{flag, strconv.Itoa(int(port.Values[0]))}
|
||||
}
|
||||
|
||||
// ipsetName returns the ipset name, suffixed with "-v6" for the v6 router
|
||||
// to avoid collisions since ipsets are global in the kernel.
|
||||
func (r *router) ipsetName(name string) string {
|
||||
if r.v6 {
|
||||
return name + "-v6"
|
||||
}
|
||||
return name
|
||||
}
|
||||
|
||||
func (r *router) createIPSet(name string) error {
|
||||
opts := ipset.CreateOptions{
|
||||
Replace: true,
|
||||
}
|
||||
if r.v6 {
|
||||
opts.Family = ipset.FamilyIPV6
|
||||
}
|
||||
|
||||
if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil {
|
||||
return fmt.Errorf("create ipset %s: %w", name, err)
|
||||
|
||||
@@ -9,6 +9,7 @@ type Rule struct {
|
||||
mangleSpecs []string
|
||||
ip string
|
||||
chain string
|
||||
v6 bool
|
||||
}
|
||||
|
||||
// GetRuleID returns the rule id
|
||||
|
||||
@@ -4,6 +4,8 @@ import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
@@ -37,6 +39,12 @@ type ShutdownState struct {
|
||||
|
||||
ACLEntries aclEntries `json:"acl_entries,omitempty"`
|
||||
ACLIPsetStore *ipsetStore `json:"acl_ipset_store,omitempty"`
|
||||
|
||||
// IPv6 counterparts
|
||||
RouteRules6 routeRules `json:"route_rules_v6,omitempty"`
|
||||
RouteIPsetCounter6 *ipsetCounter `json:"route_ipset_counter_v6,omitempty"`
|
||||
ACLEntries6 aclEntries `json:"acl_entries_v6,omitempty"`
|
||||
ACLIPsetStore6 *ipsetStore `json:"acl_ipset_store_v6,omitempty"`
|
||||
}
|
||||
|
||||
func (s *ShutdownState) Name() string {
|
||||
@@ -67,6 +75,28 @@ func (s *ShutdownState) Cleanup() error {
|
||||
ipt.aclMgr.ipsetStore = s.ACLIPsetStore
|
||||
}
|
||||
|
||||
// Clean up v6 state even if the current run has no IPv6.
|
||||
// The previous run may have left ip6tables rules behind.
|
||||
if !ipt.hasIPv6() {
|
||||
if err := ipt.createIPv6Components(s.InterfaceState, mtu); err != nil {
|
||||
log.Warnf("failed to create v6 components for cleanup: %v", err)
|
||||
}
|
||||
}
|
||||
if ipt.hasIPv6() {
|
||||
if s.RouteRules6 != nil {
|
||||
ipt.router6.rules = s.RouteRules6
|
||||
}
|
||||
if s.RouteIPsetCounter6 != nil {
|
||||
ipt.router6.ipsetCounter.LoadData(s.RouteIPsetCounter6)
|
||||
}
|
||||
if s.ACLEntries6 != nil {
|
||||
ipt.aclMgr6.entries = s.ACLEntries6
|
||||
}
|
||||
if s.ACLIPsetStore6 != nil {
|
||||
ipt.aclMgr6.ipsetStore = s.ACLIPsetStore6
|
||||
}
|
||||
}
|
||||
|
||||
if err := ipt.Close(nil); err != nil {
|
||||
return fmt.Errorf("reset iptables manager: %w", err)
|
||||
}
|
||||
|
||||
@@ -33,15 +33,12 @@ const (
|
||||
|
||||
const flushError = "flush: %w"
|
||||
|
||||
var (
|
||||
anyIP = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
|
||||
)
|
||||
|
||||
type AclManager struct {
|
||||
rConn *nftables.Conn
|
||||
sConn *nftables.Conn
|
||||
wgIface iFaceMapper
|
||||
routingFwChainName string
|
||||
af addrFamily
|
||||
|
||||
workTable *nftables.Table
|
||||
chainInputRules *nftables.Chain
|
||||
@@ -67,6 +64,7 @@ func newAclManager(table *nftables.Table, wgIface iFaceMapper, routingFwChainNam
|
||||
wgIface: wgIface,
|
||||
workTable: table,
|
||||
routingFwChainName: routingFwChainName,
|
||||
af: familyForAddr(table.Family == nftables.TableFamilyIPv4),
|
||||
|
||||
ipsetStore: newIpsetStore(),
|
||||
rules: make(map[string]*Rule),
|
||||
@@ -145,7 +143,7 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
|
||||
}
|
||||
|
||||
if _, ok := ips[r.ip.String()]; ok {
|
||||
err := m.sConn.SetDeleteElements(r.nftSet, []nftables.SetElement{{Key: r.ip.To4()}})
|
||||
err := m.sConn.SetDeleteElements(r.nftSet, []nftables.SetElement{{Key: ipToBytes(r.ip, m.af)}})
|
||||
if err != nil {
|
||||
log.Errorf("delete elements for set %q: %v", r.nftSet.Name, err)
|
||||
}
|
||||
@@ -254,11 +252,11 @@ func (m *AclManager) addIOFiltering(
|
||||
expressions = append(expressions, &expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: uint32(9),
|
||||
Offset: m.af.protoOffset,
|
||||
Len: uint32(1),
|
||||
})
|
||||
|
||||
protoData, err := protoToInt(proto)
|
||||
protoData, err := m.af.protoNum(proto)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("convert protocol to number: %v", err)
|
||||
}
|
||||
@@ -270,19 +268,16 @@ func (m *AclManager) addIOFiltering(
|
||||
})
|
||||
}
|
||||
|
||||
rawIP := ip.To4()
|
||||
rawIP := ipToBytes(ip, m.af)
|
||||
// check if rawIP contains zeroed IPv4 0.0.0.0 value
|
||||
// in that case not add IP match expression into the rule definition
|
||||
if !bytes.HasPrefix(anyIP, rawIP) {
|
||||
// source address position
|
||||
addrOffset := uint32(12)
|
||||
|
||||
if slices.ContainsFunc(rawIP, func(v byte) bool { return v != 0 }) {
|
||||
expressions = append(expressions,
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: addrOffset,
|
||||
Len: 4,
|
||||
Offset: m.af.srcAddrOffset,
|
||||
Len: m.af.addrLen,
|
||||
},
|
||||
)
|
||||
// add individual IP for match if no ipset defined
|
||||
@@ -587,7 +582,7 @@ func (m *AclManager) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr
|
||||
|
||||
func (m *AclManager) addIpToSet(ipsetName string, ip net.IP) (*nftables.Set, error) {
|
||||
ipset, err := m.rConn.GetSetByName(m.workTable, ipsetName)
|
||||
rawIP := ip.To4()
|
||||
rawIP := ipToBytes(ip, m.af)
|
||||
if err != nil {
|
||||
if ipset, err = m.createSet(m.workTable, ipsetName); err != nil {
|
||||
return nil, fmt.Errorf("get set name: %v", err)
|
||||
@@ -619,7 +614,7 @@ func (m *AclManager) createSet(table *nftables.Table, name string) (*nftables.Se
|
||||
Name: name,
|
||||
Table: table,
|
||||
Dynamic: true,
|
||||
KeyType: nftables.TypeIPAddr,
|
||||
KeyType: m.af.setKeyType,
|
||||
}
|
||||
|
||||
if err := m.rConn.AddSet(ipset, nil); err != nil {
|
||||
@@ -707,15 +702,12 @@ func ifname(n string) []byte {
|
||||
return b
|
||||
}
|
||||
|
||||
func protoToInt(protocol firewall.Protocol) (uint8, error) {
|
||||
switch protocol {
|
||||
case firewall.ProtocolTCP:
|
||||
return unix.IPPROTO_TCP, nil
|
||||
case firewall.ProtocolUDP:
|
||||
return unix.IPPROTO_UDP, nil
|
||||
case firewall.ProtocolICMP:
|
||||
return unix.IPPROTO_ICMP, nil
|
||||
}
|
||||
|
||||
return 0, fmt.Errorf("unsupported protocol: %s", protocol)
|
||||
// ipToBytes converts net.IP to the correct byte length for the address family.
|
||||
func ipToBytes(ip net.IP, af addrFamily) []byte {
|
||||
if af.addrLen == 4 {
|
||||
return ip.To4()
|
||||
}
|
||||
return ip.To16()
|
||||
}
|
||||
|
||||
|
||||
81
client/firewall/nftables/addr_family_linux.go
Normal file
81
client/firewall/nftables/addr_family_linux.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/google/nftables"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
)
|
||||
|
||||
var (
|
||||
// afIPv4 defines IPv4 header layout and nftables types.
|
||||
afIPv4 = addrFamily{
|
||||
protoOffset: 9,
|
||||
srcAddrOffset: 12,
|
||||
dstAddrOffset: 16,
|
||||
addrLen: net.IPv4len,
|
||||
totalBits: 8 * net.IPv4len,
|
||||
setKeyType: nftables.TypeIPAddr,
|
||||
tableFamily: nftables.TableFamilyIPv4,
|
||||
icmpProto: unix.IPPROTO_ICMP,
|
||||
}
|
||||
// afIPv6 defines IPv6 header layout and nftables types.
|
||||
afIPv6 = addrFamily{
|
||||
protoOffset: 6,
|
||||
srcAddrOffset: 8,
|
||||
dstAddrOffset: 24,
|
||||
addrLen: net.IPv6len,
|
||||
totalBits: 8 * net.IPv6len,
|
||||
setKeyType: nftables.TypeIP6Addr,
|
||||
tableFamily: nftables.TableFamilyIPv6,
|
||||
icmpProto: unix.IPPROTO_ICMPV6,
|
||||
}
|
||||
)
|
||||
|
||||
// addrFamily holds protocol-specific constants for nftables expression building.
|
||||
type addrFamily struct {
|
||||
// protoOffset is the IP header offset for the protocol/next-header field (9 for v4, 6 for v6)
|
||||
protoOffset uint32
|
||||
// srcAddrOffset is the IP header offset for the source address (12 for v4, 8 for v6)
|
||||
srcAddrOffset uint32
|
||||
// dstAddrOffset is the IP header offset for the destination address (16 for v4, 24 for v6)
|
||||
dstAddrOffset uint32
|
||||
// addrLen is the byte length of addresses (4 for v4, 16 for v6)
|
||||
addrLen uint32
|
||||
// totalBits is the address size in bits (32 for v4, 128 for v6)
|
||||
totalBits int
|
||||
// setKeyType is the nftables set data type for addresses
|
||||
setKeyType nftables.SetDatatype
|
||||
// tableFamily is the nftables table family
|
||||
tableFamily nftables.TableFamily
|
||||
// icmpProto is the ICMP protocol number for this family (1 for v4, 58 for v6)
|
||||
icmpProto uint8
|
||||
}
|
||||
|
||||
// familyForAddr returns the address family for the given IP.
|
||||
func familyForAddr(is4 bool) addrFamily {
|
||||
if is4 {
|
||||
return afIPv4
|
||||
}
|
||||
return afIPv6
|
||||
}
|
||||
|
||||
// protoNum converts a firewall protocol to the IP protocol number,
|
||||
// using the correct ICMP variant for the address family.
|
||||
func (af addrFamily) protoNum(protocol firewall.Protocol) (uint8, error) {
|
||||
switch protocol {
|
||||
case firewall.ProtocolTCP:
|
||||
return unix.IPPROTO_TCP, nil
|
||||
case firewall.ProtocolUDP:
|
||||
return unix.IPPROTO_UDP, nil
|
||||
case firewall.ProtocolICMP:
|
||||
return af.icmpProto, nil
|
||||
case firewall.ProtocolALL:
|
||||
return 0, nil
|
||||
default:
|
||||
return 0, fmt.Errorf("unsupported protocol: %s", protocol)
|
||||
}
|
||||
}
|
||||
@@ -11,9 +11,11 @@ import (
|
||||
"github.com/google/nftables"
|
||||
"github.com/google/nftables/binaryutil"
|
||||
"github.com/google/nftables/expr"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
@@ -49,8 +51,13 @@ type Manager struct {
|
||||
rConn *nftables.Conn
|
||||
wgIface iFaceMapper
|
||||
|
||||
router *router
|
||||
aclManager *AclManager
|
||||
router *router
|
||||
aclManager *AclManager
|
||||
|
||||
// IPv6 counterparts, nil when no v6 overlay
|
||||
router6 *router
|
||||
aclManager6 *AclManager
|
||||
|
||||
notrackOutputChain *nftables.Chain
|
||||
notrackPreroutingChain *nftables.Chain
|
||||
}
|
||||
@@ -62,7 +69,8 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) {
|
||||
wgIface: wgIface,
|
||||
}
|
||||
|
||||
workTable := &nftables.Table{Name: getTableName(), Family: nftables.TableFamilyIPv4}
|
||||
tableName := getTableName()
|
||||
workTable := &nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4}
|
||||
|
||||
var err error
|
||||
m.router, err = newRouter(workTable, wgIface, mtu)
|
||||
@@ -75,11 +83,70 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) {
|
||||
return nil, fmt.Errorf("create acl manager: %w", err)
|
||||
}
|
||||
|
||||
if wgIface.Address().HasIPv6() {
|
||||
if err := m.createIPv6Components(tableName, wgIface, mtu); err != nil {
|
||||
return nil, fmt.Errorf("create IPv6 firewall: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m *Manager) createIPv6Components(tableName string, wgIface iFaceMapper, mtu uint16) error {
|
||||
workTable6 := &nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv6}
|
||||
|
||||
var err error
|
||||
m.router6, err = newRouter(workTable6, wgIface, mtu)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create v6 router: %w", err)
|
||||
}
|
||||
|
||||
// Share the same IP forwarding state with the v4 router, since
|
||||
// EnableIPForwarding controls both v4 and v6 sysctls.
|
||||
m.router6.ipFwdState = m.router.ipFwdState
|
||||
|
||||
m.aclManager6, err = newAclManager(workTable6, wgIface, chainNameRoutingFw)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create v6 acl manager: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// hasIPv6 reports whether the manager has IPv6 components initialized.
|
||||
func (m *Manager) hasIPv6() bool {
|
||||
return m.router6 != nil
|
||||
}
|
||||
|
||||
func (m *Manager) initIPv6() error {
|
||||
workTable6, err := m.createWorkTableFamily(nftables.TableFamilyIPv6)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create v6 work table: %w", err)
|
||||
}
|
||||
|
||||
if err := m.router6.init(workTable6); err != nil {
|
||||
return fmt.Errorf("v6 router init: %w", err)
|
||||
}
|
||||
|
||||
if err := m.aclManager6.init(workTable6); err != nil {
|
||||
return fmt.Errorf("v6 acl manager init: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Init nftables firewall manager
|
||||
func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||
if err := m.initFirewall(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m.persistState(stateManager)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) initFirewall() error {
|
||||
workTable, err := m.createWorkTable()
|
||||
if err != nil {
|
||||
return fmt.Errorf("create work table: %w", err)
|
||||
@@ -90,20 +157,32 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||
}
|
||||
|
||||
if err := m.aclManager.init(workTable); err != nil {
|
||||
// TODO: cleanup router
|
||||
m.rollbackInit()
|
||||
return fmt.Errorf("acl manager init: %w", err)
|
||||
}
|
||||
|
||||
if m.hasIPv6() {
|
||||
if err := m.initIPv6(); err != nil {
|
||||
// Peer has a v6 address: v6 firewall MUST work or we risk fail-open.
|
||||
m.rollbackInit()
|
||||
return fmt.Errorf("init IPv6 firewall (required because peer has IPv6 address): %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := m.initNoTrackChains(workTable); err != nil {
|
||||
log.Warnf("raw priority chains not available, notrack rules will be disabled: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// persistState saves the current interface state for potential recreation on restart.
|
||||
// Unlike iptables, which requires tracking individual rules, nftables maintains
|
||||
// a known state (our netbird table plus a few static rules). This allows for easy
|
||||
// cleanup using Close() without needing to store specific rules.
|
||||
func (m *Manager) persistState(stateManager *statemanager.Manager) {
|
||||
stateManager.RegisterState(&ShutdownState{})
|
||||
|
||||
// We only need to record minimal interface state for potential recreation.
|
||||
// Unlike iptables, which requires tracking individual rules, nftables maintains
|
||||
// a known state (our netbird table plus a few static rules). This allows for easy
|
||||
// cleanup using Close() without needing to store specific rules.
|
||||
if err := stateManager.UpdateState(&ShutdownState{
|
||||
InterfaceState: &InterfaceState{
|
||||
NameStr: m.wgIface.Name(),
|
||||
@@ -115,14 +194,29 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||
log.Errorf("failed to update state: %v", err)
|
||||
}
|
||||
|
||||
// persist early
|
||||
go func() {
|
||||
if err := stateManager.PersistState(context.Background()); err != nil {
|
||||
log.Errorf("failed to persist state: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
return nil
|
||||
// rollbackInit performs best-effort cleanup of already-initialized state when Init fails partway through.
|
||||
func (m *Manager) rollbackInit() {
|
||||
if err := m.router.Reset(); err != nil {
|
||||
log.Warnf("rollback router: %v", err)
|
||||
}
|
||||
if m.hasIPv6() {
|
||||
if err := m.router6.Reset(); err != nil {
|
||||
log.Warnf("rollback v6 router: %v", err)
|
||||
}
|
||||
}
|
||||
if err := m.cleanupNetbirdTables(); err != nil {
|
||||
log.Warnf("cleanup tables: %v", err)
|
||||
}
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
log.Warnf("flush: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// AddPeerFiltering rule to the firewall
|
||||
@@ -141,12 +235,14 @@ func (m *Manager) AddPeerFiltering(
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
rawIP := ip.To4()
|
||||
if rawIP == nil {
|
||||
return nil, fmt.Errorf("unsupported IP version: %s", ip.String())
|
||||
if ip.To4() != nil {
|
||||
return m.aclManager.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
|
||||
}
|
||||
|
||||
return m.aclManager.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
|
||||
if !m.hasIPv6() {
|
||||
return nil, fmt.Errorf("IPv6 not initialized, cannot add rule for %s", ip)
|
||||
}
|
||||
return m.aclManager6.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
|
||||
}
|
||||
|
||||
func (m *Manager) AddRouteFiltering(
|
||||
@@ -160,8 +256,11 @@ func (m *Manager) AddRouteFiltering(
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if destination.IsPrefix() && !destination.Prefix.Addr().Is4() {
|
||||
return nil, fmt.Errorf("unsupported IP version: %s", destination.Prefix.Addr().String())
|
||||
if isIPv6RouteRule(sources, destination) {
|
||||
if !m.hasIPv6() {
|
||||
return nil, fmt.Errorf("IPv6 not initialized, cannot add route rule")
|
||||
}
|
||||
return m.router6.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||
}
|
||||
|
||||
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||
@@ -172,14 +271,38 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if m.hasIPv6() && isIPv6Rule(rule) {
|
||||
return m.aclManager6.DeletePeerRule(rule)
|
||||
}
|
||||
return m.aclManager.DeletePeerRule(rule)
|
||||
}
|
||||
|
||||
// DeleteRouteRule deletes a routing rule
|
||||
func isIPv6Rule(rule firewall.Rule) bool {
|
||||
r, ok := rule.(*Rule)
|
||||
return ok && r.nftRule != nil && r.nftRule.Table != nil && r.nftRule.Table.Family == nftables.TableFamilyIPv6
|
||||
}
|
||||
|
||||
// isIPv6RouteRule determines whether a route rule belongs to the v6 table.
|
||||
// For static routes, the destination prefix determines the family. For dynamic
|
||||
// routes (DomainSet), the sources determine the family since management
|
||||
// duplicates dynamic rules per family.
|
||||
func isIPv6RouteRule(sources []netip.Prefix, destination firewall.Network) bool {
|
||||
if destination.IsPrefix() {
|
||||
return destination.Prefix.Addr().Is6()
|
||||
}
|
||||
return len(sources) > 0 && sources[0].Addr().Is6()
|
||||
}
|
||||
|
||||
// DeleteRouteRule deletes a routing rule.
|
||||
// Route rules are keyed by content hash, so the rule exists in exactly one
|
||||
// router. We check v4 first; if the key isn't there, try v6.
|
||||
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if m.hasIPv6() && !m.router.hasRule(rule.ID()) {
|
||||
return m.router6.DeleteRouteRule(rule)
|
||||
}
|
||||
return m.router.DeleteRouteRule(rule)
|
||||
}
|
||||
|
||||
@@ -195,17 +318,63 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.AddNatRule(pair)
|
||||
if pair.Destination.IsPrefix() && pair.Destination.Prefix.Addr().Is6() {
|
||||
if !m.hasIPv6() {
|
||||
return fmt.Errorf("IPv6 not initialized, cannot add NAT rule")
|
||||
}
|
||||
return m.router6.AddNatRule(pair)
|
||||
}
|
||||
|
||||
if err := m.router.AddNatRule(pair); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Dynamic routes (DomainSet) need NAT in both tables since resolved IPs
|
||||
// can be either v4 or v6.
|
||||
if m.hasIPv6() && pair.Destination.IsSet() {
|
||||
v6Pair := pair
|
||||
v6Pair.Source = firewall.Network{Prefix: netip.PrefixFrom(netip.IPv6Unspecified(), 0)}
|
||||
if err := m.router6.AddNatRule(v6Pair); err != nil {
|
||||
return fmt.Errorf("add v6 NAT rule: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.RemoveNatRule(pair)
|
||||
if pair.Destination.IsPrefix() && pair.Destination.Prefix.Addr().Is6() {
|
||||
if !m.hasIPv6() {
|
||||
return nil
|
||||
}
|
||||
return m.router6.RemoveNatRule(pair)
|
||||
}
|
||||
|
||||
if err := m.router.RemoveNatRule(pair); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if m.hasIPv6() && pair.Destination.IsSet() {
|
||||
v6Pair := pair
|
||||
v6Pair.Source = firewall.Network{Prefix: netip.PrefixFrom(netip.IPv6Unspecified(), 0)}
|
||||
if err := m.router6.RemoveNatRule(v6Pair); err != nil {
|
||||
return fmt.Errorf("remove v6 NAT rule: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AllowNetbird allows netbird interface traffic
|
||||
// AllowNetbird allows netbird interface traffic.
|
||||
// TODO: In USP mode this only adds ACCEPT to the netbird table's own chains,
|
||||
// which doesn't override DROP rules in external tables (e.g. firewalld).
|
||||
// Should add passthrough rules to external chains (like the native mode router's
|
||||
// addExternalChainsRules does) for both the netbird table family and inet tables.
|
||||
// The netbird table itself is fine (routing chains already exist there), but
|
||||
// non-netbird tables with INPUT/FORWARD hooks can still DROP our WG traffic.
|
||||
func (m *Manager) AllowNetbird() error {
|
||||
if !m.wgIface.IsUserspaceBind() {
|
||||
return nil
|
||||
@@ -217,6 +386,11 @@ func (m *Manager) AllowNetbird() error {
|
||||
if err := m.aclManager.createDefaultAllowRules(); err != nil {
|
||||
return fmt.Errorf("create default allow rules: %w", err)
|
||||
}
|
||||
if m.hasIPv6() {
|
||||
if err := m.aclManager6.createDefaultAllowRules(); err != nil {
|
||||
return fmt.Errorf("create v6 default allow rules: %w", err)
|
||||
}
|
||||
}
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
return fmt.Errorf("flush allow input netbird rules: %w", err)
|
||||
}
|
||||
@@ -226,7 +400,13 @@ func (m *Manager) AllowNetbird() error {
|
||||
|
||||
// SetLegacyManagement sets the route manager to use legacy management
|
||||
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
||||
return firewall.SetLegacyManagement(m.router, isLegacy)
|
||||
if err := firewall.SetLegacyManagement(m.router, isLegacy); err != nil {
|
||||
return err
|
||||
}
|
||||
if m.hasIPv6() {
|
||||
return firewall.SetLegacyManagement(m.router6, isLegacy)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the firewall manager
|
||||
@@ -234,23 +414,31 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
var merr *multierror.Error
|
||||
|
||||
if err := m.router.Reset(); err != nil {
|
||||
return fmt.Errorf("reset router: %v", err)
|
||||
merr = multierror.Append(merr, fmt.Errorf("reset router: %v", err))
|
||||
}
|
||||
|
||||
if m.hasIPv6() {
|
||||
if err := m.router6.Reset(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("reset v6 router: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
if err := m.cleanupNetbirdTables(); err != nil {
|
||||
return fmt.Errorf("cleanup netbird tables: %v", err)
|
||||
merr = multierror.Append(merr, fmt.Errorf("cleanup netbird tables: %v", err))
|
||||
}
|
||||
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
return fmt.Errorf(flushError, err)
|
||||
merr = multierror.Append(merr, fmt.Errorf(flushError, err))
|
||||
}
|
||||
|
||||
if err := stateManager.DeleteState(&ShutdownState{}); err != nil {
|
||||
return fmt.Errorf("delete state: %v", err)
|
||||
merr = multierror.Append(merr, fmt.Errorf("delete state: %v", err))
|
||||
}
|
||||
|
||||
return nil
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (m *Manager) cleanupNetbirdTables() error {
|
||||
@@ -299,6 +487,12 @@ func (m *Manager) Flush() error {
|
||||
return err
|
||||
}
|
||||
|
||||
if m.hasIPv6() {
|
||||
if err := m.aclManager6.Flush(); err != nil {
|
||||
return fmt.Errorf("flush v6 acl: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := m.refreshNoTrackChains(); err != nil {
|
||||
log.Errorf("failed to refresh notrack chains: %v", err)
|
||||
}
|
||||
@@ -311,6 +505,9 @@ func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error)
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if m.hasIPv6() && rule.TranslatedAddress.Is6() {
|
||||
return m.router6.AddDNATRule(rule)
|
||||
}
|
||||
return m.router.AddDNATRule(rule)
|
||||
}
|
||||
|
||||
@@ -319,6 +516,9 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if m.hasIPv6() && !m.router.hasDNATRule(rule.ID()) {
|
||||
return m.router6.DeleteDNATRule(rule)
|
||||
}
|
||||
return m.router.DeleteDNATRule(rule)
|
||||
}
|
||||
|
||||
@@ -327,7 +527,26 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.UpdateSet(set, prefixes)
|
||||
var v4Prefixes, v6Prefixes []netip.Prefix
|
||||
for _, p := range prefixes {
|
||||
if p.Addr().Is6() {
|
||||
v6Prefixes = append(v6Prefixes, p)
|
||||
} else {
|
||||
v4Prefixes = append(v4Prefixes, p)
|
||||
}
|
||||
}
|
||||
|
||||
if err := m.router.UpdateSet(set, v4Prefixes); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if m.hasIPv6() && len(v6Prefixes) > 0 {
|
||||
if err := m.router6.UpdateSet(set, v6Prefixes); err != nil {
|
||||
return fmt.Errorf("update v6 set: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
|
||||
@@ -335,6 +554,9 @@ func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protoco
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if m.hasIPv6() && localAddr.Is6() {
|
||||
return m.router6.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
@@ -343,6 +565,9 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if m.hasIPv6() && localAddr.Is6() {
|
||||
return m.router6.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
@@ -533,7 +758,11 @@ func (m *Manager) refreshNoTrackChains() error {
|
||||
}
|
||||
|
||||
func (m *Manager) createWorkTable() (*nftables.Table, error) {
|
||||
tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
||||
return m.createWorkTableFamily(nftables.TableFamilyIPv4)
|
||||
}
|
||||
|
||||
func (m *Manager) createWorkTableFamily(family nftables.TableFamily) (*nftables.Table, error) {
|
||||
tables, err := m.rConn.ListTablesOfFamily(family)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list of tables: %w", err)
|
||||
}
|
||||
@@ -545,7 +774,7 @@ func (m *Manager) createWorkTable() (*nftables.Table, error) {
|
||||
}
|
||||
}
|
||||
|
||||
table := m.rConn.AddTable(&nftables.Table{Name: getTableName(), Family: nftables.TableFamilyIPv4})
|
||||
table := m.rConn.AddTable(&nftables.Table{Name: tableName, Family: family})
|
||||
err = m.rConn.Flush()
|
||||
return table, err
|
||||
}
|
||||
|
||||
@@ -385,10 +385,134 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
|
||||
err = manager.AddNatRule(pair)
|
||||
require.NoError(t, err, "failed to add NAT rule")
|
||||
|
||||
dnatRule, err := manager.AddDNATRule(fw.ForwardRule{
|
||||
Protocol: fw.ProtocolTCP,
|
||||
DestinationPort: fw.Port{Values: []uint16{8080}},
|
||||
TranslatedAddress: netip.MustParseAddr("100.96.0.2"),
|
||||
TranslatedPort: fw.Port{Values: []uint16{80}},
|
||||
})
|
||||
require.NoError(t, err, "failed to add DNAT rule")
|
||||
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.DeleteDNATRule(dnatRule), "failed to delete DNAT rule")
|
||||
})
|
||||
|
||||
stdout, stderr = runIptablesSave(t)
|
||||
verifyIptablesOutput(t, stdout, stderr)
|
||||
}
|
||||
|
||||
func TestNftablesManagerIPv6CompatibilityWithIp6tables(t *testing.T) {
|
||||
if check() != NFTABLES {
|
||||
t.Skip("nftables not supported on this system")
|
||||
}
|
||||
|
||||
for _, bin := range []string{"ip6tables", "ip6tables-save", "iptables-save"} {
|
||||
if _, err := exec.LookPath(bin); err != nil {
|
||||
t.Skipf("%s not available on this system: %v", bin, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Seed ip6 tables in the nft backend. Docker may not create them.
|
||||
seedIp6tables(t)
|
||||
|
||||
ifaceMockV6 := &iFaceMock{
|
||||
NameFunc: func() string { return "wt-test" },
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: netip.MustParseAddr("100.96.0.1"),
|
||||
Network: netip.MustParsePrefix("100.96.0.0/16"),
|
||||
IPv6: netip.MustParseAddr("fd00::1"),
|
||||
IPv6Net: netip.MustParsePrefix("fd00::/64"),
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMockV6, iface.DefaultMTU)
|
||||
require.NoError(t, err, "create manager")
|
||||
require.NoError(t, manager.Init(nil))
|
||||
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.Close(nil), "close manager")
|
||||
|
||||
stdout, stderr := runIp6tablesSave(t)
|
||||
verifyIp6tablesOutput(t, stdout, stderr)
|
||||
})
|
||||
|
||||
ip := netip.MustParseAddr("fd00::2")
|
||||
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
|
||||
require.NoError(t, err, "add v6 peer filtering rule")
|
||||
|
||||
_, err = manager.AddRouteFiltering(
|
||||
nil,
|
||||
[]netip.Prefix{netip.MustParsePrefix("fd00:1::/64")},
|
||||
fw.Network{Prefix: netip.MustParsePrefix("2001:db8::/48")},
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
&fw.Port{Values: []uint16{443}},
|
||||
fw.ActionAccept,
|
||||
)
|
||||
require.NoError(t, err, "add v6 route filtering rule")
|
||||
|
||||
err = manager.AddNatRule(fw.RouterPair{
|
||||
Source: fw.Network{Prefix: netip.MustParsePrefix("fd00::/64")},
|
||||
Destination: fw.Network{Prefix: netip.MustParsePrefix("2001:db8::/48")},
|
||||
Masquerade: true,
|
||||
})
|
||||
require.NoError(t, err, "add v6 NAT rule")
|
||||
|
||||
dnatRule, err := manager.AddDNATRule(fw.ForwardRule{
|
||||
Protocol: fw.ProtocolTCP,
|
||||
DestinationPort: fw.Port{Values: []uint16{8080}},
|
||||
TranslatedAddress: netip.MustParseAddr("fd00::2"),
|
||||
TranslatedPort: fw.Port{Values: []uint16{80}},
|
||||
})
|
||||
require.NoError(t, err, "add v6 DNAT rule")
|
||||
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.DeleteDNATRule(dnatRule), "delete v6 DNAT rule")
|
||||
})
|
||||
|
||||
stdout, stderr := runIptablesSave(t)
|
||||
verifyIptablesOutput(t, stdout, stderr)
|
||||
|
||||
stdout, stderr = runIp6tablesSave(t)
|
||||
verifyIp6tablesOutput(t, stdout, stderr)
|
||||
}
|
||||
|
||||
func seedIp6tables(t *testing.T) {
|
||||
t.Helper()
|
||||
for _, tc := range []struct{ table, chain string }{
|
||||
{"filter", "FORWARD"},
|
||||
{"nat", "POSTROUTING"},
|
||||
{"mangle", "FORWARD"},
|
||||
} {
|
||||
add := exec.Command("ip6tables", "-t", tc.table, "-A", tc.chain, "-j", "ACCEPT")
|
||||
require.NoError(t, add.Run(), "seed ip6tables -t %s", tc.table)
|
||||
del := exec.Command("ip6tables", "-t", tc.table, "-D", tc.chain, "-j", "ACCEPT")
|
||||
require.NoError(t, del.Run(), "unseed ip6tables -t %s", tc.table)
|
||||
}
|
||||
}
|
||||
|
||||
func runIp6tablesSave(t *testing.T) (string, string) {
|
||||
t.Helper()
|
||||
var stdout, stderr bytes.Buffer
|
||||
cmd := exec.Command("ip6tables-save")
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
require.NoError(t, cmd.Run(), "ip6tables-save failed")
|
||||
return stdout.String(), stderr.String()
|
||||
}
|
||||
|
||||
func verifyIp6tablesOutput(t *testing.T, stdout, stderr string) {
|
||||
t.Helper()
|
||||
require.NotContains(t, stdout, "Table `nat' is incompatible",
|
||||
"ip6tables-save: nat table incompatible. Full output: %s", stdout)
|
||||
require.NotContains(t, stdout, "Table `mangle' is incompatible",
|
||||
"ip6tables-save: mangle table incompatible. Full output: %s", stdout)
|
||||
require.NotContains(t, stdout, "Table `filter' is incompatible",
|
||||
"ip6tables-save: filter table incompatible. Full output: %s", stdout)
|
||||
}
|
||||
|
||||
func TestNftablesManagerCompatibilityWithIptablesFor6kPrefixes(t *testing.T) {
|
||||
if check() != NFTABLES {
|
||||
t.Skip("nftables not supported on this system")
|
||||
|
||||
@@ -47,8 +47,10 @@ const (
|
||||
dnatSuffix = "_dnat"
|
||||
snatSuffix = "_snat"
|
||||
|
||||
// ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation
|
||||
ipTCPHeaderMinSize = 40
|
||||
// ipv4TCPHeaderSize is the minimum IPv4 (20) + TCP (20) header size for MSS calculation.
|
||||
ipv4TCPHeaderSize = 40
|
||||
// ipv6TCPHeaderSize is the minimum IPv6 (40) + TCP (20) header size for MSS calculation.
|
||||
ipv6TCPHeaderSize = 60
|
||||
|
||||
// maxPrefixesSet 1638 prefixes start to fail, taking some margin
|
||||
maxPrefixesSet = 1500
|
||||
@@ -73,6 +75,7 @@ type router struct {
|
||||
rules map[string]*nftables.Rule
|
||||
ipsetCounter *refcounter.Counter[string, setInput, *nftables.Set]
|
||||
|
||||
af addrFamily
|
||||
wgIface iFaceMapper
|
||||
ipFwdState *ipfwdstate.IPForwardingState
|
||||
legacyManagement bool
|
||||
@@ -85,6 +88,7 @@ func newRouter(workTable *nftables.Table, wgIface iFaceMapper, mtu uint16) (*rou
|
||||
workTable: workTable,
|
||||
chains: make(map[string]*nftables.Chain),
|
||||
rules: make(map[string]*nftables.Rule),
|
||||
af: familyForAddr(workTable.Family == nftables.TableFamilyIPv4),
|
||||
wgIface: wgIface,
|
||||
ipFwdState: ipfwdstate.NewIPForwardingState(),
|
||||
mtu: mtu,
|
||||
@@ -143,7 +147,7 @@ func (r *router) Reset() error {
|
||||
func (r *router) removeNatPreroutingRules() error {
|
||||
table := &nftables.Table{
|
||||
Name: tableNat,
|
||||
Family: nftables.TableFamilyIPv4,
|
||||
Family: r.af.tableFamily,
|
||||
}
|
||||
chain := &nftables.Chain{
|
||||
Name: chainNameNatPrerouting,
|
||||
@@ -176,7 +180,7 @@ func (r *router) removeNatPreroutingRules() error {
|
||||
}
|
||||
|
||||
func (r *router) loadFilterTable() (*nftables.Table, error) {
|
||||
tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
||||
tables, err := r.conn.ListTablesOfFamily(r.af.tableFamily)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list tables: %w", err)
|
||||
}
|
||||
@@ -408,7 +412,7 @@ func (r *router) AddRouteFiltering(
|
||||
|
||||
// Handle protocol
|
||||
if proto != firewall.ProtocolALL {
|
||||
protoNum, err := protoToInt(proto)
|
||||
protoNum, err := r.af.protoNum(proto)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("convert protocol to number: %w", err)
|
||||
}
|
||||
@@ -468,7 +472,24 @@ func (r *router) getIpSet(set firewall.Set, prefixes []netip.Prefix, isSource bo
|
||||
return nil, fmt.Errorf("create or get ipset: %w", err)
|
||||
}
|
||||
|
||||
return getIpSetExprs(ref, isSource)
|
||||
return r.getIpSetExprs(ref, isSource)
|
||||
}
|
||||
|
||||
func (r *router) iptablesProto() iptables.Protocol {
|
||||
if r.af.tableFamily == nftables.TableFamilyIPv6 {
|
||||
return iptables.ProtocolIPv6
|
||||
}
|
||||
return iptables.ProtocolIPv4
|
||||
}
|
||||
|
||||
func (r *router) hasRule(id string) bool {
|
||||
_, ok := r.rules[id]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (r *router) hasDNATRule(id string) bool {
|
||||
_, ok := r.rules[id+dnatSuffix]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (r *router) DeleteRouteRule(rule firewall.Rule) error {
|
||||
@@ -517,10 +538,10 @@ func (r *router) createIpSet(setName string, input setInput) (*nftables.Set, err
|
||||
Table: r.workTable,
|
||||
// required for prefixes
|
||||
Interval: true,
|
||||
KeyType: nftables.TypeIPAddr,
|
||||
KeyType: r.af.setKeyType,
|
||||
}
|
||||
|
||||
elements := convertPrefixesToSet(prefixes)
|
||||
elements := r.convertPrefixesToSet(prefixes)
|
||||
nElements := len(elements)
|
||||
|
||||
maxElements := maxPrefixesSet * 2
|
||||
@@ -553,23 +574,17 @@ func (r *router) createIpSet(setName string, input setInput) (*nftables.Set, err
|
||||
return nfset, nil
|
||||
}
|
||||
|
||||
func convertPrefixesToSet(prefixes []netip.Prefix) []nftables.SetElement {
|
||||
func (r *router) convertPrefixesToSet(prefixes []netip.Prefix) []nftables.SetElement {
|
||||
var elements []nftables.SetElement
|
||||
for _, prefix := range prefixes {
|
||||
// TODO: Implement IPv6 support
|
||||
if prefix.Addr().Is6() {
|
||||
log.Tracef("skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix)
|
||||
continue
|
||||
}
|
||||
|
||||
// nftables needs half-open intervals [firstIP, lastIP) for prefixes
|
||||
// e.g. 10.0.0.0/24 becomes [10.0.0.0, 10.0.1.0), 10.1.1.1/32 becomes [10.1.1.1, 10.1.1.2) etc
|
||||
firstIP := prefix.Addr()
|
||||
lastIP := calculateLastIP(prefix).Next()
|
||||
|
||||
elements = append(elements,
|
||||
// the nft tool also adds a line like this, see https://github.com/google/nftables/issues/247
|
||||
// nftables.SetElement{Key: []byte{0, 0, 0, 0}, IntervalEnd: true},
|
||||
// the nft tool also adds a zero-address IntervalEnd element, see https://github.com/google/nftables/issues/247
|
||||
// nftables.SetElement{Key: make([]byte, r.af.addrLen), IntervalEnd: true},
|
||||
nftables.SetElement{Key: firstIP.AsSlice()},
|
||||
nftables.SetElement{Key: lastIP.AsSlice(), IntervalEnd: true},
|
||||
)
|
||||
@@ -579,10 +594,20 @@ func convertPrefixesToSet(prefixes []netip.Prefix) []nftables.SetElement {
|
||||
|
||||
// calculateLastIP determines the last IP in a given prefix.
|
||||
func calculateLastIP(prefix netip.Prefix) netip.Addr {
|
||||
hostMask := ^uint32(0) >> prefix.Masked().Bits()
|
||||
lastIP := uint32FromNetipAddr(prefix.Addr()) | hostMask
|
||||
masked := prefix.Masked()
|
||||
if masked.Addr().Is4() {
|
||||
hostMask := ^uint32(0) >> masked.Bits()
|
||||
lastIP := uint32FromNetipAddr(masked.Addr()) | hostMask
|
||||
return netip.AddrFrom4(uint32ToBytes(lastIP))
|
||||
}
|
||||
|
||||
return netip.AddrFrom4(uint32ToBytes(lastIP))
|
||||
// IPv6: set host bits to all 1s
|
||||
b := masked.Addr().As16()
|
||||
bits := masked.Bits()
|
||||
for i := bits; i < 128; i++ {
|
||||
b[i/8] |= 1 << (7 - i%8)
|
||||
}
|
||||
return netip.AddrFrom16(b)
|
||||
}
|
||||
|
||||
// Utility function to convert netip.Addr to uint32.
|
||||
@@ -834,9 +859,12 @@ func (r *router) addPostroutingRules() {
|
||||
}
|
||||
|
||||
// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic.
|
||||
// TODO: Add IPv6 support
|
||||
func (r *router) addMSSClampingRules() error {
|
||||
mss := r.mtu - ipTCPHeaderMinSize
|
||||
overhead := uint16(ipv4TCPHeaderSize)
|
||||
if r.af.tableFamily == nftables.TableFamilyIPv6 {
|
||||
overhead = ipv6TCPHeaderSize
|
||||
}
|
||||
mss := r.mtu - overhead
|
||||
|
||||
exprsOut := []expr.Any{
|
||||
&expr.Meta{
|
||||
@@ -1043,17 +1071,22 @@ func (r *router) acceptFilterTableRules() error {
|
||||
log.Debugf("Used %s to add accept forward and input rules", fw)
|
||||
}()
|
||||
|
||||
// Try iptables first and fallback to nftables if iptables is not available
|
||||
ipt, err := iptables.New()
|
||||
// Try iptables first and fallback to nftables if iptables is not available.
|
||||
// Use the correct protocol (iptables vs ip6tables) for the address family.
|
||||
ipt, err := iptables.NewWithProtocol(r.iptablesProto())
|
||||
if err != nil {
|
||||
// iptables is not available but the filter table exists
|
||||
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
|
||||
|
||||
fw = "nftables"
|
||||
return r.acceptFilterRulesNftables(r.filterTable)
|
||||
}
|
||||
|
||||
return r.acceptFilterRulesIptables(ipt)
|
||||
if err := r.acceptFilterRulesIptables(ipt); err != nil {
|
||||
log.Warnf("iptables failed (table may be incompatible), falling back to nftables: %v", err)
|
||||
fw = "nftables"
|
||||
return r.acceptFilterRulesNftables(r.filterTable)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) acceptFilterRulesIptables(ipt *iptables.IPTables) error {
|
||||
@@ -1222,13 +1255,17 @@ func (r *router) removeFilterTableRules() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
ipt, err := iptables.New()
|
||||
ipt, err := iptables.NewWithProtocol(r.iptablesProto())
|
||||
if err != nil {
|
||||
log.Debugf("iptables not available, using nftables to remove filter rules: %v", err)
|
||||
return r.removeAcceptRulesFromTable(r.filterTable)
|
||||
}
|
||||
|
||||
return r.removeAcceptFilterRulesIptables(ipt)
|
||||
if err := r.removeAcceptFilterRulesIptables(ipt); err != nil {
|
||||
log.Debugf("iptables removal failed (table may be incompatible), falling back to nftables: %v", err)
|
||||
return r.removeAcceptRulesFromTable(r.filterTable)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) removeAcceptRulesFromTable(table *nftables.Table) error {
|
||||
@@ -1295,7 +1332,7 @@ func (r *router) removeExternalChainsRules() error {
|
||||
func (r *router) findExternalChains() []*nftables.Chain {
|
||||
var chains []*nftables.Chain
|
||||
|
||||
families := []nftables.TableFamily{nftables.TableFamilyIPv4, nftables.TableFamilyINet}
|
||||
families := []nftables.TableFamily{r.af.tableFamily, nftables.TableFamilyINet}
|
||||
|
||||
for _, family := range families {
|
||||
allChains, err := r.conn.ListChainsOfTableFamily(family)
|
||||
@@ -1319,8 +1356,8 @@ func (r *router) isExternalChain(chain *nftables.Chain) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// Skip all iptables-managed tables in the ip family
|
||||
if chain.Table.Family == nftables.TableFamilyIPv4 && isIptablesTable(chain.Table.Name) {
|
||||
// Skip iptables/ip6tables-managed tables (adding nft-native rules breaks iptables-save compat)
|
||||
if (chain.Table.Family == nftables.TableFamilyIPv4 || chain.Table.Family == nftables.TableFamilyIPv6) && isIptablesTable(chain.Table.Name) {
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -1461,7 +1498,7 @@ func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
protoNum, err := protoToInt(rule.Protocol)
|
||||
protoNum, err := r.af.protoNum(rule.Protocol)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("convert protocol to number: %w", err)
|
||||
}
|
||||
@@ -1524,7 +1561,7 @@ func (r *router) addDnatRedirect(rule firewall.ForwardRule, protoNum uint8, rule
|
||||
dnatExprs = append(dnatExprs,
|
||||
&expr.NAT{
|
||||
Type: expr.NATTypeDestNAT,
|
||||
Family: uint32(nftables.TableFamilyIPv4),
|
||||
Family: uint32(r.af.tableFamily),
|
||||
RegAddrMin: 1,
|
||||
RegProtoMin: regProtoMin,
|
||||
RegProtoMax: regProtoMax,
|
||||
@@ -1620,7 +1657,7 @@ func (r *router) addXTablesRedirect(dnatExprs []expr.Any, ruleKey string, rule f
|
||||
dnatRule := &nftables.Rule{
|
||||
Table: &nftables.Table{
|
||||
Name: tableNat,
|
||||
Family: nftables.TableFamilyIPv4,
|
||||
Family: r.af.tableFamily,
|
||||
},
|
||||
Chain: &nftables.Chain{
|
||||
Name: chainNameNatPrerouting,
|
||||
@@ -1655,8 +1692,8 @@ func (r *router) addDnatMasq(rule firewall.ForwardRule, protoNum uint8, ruleKey
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: 16,
|
||||
Len: 4,
|
||||
Offset: r.af.dstAddrOffset,
|
||||
Len: r.af.addrLen,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
@@ -1734,7 +1771,7 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||
return fmt.Errorf("get set %s: %w", set.HashedName(), err)
|
||||
}
|
||||
|
||||
elements := convertPrefixesToSet(prefixes)
|
||||
elements := r.convertPrefixesToSet(prefixes)
|
||||
if err := r.conn.SetAddElements(nfset, elements); err != nil {
|
||||
return fmt.Errorf("add elements to set %s: %w", set.HashedName(), err)
|
||||
}
|
||||
@@ -1756,7 +1793,7 @@ func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol
|
||||
return nil
|
||||
}
|
||||
|
||||
protoNum, err := protoToInt(protocol)
|
||||
protoNum, err := r.af.protoNum(protocol)
|
||||
if err != nil {
|
||||
return fmt.Errorf("convert protocol to number: %w", err)
|
||||
}
|
||||
@@ -1787,7 +1824,11 @@ func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol
|
||||
},
|
||||
}
|
||||
|
||||
exprs = append(exprs, applyPrefix(netip.PrefixFrom(localAddr, 32), false)...)
|
||||
bits := 32
|
||||
if localAddr.Is6() {
|
||||
bits = 128
|
||||
}
|
||||
exprs = append(exprs, r.applyPrefix(netip.PrefixFrom(localAddr, bits), false)...)
|
||||
|
||||
exprs = append(exprs,
|
||||
&expr.Immediate{
|
||||
@@ -1800,7 +1841,7 @@ func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol
|
||||
},
|
||||
&expr.NAT{
|
||||
Type: expr.NATTypeDestNAT,
|
||||
Family: uint32(nftables.TableFamilyIPv4),
|
||||
Family: uint32(r.af.tableFamily),
|
||||
RegAddrMin: 1,
|
||||
RegProtoMin: 2,
|
||||
RegProtoMax: 0,
|
||||
@@ -1887,7 +1928,7 @@ func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol,
|
||||
return err
|
||||
}
|
||||
|
||||
protoNum, err := protoToInt(protocol)
|
||||
protoNum, err := r.af.protoNum(protocol)
|
||||
if err != nil {
|
||||
return fmt.Errorf("convert protocol to number: %w", err)
|
||||
}
|
||||
@@ -1912,7 +1953,11 @@ func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol,
|
||||
},
|
||||
}
|
||||
|
||||
exprs = append(exprs, applyPrefix(netip.PrefixFrom(localAddr, 32), false)...)
|
||||
bits := 32
|
||||
if localAddr.Is6() {
|
||||
bits = 128
|
||||
}
|
||||
exprs = append(exprs, r.applyPrefix(netip.PrefixFrom(localAddr, bits), false)...)
|
||||
|
||||
exprs = append(exprs,
|
||||
&expr.Immediate{
|
||||
@@ -1925,7 +1970,7 @@ func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol,
|
||||
},
|
||||
&expr.NAT{
|
||||
Type: expr.NATTypeDestNAT,
|
||||
Family: uint32(nftables.TableFamilyIPv4),
|
||||
Family: uint32(r.af.tableFamily),
|
||||
RegAddrMin: 1,
|
||||
RegProtoMin: 2,
|
||||
},
|
||||
@@ -1993,45 +2038,44 @@ func (r *router) applyNetwork(
|
||||
}
|
||||
|
||||
if network.IsPrefix() {
|
||||
return applyPrefix(network.Prefix, isSource), nil
|
||||
return r.applyPrefix(network.Prefix, isSource), nil
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// applyPrefix generates nftables expressions for a CIDR prefix
|
||||
func applyPrefix(prefix netip.Prefix, isSource bool) []expr.Any {
|
||||
// dst offset
|
||||
offset := uint32(16)
|
||||
func (r *router) applyPrefix(prefix netip.Prefix, isSource bool) []expr.Any {
|
||||
// dst offset by default
|
||||
offset := r.af.dstAddrOffset
|
||||
if isSource {
|
||||
// src offset
|
||||
offset = 12
|
||||
offset = r.af.srcAddrOffset
|
||||
}
|
||||
|
||||
ones := prefix.Bits()
|
||||
// 0.0.0.0/0 doesn't need extra expressions
|
||||
// unspecified address (/0) doesn't need extra expressions
|
||||
if ones == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
mask := net.CIDRMask(ones, 32)
|
||||
mask := net.CIDRMask(ones, r.af.totalBits)
|
||||
xor := make([]byte, r.af.addrLen)
|
||||
|
||||
return []expr.Any{
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: offset,
|
||||
Len: 4,
|
||||
Len: r.af.addrLen,
|
||||
},
|
||||
// netmask
|
||||
&expr.Bitwise{
|
||||
DestRegister: 1,
|
||||
SourceRegister: 1,
|
||||
Len: 4,
|
||||
Len: r.af.addrLen,
|
||||
Mask: mask,
|
||||
Xor: []byte{0, 0, 0, 0},
|
||||
Xor: xor,
|
||||
},
|
||||
// net address
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
@@ -2114,13 +2158,12 @@ func getCtNewExprs() []expr.Any {
|
||||
}
|
||||
}
|
||||
|
||||
func getIpSetExprs(ref refcounter.Ref[*nftables.Set], isSource bool) ([]expr.Any, error) {
|
||||
|
||||
// dst offset
|
||||
offset := uint32(16)
|
||||
func (r *router) getIpSetExprs(ref refcounter.Ref[*nftables.Set], isSource bool) ([]expr.Any, error) {
|
||||
// dst offset by default
|
||||
offset := r.af.dstAddrOffset
|
||||
if isSource {
|
||||
// src offset
|
||||
offset = 12
|
||||
offset = r.af.srcAddrOffset
|
||||
}
|
||||
|
||||
return []expr.Any{
|
||||
@@ -2128,7 +2171,7 @@ func getIpSetExprs(ref refcounter.Ref[*nftables.Set], isSource bool) ([]expr.Any
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: offset,
|
||||
Len: 4,
|
||||
Len: r.af.addrLen,
|
||||
},
|
||||
&expr.Lookup{
|
||||
SourceRegister: 1,
|
||||
|
||||
@@ -90,8 +90,9 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
|
||||
}
|
||||
|
||||
// Build CIDR matching expressions
|
||||
sourceExp := applyPrefix(testCase.InputPair.Source.Prefix, true)
|
||||
destExp := applyPrefix(testCase.InputPair.Destination.Prefix, false)
|
||||
testRouter := &router{af: afIPv4}
|
||||
sourceExp := testRouter.applyPrefix(testCase.InputPair.Source.Prefix, true)
|
||||
destExp := testRouter.applyPrefix(testCase.InputPair.Destination.Prefix, false)
|
||||
|
||||
// Combine all expressions in the correct order
|
||||
// nolint:gocritic
|
||||
@@ -508,6 +509,136 @@ func TestNftablesCreateIpSet(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestNftablesCreateIpSet_IPv6(t *testing.T) {
|
||||
if check() != NFTABLES {
|
||||
t.Skip("nftables not supported on this system")
|
||||
}
|
||||
|
||||
workTable, err := createWorkTableIPv6()
|
||||
require.NoError(t, err, "Failed to create v6 work table")
|
||||
defer deleteWorkTableIPv6()
|
||||
|
||||
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err, "Failed to create router")
|
||||
require.NoError(t, r.init(workTable))
|
||||
defer func() {
|
||||
require.NoError(t, r.Reset(), "Failed to reset router")
|
||||
}()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
sources []netip.Prefix
|
||||
expected []netip.Prefix
|
||||
}{
|
||||
{
|
||||
name: "Single IPv6",
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("2001:db8::1/128")},
|
||||
},
|
||||
{
|
||||
name: "Multiple IPv6 Subnets",
|
||||
sources: []netip.Prefix{
|
||||
netip.MustParsePrefix("fd00::/64"),
|
||||
netip.MustParsePrefix("2001:db8::/48"),
|
||||
netip.MustParsePrefix("fe80::/10"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Overlapping IPv6",
|
||||
sources: []netip.Prefix{
|
||||
netip.MustParsePrefix("fd00::/48"),
|
||||
netip.MustParsePrefix("fd00::/64"),
|
||||
netip.MustParsePrefix("fd00::1/128"),
|
||||
},
|
||||
expected: []netip.Prefix{
|
||||
netip.MustParsePrefix("fd00::/48"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Mixed prefix lengths",
|
||||
sources: []netip.Prefix{
|
||||
netip.MustParsePrefix("2001:db8:1::/48"),
|
||||
netip.MustParsePrefix("2001:db8:2::1/128"),
|
||||
netip.MustParsePrefix("fd00:abcd::/32"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
setName := firewall.NewPrefixSet(tt.sources).HashedName()
|
||||
set, err := r.createIpSet(setName, setInput{prefixes: tt.sources})
|
||||
require.NoError(t, err, "Failed to create IPv6 set")
|
||||
require.NotNil(t, set)
|
||||
|
||||
assert.Equal(t, setName, set.Name)
|
||||
assert.True(t, set.Interval)
|
||||
assert.Equal(t, nftables.TypeIP6Addr, set.KeyType)
|
||||
|
||||
fetchedSet, err := r.conn.GetSetByName(r.workTable, setName)
|
||||
require.NoError(t, err, "Failed to fetch created set")
|
||||
|
||||
elements, err := r.conn.GetSetElements(fetchedSet)
|
||||
require.NoError(t, err, "Failed to get set elements")
|
||||
|
||||
uniquePrefixes := make(map[string]bool)
|
||||
for _, elem := range elements {
|
||||
if !elem.IntervalEnd && len(elem.Key) == 16 {
|
||||
ip := netip.AddrFrom16([16]byte(elem.Key))
|
||||
uniquePrefixes[ip.String()] = true
|
||||
}
|
||||
}
|
||||
|
||||
expectedCount := len(tt.expected)
|
||||
if expectedCount == 0 {
|
||||
expectedCount = len(tt.sources)
|
||||
}
|
||||
assert.Equal(t, expectedCount, len(uniquePrefixes), "unique prefix count mismatch")
|
||||
|
||||
r.conn.DelSet(set)
|
||||
require.NoError(t, r.conn.Flush())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func createWorkTableIPv6() (*nftables.Table, error) {
|
||||
sConn, err := nftables.New(nftables.AsLasting())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tables, err := sConn.ListTablesOfFamily(nftables.TableFamilyIPv6)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, t := range tables {
|
||||
if t.Name == tableNameNetbird {
|
||||
sConn.DelTable(t)
|
||||
}
|
||||
}
|
||||
|
||||
table := sConn.AddTable(&nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv6})
|
||||
err = sConn.Flush()
|
||||
return table, err
|
||||
}
|
||||
|
||||
func deleteWorkTableIPv6() {
|
||||
sConn, err := nftables.New(nftables.AsLasting())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
tables, err := sConn.ListTablesOfFamily(nftables.TableFamilyIPv6)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
for _, t := range tables {
|
||||
if t.Name == tableNameNetbird {
|
||||
sConn.DelTable(t)
|
||||
_ = sConn.Flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func verifyRule(t *testing.T, rule *nftables.Rule, sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort, dPort *firewall.Port, direction firewall.RuleDirection, action firewall.Action, expectSet bool) {
|
||||
t.Helper()
|
||||
|
||||
@@ -627,7 +758,7 @@ func containsPort(exprs []expr.Any, port *firewall.Port, isSource bool) bool {
|
||||
|
||||
func containsProtocol(exprs []expr.Any, proto firewall.Protocol) bool {
|
||||
var metaFound, cmpFound bool
|
||||
expectedProto, _ := protoToInt(proto)
|
||||
expectedProto, _ := afIPv4.protoNum(proto)
|
||||
for _, e := range exprs {
|
||||
switch ex := e.(type) {
|
||||
case *expr.Meta:
|
||||
@@ -854,3 +985,55 @@ func TestRouter_AddNatRule_WithStaleEntry(t *testing.T) {
|
||||
}
|
||||
assert.Equal(t, 1, found, "NAT rule should exist in kernel")
|
||||
}
|
||||
|
||||
func TestCalculateLastIP(t *testing.T) {
|
||||
tests := []struct {
|
||||
prefix string
|
||||
want string
|
||||
}{
|
||||
{"10.0.0.0/24", "10.0.0.255"},
|
||||
{"10.0.0.0/32", "10.0.0.0"},
|
||||
{"0.0.0.0/0", "255.255.255.255"},
|
||||
{"192.168.1.0/28", "192.168.1.15"},
|
||||
{"fd00::/64", "fd00::ffff:ffff:ffff:ffff"},
|
||||
{"fd00::/128", "fd00::"},
|
||||
{"2001:db8::/48", "2001:db8:0:ffff:ffff:ffff:ffff:ffff"},
|
||||
{"::/0", "ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.prefix, func(t *testing.T) {
|
||||
prefix := netip.MustParsePrefix(tt.prefix)
|
||||
got := calculateLastIP(prefix)
|
||||
assert.Equal(t, tt.want, got.String())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertPrefixesToSet_IPv6(t *testing.T) {
|
||||
r := &router{af: afIPv6}
|
||||
prefixes := []netip.Prefix{
|
||||
netip.MustParsePrefix("fd00::/64"),
|
||||
netip.MustParsePrefix("2001:db8::1/128"),
|
||||
}
|
||||
|
||||
elements := r.convertPrefixesToSet(prefixes)
|
||||
|
||||
// Each prefix produces 2 elements (start + end)
|
||||
require.Len(t, elements, 4)
|
||||
|
||||
// fd00::/64 start
|
||||
assert.Equal(t, netip.MustParseAddr("fd00::").As16(), [16]byte(elements[0].Key))
|
||||
assert.False(t, elements[0].IntervalEnd)
|
||||
|
||||
// fd00::/64 end (fd00:0:0:1::, one past the last)
|
||||
assert.Equal(t, netip.MustParseAddr("fd00:0:0:1::").As16(), [16]byte(elements[1].Key))
|
||||
assert.True(t, elements[1].IntervalEnd)
|
||||
|
||||
// 2001:db8::1/128 start
|
||||
assert.Equal(t, netip.MustParseAddr("2001:db8::1").As16(), [16]byte(elements[2].Key))
|
||||
assert.False(t, elements[2].IntervalEnd)
|
||||
|
||||
// 2001:db8::1/128 end (2001:db8::2)
|
||||
assert.Equal(t, netip.MustParseAddr("2001:db8::2").As16(), [16]byte(elements[3].Key))
|
||||
assert.True(t, elements[3].IntervalEnd)
|
||||
}
|
||||
|
||||
@@ -5,8 +5,10 @@ import (
|
||||
"os/exec"
|
||||
"syscall"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
@@ -29,15 +31,20 @@ func (m *Manager) Close(*statemanager.Manager) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !isFirewallRuleActive(firewallRuleName) {
|
||||
return nil
|
||||
var merr *multierror.Error
|
||||
if isFirewallRuleActive(firewallRuleName) {
|
||||
if err := manageFirewallRule(firewallRuleName, deleteRule); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove windows firewall rule: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
if err := manageFirewallRule(firewallRuleName, deleteRule); err != nil {
|
||||
return fmt.Errorf("couldn't remove windows firewall: %w", err)
|
||||
if isFirewallRuleActive(firewallRuleName + "-v6") {
|
||||
if err := manageFirewallRule(firewallRuleName+"-v6", deleteRule); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove windows v6 firewall rule: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
// AllowNetbird allows netbird interface traffic
|
||||
@@ -46,17 +53,33 @@ func (m *Manager) AllowNetbird() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
if isFirewallRuleActive(firewallRuleName) {
|
||||
return nil
|
||||
if !isFirewallRuleActive(firewallRuleName) {
|
||||
if err := manageFirewallRule(firewallRuleName,
|
||||
addRule,
|
||||
"dir=in",
|
||||
"enable=yes",
|
||||
"action=allow",
|
||||
"profile=any",
|
||||
"localip="+m.wgIface.Address().IP.String(),
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return manageFirewallRule(firewallRuleName,
|
||||
addRule,
|
||||
"dir=in",
|
||||
"enable=yes",
|
||||
"action=allow",
|
||||
"profile=any",
|
||||
"localip="+m.wgIface.Address().IP.String(),
|
||||
)
|
||||
|
||||
if v6 := m.wgIface.Address().IPv6; v6.IsValid() && !isFirewallRuleActive(firewallRuleName+"-v6") {
|
||||
if err := manageFirewallRule(firewallRuleName+"-v6",
|
||||
addRule,
|
||||
"dir=in",
|
||||
"enable=yes",
|
||||
"action=allow",
|
||||
"profile=any",
|
||||
"localip="+v6.String(),
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func manageFirewallRule(ruleName string, action action, extraArgs ...string) error {
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
@@ -64,5 +65,7 @@ type ConnKey struct {
|
||||
}
|
||||
|
||||
func (c ConnKey) String() string {
|
||||
return fmt.Sprintf("%s:%d → %s:%d", c.SrcIP.Unmap(), c.SrcPort, c.DstIP.Unmap(), c.DstPort)
|
||||
return net.JoinHostPort(c.SrcIP.Unmap().String(), strconv.Itoa(int(c.SrcPort))) +
|
||||
" → " +
|
||||
net.JoinHostPort(c.DstIP.Unmap().String(), strconv.Itoa(int(c.DstPort)))
|
||||
}
|
||||
|
||||
@@ -21,9 +21,10 @@ const (
|
||||
// ICMPCleanupInterval is how often we check for stale ICMP connections
|
||||
ICMPCleanupInterval = 15 * time.Second
|
||||
|
||||
// MaxICMPPayloadLength is the maximum length of ICMP payload we consider for original packet info,
|
||||
// which includes the IP header (20 bytes) and transport header (8 bytes)
|
||||
MaxICMPPayloadLength = 28
|
||||
// MaxICMPPayloadLength is the maximum length of ICMP payload we consider for original packet info.
|
||||
// IPv4: 20-byte header + 8-byte transport = 28 bytes.
|
||||
// IPv6: 40-byte header + 8-byte transport = 48 bytes.
|
||||
MaxICMPPayloadLength = 48
|
||||
)
|
||||
|
||||
// ICMPConnKey uniquely identifies an ICMP connection
|
||||
@@ -74,32 +75,64 @@ func (info ICMPInfo) String() string {
|
||||
return info.TypeCode.String()
|
||||
}
|
||||
|
||||
// isErrorMessage returns true if this ICMP type carries original packet info
|
||||
// isErrorMessage returns true if this ICMP type carries original packet info.
|
||||
// Covers both ICMPv4 and ICMPv6 error types. Without a family field we match
|
||||
// both sets; type 3 overlaps (v4 DestUnreachable / v6 TimeExceeded) so it's
|
||||
// kept as a literal.
|
||||
func (info ICMPInfo) isErrorMessage() bool {
|
||||
typ := info.TypeCode.Type()
|
||||
return typ == 3 || // Destination Unreachable
|
||||
typ == 5 || // Redirect
|
||||
typ == 11 || // Time Exceeded
|
||||
typ == 12 // Parameter Problem
|
||||
// ICMPv4 error types
|
||||
if typ == layers.ICMPv4TypeDestinationUnreachable ||
|
||||
typ == layers.ICMPv4TypeRedirect ||
|
||||
typ == layers.ICMPv4TypeTimeExceeded ||
|
||||
typ == layers.ICMPv4TypeParameterProblem {
|
||||
return true
|
||||
}
|
||||
// ICMPv6 error types (type 3 already matched above as v4 DestUnreachable)
|
||||
if typ == layers.ICMPv6TypeDestinationUnreachable ||
|
||||
typ == layers.ICMPv6TypePacketTooBig ||
|
||||
typ == layers.ICMPv6TypeParameterProblem {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// parseOriginalPacket extracts info about the original packet from ICMP payload
|
||||
func (info ICMPInfo) parseOriginalPacket() string {
|
||||
if info.PayloadLen < MaxICMPPayloadLength {
|
||||
if info.PayloadLen == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
// TODO: handle IPv6
|
||||
if version := (info.PayloadData[0] >> 4) & 0xF; version != 4 {
|
||||
version := (info.PayloadData[0] >> 4) & 0xF
|
||||
|
||||
var protocol uint8
|
||||
var srcIP, dstIP net.IP
|
||||
var transportData []byte
|
||||
|
||||
switch version {
|
||||
case 4:
|
||||
// 20-byte IPv4 header + 8-byte transport minimum
|
||||
if info.PayloadLen < 28 {
|
||||
return ""
|
||||
}
|
||||
protocol = info.PayloadData[9]
|
||||
srcIP = net.IP(info.PayloadData[12:16])
|
||||
dstIP = net.IP(info.PayloadData[16:20])
|
||||
transportData = info.PayloadData[20:]
|
||||
case 6:
|
||||
// 40-byte IPv6 header + 8-byte transport minimum
|
||||
if info.PayloadLen < 48 {
|
||||
return ""
|
||||
}
|
||||
// Next Header field in IPv6 header
|
||||
protocol = info.PayloadData[6]
|
||||
srcIP = net.IP(info.PayloadData[8:24])
|
||||
dstIP = net.IP(info.PayloadData[24:40])
|
||||
transportData = info.PayloadData[40:]
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
|
||||
protocol := info.PayloadData[9]
|
||||
srcIP := net.IP(info.PayloadData[12:16])
|
||||
dstIP := net.IP(info.PayloadData[16:20])
|
||||
|
||||
transportData := info.PayloadData[20:]
|
||||
|
||||
switch nftypes.Protocol(protocol) {
|
||||
case nftypes.TCP:
|
||||
srcPort := uint16(transportData[0])<<8 | uint16(transportData[1])
|
||||
@@ -247,9 +280,10 @@ func (t *ICMPTracker) track(
|
||||
t.sendEvent(nftypes.TypeStart, conn, ruleId)
|
||||
}
|
||||
|
||||
// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request
|
||||
// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request.
|
||||
// Accepts both ICMPv4 (type 0) and ICMPv6 (type 129) echo replies.
|
||||
func (t *ICMPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, icmpType uint8, size int) bool {
|
||||
if icmpType != uint8(layers.ICMPv4TypeEchoReply) {
|
||||
if icmpType != uint8(layers.ICMPv4TypeEchoReply) && icmpType != uint8(layers.ICMPv6TypeEchoReply) {
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -301,6 +335,13 @@ func (t *ICMPTracker) cleanup() {
|
||||
}
|
||||
}
|
||||
|
||||
func icmpProtocolForAddr(ip netip.Addr) nftypes.Protocol {
|
||||
if ip.Is6() {
|
||||
return nftypes.ICMPv6
|
||||
}
|
||||
return nftypes.ICMP
|
||||
}
|
||||
|
||||
// Close stops the cleanup routine and releases resources
|
||||
func (t *ICMPTracker) Close() {
|
||||
t.tickerCancel()
|
||||
@@ -316,7 +357,7 @@ func (t *ICMPTracker) sendEvent(typ nftypes.Type, conn *ICMPConnTrack, ruleID []
|
||||
Type: typ,
|
||||
RuleID: ruleID,
|
||||
Direction: conn.Direction,
|
||||
Protocol: nftypes.ICMP, // TODO: adjust for IPv6/icmpv6
|
||||
Protocol: icmpProtocolForAddr(conn.SourceIP),
|
||||
SourceIP: conn.SourceIP,
|
||||
DestIP: conn.DestIP,
|
||||
ICMPType: conn.ICMPType,
|
||||
@@ -334,7 +375,7 @@ func (t *ICMPTracker) sendStartEvent(direction nftypes.Direction, srcIP netip.Ad
|
||||
Type: nftypes.TypeStart,
|
||||
RuleID: ruleID,
|
||||
Direction: direction,
|
||||
Protocol: nftypes.ICMP,
|
||||
Protocol: icmpProtocolForAddr(srcIP),
|
||||
SourceIP: srcIP,
|
||||
DestIP: dstIP,
|
||||
ICMPType: typ,
|
||||
|
||||
@@ -35,8 +35,10 @@ import (
|
||||
const (
|
||||
layerTypeAll = 255
|
||||
|
||||
// ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation
|
||||
ipTCPHeaderMinSize = 40
|
||||
// ipv4TCPHeaderMinSize represents minimum IPv4 (20) + TCP (20) header size for MSS calculation
|
||||
ipv4TCPHeaderMinSize = 40
|
||||
// ipv6TCPHeaderMinSize represents minimum IPv6 (40) + TCP (20) header size for MSS calculation
|
||||
ipv6TCPHeaderMinSize = 60
|
||||
)
|
||||
|
||||
// serviceKey represents a protocol/port combination for netstack service registry
|
||||
@@ -137,9 +139,10 @@ type Manager struct {
|
||||
netstackServices map[serviceKey]struct{}
|
||||
netstackServiceMutex sync.RWMutex
|
||||
|
||||
mtu uint16
|
||||
mssClampValue uint16
|
||||
mssClampEnabled bool
|
||||
mtu uint16
|
||||
mssClampValueIPv4 uint16
|
||||
mssClampValueIPv6 uint16
|
||||
mssClampEnabled bool
|
||||
|
||||
// Only one hook per protocol is supported. Outbound direction only.
|
||||
udpHookOut atomic.Pointer[packetHook]
|
||||
@@ -163,11 +166,28 @@ type decoder struct {
|
||||
icmp4 layers.ICMPv4
|
||||
icmp6 layers.ICMPv6
|
||||
decoded []gopacket.LayerType
|
||||
parser *gopacket.DecodingLayerParser
|
||||
parser4 *gopacket.DecodingLayerParser
|
||||
parser6 *gopacket.DecodingLayerParser
|
||||
|
||||
dnatOrigPort uint16
|
||||
}
|
||||
|
||||
// decodePacket decodes packet data using the appropriate parser based on IP version.
|
||||
func (d *decoder) decodePacket(data []byte) error {
|
||||
if len(data) == 0 {
|
||||
return errors.New("empty packet")
|
||||
}
|
||||
version := data[0] >> 4
|
||||
switch version {
|
||||
case 4:
|
||||
return d.parser4.DecodeLayers(data, &d.decoded)
|
||||
case 6:
|
||||
return d.parser6.DecodeLayers(data, &d.decoded)
|
||||
default:
|
||||
return fmt.Errorf("unknown IP version %d", version)
|
||||
}
|
||||
}
|
||||
|
||||
// Create userspace firewall manager constructor
|
||||
func Create(iface common.IFaceMapper, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) {
|
||||
return create(iface, nil, disableServerRoutes, flowLogger, mtu)
|
||||
@@ -225,11 +245,17 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
||||
d := &decoder{
|
||||
decoded: []gopacket.LayerType{},
|
||||
}
|
||||
d.parser = gopacket.NewDecodingLayerParser(
|
||||
d.parser4 = gopacket.NewDecodingLayerParser(
|
||||
layers.LayerTypeIPv4,
|
||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||
)
|
||||
d.parser.IgnoreUnsupported = true
|
||||
d.parser4.IgnoreUnsupported = true
|
||||
|
||||
d.parser6 = gopacket.NewDecodingLayerParser(
|
||||
layers.LayerTypeIPv6,
|
||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||
)
|
||||
d.parser6.IgnoreUnsupported = true
|
||||
return d
|
||||
},
|
||||
},
|
||||
@@ -255,7 +281,8 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
||||
|
||||
if !disableMSSClamping {
|
||||
m.mssClampEnabled = true
|
||||
m.mssClampValue = mtu - ipTCPHeaderMinSize
|
||||
m.mssClampValueIPv4 = mtu - ipv4TCPHeaderMinSize
|
||||
m.mssClampValueIPv6 = mtu - ipv6TCPHeaderMinSize
|
||||
}
|
||||
if err := m.localipmanager.UpdateLocalIPs(iface); err != nil {
|
||||
return nil, fmt.Errorf("update local IPs: %w", err)
|
||||
@@ -282,9 +309,14 @@ func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) (firewall.Rule, e
|
||||
wgPrefix := iface.Address().Network
|
||||
log.Debugf("blocking invalid routed traffic for %s", wgPrefix)
|
||||
|
||||
sources := []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
|
||||
if v6 := iface.Address().IPv6Net; v6.IsValid() {
|
||||
sources = append(sources, netip.PrefixFrom(netip.IPv6Unspecified(), 0))
|
||||
}
|
||||
|
||||
rule, err := m.addRouteFiltering(
|
||||
nil,
|
||||
[]netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)},
|
||||
sources,
|
||||
firewall.Network{Prefix: wgPrefix},
|
||||
firewall.ProtocolALL,
|
||||
nil,
|
||||
@@ -292,7 +324,22 @@ func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) (firewall.Rule, e
|
||||
firewall.ActionDrop,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("block wg nte : %w", err)
|
||||
return nil, fmt.Errorf("block wg v4 net: %w", err)
|
||||
}
|
||||
|
||||
if v6Net := iface.Address().IPv6Net; v6Net.IsValid() {
|
||||
log.Debugf("blocking invalid routed traffic for %s", v6Net)
|
||||
if _, err := m.addRouteFiltering(
|
||||
nil,
|
||||
sources,
|
||||
firewall.Network{Prefix: v6Net},
|
||||
firewall.ProtocolALL,
|
||||
nil,
|
||||
nil,
|
||||
firewall.ActionDrop,
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("block wg v6 net: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Block networks that we're a client of
|
||||
@@ -509,7 +556,7 @@ func (m *Manager) addRouteFiltering(
|
||||
mgmtId: id,
|
||||
sources: sources,
|
||||
dstSet: destination.Set,
|
||||
protoLayer: protoToLayer(proto, layers.LayerTypeIPv4),
|
||||
protoLayer: protoToLayer(proto, ipLayerFromPrefix(destination.Prefix)),
|
||||
srcPort: sPort,
|
||||
dstPort: dPort,
|
||||
action: action,
|
||||
@@ -663,11 +710,7 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||
}
|
||||
|
||||
destinations := matches[0].destinations
|
||||
for _, prefix := range prefixes {
|
||||
if prefix.Addr().Is4() {
|
||||
destinations = append(destinations, prefix)
|
||||
}
|
||||
}
|
||||
destinations = append(destinations, prefixes...)
|
||||
|
||||
slices.SortFunc(destinations, func(a, b netip.Prefix) int {
|
||||
cmp := a.Addr().Compare(b.Addr())
|
||||
@@ -706,7 +749,7 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool {
|
||||
d := m.decoders.Get().(*decoder)
|
||||
defer m.decoders.Put(d)
|
||||
|
||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||
if err := d.decodePacket(packetData); err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -790,12 +833,28 @@ func (m *Manager) clampTCPMSS(packetData []byte, d *decoder) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
var mssClampValue uint16
|
||||
var ipHeaderSize int
|
||||
switch d.decoded[0] {
|
||||
case layers.LayerTypeIPv4:
|
||||
mssClampValue = m.mssClampValueIPv4
|
||||
ipHeaderSize = int(d.ip4.IHL) * 4
|
||||
if ipHeaderSize < 20 {
|
||||
return false
|
||||
}
|
||||
case layers.LayerTypeIPv6:
|
||||
mssClampValue = m.mssClampValueIPv6
|
||||
ipHeaderSize = 40
|
||||
default:
|
||||
return false
|
||||
}
|
||||
|
||||
mssOptionIndex := -1
|
||||
var currentMSS uint16
|
||||
for i, opt := range d.tcp.Options {
|
||||
if opt.OptionType == layers.TCPOptionKindMSS && len(opt.OptionData) == 2 {
|
||||
currentMSS = binary.BigEndian.Uint16(opt.OptionData)
|
||||
if currentMSS > m.mssClampValue {
|
||||
if currentMSS > mssClampValue {
|
||||
mssOptionIndex = i
|
||||
break
|
||||
}
|
||||
@@ -806,20 +865,15 @@ func (m *Manager) clampTCPMSS(packetData []byte, d *decoder) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
ipHeaderSize := int(d.ip4.IHL) * 4
|
||||
if ipHeaderSize < 20 {
|
||||
if !m.updateMSSOption(packetData, d, mssOptionIndex, mssClampValue, ipHeaderSize) {
|
||||
return false
|
||||
}
|
||||
|
||||
if !m.updateMSSOption(packetData, d, mssOptionIndex, ipHeaderSize) {
|
||||
return false
|
||||
}
|
||||
|
||||
m.logger.Trace2("Clamped TCP MSS from %d to %d", currentMSS, m.mssClampValue)
|
||||
m.logger.Trace2("Clamped TCP MSS from %d to %d", currentMSS, mssClampValue)
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *Manager) updateMSSOption(packetData []byte, d *decoder, mssOptionIndex, ipHeaderSize int) bool {
|
||||
func (m *Manager) updateMSSOption(packetData []byte, d *decoder, mssOptionIndex int, mssClampValue uint16, ipHeaderSize int) bool {
|
||||
tcpHeaderStart := ipHeaderSize
|
||||
tcpOptionsStart := tcpHeaderStart + 20
|
||||
|
||||
@@ -834,7 +888,7 @@ func (m *Manager) updateMSSOption(packetData []byte, d *decoder, mssOptionIndex,
|
||||
}
|
||||
|
||||
mssValueOffset := optOffset + 2
|
||||
binary.BigEndian.PutUint16(packetData[mssValueOffset:mssValueOffset+2], m.mssClampValue)
|
||||
binary.BigEndian.PutUint16(packetData[mssValueOffset:mssValueOffset+2], mssClampValue)
|
||||
|
||||
m.recalculateTCPChecksum(packetData, d, tcpHeaderStart)
|
||||
return true
|
||||
@@ -844,18 +898,32 @@ func (m *Manager) recalculateTCPChecksum(packetData []byte, d *decoder, tcpHeade
|
||||
tcpLayer := packetData[tcpHeaderStart:]
|
||||
tcpLength := len(packetData) - tcpHeaderStart
|
||||
|
||||
// Zero out existing checksum
|
||||
tcpLayer[16] = 0
|
||||
tcpLayer[17] = 0
|
||||
|
||||
// Build pseudo-header checksum based on IP version
|
||||
var pseudoSum uint32
|
||||
pseudoSum += uint32(d.ip4.SrcIP[0])<<8 | uint32(d.ip4.SrcIP[1])
|
||||
pseudoSum += uint32(d.ip4.SrcIP[2])<<8 | uint32(d.ip4.SrcIP[3])
|
||||
pseudoSum += uint32(d.ip4.DstIP[0])<<8 | uint32(d.ip4.DstIP[1])
|
||||
pseudoSum += uint32(d.ip4.DstIP[2])<<8 | uint32(d.ip4.DstIP[3])
|
||||
pseudoSum += uint32(d.ip4.Protocol)
|
||||
pseudoSum += uint32(tcpLength)
|
||||
switch d.decoded[0] {
|
||||
case layers.LayerTypeIPv4:
|
||||
pseudoSum += uint32(d.ip4.SrcIP[0])<<8 | uint32(d.ip4.SrcIP[1])
|
||||
pseudoSum += uint32(d.ip4.SrcIP[2])<<8 | uint32(d.ip4.SrcIP[3])
|
||||
pseudoSum += uint32(d.ip4.DstIP[0])<<8 | uint32(d.ip4.DstIP[1])
|
||||
pseudoSum += uint32(d.ip4.DstIP[2])<<8 | uint32(d.ip4.DstIP[3])
|
||||
pseudoSum += uint32(d.ip4.Protocol)
|
||||
pseudoSum += uint32(tcpLength)
|
||||
case layers.LayerTypeIPv6:
|
||||
for i := 0; i < 16; i += 2 {
|
||||
pseudoSum += uint32(d.ip6.SrcIP[i])<<8 | uint32(d.ip6.SrcIP[i+1])
|
||||
}
|
||||
for i := 0; i < 16; i += 2 {
|
||||
pseudoSum += uint32(d.ip6.DstIP[i])<<8 | uint32(d.ip6.DstIP[i+1])
|
||||
}
|
||||
pseudoSum += uint32(tcpLength)
|
||||
pseudoSum += uint32(layers.IPProtocolTCP)
|
||||
}
|
||||
|
||||
var sum = pseudoSum
|
||||
sum := pseudoSum
|
||||
for i := 0; i < tcpLength-1; i += 2 {
|
||||
sum += uint32(tcpLayer[i])<<8 | uint32(tcpLayer[i+1])
|
||||
}
|
||||
@@ -893,6 +961,9 @@ func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, packetData
|
||||
}
|
||||
case layers.LayerTypeICMPv4:
|
||||
m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, d.icmp4.Payload, size)
|
||||
case layers.LayerTypeICMPv6:
|
||||
id, tc := icmpv6EchoFields(d)
|
||||
m.icmpTracker.TrackOutbound(srcIP, dstIP, id, tc, d.icmp6.Payload, size)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -906,6 +977,9 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byt
|
||||
m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size, d.dnatOrigPort)
|
||||
case layers.LayerTypeICMPv4:
|
||||
m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, d.icmp4.Payload, size)
|
||||
case layers.LayerTypeICMPv6:
|
||||
id, tc := icmpv6EchoFields(d)
|
||||
m.icmpTracker.TrackInbound(srcIP, dstIP, id, tc, ruleID, d.icmp6.Payload, size)
|
||||
}
|
||||
|
||||
d.dnatOrigPort = 0
|
||||
@@ -948,15 +1022,19 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool {
|
||||
|
||||
// TODO: pass fragments of routed packets to forwarder
|
||||
if fragment {
|
||||
m.logger.Trace4("packet is a fragment: src=%v dst=%v id=%v flags=%v",
|
||||
srcIP, dstIP, d.ip4.Id, d.ip4.Flags)
|
||||
if d.decoded[0] == layers.LayerTypeIPv4 {
|
||||
m.logger.Trace4("packet is a fragment: src=%v dst=%v id=%v flags=%v",
|
||||
srcIP, dstIP, d.ip4.Id, d.ip4.Flags)
|
||||
} else {
|
||||
m.logger.Trace2("packet is an IPv6 fragment: src=%v dst=%v", srcIP, dstIP)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// TODO: optimize port DNAT by caching matched rules in conntrack
|
||||
if translated := m.translateInboundPortDNAT(packetData, d, srcIP, dstIP); translated {
|
||||
// Re-decode after port DNAT translation to update port information
|
||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||
if err := d.decodePacket(packetData); err != nil {
|
||||
m.logger.Error1("failed to re-decode packet after port DNAT: %v", err)
|
||||
return true
|
||||
}
|
||||
@@ -965,7 +1043,7 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool {
|
||||
|
||||
if translated := m.translateInboundReverse(packetData, d); translated {
|
||||
// Re-decode after translation to get original addresses
|
||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||
if err := d.decodePacket(packetData); err != nil {
|
||||
m.logger.Error1("failed to re-decode packet after reverse DNAT: %v", err)
|
||||
return true
|
||||
}
|
||||
@@ -1097,6 +1175,48 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe
|
||||
return true
|
||||
}
|
||||
|
||||
// icmpv6EchoFields extracts the echo identifier from an ICMPv6 packet and maps
|
||||
// the ICMPv6 type code to an ICMPv4TypeCode so the ICMP conntrack can handle
|
||||
// both families uniformly. The echo ID is in the first two payload bytes.
|
||||
func icmpv6EchoFields(d *decoder) (id uint16, tc layers.ICMPv4TypeCode) {
|
||||
if len(d.icmp6.Payload) >= 2 {
|
||||
id = uint16(d.icmp6.Payload[0])<<8 | uint16(d.icmp6.Payload[1])
|
||||
}
|
||||
// Map ICMPv6 echo types to ICMPv4 equivalents for unified tracking.
|
||||
switch d.icmp6.TypeCode.Type() {
|
||||
case layers.ICMPv6TypeEchoRequest:
|
||||
tc = layers.CreateICMPv4TypeCode(layers.ICMPv4TypeEchoRequest, 0)
|
||||
case layers.ICMPv6TypeEchoReply:
|
||||
tc = layers.CreateICMPv4TypeCode(layers.ICMPv4TypeEchoReply, 0)
|
||||
default:
|
||||
tc = layers.CreateICMPv4TypeCode(d.icmp6.TypeCode.Type(), d.icmp6.TypeCode.Code())
|
||||
}
|
||||
return id, tc
|
||||
}
|
||||
|
||||
// protoLayerMatches checks if a packet's protocol layer matches a rule's expected
|
||||
// protocol layer. ICMPv4 and ICMPv6 are treated as equivalent when matching
|
||||
// ICMP rules since management sends a single ICMP rule for both families.
|
||||
func protoLayerMatches(ruleLayer, packetLayer gopacket.LayerType) bool {
|
||||
if ruleLayer == packetLayer {
|
||||
return true
|
||||
}
|
||||
if ruleLayer == layers.LayerTypeICMPv4 && packetLayer == layers.LayerTypeICMPv6 {
|
||||
return true
|
||||
}
|
||||
if ruleLayer == layers.LayerTypeICMPv6 && packetLayer == layers.LayerTypeICMPv4 {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func ipLayerFromPrefix(p netip.Prefix) gopacket.LayerType {
|
||||
if p.Addr().Is6() {
|
||||
return layers.LayerTypeIPv6
|
||||
}
|
||||
return layers.LayerTypeIPv4
|
||||
}
|
||||
|
||||
func protoToLayer(proto firewall.Protocol, ipLayer gopacket.LayerType) gopacket.LayerType {
|
||||
switch proto {
|
||||
case firewall.ProtocolTCP:
|
||||
@@ -1120,8 +1240,10 @@ func getProtocolFromPacket(d *decoder) nftypes.Protocol {
|
||||
return nftypes.TCP
|
||||
case layers.LayerTypeUDP:
|
||||
return nftypes.UDP
|
||||
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
|
||||
case layers.LayerTypeICMPv4:
|
||||
return nftypes.ICMP
|
||||
case layers.LayerTypeICMPv6:
|
||||
return nftypes.ICMPv6
|
||||
default:
|
||||
return nftypes.ProtocolUnknown
|
||||
}
|
||||
@@ -1142,7 +1264,7 @@ func getPortsFromPacket(d *decoder) (srcPort, dstPort uint16) {
|
||||
// It returns true, false if the packet is valid and not a fragment.
|
||||
// It returns true, true if the packet is a fragment and valid.
|
||||
func (m *Manager) isValidPacket(d *decoder, packetData []byte) (bool, bool) {
|
||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||
if err := d.decodePacket(packetData); err != nil {
|
||||
m.logger.Trace1("couldn't decode packet, err: %s", err)
|
||||
return false, false
|
||||
}
|
||||
@@ -1155,10 +1277,18 @@ func (m *Manager) isValidPacket(d *decoder, packetData []byte) (bool, bool) {
|
||||
}
|
||||
|
||||
// Fragments are also valid
|
||||
if l == 1 && d.decoded[0] == layers.LayerTypeIPv4 {
|
||||
ip4 := d.ip4
|
||||
if ip4.Flags&layers.IPv4MoreFragments != 0 || ip4.FragOffset != 0 {
|
||||
return true, true
|
||||
if l == 1 {
|
||||
switch d.decoded[0] {
|
||||
case layers.LayerTypeIPv4:
|
||||
if d.ip4.Flags&layers.IPv4MoreFragments != 0 || d.ip4.FragOffset != 0 {
|
||||
return true, true
|
||||
}
|
||||
case layers.LayerTypeIPv6:
|
||||
// IPv6 uses Fragment extension header (NextHeader=44). If gopacket
|
||||
// only decoded the IPv6 layer, the transport is in a fragment.
|
||||
if d.ip6.NextHeader == layers.IPProtocolIPv6Fragment {
|
||||
return true, true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1196,21 +1326,34 @@ func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP netip.Addr,
|
||||
size,
|
||||
)
|
||||
|
||||
// TODO: ICMPv6
|
||||
case layers.LayerTypeICMPv6:
|
||||
id, _ := icmpv6EchoFields(d)
|
||||
return m.icmpTracker.IsValidInbound(
|
||||
srcIP,
|
||||
dstIP,
|
||||
id,
|
||||
d.icmp6.TypeCode.Type(),
|
||||
size,
|
||||
)
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// isSpecialICMP returns true if the packet is a special ICMP packet that should be allowed
|
||||
// isSpecialICMP returns true if the packet is a special ICMP error packet that should be allowed.
|
||||
func (m *Manager) isSpecialICMP(d *decoder) bool {
|
||||
if d.decoded[1] != layers.LayerTypeICMPv4 {
|
||||
return false
|
||||
switch d.decoded[1] {
|
||||
case layers.LayerTypeICMPv4:
|
||||
icmpType := d.icmp4.TypeCode.Type()
|
||||
return icmpType == layers.ICMPv4TypeDestinationUnreachable ||
|
||||
icmpType == layers.ICMPv4TypeTimeExceeded
|
||||
case layers.LayerTypeICMPv6:
|
||||
icmpType := d.icmp6.TypeCode.Type()
|
||||
return icmpType == layers.ICMPv6TypeDestinationUnreachable ||
|
||||
icmpType == layers.ICMPv6TypePacketTooBig ||
|
||||
icmpType == layers.ICMPv6TypeTimeExceeded
|
||||
}
|
||||
|
||||
icmpType := d.icmp4.TypeCode.Type()
|
||||
return icmpType == layers.ICMPv4TypeDestinationUnreachable ||
|
||||
icmpType == layers.ICMPv4TypeTimeExceeded
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Manager) peerACLsBlock(srcIP netip.Addr, d *decoder, packetData []byte) ([]byte, bool) {
|
||||
@@ -1267,7 +1410,7 @@ func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d
|
||||
return rule.mgmtId, rule.drop, true
|
||||
}
|
||||
|
||||
if payloadLayer != rule.protoLayer {
|
||||
if !protoLayerMatches(rule.protoLayer, payloadLayer) {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -1302,8 +1445,7 @@ func (m *Manager) routeACLsPass(srcIP, dstIP netip.Addr, protoLayer gopacket.Lay
|
||||
}
|
||||
|
||||
func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, protoLayer gopacket.LayerType, srcPort, dstPort uint16) bool {
|
||||
// TODO: handle ipv6 vs ipv4 icmp rules
|
||||
if rule.protoLayer != layerTypeAll && rule.protoLayer != protoLayer {
|
||||
if rule.protoLayer != layerTypeAll && !protoLayerMatches(rule.protoLayer, protoLayer) {
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -1473,7 +1615,8 @@ func (m *Manager) shouldForward(d *decoder, dstIP netip.Addr) bool {
|
||||
}
|
||||
|
||||
// traffic to our other local interfaces (not NetBird IP) - always forward
|
||||
if dstIP != m.wgIface.Address().IP {
|
||||
addr := m.wgIface.Address()
|
||||
if dstIP != addr.IP && (!addr.IPv6.IsValid() || dstIP != addr.IPv6) {
|
||||
return true
|
||||
}
|
||||
|
||||
|
||||
@@ -1023,7 +1023,8 @@ func BenchmarkMSSClamping(b *testing.B) {
|
||||
}()
|
||||
|
||||
manager.mssClampEnabled = true
|
||||
manager.mssClampValue = 1240
|
||||
manager.mssClampValueIPv4 = 1240
|
||||
manager.mssClampValueIPv6 = 1220
|
||||
|
||||
srcIP := net.ParseIP("100.64.0.2")
|
||||
dstIP := net.ParseIP("8.8.8.8")
|
||||
@@ -1088,7 +1089,8 @@ func BenchmarkMSSClampingOverhead(b *testing.B) {
|
||||
|
||||
manager.mssClampEnabled = sc.enabled
|
||||
if sc.enabled {
|
||||
manager.mssClampValue = 1240
|
||||
manager.mssClampValueIPv4 = 1240
|
||||
manager.mssClampValueIPv6 = 1220
|
||||
}
|
||||
|
||||
srcIP := net.ParseIP("100.64.0.2")
|
||||
@@ -1141,7 +1143,8 @@ func BenchmarkMSSClampingMemory(b *testing.B) {
|
||||
}()
|
||||
|
||||
manager.mssClampEnabled = true
|
||||
manager.mssClampValue = 1240
|
||||
manager.mssClampValueIPv4 = 1240
|
||||
manager.mssClampValueIPv6 = 1220
|
||||
|
||||
srcIP := net.ParseIP("100.64.0.2")
|
||||
dstIP := net.ParseIP("8.8.8.8")
|
||||
|
||||
@@ -539,53 +539,236 @@ func TestPeerACLFiltering(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestPeerACLFilteringIPv6(t *testing.T) {
|
||||
localIP := netip.MustParseAddr("100.10.0.100")
|
||||
localIPv6 := netip.MustParseAddr("fd00::100")
|
||||
wgNet := netip.MustParsePrefix("100.10.0.0/16")
|
||||
wgNetV6 := netip.MustParsePrefix("fd00::/64")
|
||||
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: localIP,
|
||||
Network: wgNet,
|
||||
IPv6: localIPv6,
|
||||
IPv6Net: wgNetV6,
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { require.NoError(t, manager.Close(nil)) })
|
||||
|
||||
err = manager.UpdateLocalIPs()
|
||||
require.NoError(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
srcIP string
|
||||
dstIP string
|
||||
proto fw.Protocol
|
||||
srcPort uint16
|
||||
dstPort uint16
|
||||
ruleIP string
|
||||
ruleProto fw.Protocol
|
||||
ruleDstPort *fw.Port
|
||||
ruleAction fw.Action
|
||||
shouldBeBlocked bool
|
||||
}{
|
||||
{
|
||||
name: "IPv6: allow TCP from peer",
|
||||
srcIP: "fd00::1",
|
||||
dstIP: "fd00::100",
|
||||
proto: fw.ProtocolTCP,
|
||||
srcPort: 12345,
|
||||
dstPort: 443,
|
||||
ruleIP: "fd00::1",
|
||||
ruleProto: fw.ProtocolTCP,
|
||||
ruleDstPort: &fw.Port{Values: []uint16{443}},
|
||||
ruleAction: fw.ActionAccept,
|
||||
shouldBeBlocked: false,
|
||||
},
|
||||
{
|
||||
name: "IPv6: allow UDP from peer",
|
||||
srcIP: "fd00::1",
|
||||
dstIP: "fd00::100",
|
||||
proto: fw.ProtocolUDP,
|
||||
srcPort: 12345,
|
||||
dstPort: 53,
|
||||
ruleIP: "fd00::1",
|
||||
ruleProto: fw.ProtocolUDP,
|
||||
ruleDstPort: &fw.Port{Values: []uint16{53}},
|
||||
ruleAction: fw.ActionAccept,
|
||||
shouldBeBlocked: false,
|
||||
},
|
||||
{
|
||||
name: "IPv6: allow ICMPv6 from peer",
|
||||
srcIP: "fd00::1",
|
||||
dstIP: "fd00::100",
|
||||
proto: fw.ProtocolICMP,
|
||||
ruleIP: "fd00::1",
|
||||
ruleProto: fw.ProtocolICMP,
|
||||
ruleAction: fw.ActionAccept,
|
||||
shouldBeBlocked: false,
|
||||
},
|
||||
{
|
||||
name: "IPv6: block TCP without rule",
|
||||
srcIP: "fd00::2",
|
||||
dstIP: "fd00::100",
|
||||
proto: fw.ProtocolTCP,
|
||||
srcPort: 12345,
|
||||
dstPort: 443,
|
||||
ruleIP: "fd00::1",
|
||||
ruleProto: fw.ProtocolTCP,
|
||||
ruleDstPort: &fw.Port{Values: []uint16{443}},
|
||||
ruleAction: fw.ActionAccept,
|
||||
shouldBeBlocked: true,
|
||||
},
|
||||
{
|
||||
name: "IPv6: drop rule",
|
||||
srcIP: "fd00::1",
|
||||
dstIP: "fd00::100",
|
||||
proto: fw.ProtocolTCP,
|
||||
srcPort: 12345,
|
||||
dstPort: 22,
|
||||
ruleIP: "fd00::1",
|
||||
ruleProto: fw.ProtocolTCP,
|
||||
ruleDstPort: &fw.Port{Values: []uint16{22}},
|
||||
ruleAction: fw.ActionDrop,
|
||||
shouldBeBlocked: true,
|
||||
},
|
||||
{
|
||||
name: "IPv6: allow all protocols",
|
||||
srcIP: "fd00::1",
|
||||
dstIP: "fd00::100",
|
||||
proto: fw.ProtocolUDP,
|
||||
srcPort: 12345,
|
||||
dstPort: 9999,
|
||||
ruleIP: "fd00::1",
|
||||
ruleProto: fw.ProtocolALL,
|
||||
ruleAction: fw.ActionAccept,
|
||||
shouldBeBlocked: false,
|
||||
},
|
||||
{
|
||||
name: "IPv6: v4 wildcard ICMP rule matches ICMPv6 via protoLayerMatches",
|
||||
srcIP: "fd00::1",
|
||||
dstIP: "fd00::100",
|
||||
proto: fw.ProtocolICMP,
|
||||
ruleIP: "0.0.0.0",
|
||||
ruleProto: fw.ProtocolICMP,
|
||||
ruleAction: fw.ActionAccept,
|
||||
shouldBeBlocked: false,
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("IPv6 implicit DROP (no rules)", func(t *testing.T) {
|
||||
packet := createTestPacket(t, "fd00::1", "fd00::100", fw.ProtocolTCP, 12345, 443)
|
||||
isDropped := manager.FilterInbound(packet, 0)
|
||||
require.True(t, isDropped, "IPv6 packet should be dropped when no rules exist")
|
||||
})
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if tc.ruleAction == fw.ActionDrop {
|
||||
rules, err := manager.AddPeerFiltering(nil, net.ParseIP(tc.ruleIP), fw.ProtocolALL, nil, nil, fw.ActionAccept, "")
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
for _, rule := range rules {
|
||||
require.NoError(t, manager.DeletePeerRule(rule))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
rules, err := manager.AddPeerFiltering(nil, net.ParseIP(tc.ruleIP), tc.ruleProto, nil, tc.ruleDstPort, tc.ruleAction, "")
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, rules)
|
||||
t.Cleanup(func() {
|
||||
for _, rule := range rules {
|
||||
require.NoError(t, manager.DeletePeerRule(rule))
|
||||
}
|
||||
})
|
||||
|
||||
packet := createTestPacket(t, tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort)
|
||||
isDropped := manager.FilterInbound(packet, 0)
|
||||
require.Equal(t, tc.shouldBeBlocked, isDropped, "packet filter result mismatch")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func createTestPacket(t *testing.T, srcIP, dstIP string, proto fw.Protocol, srcPort, dstPort uint16) []byte {
|
||||
t.Helper()
|
||||
|
||||
src := net.ParseIP(srcIP)
|
||||
dst := net.ParseIP(dstIP)
|
||||
|
||||
buf := gopacket.NewSerializeBuffer()
|
||||
opts := gopacket.SerializeOptions{
|
||||
ComputeChecksums: true,
|
||||
FixLengths: true,
|
||||
}
|
||||
|
||||
ipLayer := &layers.IPv4{
|
||||
Version: 4,
|
||||
TTL: 64,
|
||||
SrcIP: net.ParseIP(srcIP),
|
||||
DstIP: net.ParseIP(dstIP),
|
||||
}
|
||||
// Detect address family
|
||||
isV6 := src.To4() == nil
|
||||
|
||||
var err error
|
||||
switch proto {
|
||||
case fw.ProtocolTCP:
|
||||
ipLayer.Protocol = layers.IPProtocolTCP
|
||||
tcp := &layers.TCP{
|
||||
SrcPort: layers.TCPPort(srcPort),
|
||||
DstPort: layers.TCPPort(dstPort),
|
||||
}
|
||||
err = tcp.SetNetworkLayerForChecksum(ipLayer)
|
||||
require.NoError(t, err)
|
||||
err = gopacket.SerializeLayers(buf, opts, ipLayer, tcp)
|
||||
|
||||
case fw.ProtocolUDP:
|
||||
ipLayer.Protocol = layers.IPProtocolUDP
|
||||
udp := &layers.UDP{
|
||||
SrcPort: layers.UDPPort(srcPort),
|
||||
DstPort: layers.UDPPort(dstPort),
|
||||
if isV6 {
|
||||
ip6 := &layers.IPv6{
|
||||
Version: 6,
|
||||
HopLimit: 64,
|
||||
SrcIP: src,
|
||||
DstIP: dst,
|
||||
}
|
||||
err = udp.SetNetworkLayerForChecksum(ipLayer)
|
||||
require.NoError(t, err)
|
||||
err = gopacket.SerializeLayers(buf, opts, ipLayer, udp)
|
||||
|
||||
case fw.ProtocolICMP:
|
||||
ipLayer.Protocol = layers.IPProtocolICMPv4
|
||||
icmp := &layers.ICMPv4{
|
||||
TypeCode: layers.CreateICMPv4TypeCode(layers.ICMPv4TypeEchoRequest, 0),
|
||||
switch proto {
|
||||
case fw.ProtocolTCP:
|
||||
ip6.NextHeader = layers.IPProtocolTCP
|
||||
tcp := &layers.TCP{SrcPort: layers.TCPPort(srcPort), DstPort: layers.TCPPort(dstPort)}
|
||||
_ = tcp.SetNetworkLayerForChecksum(ip6)
|
||||
err = gopacket.SerializeLayers(buf, opts, ip6, tcp)
|
||||
case fw.ProtocolUDP:
|
||||
ip6.NextHeader = layers.IPProtocolUDP
|
||||
udp := &layers.UDP{SrcPort: layers.UDPPort(srcPort), DstPort: layers.UDPPort(dstPort)}
|
||||
_ = udp.SetNetworkLayerForChecksum(ip6)
|
||||
err = gopacket.SerializeLayers(buf, opts, ip6, udp)
|
||||
case fw.ProtocolICMP:
|
||||
ip6.NextHeader = layers.IPProtocolICMPv6
|
||||
icmp := &layers.ICMPv6{
|
||||
TypeCode: layers.CreateICMPv6TypeCode(layers.ICMPv6TypeEchoRequest, 0),
|
||||
}
|
||||
_ = icmp.SetNetworkLayerForChecksum(ip6)
|
||||
err = gopacket.SerializeLayers(buf, opts, ip6, icmp)
|
||||
default:
|
||||
err = gopacket.SerializeLayers(buf, opts, ip6)
|
||||
}
|
||||
} else {
|
||||
ip4 := &layers.IPv4{
|
||||
Version: 4,
|
||||
TTL: 64,
|
||||
SrcIP: src,
|
||||
DstIP: dst,
|
||||
}
|
||||
err = gopacket.SerializeLayers(buf, opts, ipLayer, icmp)
|
||||
|
||||
default:
|
||||
err = gopacket.SerializeLayers(buf, opts, ipLayer)
|
||||
switch proto {
|
||||
case fw.ProtocolTCP:
|
||||
ip4.Protocol = layers.IPProtocolTCP
|
||||
tcp := &layers.TCP{SrcPort: layers.TCPPort(srcPort), DstPort: layers.TCPPort(dstPort)}
|
||||
_ = tcp.SetNetworkLayerForChecksum(ip4)
|
||||
err = gopacket.SerializeLayers(buf, opts, ip4, tcp)
|
||||
case fw.ProtocolUDP:
|
||||
ip4.Protocol = layers.IPProtocolUDP
|
||||
udp := &layers.UDP{SrcPort: layers.UDPPort(srcPort), DstPort: layers.UDPPort(dstPort)}
|
||||
_ = udp.SetNetworkLayerForChecksum(ip4)
|
||||
err = gopacket.SerializeLayers(buf, opts, ip4, udp)
|
||||
case fw.ProtocolICMP:
|
||||
ip4.Protocol = layers.IPProtocolICMPv4
|
||||
icmp := &layers.ICMPv4{TypeCode: layers.CreateICMPv4TypeCode(layers.ICMPv4TypeEchoRequest, 0)}
|
||||
err = gopacket.SerializeLayers(buf, opts, ip4, icmp)
|
||||
default:
|
||||
err = gopacket.SerializeLayers(buf, opts, ip4)
|
||||
}
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
@@ -1498,3 +1681,103 @@ func TestRouteACLSet(t *testing.T) {
|
||||
_, isAllowed = manager.routeACLsPass(srcIP, dstIP, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
|
||||
require.True(t, isAllowed, "After set update, traffic to the added network should be allowed")
|
||||
}
|
||||
|
||||
// TestRouteACLFilteringIPv6 tests IPv6 route ACL matching directly via routeACLsPass.
|
||||
// Note: full FilterInbound for routed IPv6 traffic drops at the forwarder stage (IPv4-only)
|
||||
// but the ACL decision itself is correct.
|
||||
func TestRouteACLFilteringIPv6(t *testing.T) {
|
||||
manager := setupRoutedManager(t, "10.10.0.100/16")
|
||||
|
||||
v6Dst := netip.MustParsePrefix("fd00:dead:beef::/48")
|
||||
_, err := manager.AddRouteFiltering(
|
||||
nil,
|
||||
[]netip.Prefix{netip.MustParsePrefix("fd00::/16")},
|
||||
fw.Network{Prefix: v6Dst},
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
&fw.Port{Values: []uint16{80}},
|
||||
fw.ActionAccept,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = manager.AddRouteFiltering(
|
||||
nil,
|
||||
[]netip.Prefix{netip.MustParsePrefix("fd00::/16")},
|
||||
fw.Network{Prefix: netip.MustParsePrefix("fd00:dead:beef:1::/64")},
|
||||
fw.ProtocolALL,
|
||||
nil,
|
||||
nil,
|
||||
fw.ActionDrop,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
srcIP netip.Addr
|
||||
dstIP netip.Addr
|
||||
proto gopacket.LayerType
|
||||
srcPort uint16
|
||||
dstPort uint16
|
||||
allowed bool
|
||||
}{
|
||||
{
|
||||
name: "IPv6 TCP to allowed dest",
|
||||
srcIP: netip.MustParseAddr("fd00::1"),
|
||||
dstIP: netip.MustParseAddr("fd00:dead:beef::80"),
|
||||
proto: layers.LayerTypeTCP,
|
||||
srcPort: 12345,
|
||||
dstPort: 80,
|
||||
allowed: true,
|
||||
},
|
||||
{
|
||||
name: "IPv6 TCP wrong port",
|
||||
srcIP: netip.MustParseAddr("fd00::1"),
|
||||
dstIP: netip.MustParseAddr("fd00:dead:beef::80"),
|
||||
proto: layers.LayerTypeTCP,
|
||||
srcPort: 12345,
|
||||
dstPort: 443,
|
||||
allowed: false,
|
||||
},
|
||||
{
|
||||
name: "IPv6 UDP not matched by TCP rule",
|
||||
srcIP: netip.MustParseAddr("fd00::1"),
|
||||
dstIP: netip.MustParseAddr("fd00:dead:beef::80"),
|
||||
proto: layers.LayerTypeUDP,
|
||||
srcPort: 12345,
|
||||
dstPort: 80,
|
||||
allowed: false,
|
||||
},
|
||||
{
|
||||
name: "IPv6 ICMPv6 matches ICMP rule via protoLayerMatches",
|
||||
srcIP: netip.MustParseAddr("fd00::1"),
|
||||
dstIP: netip.MustParseAddr("fd00:dead:beef::80"),
|
||||
proto: layers.LayerTypeICMPv6,
|
||||
allowed: false,
|
||||
},
|
||||
{
|
||||
name: "IPv6 to denied subnet",
|
||||
srcIP: netip.MustParseAddr("fd00::1"),
|
||||
dstIP: netip.MustParseAddr("fd00:dead:beef:1::1"),
|
||||
proto: layers.LayerTypeTCP,
|
||||
srcPort: 12345,
|
||||
dstPort: 80,
|
||||
allowed: false,
|
||||
},
|
||||
{
|
||||
name: "IPv6 source outside allowed range",
|
||||
srcIP: netip.MustParseAddr("fe80::1"),
|
||||
dstIP: netip.MustParseAddr("fd00:dead:beef::80"),
|
||||
proto: layers.LayerTypeTCP,
|
||||
srcPort: 12345,
|
||||
dstPort: 80,
|
||||
allowed: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
_, pass := manager.routeACLsPass(tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort)
|
||||
require.Equal(t, tc.allowed, pass, "route ACL result mismatch")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -527,11 +527,16 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
||||
d := &decoder{
|
||||
decoded: []gopacket.LayerType{},
|
||||
}
|
||||
d.parser = gopacket.NewDecodingLayerParser(
|
||||
d.parser4 = gopacket.NewDecodingLayerParser(
|
||||
layers.LayerTypeIPv4,
|
||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||
)
|
||||
d.parser.IgnoreUnsupported = true
|
||||
d.parser4.IgnoreUnsupported = true
|
||||
d.parser6 = gopacket.NewDecodingLayerParser(
|
||||
layers.LayerTypeIPv6,
|
||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||
)
|
||||
d.parser6.IgnoreUnsupported = true
|
||||
return d
|
||||
},
|
||||
}
|
||||
@@ -630,11 +635,16 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||
d := &decoder{
|
||||
decoded: []gopacket.LayerType{},
|
||||
}
|
||||
d.parser = gopacket.NewDecodingLayerParser(
|
||||
d.parser4 = gopacket.NewDecodingLayerParser(
|
||||
layers.LayerTypeIPv4,
|
||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||
)
|
||||
d.parser.IgnoreUnsupported = true
|
||||
d.parser4.IgnoreUnsupported = true
|
||||
d.parser6 = gopacket.NewDecodingLayerParser(
|
||||
layers.LayerTypeIPv6,
|
||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||
)
|
||||
d.parser6.IgnoreUnsupported = true
|
||||
return d
|
||||
},
|
||||
}
|
||||
@@ -1040,8 +1050,8 @@ func TestMSSClamping(t *testing.T) {
|
||||
}()
|
||||
|
||||
require.True(t, manager.mssClampEnabled, "MSS clamping should be enabled by default")
|
||||
expectedMSSValue := uint16(1280 - ipTCPHeaderMinSize)
|
||||
require.Equal(t, expectedMSSValue, manager.mssClampValue, "MSS clamp value should be MTU - 40")
|
||||
require.Equal(t, uint16(1280-ipv4TCPHeaderMinSize), manager.mssClampValueIPv4, "IPv4 MSS clamp value should be MTU - 40")
|
||||
require.Equal(t, uint16(1280-ipv6TCPHeaderMinSize), manager.mssClampValueIPv6, "IPv6 MSS clamp value should be MTU - 60")
|
||||
|
||||
err = manager.UpdateLocalIPs()
|
||||
require.NoError(t, err)
|
||||
@@ -1059,7 +1069,7 @@ func TestMSSClamping(t *testing.T) {
|
||||
require.Len(t, d.tcp.Options, 1, "Should have MSS option")
|
||||
require.Equal(t, uint8(layers.TCPOptionKindMSS), uint8(d.tcp.Options[0].OptionType))
|
||||
actualMSS := binary.BigEndian.Uint16(d.tcp.Options[0].OptionData)
|
||||
require.Equal(t, expectedMSSValue, actualMSS, "MSS should be clamped to MTU - 40")
|
||||
require.Equal(t, manager.mssClampValueIPv4, actualMSS, "MSS should be clamped to MTU - 40")
|
||||
})
|
||||
|
||||
t.Run("SYN packet with low MSS unchanged", func(t *testing.T) {
|
||||
@@ -1083,7 +1093,7 @@ func TestMSSClamping(t *testing.T) {
|
||||
d := parsePacket(t, packet)
|
||||
require.Len(t, d.tcp.Options, 1, "Should have MSS option")
|
||||
actualMSS := binary.BigEndian.Uint16(d.tcp.Options[0].OptionData)
|
||||
require.Equal(t, expectedMSSValue, actualMSS, "MSS in SYN-ACK should be clamped")
|
||||
require.Equal(t, manager.mssClampValueIPv4, actualMSS, "MSS in SYN-ACK should be clamped")
|
||||
})
|
||||
|
||||
t.Run("Non-SYN packet unchanged", func(t *testing.T) {
|
||||
@@ -1255,13 +1265,18 @@ func TestShouldForward(t *testing.T) {
|
||||
d := &decoder{
|
||||
decoded: []gopacket.LayerType{},
|
||||
}
|
||||
d.parser = gopacket.NewDecodingLayerParser(
|
||||
d.parser4 = gopacket.NewDecodingLayerParser(
|
||||
layers.LayerTypeIPv4,
|
||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||
)
|
||||
d.parser.IgnoreUnsupported = true
|
||||
d.parser4.IgnoreUnsupported = true
|
||||
d.parser6 = gopacket.NewDecodingLayerParser(
|
||||
layers.LayerTypeIPv6,
|
||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||
)
|
||||
d.parser6.IgnoreUnsupported = true
|
||||
|
||||
err = d.parser.DecodeLayers(buf.Bytes(), &d.decoded)
|
||||
err = d.decodePacket(buf.Bytes())
|
||||
require.NoError(t, err)
|
||||
|
||||
return d
|
||||
@@ -1321,6 +1336,44 @@ func TestShouldForward(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
// Add IPv6 to the interface and test dual-stack cases
|
||||
wgIPv6 := netip.MustParseAddr("fd00::1")
|
||||
otherIPv6 := netip.MustParseAddr("fd00::2")
|
||||
ifaceMock.AddressFunc = func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: wgIP,
|
||||
Network: netip.PrefixFrom(wgIP, 24),
|
||||
IPv6: wgIPv6,
|
||||
IPv6Net: netip.PrefixFrom(wgIPv6, 64),
|
||||
}
|
||||
}
|
||||
|
||||
// Re-create manager to pick up the new address with IPv6
|
||||
require.NoError(t, manager.Close(nil))
|
||||
manager, err = Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
|
||||
v6Cases := []struct {
|
||||
name string
|
||||
dstIP netip.Addr
|
||||
expected bool
|
||||
description string
|
||||
}{
|
||||
{"v6 traffic to other address", otherIPv6, true, "should forward v6 traffic not destined to our v6 address"},
|
||||
{"v6 traffic to our v6 IP", wgIPv6, false, "should not forward traffic destined to our v6 address"},
|
||||
{"v4 traffic to other with v6 configured", otherIP, true, "should forward v4 traffic when v6 configured"},
|
||||
{"v4 traffic to our v4 IP with v6 configured", wgIP, false, "should not forward traffic to our v4 address"},
|
||||
}
|
||||
for _, tt := range v6Cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
manager.localForwarding = true
|
||||
manager.netstack = false
|
||||
decoder := createTCPDecoder(8080)
|
||||
result := manager.shouldForward(decoder, tt.dstIP)
|
||||
require.Equal(t, tt.expected, result, tt.description)
|
||||
})
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Configure manager
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
package forwarder
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
@@ -47,17 +48,23 @@ func (e *endpoint) LinkAddress() tcpip.LinkAddress {
|
||||
func (e *endpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) {
|
||||
var written int
|
||||
for _, pkt := range pkts.AsSlice() {
|
||||
netHeader := header.IPv4(pkt.NetworkHeader().View().AsSlice())
|
||||
|
||||
data := stack.PayloadSince(pkt.NetworkHeader())
|
||||
if data == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Send the packet through WireGuard
|
||||
address := netHeader.DestinationAddress()
|
||||
err := e.device.CreateOutboundPacket(data.AsSlice(), address.AsSlice())
|
||||
if err != nil {
|
||||
raw := pkt.NetworkHeader().View().AsSlice()
|
||||
if len(raw) == 0 {
|
||||
continue
|
||||
}
|
||||
var address tcpip.Address
|
||||
if raw[0]>>4 == 6 {
|
||||
address = header.IPv6(raw).DestinationAddress()
|
||||
} else {
|
||||
address = header.IPv4(raw).DestinationAddress()
|
||||
}
|
||||
|
||||
if err := e.device.CreateOutboundPacket(data.AsSlice(), address.AsSlice()); err != nil {
|
||||
e.logger.Error1("CreateOutboundPacket: %v", err)
|
||||
continue
|
||||
}
|
||||
@@ -103,5 +110,7 @@ type epID stack.TransportEndpointID
|
||||
|
||||
func (i epID) String() string {
|
||||
// src and remote is swapped
|
||||
return fmt.Sprintf("%s:%d → %s:%d", i.RemoteAddress, i.RemotePort, i.LocalAddress, i.LocalPort)
|
||||
return net.JoinHostPort(i.RemoteAddress.String(), strconv.Itoa(int(i.RemotePort))) +
|
||||
" → " +
|
||||
net.JoinHostPort(i.LocalAddress.String(), strconv.Itoa(int(i.LocalPort)))
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
||||
@@ -36,25 +37,31 @@ type Forwarder struct {
|
||||
logger *nblog.Logger
|
||||
flowLogger nftypes.FlowLogger
|
||||
// ruleIdMap is used to store the rule ID for a given connection
|
||||
ruleIdMap sync.Map
|
||||
stack *stack.Stack
|
||||
endpoint *endpoint
|
||||
udpForwarder *udpForwarder
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
ip tcpip.Address
|
||||
netstack bool
|
||||
hasRawICMPAccess bool
|
||||
pingSemaphore chan struct{}
|
||||
ruleIdMap sync.Map
|
||||
stack *stack.Stack
|
||||
endpoint *endpoint
|
||||
udpForwarder *udpForwarder
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
ip tcpip.Address
|
||||
ipv6 tcpip.Address
|
||||
netstack bool
|
||||
hasRawICMPAccess bool
|
||||
hasRawICMPv6Access bool
|
||||
pingSemaphore chan struct{}
|
||||
}
|
||||
|
||||
func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool, mtu uint16) (*Forwarder, error) {
|
||||
s := stack.New(stack.Options{
|
||||
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
|
||||
NetworkProtocols: []stack.NetworkProtocolFactory{
|
||||
ipv4.NewProtocol,
|
||||
ipv6.NewProtocol,
|
||||
},
|
||||
TransportProtocols: []stack.TransportProtocolFactory{
|
||||
tcp.NewProtocol,
|
||||
udp.NewProtocol,
|
||||
icmp.NewProtocol4,
|
||||
icmp.NewProtocol6,
|
||||
},
|
||||
HandleLocal: false,
|
||||
})
|
||||
@@ -73,7 +80,7 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
||||
protoAddr := tcpip.ProtocolAddress{
|
||||
Protocol: ipv4.ProtocolNumber,
|
||||
AddressWithPrefix: tcpip.AddressWithPrefix{
|
||||
Address: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()),
|
||||
Address: tcpip.AddrFrom4(iface.Address().IP.As4()),
|
||||
PrefixLen: iface.Address().Network.Bits(),
|
||||
},
|
||||
}
|
||||
@@ -82,6 +89,19 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
||||
return nil, fmt.Errorf("failed to add protocol address: %s", err)
|
||||
}
|
||||
|
||||
if v6 := iface.Address().IPv6; v6.IsValid() {
|
||||
v6Addr := tcpip.ProtocolAddress{
|
||||
Protocol: ipv6.ProtocolNumber,
|
||||
AddressWithPrefix: tcpip.AddressWithPrefix{
|
||||
Address: tcpip.AddrFrom16(v6.As16()),
|
||||
PrefixLen: iface.Address().IPv6Net.Bits(),
|
||||
},
|
||||
}
|
||||
if err := s.AddProtocolAddress(nicID, v6Addr, stack.AddressProperties{}); err != nil {
|
||||
return nil, fmt.Errorf("add IPv6 protocol address: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
defaultSubnet, err := tcpip.NewSubnet(
|
||||
tcpip.AddrFrom4([4]byte{0, 0, 0, 0}),
|
||||
tcpip.MaskFromBytes([]byte{0, 0, 0, 0}),
|
||||
@@ -90,6 +110,14 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
||||
return nil, fmt.Errorf("creating default subnet: %w", err)
|
||||
}
|
||||
|
||||
defaultSubnetV6, err := tcpip.NewSubnet(
|
||||
tcpip.AddrFrom16([16]byte{}),
|
||||
tcpip.MaskFromBytes(make([]byte, 16)),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating default v6 subnet: %w", err)
|
||||
}
|
||||
|
||||
if err := s.SetPromiscuousMode(nicID, true); err != nil {
|
||||
return nil, fmt.Errorf("set promiscuous mode: %s", err)
|
||||
}
|
||||
@@ -98,10 +126,8 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
||||
}
|
||||
|
||||
s.SetRouteTable([]tcpip.Route{
|
||||
{
|
||||
Destination: defaultSubnet,
|
||||
NIC: nicID,
|
||||
},
|
||||
{Destination: defaultSubnet, NIC: nicID},
|
||||
{Destination: defaultSubnetV6, NIC: nicID},
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
@@ -114,7 +140,8 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
netstack: netstack,
|
||||
ip: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()),
|
||||
ip: tcpip.AddrFrom4(iface.Address().IP.As4()),
|
||||
ipv6: addrFromNetipAddr(iface.Address().IPv6),
|
||||
pingSemaphore: make(chan struct{}, 3),
|
||||
}
|
||||
|
||||
@@ -131,7 +158,10 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
||||
udpForwarder := udp.NewForwarder(s, f.handleUDP)
|
||||
s.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket)
|
||||
|
||||
s.SetTransportProtocolHandler(icmp.ProtocolNumber4, f.handleICMP)
|
||||
// ICMP is handled directly in InjectIncomingPacket, bypassing gVisor's
|
||||
// network layer. This avoids duplicate echo replies (v4) and the v6
|
||||
// auto-reply bug where gVisor responds at the network layer before
|
||||
// our transport handler fires.
|
||||
|
||||
f.checkICMPCapability()
|
||||
|
||||
@@ -140,8 +170,30 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
||||
}
|
||||
|
||||
func (f *Forwarder) InjectIncomingPacket(payload []byte) error {
|
||||
if len(payload) < header.IPv4MinimumSize {
|
||||
return fmt.Errorf("packet too small: %d bytes", len(payload))
|
||||
if len(payload) == 0 {
|
||||
return fmt.Errorf("empty packet")
|
||||
}
|
||||
|
||||
var protoNum tcpip.NetworkProtocolNumber
|
||||
switch payload[0] >> 4 {
|
||||
case 4:
|
||||
if len(payload) < header.IPv4MinimumSize {
|
||||
return fmt.Errorf("IPv4 packet too small: %d bytes", len(payload))
|
||||
}
|
||||
if f.handleICMPDirect(payload) {
|
||||
return nil
|
||||
}
|
||||
protoNum = ipv4.ProtocolNumber
|
||||
case 6:
|
||||
if len(payload) < header.IPv6MinimumSize {
|
||||
return fmt.Errorf("IPv6 packet too small: %d bytes", len(payload))
|
||||
}
|
||||
if f.handleICMPDirect(payload) {
|
||||
return nil
|
||||
}
|
||||
protoNum = ipv6.ProtocolNumber
|
||||
default:
|
||||
return fmt.Errorf("unknown IP version: %d", payload[0]>>4)
|
||||
}
|
||||
|
||||
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
||||
@@ -150,11 +202,95 @@ func (f *Forwarder) InjectIncomingPacket(payload []byte) error {
|
||||
defer pkt.DecRef()
|
||||
|
||||
if f.endpoint.dispatcher != nil {
|
||||
f.endpoint.dispatcher.DeliverNetworkPacket(ipv4.ProtocolNumber, pkt)
|
||||
f.endpoint.dispatcher.DeliverNetworkPacket(protoNum, pkt)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleICMPDirect intercepts ICMP packets from raw IP payloads before they
|
||||
// enter gVisor. It synthesizes the TransportEndpointID and PacketBuffer that
|
||||
// the existing handlers expect, then dispatches to handleICMP/handleICMPv6.
|
||||
// This bypasses gVisor's network layer which causes duplicate v4 echo replies
|
||||
// and auto-replies to all v6 echo requests in promiscuous mode.
|
||||
//
|
||||
// Unlike gVisor's network layer, this does not validate ICMP checksums or
|
||||
// reassemble IP fragments. Fragmented ICMP packets fall through to gVisor.
|
||||
func parseICMPv4(payload []byte) (ipHdrLen int, src, dst tcpip.Address, ok bool) {
|
||||
ip := header.IPv4(payload)
|
||||
if ip.Protocol() != uint8(header.ICMPv4ProtocolNumber) {
|
||||
return 0, src, dst, false
|
||||
}
|
||||
if ip.FragmentOffset() != 0 || ip.Flags()&header.IPv4FlagMoreFragments != 0 {
|
||||
return 0, src, dst, false
|
||||
}
|
||||
ipHdrLen = int(ip.HeaderLength())
|
||||
if len(payload)-ipHdrLen < header.ICMPv4MinimumSize {
|
||||
return 0, src, dst, false
|
||||
}
|
||||
return ipHdrLen, ip.SourceAddress(), ip.DestinationAddress(), true
|
||||
}
|
||||
|
||||
func parseICMPv6(payload []byte) (ipHdrLen int, src, dst tcpip.Address, ok bool) {
|
||||
ip := header.IPv6(payload)
|
||||
if ip.NextHeader() != uint8(header.ICMPv6ProtocolNumber) {
|
||||
return 0, src, dst, false
|
||||
}
|
||||
ipHdrLen = header.IPv6MinimumSize
|
||||
if len(payload)-ipHdrLen < header.ICMPv6MinimumSize {
|
||||
return 0, src, dst, false
|
||||
}
|
||||
return ipHdrLen, ip.SourceAddress(), ip.DestinationAddress(), true
|
||||
}
|
||||
|
||||
func (f *Forwarder) handleICMPDirect(payload []byte) bool {
|
||||
var (
|
||||
ipHdrLen int
|
||||
srcAddr tcpip.Address
|
||||
dstAddr tcpip.Address
|
||||
ok bool
|
||||
)
|
||||
switch payload[0] >> 4 {
|
||||
case 4:
|
||||
ipHdrLen, srcAddr, dstAddr, ok = parseICMPv4(payload)
|
||||
case 6:
|
||||
ipHdrLen, srcAddr, dstAddr, ok = parseICMPv6(payload)
|
||||
}
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
// Let gVisor handle ICMP destined for our own addresses natively.
|
||||
// Its network-layer auto-reply is correct and efficient for local traffic.
|
||||
if f.ip.Equal(dstAddr) || f.ipv6.Equal(dstAddr) {
|
||||
return false
|
||||
}
|
||||
|
||||
id := stack.TransportEndpointID{
|
||||
LocalAddress: dstAddr,
|
||||
RemoteAddress: srcAddr,
|
||||
}
|
||||
|
||||
// Build a PacketBuffer with headers consumed the same way gVisor would.
|
||||
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
||||
Payload: buffer.MakeWithData(payload),
|
||||
})
|
||||
defer pkt.DecRef()
|
||||
|
||||
if _, ok := pkt.NetworkHeader().Consume(ipHdrLen); !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
icmpPayload := payload[ipHdrLen:]
|
||||
if _, ok := pkt.TransportHeader().Consume(len(icmpPayload)); !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
if payload[0]>>4 == 6 {
|
||||
return f.handleICMPv6(id, pkt)
|
||||
}
|
||||
return f.handleICMP(id, pkt)
|
||||
}
|
||||
|
||||
// Stop gracefully shuts down the forwarder
|
||||
func (f *Forwarder) Stop() {
|
||||
f.cancel()
|
||||
@@ -167,11 +303,14 @@ func (f *Forwarder) Stop() {
|
||||
f.stack.Wait()
|
||||
}
|
||||
|
||||
func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP {
|
||||
func (f *Forwarder) determineDialAddr(addr tcpip.Address) netip.Addr {
|
||||
if f.netstack && f.ip.Equal(addr) {
|
||||
return net.IPv4(127, 0, 0, 1)
|
||||
return netip.AddrFrom4([4]byte{127, 0, 0, 1})
|
||||
}
|
||||
return addr.AsSlice()
|
||||
if f.netstack && f.ipv6.Equal(addr) {
|
||||
return netip.IPv6Loopback()
|
||||
}
|
||||
return addrToNetipAddr(addr)
|
||||
}
|
||||
|
||||
func (f *Forwarder) RegisterRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, ruleID []byte) {
|
||||
@@ -205,23 +344,50 @@ func buildKey(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) conntrack.ConnKe
|
||||
}
|
||||
}
|
||||
|
||||
// addrFromNetipAddr converts a netip.Addr to a gvisor tcpip.Address without allocating.
|
||||
func addrFromNetipAddr(addr netip.Addr) tcpip.Address {
|
||||
if !addr.IsValid() {
|
||||
return tcpip.Address{}
|
||||
}
|
||||
if addr.Is4() {
|
||||
return tcpip.AddrFrom4(addr.As4())
|
||||
}
|
||||
return tcpip.AddrFrom16(addr.As16())
|
||||
}
|
||||
|
||||
// addrToNetipAddr converts a gvisor tcpip.Address to netip.Addr without allocating.
|
||||
func addrToNetipAddr(addr tcpip.Address) netip.Addr {
|
||||
switch addr.Len() {
|
||||
case 4:
|
||||
return netip.AddrFrom4(addr.As4())
|
||||
case 16:
|
||||
return netip.AddrFrom16(addr.As16())
|
||||
default:
|
||||
return netip.Addr{}
|
||||
}
|
||||
}
|
||||
|
||||
// checkICMPCapability tests whether we have raw ICMP socket access at startup.
|
||||
func (f *Forwarder) checkICMPCapability() {
|
||||
f.hasRawICMPAccess = probeRawICMP("ip4:icmp", "0.0.0.0", f.logger)
|
||||
f.hasRawICMPv6Access = probeRawICMP("ip6:ipv6-icmp", "::", f.logger)
|
||||
}
|
||||
|
||||
func probeRawICMP(network, addr string, logger *nblog.Logger) bool {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
lc := net.ListenConfig{}
|
||||
conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0")
|
||||
conn, err := lc.ListenPacket(ctx, network, addr)
|
||||
if err != nil {
|
||||
f.hasRawICMPAccess = false
|
||||
f.logger.Debug("forwarder: No raw ICMP socket access, will use ping binary fallback")
|
||||
return
|
||||
logger.Debug1("forwarder: no raw %s socket access, will use ping binary fallback", network)
|
||||
return false
|
||||
}
|
||||
|
||||
if err := conn.Close(); err != nil {
|
||||
f.logger.Debug1("forwarder: Failed to close ICMP capability test socket: %v", err)
|
||||
logger.Debug2("forwarder: failed to close %s capability test socket: %v", network, err)
|
||||
}
|
||||
|
||||
f.hasRawICMPAccess = true
|
||||
f.logger.Debug("forwarder: Raw ICMP socket access available")
|
||||
logger.Debug1("forwarder: raw %s socket access available", network)
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -35,7 +35,7 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt *stack.PacketBu
|
||||
}
|
||||
|
||||
icmpData := stack.PayloadSince(pkt.TransportHeader()).AsSlice()
|
||||
conn, err := f.forwardICMPPacket(id, icmpData, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()), 100*time.Millisecond)
|
||||
conn, err := f.forwardICMPPacket(id, icmpData, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()), false, 100*time.Millisecond)
|
||||
if err != nil {
|
||||
f.logger.Error2("forwarder: Failed to forward ICMP packet for %v: %v", epID(id), err)
|
||||
return true
|
||||
@@ -58,7 +58,7 @@ func (f *Forwarder) handleICMPEcho(flowID uuid.UUID, id stack.TransportEndpointI
|
||||
defer func() { <-f.pingSemaphore }()
|
||||
|
||||
if f.hasRawICMPAccess {
|
||||
f.handleICMPViaSocket(flowID, id, icmpType, icmpCode, icmpData, rxBytes)
|
||||
f.handleICMPViaSocket(flowID, id, icmpType, icmpCode, icmpData, rxBytes, false)
|
||||
} else {
|
||||
f.handleICMPViaPing(flowID, id, icmpType, icmpCode, icmpData, rxBytes)
|
||||
}
|
||||
@@ -72,18 +72,23 @@ func (f *Forwarder) handleICMPEcho(flowID uuid.UUID, id stack.TransportEndpointI
|
||||
|
||||
// forwardICMPPacket creates a raw ICMP socket and sends the packet, returning the connection.
|
||||
// The caller is responsible for closing the returned connection.
|
||||
func (f *Forwarder) forwardICMPPacket(id stack.TransportEndpointID, payload []byte, icmpType, icmpCode uint8, timeout time.Duration) (net.PacketConn, error) {
|
||||
func (f *Forwarder) forwardICMPPacket(id stack.TransportEndpointID, payload []byte, icmpType, icmpCode uint8, v6 bool, timeout time.Duration) (net.PacketConn, error) {
|
||||
ctx, cancel := context.WithTimeout(f.ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
network, listenAddr := "ip4:icmp", "0.0.0.0"
|
||||
if v6 {
|
||||
network, listenAddr = "ip6:ipv6-icmp", "::"
|
||||
}
|
||||
|
||||
lc := net.ListenConfig{}
|
||||
conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0")
|
||||
conn, err := lc.ListenPacket(ctx, network, listenAddr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create ICMP socket: %w", err)
|
||||
}
|
||||
|
||||
dstIP := f.determineDialAddr(id.LocalAddress)
|
||||
dst := &net.IPAddr{IP: dstIP}
|
||||
dst := &net.IPAddr{IP: dstIP.AsSlice()}
|
||||
|
||||
if _, err = conn.WriteTo(payload, dst); err != nil {
|
||||
if closeErr := conn.Close(); closeErr != nil {
|
||||
@@ -98,11 +103,11 @@ func (f *Forwarder) forwardICMPPacket(id stack.TransportEndpointID, payload []by
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// handleICMPViaSocket handles ICMP echo requests using raw sockets.
|
||||
func (f *Forwarder) handleICMPViaSocket(flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8, icmpData []byte, rxBytes int) {
|
||||
// handleICMPViaSocket handles ICMP echo requests using raw sockets for both v4 and v6.
|
||||
func (f *Forwarder) handleICMPViaSocket(flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8, icmpData []byte, rxBytes int, v6 bool) {
|
||||
sendTime := time.Now()
|
||||
|
||||
conn, err := f.forwardICMPPacket(id, icmpData, icmpType, icmpCode, 5*time.Second)
|
||||
conn, err := f.forwardICMPPacket(id, icmpData, icmpType, icmpCode, v6, 5*time.Second)
|
||||
if err != nil {
|
||||
f.logger.Error2("forwarder: Failed to send ICMP packet for %v: %v", epID(id), err)
|
||||
return
|
||||
@@ -113,16 +118,20 @@ func (f *Forwarder) handleICMPViaSocket(flowID uuid.UUID, id stack.TransportEndp
|
||||
}
|
||||
}()
|
||||
|
||||
txBytes := f.handleEchoResponse(conn, id)
|
||||
txBytes := f.handleEchoResponse(conn, id, v6)
|
||||
rtt := time.Since(sendTime).Round(10 * time.Microsecond)
|
||||
|
||||
f.logger.Trace4("forwarder: Forwarded ICMP echo reply %v type %v code %v (rtt=%v, raw socket)",
|
||||
epID(id), icmpType, icmpCode, rtt)
|
||||
proto := "ICMP"
|
||||
if v6 {
|
||||
proto = "ICMPv6"
|
||||
}
|
||||
f.logger.Trace5("forwarder: Forwarded %s echo reply %v type %v code %v (rtt=%v, raw socket)",
|
||||
proto, epID(id), icmpType, icmpCode, rtt)
|
||||
|
||||
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
|
||||
}
|
||||
|
||||
func (f *Forwarder) handleEchoResponse(conn net.PacketConn, id stack.TransportEndpointID) int {
|
||||
func (f *Forwarder) handleEchoResponse(conn net.PacketConn, id stack.TransportEndpointID, v6 bool) int {
|
||||
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
||||
f.logger.Error1("forwarder: Failed to set read deadline for ICMP response: %v", err)
|
||||
return 0
|
||||
@@ -137,6 +146,19 @@ func (f *Forwarder) handleEchoResponse(conn net.PacketConn, id stack.TransportEn
|
||||
return 0
|
||||
}
|
||||
|
||||
if v6 {
|
||||
// Recompute checksum: the raw socket response has a checksum computed
|
||||
// over the real endpoint addresses, but we inject with overlay addresses.
|
||||
icmpHdr := header.ICMPv6(response[:n])
|
||||
icmpHdr.SetChecksum(0)
|
||||
icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
|
||||
Header: icmpHdr,
|
||||
Src: id.LocalAddress,
|
||||
Dst: id.RemoteAddress,
|
||||
}))
|
||||
return f.injectICMPv6Reply(id, response[:n])
|
||||
}
|
||||
|
||||
return f.injectICMPReply(id, response[:n])
|
||||
}
|
||||
|
||||
@@ -150,19 +172,23 @@ func (f *Forwarder) sendICMPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.T
|
||||
txPackets = 1
|
||||
}
|
||||
|
||||
srcIp := netip.AddrFrom4(id.RemoteAddress.As4())
|
||||
dstIp := netip.AddrFrom4(id.LocalAddress.As4())
|
||||
srcIp := addrToNetipAddr(id.RemoteAddress)
|
||||
dstIp := addrToNetipAddr(id.LocalAddress)
|
||||
|
||||
proto := nftypes.ICMP
|
||||
if srcIp.Is6() {
|
||||
proto = nftypes.ICMPv6
|
||||
}
|
||||
|
||||
fields := nftypes.EventFields{
|
||||
FlowID: flowID,
|
||||
Type: typ,
|
||||
Direction: nftypes.Ingress,
|
||||
Protocol: nftypes.ICMP,
|
||||
// TODO: handle ipv6
|
||||
SourceIP: srcIp,
|
||||
DestIP: dstIp,
|
||||
ICMPType: icmpType,
|
||||
ICMPCode: icmpCode,
|
||||
Protocol: proto,
|
||||
SourceIP: srcIp,
|
||||
DestIP: dstIp,
|
||||
ICMPType: icmpType,
|
||||
ICMPCode: icmpCode,
|
||||
|
||||
RxBytes: rxBytes,
|
||||
TxBytes: txBytes,
|
||||
@@ -209,26 +235,164 @@ func (f *Forwarder) handleICMPViaPing(flowID uuid.UUID, id stack.TransportEndpoi
|
||||
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
|
||||
}
|
||||
|
||||
// handleICMPv6 handles ICMPv6 packets from the network stack.
|
||||
func (f *Forwarder) handleICMPv6(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
|
||||
icmpHdr := header.ICMPv6(pkt.TransportHeader().View().AsSlice())
|
||||
|
||||
flowID := uuid.New()
|
||||
f.sendICMPEvent(nftypes.TypeStart, flowID, id, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()), 0, 0)
|
||||
|
||||
if icmpHdr.Type() == header.ICMPv6EchoRequest {
|
||||
return f.handleICMPv6Echo(flowID, id, pkt, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()))
|
||||
}
|
||||
|
||||
// For non-echo types (Destination Unreachable, Packet Too Big, etc), forward without waiting
|
||||
if !f.hasRawICMPv6Access {
|
||||
f.logger.Debug2("forwarder: Cannot handle ICMPv6 type %v without raw socket access for %v", icmpHdr.Type(), epID(id))
|
||||
return false
|
||||
}
|
||||
|
||||
icmpData := stack.PayloadSince(pkt.TransportHeader()).AsSlice()
|
||||
conn, err := f.forwardICMPPacket(id, icmpData, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()), true, 100*time.Millisecond)
|
||||
if err != nil {
|
||||
f.logger.Error2("forwarder: Failed to forward ICMPv6 packet for %v: %v", epID(id), err)
|
||||
return true
|
||||
}
|
||||
if err := conn.Close(); err != nil {
|
||||
f.logger.Debug1("forwarder: Failed to close ICMPv6 socket: %v", err)
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// handleICMPv6Echo handles ICMPv6 echo requests via raw socket or ping binary fallback.
|
||||
func (f *Forwarder) handleICMPv6Echo(flowID uuid.UUID, id stack.TransportEndpointID, pkt *stack.PacketBuffer, icmpType, icmpCode uint8) bool {
|
||||
select {
|
||||
case f.pingSemaphore <- struct{}{}:
|
||||
icmpData := stack.PayloadSince(pkt.TransportHeader()).ToSlice()
|
||||
rxBytes := pkt.Size()
|
||||
|
||||
go func() {
|
||||
defer func() { <-f.pingSemaphore }()
|
||||
|
||||
if f.hasRawICMPv6Access {
|
||||
f.handleICMPViaSocket(flowID, id, icmpType, icmpCode, icmpData, rxBytes, true)
|
||||
} else {
|
||||
f.handleICMPv6ViaPing(flowID, id, icmpType, icmpCode, icmpData, rxBytes)
|
||||
}
|
||||
}()
|
||||
default:
|
||||
f.logger.Debug3("forwarder: ICMPv6 rate limit exceeded for %v type %v code %v", epID(id), icmpType, icmpCode)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// handleICMPv6ViaPing uses the system ping6 binary for ICMPv6 echo.
|
||||
func (f *Forwarder) handleICMPv6ViaPing(flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8, icmpData []byte, rxBytes int) {
|
||||
ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
dstIP := f.determineDialAddr(id.LocalAddress)
|
||||
cmd := buildPingCommand(ctx, dstIP, 5*time.Second)
|
||||
|
||||
pingStart := time.Now()
|
||||
if err := cmd.Run(); err != nil {
|
||||
f.logger.Warn4("forwarder: Ping6 failed for %v type %v code %v: %v", epID(id), icmpType, icmpCode, err)
|
||||
return
|
||||
}
|
||||
rtt := time.Since(pingStart).Round(10 * time.Microsecond)
|
||||
|
||||
f.logger.Trace3("forwarder: Forwarded ICMPv6 echo request %v type %v code %v",
|
||||
epID(id), icmpType, icmpCode)
|
||||
|
||||
txBytes := f.synthesizeICMPv6EchoReply(id, icmpData)
|
||||
|
||||
f.logger.Trace4("forwarder: Forwarded ICMPv6 echo reply %v type %v code %v (rtt=%v, ping binary)",
|
||||
epID(id), icmpType, icmpCode, rtt)
|
||||
|
||||
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
|
||||
}
|
||||
|
||||
// synthesizeICMPv6EchoReply creates an ICMPv6 echo reply and injects it back.
|
||||
func (f *Forwarder) synthesizeICMPv6EchoReply(id stack.TransportEndpointID, icmpData []byte) int {
|
||||
replyICMP := make([]byte, len(icmpData))
|
||||
copy(replyICMP, icmpData)
|
||||
|
||||
replyHdr := header.ICMPv6(replyICMP)
|
||||
replyHdr.SetType(header.ICMPv6EchoReply)
|
||||
replyHdr.SetChecksum(0)
|
||||
// ICMPv6Checksum computes the pseudo-header internally from Src/Dst.
|
||||
// Header contains the full ICMP message, so PayloadCsum/PayloadLen are zero.
|
||||
replyHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
|
||||
Header: replyHdr,
|
||||
Src: id.LocalAddress,
|
||||
Dst: id.RemoteAddress,
|
||||
}))
|
||||
|
||||
return f.injectICMPv6Reply(id, replyICMP)
|
||||
}
|
||||
|
||||
// injectICMPv6Reply wraps an ICMPv6 payload in an IPv6 header and sends to the peer.
|
||||
func (f *Forwarder) injectICMPv6Reply(id stack.TransportEndpointID, icmpPayload []byte) int {
|
||||
ipHdr := make([]byte, header.IPv6MinimumSize)
|
||||
ip := header.IPv6(ipHdr)
|
||||
ip.Encode(&header.IPv6Fields{
|
||||
PayloadLength: uint16(len(icmpPayload)),
|
||||
TransportProtocol: header.ICMPv6ProtocolNumber,
|
||||
HopLimit: 64,
|
||||
SrcAddr: id.LocalAddress,
|
||||
DstAddr: id.RemoteAddress,
|
||||
})
|
||||
|
||||
fullPacket := make([]byte, 0, len(ipHdr)+len(icmpPayload))
|
||||
fullPacket = append(fullPacket, ipHdr...)
|
||||
fullPacket = append(fullPacket, icmpPayload...)
|
||||
|
||||
if err := f.endpoint.device.CreateOutboundPacket(fullPacket, id.RemoteAddress.AsSlice()); err != nil {
|
||||
f.logger.Error1("forwarder: Failed to send ICMPv6 reply to peer: %v", err)
|
||||
return 0
|
||||
}
|
||||
|
||||
return len(fullPacket)
|
||||
}
|
||||
|
||||
const (
|
||||
pingBin = "ping"
|
||||
ping6Bin = "ping6"
|
||||
)
|
||||
|
||||
// buildPingCommand creates a platform-specific ping command.
|
||||
func buildPingCommand(ctx context.Context, target net.IP, timeout time.Duration) *exec.Cmd {
|
||||
// Most platforms auto-detect IPv6 from raw addresses. macOS/iOS/OpenBSD require ping6.
|
||||
func buildPingCommand(ctx context.Context, target netip.Addr, timeout time.Duration) *exec.Cmd {
|
||||
timeoutSec := int(timeout.Seconds())
|
||||
if timeoutSec < 1 {
|
||||
timeoutSec = 1
|
||||
}
|
||||
|
||||
isV6 := target.Is6()
|
||||
timeoutStr := fmt.Sprintf("%d", timeoutSec)
|
||||
|
||||
switch runtime.GOOS {
|
||||
case "linux", "android":
|
||||
return exec.CommandContext(ctx, "ping", "-c", "1", "-W", fmt.Sprintf("%d", timeoutSec), "-q", target.String())
|
||||
return exec.CommandContext(ctx, pingBin, "-c", "1", "-W", timeoutStr, "-q", target.String())
|
||||
case "darwin", "ios":
|
||||
return exec.CommandContext(ctx, "ping", "-c", "1", "-t", fmt.Sprintf("%d", timeoutSec), "-q", target.String())
|
||||
bin := pingBin
|
||||
if isV6 {
|
||||
bin = ping6Bin
|
||||
}
|
||||
return exec.CommandContext(ctx, bin, "-c", "1", "-t", timeoutStr, "-q", target.String())
|
||||
case "freebsd":
|
||||
return exec.CommandContext(ctx, "ping", "-c", "1", "-t", fmt.Sprintf("%d", timeoutSec), target.String())
|
||||
return exec.CommandContext(ctx, pingBin, "-c", "1", "-t", timeoutStr, target.String())
|
||||
case "openbsd", "netbsd":
|
||||
return exec.CommandContext(ctx, "ping", "-c", "1", "-w", fmt.Sprintf("%d", timeoutSec), target.String())
|
||||
bin := pingBin
|
||||
if isV6 {
|
||||
bin = ping6Bin
|
||||
}
|
||||
return exec.CommandContext(ctx, bin, "-c", "1", "-w", timeoutStr, target.String())
|
||||
case "windows":
|
||||
return exec.CommandContext(ctx, "ping", "-n", "1", "-w", fmt.Sprintf("%d", timeoutSec*1000), target.String())
|
||||
return exec.CommandContext(ctx, pingBin, "-n", "1", "-w", fmt.Sprintf("%d", timeoutSec*1000), target.String())
|
||||
default:
|
||||
return exec.CommandContext(ctx, "ping", "-c", "1", target.String())
|
||||
return exec.CommandContext(ctx, pingBin, "-c", "1", target.String())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -2,10 +2,9 @@ package forwarder
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"sync"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -33,7 +32,7 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
|
||||
}
|
||||
}()
|
||||
|
||||
dialAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
|
||||
dialAddr := net.JoinHostPort(f.determineDialAddr(id.LocalAddress).String(), strconv.Itoa(int(id.LocalPort)))
|
||||
|
||||
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr)
|
||||
if err != nil {
|
||||
@@ -133,15 +132,14 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
|
||||
}
|
||||
|
||||
func (f *Forwarder) sendTCPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, rxBytes, txBytes, rxPackets, txPackets uint64) {
|
||||
srcIp := netip.AddrFrom4(id.RemoteAddress.As4())
|
||||
dstIp := netip.AddrFrom4(id.LocalAddress.As4())
|
||||
srcIp := addrToNetipAddr(id.RemoteAddress)
|
||||
dstIp := addrToNetipAddr(id.LocalAddress)
|
||||
|
||||
fields := nftypes.EventFields{
|
||||
FlowID: flowID,
|
||||
Type: typ,
|
||||
Direction: nftypes.Ingress,
|
||||
Protocol: nftypes.TCP,
|
||||
// TODO: handle ipv6
|
||||
FlowID: flowID,
|
||||
Type: typ,
|
||||
Direction: nftypes.Ingress,
|
||||
Protocol: nftypes.TCP,
|
||||
SourceIP: srcIp,
|
||||
DestIP: dstIp,
|
||||
SourcePort: id.RemotePort,
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@@ -158,7 +158,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool {
|
||||
}
|
||||
}()
|
||||
|
||||
dstAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
|
||||
dstAddr := net.JoinHostPort(f.determineDialAddr(id.LocalAddress).String(), strconv.Itoa(int(id.LocalPort)))
|
||||
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr)
|
||||
if err != nil {
|
||||
f.logger.Debug2("forwarder: UDP dial error for %v: %v", epID(id), err)
|
||||
@@ -276,15 +276,14 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
|
||||
|
||||
// sendUDPEvent stores flow events for UDP connections
|
||||
func (f *Forwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, rxBytes, txBytes, rxPackets, txPackets uint64) {
|
||||
srcIp := netip.AddrFrom4(id.RemoteAddress.As4())
|
||||
dstIp := netip.AddrFrom4(id.LocalAddress.As4())
|
||||
srcIp := addrToNetipAddr(id.RemoteAddress)
|
||||
dstIp := addrToNetipAddr(id.LocalAddress)
|
||||
|
||||
fields := nftypes.EventFields{
|
||||
FlowID: flowID,
|
||||
Type: typ,
|
||||
Direction: nftypes.Ingress,
|
||||
Protocol: nftypes.UDP,
|
||||
// TODO: handle ipv6
|
||||
FlowID: flowID,
|
||||
Type: typ,
|
||||
Direction: nftypes.Ingress,
|
||||
Protocol: nftypes.UDP,
|
||||
SourceIP: srcIp,
|
||||
DestIP: dstIp,
|
||||
SourcePort: id.RemotePort,
|
||||
|
||||
@@ -4,89 +4,32 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
||||
)
|
||||
|
||||
type localIPManager struct {
|
||||
mu sync.RWMutex
|
||||
|
||||
// fixed-size high array for upper byte of a IPv4 address
|
||||
ipv4Bitmap [256]*ipv4LowBitmap
|
||||
// localIPSnapshot is an immutable snapshot of local IP addresses, swapped
|
||||
// atomically so reads are lock-free.
|
||||
type localIPSnapshot struct {
|
||||
ips map[netip.Addr]struct{}
|
||||
}
|
||||
|
||||
// ipv4LowBitmap is a map for the low 16 bits of a IPv4 address
|
||||
type ipv4LowBitmap struct {
|
||||
bitmap [8192]uint32
|
||||
type localIPManager struct {
|
||||
snapshot atomic.Pointer[localIPSnapshot]
|
||||
}
|
||||
|
||||
func newLocalIPManager() *localIPManager {
|
||||
return &localIPManager{}
|
||||
m := &localIPManager{}
|
||||
m.snapshot.Store(&localIPSnapshot{
|
||||
ips: make(map[netip.Addr]struct{}),
|
||||
})
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *localIPManager) setBitmapBit(ip net.IP) {
|
||||
ipv4 := ip.To4()
|
||||
if ipv4 == nil {
|
||||
return
|
||||
}
|
||||
high := uint16(ipv4[0])
|
||||
low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3])
|
||||
|
||||
index := low / 32
|
||||
bit := low % 32
|
||||
|
||||
if m.ipv4Bitmap[high] == nil {
|
||||
m.ipv4Bitmap[high] = &ipv4LowBitmap{}
|
||||
}
|
||||
|
||||
m.ipv4Bitmap[high].bitmap[index] |= 1 << bit
|
||||
}
|
||||
|
||||
func (m *localIPManager) setBitInBitmap(ip netip.Addr, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) {
|
||||
if !ip.Is4() {
|
||||
return
|
||||
}
|
||||
ipv4 := ip.AsSlice()
|
||||
|
||||
high := uint16(ipv4[0])
|
||||
low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3])
|
||||
|
||||
if bitmap[high] == nil {
|
||||
bitmap[high] = &ipv4LowBitmap{}
|
||||
}
|
||||
|
||||
index := low / 32
|
||||
bit := low % 32
|
||||
bitmap[high].bitmap[index] |= 1 << bit
|
||||
|
||||
if _, exists := ipv4Set[ip]; !exists {
|
||||
ipv4Set[ip] = struct{}{}
|
||||
*ipv4Addresses = append(*ipv4Addresses, ip)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *localIPManager) checkBitmapBit(ip []byte) bool {
|
||||
high := uint16(ip[0])
|
||||
low := (uint16(ip[1]) << 8) | (uint16(ip[2]) << 4) | uint16(ip[3])
|
||||
|
||||
if m.ipv4Bitmap[high] == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
index := low / 32
|
||||
bit := low % 32
|
||||
return (m.ipv4Bitmap[high].bitmap[index] & (1 << bit)) != 0
|
||||
}
|
||||
|
||||
func (m *localIPManager) processIP(ip netip.Addr, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) error {
|
||||
m.setBitInBitmap(ip, bitmap, ipv4Set, ipv4Addresses)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) {
|
||||
func processInterface(iface net.Interface, ips map[netip.Addr]struct{}, addresses *[]netip.Addr) {
|
||||
addrs, err := iface.Addrs()
|
||||
if err != nil {
|
||||
log.Debugf("get addresses for interface %s failed: %v", iface.Name, err)
|
||||
@@ -104,18 +47,19 @@ func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv
|
||||
continue
|
||||
}
|
||||
|
||||
addr, ok := netip.AddrFromSlice(ip)
|
||||
parsed, ok := netip.AddrFromSlice(ip)
|
||||
if !ok {
|
||||
log.Warnf("invalid IP address %s in interface %s", ip.String(), iface.Name)
|
||||
continue
|
||||
}
|
||||
|
||||
if err := m.processIP(addr.Unmap(), bitmap, ipv4Set, ipv4Addresses); err != nil {
|
||||
log.Debugf("process IP failed: %v", err)
|
||||
}
|
||||
parsed = parsed.Unmap()
|
||||
ips[parsed] = struct{}{}
|
||||
*addresses = append(*addresses, parsed)
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateLocalIPs rebuilds the local IP snapshot and swaps it in atomically.
|
||||
func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
@@ -123,20 +67,20 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
|
||||
}
|
||||
}()
|
||||
|
||||
var newIPv4Bitmap [256]*ipv4LowBitmap
|
||||
ipv4Set := make(map[netip.Addr]struct{})
|
||||
var ipv4Addresses []netip.Addr
|
||||
ips := make(map[netip.Addr]struct{})
|
||||
var addresses []netip.Addr
|
||||
|
||||
// 127.0.0.0/8
|
||||
newIPv4Bitmap[127] = &ipv4LowBitmap{}
|
||||
for i := 0; i < 8192; i++ {
|
||||
// #nosec G602 -- bitmap is defined as [8192]uint32, loop range is correct
|
||||
newIPv4Bitmap[127].bitmap[i] = 0xFFFFFFFF
|
||||
}
|
||||
// loopback
|
||||
ips[netip.AddrFrom4([4]byte{127, 0, 0, 1})] = struct{}{}
|
||||
ips[netip.IPv6Loopback()] = struct{}{}
|
||||
|
||||
if iface != nil {
|
||||
if err := m.processIP(iface.Address().IP, &newIPv4Bitmap, ipv4Set, &ipv4Addresses); err != nil {
|
||||
return err
|
||||
ip := iface.Address().IP
|
||||
ips[ip] = struct{}{}
|
||||
addresses = append(addresses, ip)
|
||||
if v6 := iface.Address().IPv6; v6.IsValid() {
|
||||
ips[v6] = struct{}{}
|
||||
addresses = append(addresses, v6)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -147,25 +91,24 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
|
||||
// TODO: filter out down interfaces (net.FlagUp). Also handle the reverse
|
||||
// case where an interface comes up between refreshes.
|
||||
for _, intf := range interfaces {
|
||||
m.processInterface(intf, &newIPv4Bitmap, ipv4Set, &ipv4Addresses)
|
||||
processInterface(intf, ips, &addresses)
|
||||
}
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
m.ipv4Bitmap = newIPv4Bitmap
|
||||
m.mu.Unlock()
|
||||
m.snapshot.Store(&localIPSnapshot{ips: ips})
|
||||
|
||||
log.Debugf("Local IPv4 addresses: %v", ipv4Addresses)
|
||||
log.Debugf("Local IP addresses: %v", addresses)
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsLocalIP checks if the given IP is a local address. Lock-free on the read path.
|
||||
func (m *localIPManager) IsLocalIP(ip netip.Addr) bool {
|
||||
if !ip.Is4() {
|
||||
return false
|
||||
s := m.snapshot.Load()
|
||||
|
||||
if ip.Is4() && ip.As4()[0] == 127 {
|
||||
return true
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
return m.checkBitmapBit(ip.AsSlice())
|
||||
_, found := s.ips[ip]
|
||||
return found
|
||||
}
|
||||
|
||||
72
client/firewall/uspfilter/localip_bench_test.go
Normal file
72
client/firewall/uspfilter/localip_bench_test.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
func setupManager(b *testing.B) *localIPManager {
|
||||
b.Helper()
|
||||
m := newLocalIPManager()
|
||||
mock := &IFaceMock{
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: netip.MustParseAddr("100.64.0.1"),
|
||||
Network: netip.MustParsePrefix("100.64.0.0/16"),
|
||||
IPv6: netip.MustParseAddr("fd00::1"),
|
||||
IPv6Net: netip.MustParsePrefix("fd00::/64"),
|
||||
}
|
||||
},
|
||||
}
|
||||
if err := m.UpdateLocalIPs(mock); err != nil {
|
||||
b.Fatalf("UpdateLocalIPs: %v", err)
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func BenchmarkIsLocalIP_v4_hit(b *testing.B) {
|
||||
m := setupManager(b)
|
||||
ip := netip.MustParseAddr("100.64.0.1")
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
m.IsLocalIP(ip)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkIsLocalIP_v4_miss(b *testing.B) {
|
||||
m := setupManager(b)
|
||||
ip := netip.MustParseAddr("8.8.8.8")
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
m.IsLocalIP(ip)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkIsLocalIP_v6_hit(b *testing.B) {
|
||||
m := setupManager(b)
|
||||
ip := netip.MustParseAddr("fd00::1")
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
m.IsLocalIP(ip)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkIsLocalIP_v6_miss(b *testing.B) {
|
||||
m := setupManager(b)
|
||||
ip := netip.MustParseAddr("2001:db8::1")
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
m.IsLocalIP(ip)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkIsLocalIP_loopback(b *testing.B) {
|
||||
m := setupManager(b)
|
||||
ip := netip.MustParseAddr("127.0.0.1")
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
m.IsLocalIP(ip)
|
||||
}
|
||||
}
|
||||
@@ -72,14 +72,45 @@ func TestLocalIPManager(t *testing.T) {
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "IPv6 address",
|
||||
name: "IPv6 address matches",
|
||||
setupAddr: wgaddr.Address{
|
||||
IP: netip.MustParseAddr("fe80::1"),
|
||||
IP: netip.MustParseAddr("100.64.0.1"),
|
||||
Network: netip.MustParsePrefix("100.64.0.0/16"),
|
||||
IPv6: netip.MustParseAddr("fd00::1"),
|
||||
IPv6Net: netip.MustParsePrefix("fd00::/64"),
|
||||
},
|
||||
testIP: netip.MustParseAddr("fd00::1"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "IPv6 address does not match",
|
||||
setupAddr: wgaddr.Address{
|
||||
IP: netip.MustParseAddr("100.64.0.1"),
|
||||
Network: netip.MustParsePrefix("100.64.0.0/16"),
|
||||
IPv6: netip.MustParseAddr("fd00::1"),
|
||||
IPv6Net: netip.MustParsePrefix("fd00::/64"),
|
||||
},
|
||||
testIP: netip.MustParseAddr("fd00::99"),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "No aliasing between similar IPs",
|
||||
setupAddr: wgaddr.Address{
|
||||
IP: netip.MustParseAddr("192.168.1.1"),
|
||||
Network: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
testIP: netip.MustParseAddr("fe80::1"),
|
||||
testIP: netip.MustParseAddr("192.168.0.17"),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "IPv6 loopback",
|
||||
setupAddr: wgaddr.Address{
|
||||
IP: netip.MustParseAddr("100.64.0.1"),
|
||||
Network: netip.MustParsePrefix("100.64.0.0/16"),
|
||||
},
|
||||
testIP: netip.MustParseAddr("::1"),
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -171,90 +202,3 @@ func TestLocalIPManager_AllInterfaces(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// MapImplementation is a version using map[string]struct{}
|
||||
type MapImplementation struct {
|
||||
localIPs map[string]struct{}
|
||||
}
|
||||
|
||||
func BenchmarkIPChecks(b *testing.B) {
|
||||
interfaces := make([]net.IP, 16)
|
||||
for i := range interfaces {
|
||||
interfaces[i] = net.IPv4(10, 0, byte(i>>8), byte(i))
|
||||
}
|
||||
|
||||
// Setup bitmap
|
||||
bitmapManager := newLocalIPManager()
|
||||
for _, ip := range interfaces[:8] { // Add half of IPs
|
||||
bitmapManager.setBitmapBit(ip)
|
||||
}
|
||||
|
||||
// Setup map version
|
||||
mapManager := &MapImplementation{
|
||||
localIPs: make(map[string]struct{}),
|
||||
}
|
||||
for _, ip := range interfaces[:8] {
|
||||
mapManager.localIPs[ip.String()] = struct{}{}
|
||||
}
|
||||
|
||||
b.Run("Bitmap_Hit", func(b *testing.B) {
|
||||
ip := interfaces[4]
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
bitmapManager.checkBitmapBit(ip)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("Bitmap_Miss", func(b *testing.B) {
|
||||
ip := interfaces[12]
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
bitmapManager.checkBitmapBit(ip)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("Map_Hit", func(b *testing.B) {
|
||||
ip := interfaces[4]
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
// nolint:gosimple
|
||||
_ = mapManager.localIPs[ip.String()]
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("Map_Miss", func(b *testing.B) {
|
||||
ip := interfaces[12]
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
// nolint:gosimple
|
||||
_ = mapManager.localIPs[ip.String()]
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkWGPosition(b *testing.B) {
|
||||
wgIP := net.ParseIP("10.10.0.1")
|
||||
|
||||
// Create two managers - one checks WG IP first, other checks it last
|
||||
b.Run("WG_First", func(b *testing.B) {
|
||||
bm := newLocalIPManager()
|
||||
bm.setBitmapBit(wgIP)
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
bm.checkBitmapBit(wgIP)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("WG_Last", func(b *testing.B) {
|
||||
bm := newLocalIPManager()
|
||||
// Fill with other IPs first
|
||||
for i := 0; i < 15; i++ {
|
||||
bm.setBitmapBit(net.IPv4(10, 0, byte(i>>8), byte(i)))
|
||||
}
|
||||
bm.setBitmapBit(wgIP) // Add WG IP last
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
bm.checkBitmapBit(wgIP)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -13,8 +13,6 @@ import (
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
)
|
||||
|
||||
var ErrIPv4Only = errors.New("only IPv4 is supported for DNAT")
|
||||
|
||||
var (
|
||||
errInvalidIPHeaderLength = errors.New("invalid IP header length")
|
||||
)
|
||||
@@ -25,10 +23,33 @@ const (
|
||||
destinationPortOffset = 2
|
||||
|
||||
// IP address offsets in IPv4 header
|
||||
sourceIPOffset = 12
|
||||
destinationIPOffset = 16
|
||||
ipv4SrcOffset = 12
|
||||
ipv4DstOffset = 16
|
||||
|
||||
// IP address offsets in IPv6 header
|
||||
ipv6SrcOffset = 8
|
||||
ipv6DstOffset = 24
|
||||
|
||||
// IPv6 fixed header length
|
||||
ipv6HeaderLen = 40
|
||||
)
|
||||
|
||||
// ipHeaderLen returns the IP header length based on the decoded layer type.
|
||||
func ipHeaderLen(d *decoder) (int, error) {
|
||||
switch d.decoded[0] {
|
||||
case layers.LayerTypeIPv4:
|
||||
n := int(d.ip4.IHL) * 4
|
||||
if n < 20 {
|
||||
return 0, errInvalidIPHeaderLength
|
||||
}
|
||||
return n, nil
|
||||
case layers.LayerTypeIPv6:
|
||||
return ipv6HeaderLen, nil
|
||||
default:
|
||||
return 0, fmt.Errorf("unknown IP layer: %v", d.decoded[0])
|
||||
}
|
||||
}
|
||||
|
||||
// ipv4Checksum calculates IPv4 header checksum.
|
||||
func ipv4Checksum(header []byte) uint16 {
|
||||
if len(header) < 20 {
|
||||
@@ -234,14 +255,13 @@ func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
|
||||
|
||||
_, dstIP := extractPacketIPs(packetData, d)
|
||||
translatedIP, exists := m.getDNATTranslation(dstIP)
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
if err := m.rewritePacketIP(packetData, d, translatedIP, destinationIPOffset); err != nil {
|
||||
if err := m.rewritePacketIP(packetData, d, translatedIP, false); err != nil {
|
||||
m.logger.Error1("failed to rewrite packet destination: %v", err)
|
||||
return false
|
||||
}
|
||||
@@ -256,14 +276,13 @@ func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
|
||||
|
||||
srcIP, _ := extractPacketIPs(packetData, d)
|
||||
originalIP, exists := m.findReverseDNATMapping(srcIP)
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
if err := m.rewritePacketIP(packetData, d, originalIP, sourceIPOffset); err != nil {
|
||||
if err := m.rewritePacketIP(packetData, d, originalIP, true); err != nil {
|
||||
m.logger.Error1("failed to rewrite packet source: %v", err)
|
||||
return false
|
||||
}
|
||||
@@ -272,38 +291,96 @@ func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// rewritePacketIP replaces an IP address (source or destination) in the packet and updates checksums.
|
||||
func (m *Manager) rewritePacketIP(packetData []byte, d *decoder, newIP netip.Addr, ipOffset int) error {
|
||||
// extractPacketIPs extracts src and dst IP addresses directly from raw packet bytes.
|
||||
func extractPacketIPs(packetData []byte, d *decoder) (src, dst netip.Addr) {
|
||||
switch d.decoded[0] {
|
||||
case layers.LayerTypeIPv4:
|
||||
src = netip.AddrFrom4([4]byte{packetData[ipv4SrcOffset], packetData[ipv4SrcOffset+1], packetData[ipv4SrcOffset+2], packetData[ipv4SrcOffset+3]})
|
||||
dst = netip.AddrFrom4([4]byte{packetData[ipv4DstOffset], packetData[ipv4DstOffset+1], packetData[ipv4DstOffset+2], packetData[ipv4DstOffset+3]})
|
||||
case layers.LayerTypeIPv6:
|
||||
src = netip.AddrFrom16([16]byte(packetData[ipv6SrcOffset : ipv6SrcOffset+16]))
|
||||
dst = netip.AddrFrom16([16]byte(packetData[ipv6DstOffset : ipv6DstOffset+16]))
|
||||
}
|
||||
return src, dst
|
||||
}
|
||||
|
||||
// rewritePacketIP replaces a source (isSource=true) or destination IP address in the packet and updates checksums.
|
||||
func (m *Manager) rewritePacketIP(packetData []byte, d *decoder, newIP netip.Addr, isSource bool) error {
|
||||
hdrLen, err := ipHeaderLen(d)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch d.decoded[0] {
|
||||
case layers.LayerTypeIPv4:
|
||||
return m.rewriteIPv4(packetData, d, newIP, hdrLen, isSource)
|
||||
case layers.LayerTypeIPv6:
|
||||
return m.rewriteIPv6(packetData, d, newIP, hdrLen, isSource)
|
||||
default:
|
||||
return fmt.Errorf("unknown IP layer: %v", d.decoded[0])
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) rewriteIPv4(packetData []byte, d *decoder, newIP netip.Addr, hdrLen int, isSource bool) error {
|
||||
if !newIP.Is4() {
|
||||
return ErrIPv4Only
|
||||
return fmt.Errorf("cannot write IPv6 address into IPv4 packet")
|
||||
}
|
||||
|
||||
offset := ipv4DstOffset
|
||||
if isSource {
|
||||
offset = ipv4SrcOffset
|
||||
}
|
||||
|
||||
var oldIP [4]byte
|
||||
copy(oldIP[:], packetData[ipOffset:ipOffset+4])
|
||||
copy(oldIP[:], packetData[offset:offset+4])
|
||||
newIPBytes := newIP.As4()
|
||||
copy(packetData[offset:offset+4], newIPBytes[:])
|
||||
|
||||
copy(packetData[ipOffset:ipOffset+4], newIPBytes[:])
|
||||
|
||||
ipHeaderLen := int(d.ip4.IHL) * 4
|
||||
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
|
||||
return errInvalidIPHeaderLength
|
||||
}
|
||||
|
||||
// Recalculate IPv4 header checksum
|
||||
binary.BigEndian.PutUint16(packetData[10:12], 0)
|
||||
ipChecksum := ipv4Checksum(packetData[:ipHeaderLen])
|
||||
binary.BigEndian.PutUint16(packetData[10:12], ipChecksum)
|
||||
binary.BigEndian.PutUint16(packetData[10:12], ipv4Checksum(packetData[:hdrLen]))
|
||||
|
||||
// Update transport checksums incrementally
|
||||
if len(d.decoded) > 1 {
|
||||
switch d.decoded[1] {
|
||||
case layers.LayerTypeTCP:
|
||||
m.updateTCPChecksum(packetData, ipHeaderLen, oldIP[:], newIPBytes[:])
|
||||
m.updateTCPChecksum(packetData, hdrLen, oldIP[:], newIPBytes[:])
|
||||
case layers.LayerTypeUDP:
|
||||
m.updateUDPChecksum(packetData, ipHeaderLen, oldIP[:], newIPBytes[:])
|
||||
m.updateUDPChecksum(packetData, hdrLen, oldIP[:], newIPBytes[:])
|
||||
case layers.LayerTypeICMPv4:
|
||||
m.updateICMPChecksum(packetData, ipHeaderLen)
|
||||
m.updateICMPChecksum(packetData, hdrLen)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) rewriteIPv6(packetData []byte, d *decoder, newIP netip.Addr, hdrLen int, isSource bool) error {
|
||||
if !newIP.Is6() {
|
||||
return fmt.Errorf("cannot write IPv4 address into IPv6 packet")
|
||||
}
|
||||
|
||||
offset := ipv6DstOffset
|
||||
if isSource {
|
||||
offset = ipv6SrcOffset
|
||||
}
|
||||
|
||||
var oldIP [16]byte
|
||||
copy(oldIP[:], packetData[offset:offset+16])
|
||||
newIPBytes := newIP.As16()
|
||||
copy(packetData[offset:offset+16], newIPBytes[:])
|
||||
|
||||
// IPv6 has no header checksum, only update transport checksums
|
||||
if len(d.decoded) > 1 {
|
||||
switch d.decoded[1] {
|
||||
case layers.LayerTypeTCP:
|
||||
m.updateTCPChecksum(packetData, hdrLen, oldIP[:], newIPBytes[:])
|
||||
case layers.LayerTypeUDP:
|
||||
m.updateUDPChecksum(packetData, hdrLen, oldIP[:], newIPBytes[:])
|
||||
case layers.LayerTypeICMPv6:
|
||||
// ICMPv6 checksum includes pseudo-header with addresses, use incremental update
|
||||
m.updateICMPv6Checksum(packetData, hdrLen, oldIP[:], newIPBytes[:])
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -351,6 +428,20 @@ func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) {
|
||||
binary.BigEndian.PutUint16(icmpData[2:4], checksum)
|
||||
}
|
||||
|
||||
// updateICMPv6Checksum updates ICMPv6 checksum after address change.
|
||||
// ICMPv6 uses a pseudo-header (like TCP/UDP), so incremental update applies.
|
||||
func (m *Manager) updateICMPv6Checksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) {
|
||||
icmpStart := ipHeaderLen
|
||||
if len(packetData) < icmpStart+4 {
|
||||
return
|
||||
}
|
||||
|
||||
checksumOffset := icmpStart + 2
|
||||
oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2])
|
||||
newChecksum := incrementalUpdate(oldChecksum, oldIP, newIP)
|
||||
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
|
||||
}
|
||||
|
||||
// incrementalUpdate performs incremental checksum update per RFC 1624.
|
||||
func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 {
|
||||
sum := uint32(^oldChecksum)
|
||||
@@ -532,12 +623,12 @@ func (m *Manager) applyPortRule(packetData []byte, d *decoder, srcIP, dstIP neti
|
||||
|
||||
// rewriteTCPPort rewrites a TCP port (source or destination) and updates checksum.
|
||||
func (m *Manager) rewriteTCPPort(packetData []byte, d *decoder, newPort uint16, portOffset int) error {
|
||||
ipHeaderLen := int(d.ip4.IHL) * 4
|
||||
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
|
||||
return errInvalidIPHeaderLength
|
||||
hdrLen, err := ipHeaderLen(d)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tcpStart := ipHeaderLen
|
||||
tcpStart := hdrLen
|
||||
if len(packetData) < tcpStart+4 {
|
||||
return fmt.Errorf("packet too short for TCP header")
|
||||
}
|
||||
@@ -563,12 +654,12 @@ func (m *Manager) rewriteTCPPort(packetData []byte, d *decoder, newPort uint16,
|
||||
|
||||
// rewriteUDPPort rewrites a UDP port (source or destination) and updates checksum.
|
||||
func (m *Manager) rewriteUDPPort(packetData []byte, d *decoder, newPort uint16, portOffset int) error {
|
||||
ipHeaderLen := int(d.ip4.IHL) * 4
|
||||
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
|
||||
return errInvalidIPHeaderLength
|
||||
hdrLen, err := ipHeaderLen(d)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
udpStart := ipHeaderLen
|
||||
udpStart := hdrLen
|
||||
if len(packetData) < udpStart+8 {
|
||||
return fmt.Errorf("packet too short for UDP header")
|
||||
}
|
||||
|
||||
@@ -342,12 +342,17 @@ func BenchmarkDNATMemoryAllocations(b *testing.B) {
|
||||
|
||||
// Parse the packet fresh each time to get a clean decoder
|
||||
d := &decoder{decoded: []gopacket.LayerType{}}
|
||||
d.parser = gopacket.NewDecodingLayerParser(
|
||||
d.parser4 = gopacket.NewDecodingLayerParser(
|
||||
layers.LayerTypeIPv4,
|
||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||
)
|
||||
d.parser.IgnoreUnsupported = true
|
||||
err = d.parser.DecodeLayers(testPacket, &d.decoded)
|
||||
d.parser4.IgnoreUnsupported = true
|
||||
d.parser6 = gopacket.NewDecodingLayerParser(
|
||||
layers.LayerTypeIPv6,
|
||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||
)
|
||||
d.parser6.IgnoreUnsupported = true
|
||||
err = d.decodePacket(testPacket)
|
||||
assert.NoError(b, err)
|
||||
|
||||
manager.translateOutboundDNAT(testPacket, d)
|
||||
@@ -371,12 +376,17 @@ func BenchmarkDirectIPExtraction(b *testing.B) {
|
||||
b.Run("decoder_extraction", func(b *testing.B) {
|
||||
// Create decoder once for comparison
|
||||
d := &decoder{decoded: []gopacket.LayerType{}}
|
||||
d.parser = gopacket.NewDecodingLayerParser(
|
||||
d.parser4 = gopacket.NewDecodingLayerParser(
|
||||
layers.LayerTypeIPv4,
|
||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||
)
|
||||
d.parser.IgnoreUnsupported = true
|
||||
err := d.parser.DecodeLayers(packet, &d.decoded)
|
||||
d.parser4.IgnoreUnsupported = true
|
||||
d.parser6 = gopacket.NewDecodingLayerParser(
|
||||
layers.LayerTypeIPv6,
|
||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||
)
|
||||
d.parser6.IgnoreUnsupported = true
|
||||
err := d.decodePacket(packet)
|
||||
assert.NoError(b, err)
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
|
||||
@@ -86,13 +86,18 @@ func parsePacket(t testing.TB, packetData []byte) *decoder {
|
||||
d := &decoder{
|
||||
decoded: []gopacket.LayerType{},
|
||||
}
|
||||
d.parser = gopacket.NewDecodingLayerParser(
|
||||
d.parser4 = gopacket.NewDecodingLayerParser(
|
||||
layers.LayerTypeIPv4,
|
||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||
)
|
||||
d.parser.IgnoreUnsupported = true
|
||||
d.parser4.IgnoreUnsupported = true
|
||||
d.parser6 = gopacket.NewDecodingLayerParser(
|
||||
layers.LayerTypeIPv6,
|
||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||
)
|
||||
d.parser6.IgnoreUnsupported = true
|
||||
|
||||
err := d.parser.DecodeLayers(packetData, &d.decoded)
|
||||
err := d.decodePacket(packetData)
|
||||
require.NoError(t, err)
|
||||
return d
|
||||
}
|
||||
|
||||
@@ -112,10 +112,13 @@ func (t *PacketTrace) AddResultWithForwarder(stage PacketStage, message string,
|
||||
}
|
||||
|
||||
func (p *PacketBuilder) Build() ([]byte, error) {
|
||||
ip := p.buildIPLayer()
|
||||
pktLayers := []gopacket.SerializableLayer{ip}
|
||||
ipLayer, err := p.buildIPLayer()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
pktLayers := []gopacket.SerializableLayer{ipLayer}
|
||||
|
||||
transportLayer, err := p.buildTransportLayer(ip)
|
||||
transportLayer, err := p.buildTransportLayer(ipLayer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -129,30 +132,43 @@ func (p *PacketBuilder) Build() ([]byte, error) {
|
||||
return serializePacket(pktLayers)
|
||||
}
|
||||
|
||||
func (p *PacketBuilder) buildIPLayer() *layers.IPv4 {
|
||||
func (p *PacketBuilder) buildIPLayer() (gopacket.SerializableLayer, error) {
|
||||
if p.SrcIP.Is4() != p.DstIP.Is4() {
|
||||
return nil, fmt.Errorf("mixed address families: src=%s dst=%s", p.SrcIP, p.DstIP)
|
||||
}
|
||||
proto := getIPProtocolNumber(p.Protocol, p.SrcIP.Is6())
|
||||
if p.SrcIP.Is6() {
|
||||
return &layers.IPv6{
|
||||
Version: 6,
|
||||
HopLimit: 64,
|
||||
NextHeader: proto,
|
||||
SrcIP: p.SrcIP.AsSlice(),
|
||||
DstIP: p.DstIP.AsSlice(),
|
||||
}, nil
|
||||
}
|
||||
return &layers.IPv4{
|
||||
Version: 4,
|
||||
TTL: 64,
|
||||
Protocol: layers.IPProtocol(getIPProtocolNumber(p.Protocol)),
|
||||
Protocol: proto,
|
||||
SrcIP: p.SrcIP.AsSlice(),
|
||||
DstIP: p.DstIP.AsSlice(),
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *PacketBuilder) buildTransportLayer(ip *layers.IPv4) ([]gopacket.SerializableLayer, error) {
|
||||
func (p *PacketBuilder) buildTransportLayer(ipLayer gopacket.SerializableLayer) ([]gopacket.SerializableLayer, error) {
|
||||
switch p.Protocol {
|
||||
case "tcp":
|
||||
return p.buildTCPLayer(ip)
|
||||
return p.buildTCPLayer(ipLayer)
|
||||
case "udp":
|
||||
return p.buildUDPLayer(ip)
|
||||
return p.buildUDPLayer(ipLayer)
|
||||
case "icmp":
|
||||
return p.buildICMPLayer()
|
||||
return p.buildICMPLayer(ipLayer)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported protocol: %s", p.Protocol)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *PacketBuilder) buildTCPLayer(ip *layers.IPv4) ([]gopacket.SerializableLayer, error) {
|
||||
func (p *PacketBuilder) buildTCPLayer(ipLayer gopacket.SerializableLayer) ([]gopacket.SerializableLayer, error) {
|
||||
tcp := &layers.TCP{
|
||||
SrcPort: layers.TCPPort(p.SrcPort),
|
||||
DstPort: layers.TCPPort(p.DstPort),
|
||||
@@ -164,24 +180,44 @@ func (p *PacketBuilder) buildTCPLayer(ip *layers.IPv4) ([]gopacket.SerializableL
|
||||
PSH: p.TCPState != nil && p.TCPState.PSH,
|
||||
URG: p.TCPState != nil && p.TCPState.URG,
|
||||
}
|
||||
if err := tcp.SetNetworkLayerForChecksum(ip); err != nil {
|
||||
return nil, fmt.Errorf("set network layer for TCP checksum: %w", err)
|
||||
if nl, ok := ipLayer.(gopacket.NetworkLayer); ok {
|
||||
if err := tcp.SetNetworkLayerForChecksum(nl); err != nil {
|
||||
return nil, fmt.Errorf("set network layer for TCP checksum: %w", err)
|
||||
}
|
||||
}
|
||||
return []gopacket.SerializableLayer{tcp}, nil
|
||||
}
|
||||
|
||||
func (p *PacketBuilder) buildUDPLayer(ip *layers.IPv4) ([]gopacket.SerializableLayer, error) {
|
||||
func (p *PacketBuilder) buildUDPLayer(ipLayer gopacket.SerializableLayer) ([]gopacket.SerializableLayer, error) {
|
||||
udp := &layers.UDP{
|
||||
SrcPort: layers.UDPPort(p.SrcPort),
|
||||
DstPort: layers.UDPPort(p.DstPort),
|
||||
}
|
||||
if err := udp.SetNetworkLayerForChecksum(ip); err != nil {
|
||||
return nil, fmt.Errorf("set network layer for UDP checksum: %w", err)
|
||||
if nl, ok := ipLayer.(gopacket.NetworkLayer); ok {
|
||||
if err := udp.SetNetworkLayerForChecksum(nl); err != nil {
|
||||
return nil, fmt.Errorf("set network layer for UDP checksum: %w", err)
|
||||
}
|
||||
}
|
||||
return []gopacket.SerializableLayer{udp}, nil
|
||||
}
|
||||
|
||||
func (p *PacketBuilder) buildICMPLayer() ([]gopacket.SerializableLayer, error) {
|
||||
func (p *PacketBuilder) buildICMPLayer(ipLayer gopacket.SerializableLayer) ([]gopacket.SerializableLayer, error) {
|
||||
if p.SrcIP.Is6() || p.DstIP.Is6() {
|
||||
icmp := &layers.ICMPv6{
|
||||
TypeCode: layers.CreateICMPv6TypeCode(p.ICMPType, p.ICMPCode),
|
||||
}
|
||||
if nl, ok := ipLayer.(gopacket.NetworkLayer); ok {
|
||||
_ = icmp.SetNetworkLayerForChecksum(nl)
|
||||
}
|
||||
if p.ICMPType == layers.ICMPv6TypeEchoRequest || p.ICMPType == layers.ICMPv6TypeEchoReply {
|
||||
echo := &layers.ICMPv6Echo{
|
||||
Identifier: 1,
|
||||
SeqNumber: 1,
|
||||
}
|
||||
return []gopacket.SerializableLayer{icmp, echo}, nil
|
||||
}
|
||||
return []gopacket.SerializableLayer{icmp}, nil
|
||||
}
|
||||
icmp := &layers.ICMPv4{
|
||||
TypeCode: layers.CreateICMPv4TypeCode(p.ICMPType, p.ICMPCode),
|
||||
}
|
||||
@@ -204,14 +240,17 @@ func serializePacket(layers []gopacket.SerializableLayer) ([]byte, error) {
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
func getIPProtocolNumber(protocol fw.Protocol) int {
|
||||
func getIPProtocolNumber(protocol fw.Protocol, isV6 bool) layers.IPProtocol {
|
||||
switch protocol {
|
||||
case fw.ProtocolTCP:
|
||||
return int(layers.IPProtocolTCP)
|
||||
return layers.IPProtocolTCP
|
||||
case fw.ProtocolUDP:
|
||||
return int(layers.IPProtocolUDP)
|
||||
return layers.IPProtocolUDP
|
||||
case fw.ProtocolICMP:
|
||||
return int(layers.IPProtocolICMPv4)
|
||||
if isV6 {
|
||||
return layers.IPProtocolICMPv6
|
||||
}
|
||||
return layers.IPProtocolICMPv4
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
@@ -234,7 +273,7 @@ func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *Pa
|
||||
trace := &PacketTrace{Direction: direction}
|
||||
|
||||
// Initial packet decoding
|
||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||
if err := d.decodePacket(packetData); err != nil {
|
||||
trace.AddResult(StageReceived, fmt.Sprintf("Failed to decode packet: %v", err), false)
|
||||
return trace
|
||||
}
|
||||
@@ -256,6 +295,8 @@ func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *Pa
|
||||
trace.DestinationPort = uint16(d.udp.DstPort)
|
||||
case layers.LayerTypeICMPv4:
|
||||
trace.Protocol = "ICMP"
|
||||
case layers.LayerTypeICMPv6:
|
||||
trace.Protocol = "ICMPv6"
|
||||
}
|
||||
|
||||
trace.AddResult(StageReceived, fmt.Sprintf("Received %s packet: %s:%d -> %s:%d",
|
||||
@@ -319,6 +360,13 @@ func (m *Manager) buildConntrackStateMessage(d *decoder) string {
|
||||
flags&conntrack.TCPFin != 0)
|
||||
case layers.LayerTypeICMPv4:
|
||||
msg += fmt.Sprintf(" (ICMP ID=%d, Seq=%d)", d.icmp4.Id, d.icmp4.Seq)
|
||||
case layers.LayerTypeICMPv6:
|
||||
var id, seq uint16
|
||||
if len(d.icmp6.Payload) >= 4 {
|
||||
id = uint16(d.icmp6.Payload[0])<<8 | uint16(d.icmp6.Payload[1])
|
||||
seq = uint16(d.icmp6.Payload[2])<<8 | uint16(d.icmp6.Payload[3])
|
||||
}
|
||||
msg += fmt.Sprintf(" (ICMPv6 ID=%d, Seq=%d)", id, seq)
|
||||
}
|
||||
return msg
|
||||
}
|
||||
@@ -415,7 +463,7 @@ func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTr
|
||||
d := m.decoders.Get().(*decoder)
|
||||
defer m.decoders.Put(d)
|
||||
|
||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||
if err := d.decodePacket(packetData); err != nil {
|
||||
trace.AddResult(StageCompleted, "Packet dropped - decode error", false)
|
||||
return trace
|
||||
}
|
||||
@@ -434,7 +482,7 @@ func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTr
|
||||
func (m *Manager) handleInboundDNAT(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP *netip.Addr) bool {
|
||||
portDNATApplied := m.traceInboundPortDNAT(trace, packetData, d)
|
||||
if portDNATApplied {
|
||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||
if err := d.decodePacket(packetData); err != nil {
|
||||
trace.AddResult(StageInboundPortDNAT, "Failed to re-decode after port DNAT", false)
|
||||
return true
|
||||
}
|
||||
@@ -444,7 +492,7 @@ func (m *Manager) handleInboundDNAT(trace *PacketTrace, packetData []byte, d *de
|
||||
|
||||
nat1to1Applied := m.traceInbound1to1NAT(trace, packetData, d)
|
||||
if nat1to1Applied {
|
||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||
if err := d.decodePacket(packetData); err != nil {
|
||||
trace.AddResult(StageInbound1to1NAT, "Failed to re-decode after 1:1 NAT", false)
|
||||
return true
|
||||
}
|
||||
@@ -509,7 +557,7 @@ func (m *Manager) traceInbound1to1NAT(trace *PacketTrace, packetData []byte, d *
|
||||
return false
|
||||
}
|
||||
|
||||
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
|
||||
srcIP, _ := extractPacketIPs(packetData, d)
|
||||
|
||||
translated := m.translateInboundReverse(packetData, d)
|
||||
if translated {
|
||||
@@ -539,7 +587,7 @@ func (m *Manager) traceOutbound1to1NAT(trace *PacketTrace, packetData []byte, d
|
||||
return false
|
||||
}
|
||||
|
||||
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
|
||||
_, dstIP := extractPacketIPs(packetData, d)
|
||||
|
||||
translated := m.translateOutboundDNAT(packetData, d)
|
||||
if translated {
|
||||
|
||||
@@ -119,7 +119,7 @@ func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix,
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse endpoint address: %w", err)
|
||||
}
|
||||
addrPort := netip.AddrPortFrom(addr, uint16(endpoint.Port))
|
||||
addrPort := netip.AddrPortFrom(addr.Unmap(), uint16(endpoint.Port))
|
||||
c.activityRecorder.UpsertAddress(peerKey, addrPort)
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -2,7 +2,7 @@ package device
|
||||
|
||||
// TunAdapter is an interface for create tun device from external service
|
||||
type TunAdapter interface {
|
||||
ConfigureInterface(address string, mtu int, dns string, searchDomains string, routes string) (int, error)
|
||||
ConfigureInterface(address string, addressV6 string, mtu int, dns string, searchDomains string, routes string) (int, error)
|
||||
UpdateAddr(address string) error
|
||||
ProtectSocket(fd int32) bool
|
||||
}
|
||||
|
||||
@@ -63,7 +63,7 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string
|
||||
searchDomainsToString = ""
|
||||
}
|
||||
|
||||
fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), int(t.mtu), dns, searchDomainsToString, routesString)
|
||||
fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), t.address.IPv6String(), int(t.mtu), dns, searchDomainsToString, routesString)
|
||||
if err != nil {
|
||||
log.Errorf("failed to create Android interface: %s", err)
|
||||
return nil, err
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -196,18 +196,22 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// fakeAddress returns a fake address that is used to as an identifier for the peer.
|
||||
// The fake address is in the format of 127.1.x.x where x.x is the last two octets of the peer address.
|
||||
// fakeAddress returns a fake address that is used as an identifier for the peer.
|
||||
// The fake address is in the format of 127.1.x.x where x.x is derived from the
|
||||
// last two bytes of the peer address (works for both IPv4 and IPv6).
|
||||
func fakeAddress(peerAddress *net.UDPAddr) (*netip.AddrPort, error) {
|
||||
octets := strings.Split(peerAddress.IP.String(), ".")
|
||||
if len(octets) != 4 {
|
||||
return nil, fmt.Errorf("invalid IP format")
|
||||
if peerAddress == nil {
|
||||
return nil, fmt.Errorf("nil peer address")
|
||||
}
|
||||
|
||||
fakeIP, err := netip.ParseAddr(fmt.Sprintf("127.1.%s.%s", octets[2], octets[3]))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse new IP: %w", err)
|
||||
addr, ok := netip.AddrFromSlice(peerAddress.IP)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid IP format")
|
||||
}
|
||||
addr = addr.Unmap()
|
||||
|
||||
raw := addr.As16()
|
||||
fakeIP := netip.AddrFrom4([4]byte{127, 1, raw[14], raw[15]})
|
||||
|
||||
netipAddr := netip.AddrPortFrom(fakeIP, uint16(peerAddress.Port))
|
||||
return &netipAddr, nil
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"sync"
|
||||
@@ -19,6 +18,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/acl/id"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/shared/netiputil"
|
||||
)
|
||||
|
||||
var ErrSourceRangesEmpty = errors.New("sources range is empty")
|
||||
@@ -105,6 +105,10 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
|
||||
newRulePairs := make(map[id.RuleID][]firewall.Rule)
|
||||
ipsetByRuleSelectors := make(map[string]string)
|
||||
|
||||
// TODO: deny rules should be fatal: if a deny rule fails to apply, we must
|
||||
// roll back all allow rules to avoid a fail-open where allowed traffic bypasses
|
||||
// the missing deny. Currently we accumulate errors and continue.
|
||||
var merr *multierror.Error
|
||||
for _, r := range rules {
|
||||
// if this rule is member of rule selection with more than DefaultIPsCountForSet
|
||||
// it's IP address can be used in the ipset for firewall manager which supports it
|
||||
@@ -117,9 +121,8 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
|
||||
}
|
||||
pairID, rulePair, err := d.protoRuleToFirewallRule(r, ipsetName)
|
||||
if err != nil {
|
||||
log.Errorf("failed to apply firewall rule: %+v, %v", r, err)
|
||||
d.rollBack(newRulePairs)
|
||||
break
|
||||
merr = multierror.Append(merr, fmt.Errorf("apply firewall rule: %w", err))
|
||||
continue
|
||||
}
|
||||
if len(rulePair) > 0 {
|
||||
d.peerRulesPairs[pairID] = rulePair
|
||||
@@ -127,6 +130,10 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
|
||||
}
|
||||
}
|
||||
|
||||
if merr != nil {
|
||||
log.Errorf("failed to apply %d peer ACL rule(s): %v", merr.Len(), nberrors.FormatErrorOrNil(merr))
|
||||
}
|
||||
|
||||
for pairID, rules := range d.peerRulesPairs {
|
||||
if _, ok := newRulePairs[pairID]; !ok {
|
||||
for _, rule := range rules {
|
||||
@@ -216,10 +223,9 @@ func (d *DefaultManager) protoRuleToFirewallRule(
|
||||
r *mgmProto.FirewallRule,
|
||||
ipsetName string,
|
||||
) (id.RuleID, []firewall.Rule, error) {
|
||||
//nolint:staticcheck // PeerIP used for backward compatibility with old management
|
||||
ip := net.ParseIP(r.PeerIP)
|
||||
if ip == nil {
|
||||
return "", nil, fmt.Errorf("invalid IP address, skipping firewall rule")
|
||||
ip, err := extractRuleIP(r)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
protocol, err := convertToFirewallProtocol(r.Protocol)
|
||||
@@ -290,13 +296,13 @@ func portInfoEmpty(portInfo *mgmProto.PortInfo) bool {
|
||||
|
||||
func (d *DefaultManager) addInRules(
|
||||
id []byte,
|
||||
ip net.IP,
|
||||
ip netip.Addr,
|
||||
protocol firewall.Protocol,
|
||||
port *firewall.Port,
|
||||
action firewall.Action,
|
||||
ipsetName string,
|
||||
) ([]firewall.Rule, error) {
|
||||
rule, err := d.firewall.AddPeerFiltering(id, ip, protocol, nil, port, action, ipsetName)
|
||||
rule, err := d.firewall.AddPeerFiltering(id, ip.AsSlice(), protocol, nil, port, action, ipsetName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("add firewall rule: %w", err)
|
||||
}
|
||||
@@ -306,7 +312,7 @@ func (d *DefaultManager) addInRules(
|
||||
|
||||
func (d *DefaultManager) addOutRules(
|
||||
id []byte,
|
||||
ip net.IP,
|
||||
ip netip.Addr,
|
||||
protocol firewall.Protocol,
|
||||
port *firewall.Port,
|
||||
action firewall.Action,
|
||||
@@ -316,7 +322,7 @@ func (d *DefaultManager) addOutRules(
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
rule, err := d.firewall.AddPeerFiltering(id, ip, protocol, port, nil, action, ipsetName)
|
||||
rule, err := d.firewall.AddPeerFiltering(id, ip.AsSlice(), protocol, port, nil, action, ipsetName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("add firewall rule: %w", err)
|
||||
}
|
||||
@@ -324,9 +330,9 @@ func (d *DefaultManager) addOutRules(
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
// getPeerRuleID() returns unique ID for the rule based on its parameters.
|
||||
// getPeerRuleID returns unique ID for the rule based on its parameters.
|
||||
func (d *DefaultManager) getPeerRuleID(
|
||||
ip net.IP,
|
||||
ip netip.Addr,
|
||||
proto firewall.Protocol,
|
||||
direction int,
|
||||
port *firewall.Port,
|
||||
@@ -345,15 +351,25 @@ func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) st
|
||||
return fmt.Sprintf("%v:%v:%v:%s:%v", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port, rule.PortInfo)
|
||||
}
|
||||
|
||||
func (d *DefaultManager) rollBack(newRulePairs map[id.RuleID][]firewall.Rule) {
|
||||
log.Debugf("rollback ACL to previous state")
|
||||
for _, rules := range newRulePairs {
|
||||
for _, rule := range rules {
|
||||
if err := d.firewall.DeletePeerRule(rule); err != nil {
|
||||
log.Errorf("failed to delete new firewall rule (id: %v) during rollback: %v", rule.ID(), err)
|
||||
}
|
||||
|
||||
// extractRuleIP extracts the peer IP from a firewall rule.
|
||||
// If sourcePrefixes is populated (new management), decode the first entry and use its address.
|
||||
// Otherwise fall back to the deprecated PeerIP string field (old management).
|
||||
func extractRuleIP(r *mgmProto.FirewallRule) (netip.Addr, error) {
|
||||
if len(r.SourcePrefixes) > 0 {
|
||||
addr, err := netiputil.DecodeAddr(r.SourcePrefixes[0])
|
||||
if err != nil {
|
||||
return netip.Addr{}, fmt.Errorf("decode source prefix: %w", err)
|
||||
}
|
||||
return addr.Unmap(), nil
|
||||
}
|
||||
|
||||
//nolint:staticcheck // PeerIP used for backward compatibility with old management
|
||||
addr, err := netip.ParseAddr(r.PeerIP)
|
||||
if err != nil {
|
||||
return netip.Addr{}, fmt.Errorf("invalid IP address, skipping firewall rule")
|
||||
}
|
||||
return addr.Unmap(), nil
|
||||
}
|
||||
|
||||
func convertToFirewallProtocol(protocol mgmProto.RuleProtocol) (firewall.Protocol, error) {
|
||||
|
||||
@@ -430,8 +430,6 @@ func isInCGNATRange(ip net.IP) bool {
|
||||
}
|
||||
|
||||
func TestAnonymizeFirewallRules(t *testing.T) {
|
||||
// TODO: Add ipv6
|
||||
|
||||
// Example iptables-save output
|
||||
iptablesSave := `# Generated by iptables-save v1.8.7 on Thu Dec 19 10:00:00 2024
|
||||
*filter
|
||||
@@ -467,17 +465,31 @@ Chain FORWARD (policy ACCEPT 0 packets, 0 bytes)
|
||||
Chain OUTPUT (policy ACCEPT 0 packets, 0 bytes)
|
||||
pkts bytes target prot opt in out source destination`
|
||||
|
||||
// Example nftables output
|
||||
// Example ip6tables-save output
|
||||
ip6tablesSave := `# Generated by ip6tables-save v1.8.7 on Thu Dec 19 10:00:00 2024
|
||||
*filter
|
||||
:INPUT ACCEPT [0:0]
|
||||
:FORWARD ACCEPT [0:0]
|
||||
:OUTPUT ACCEPT [0:0]
|
||||
-A INPUT -s fd00:1234::1/128 -j ACCEPT
|
||||
-A INPUT -s 2607:f8b0:4005::1/128 -j DROP
|
||||
-A FORWARD -s 2001:db8::/32 -d 2607:f8b0:4005::200e/128 -j ACCEPT
|
||||
COMMIT`
|
||||
|
||||
// Example nftables output with IPv6
|
||||
nftablesRules := `table inet filter {
|
||||
chain input {
|
||||
type filter hook input priority filter; policy accept;
|
||||
ip saddr 192.168.1.1 accept
|
||||
ip saddr 44.192.140.1 drop
|
||||
ip6 saddr 2607:f8b0:4005::1 drop
|
||||
ip6 saddr fd00:1234::1 accept
|
||||
}
|
||||
chain forward {
|
||||
type filter hook forward priority filter; policy accept;
|
||||
ip saddr 10.0.0.0/8 drop
|
||||
ip saddr 44.192.140.0/24 ip daddr 52.84.12.34/24 accept
|
||||
ip6 saddr 2001:db8::/32 ip6 daddr 2607:f8b0:4005::200e/128 accept
|
||||
}
|
||||
}`
|
||||
|
||||
@@ -540,4 +552,35 @@ Chain OUTPUT (policy ACCEPT 0 packets, 0 bytes)
|
||||
assert.Contains(t, anonNftables, "table inet filter {")
|
||||
assert.Contains(t, anonNftables, "chain input {")
|
||||
assert.Contains(t, anonNftables, "type filter hook input priority filter; policy accept;")
|
||||
|
||||
// IPv6 public addresses in nftables should be anonymized
|
||||
assert.NotContains(t, anonNftables, "2607:f8b0:4005::1")
|
||||
assert.NotContains(t, anonNftables, "2607:f8b0:4005::200e")
|
||||
assert.NotContains(t, anonNftables, "2001:db8::")
|
||||
assert.Contains(t, anonNftables, "2001:db8:ffff::") // Default anonymous v6 range
|
||||
|
||||
// ULA addresses in nftables should remain unchanged (private)
|
||||
assert.Contains(t, anonNftables, "fd00:1234::1")
|
||||
|
||||
// IPv6 nftables structure preserved
|
||||
assert.Contains(t, anonNftables, "ip6 saddr")
|
||||
assert.Contains(t, anonNftables, "ip6 daddr")
|
||||
|
||||
// Test ip6tables-save anonymization
|
||||
anonIp6tablesSave := anonymizer.AnonymizeString(ip6tablesSave)
|
||||
|
||||
// ULA (private) IPv6 should remain unchanged
|
||||
assert.Contains(t, anonIp6tablesSave, "fd00:1234::1/128")
|
||||
|
||||
// Public IPv6 addresses should be anonymized
|
||||
assert.NotContains(t, anonIp6tablesSave, "2607:f8b0:4005::1")
|
||||
assert.NotContains(t, anonIp6tablesSave, "2607:f8b0:4005::200e")
|
||||
assert.NotContains(t, anonIp6tablesSave, "2001:db8::")
|
||||
assert.Contains(t, anonIp6tablesSave, "2001:db8:ffff::") // Default anonymous v6 range
|
||||
|
||||
// Structure should be preserved
|
||||
assert.Contains(t, anonIp6tablesSave, "*filter")
|
||||
assert.Contains(t, anonIp6tablesSave, "COMMIT")
|
||||
assert.Contains(t, anonIp6tablesSave, "-j DROP")
|
||||
assert.Contains(t, anonIp6tablesSave, "-j ACCEPT")
|
||||
}
|
||||
|
||||
@@ -189,10 +189,10 @@ func (s *serviceViaListener) RuntimeIP() netip.Addr {
|
||||
}
|
||||
|
||||
|
||||
// evalListenAddress figure out the listen address for the DNS server
|
||||
// first check the 53 port availability on WG interface or lo, if not success
|
||||
// pick a random port on WG interface for eBPF, if not success
|
||||
// check the 5053 port availability on WG interface or lo without eBPF usage,
|
||||
// evalListenAddress figures out the listen address for the DNS server.
|
||||
// IPv4-only: all peers have a v4 overlay address, and DNS config points to v4.
|
||||
// First checks port 53 on WG interface or lo, then tries eBPF on a random port,
|
||||
// then falls back to port 5053.
|
||||
func (s *serviceViaListener) evalListenAddress() (netip.Addr, uint16, error) {
|
||||
if s.customAddr != nil {
|
||||
return s.customAddr.Addr(), s.customAddr.Port(), nil
|
||||
@@ -278,7 +278,7 @@ func (s *serviceViaListener) tryToUseeBPF() (ebpfMgr.Manager, uint16, bool) {
|
||||
}
|
||||
|
||||
ebpfSrv := ebpf.GetEbpfManagerInstance()
|
||||
err = ebpfSrv.LoadDNSFwd(s.wgInterface.Address().IP.String(), int(port))
|
||||
err = ebpfSrv.LoadDNSFwd(s.wgInterface.Address().IP, int(port))
|
||||
if err != nil {
|
||||
log.Warnf("failed to load DNS forwarder eBPF program, error: %s", err)
|
||||
return nil, 0, false
|
||||
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/resutil"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/types"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
@@ -29,6 +30,12 @@ import (
|
||||
|
||||
var currentMTU uint16 = iface.DefaultMTU
|
||||
|
||||
// privateClientIface is the subset of the WireGuard interface needed by GetClientPrivate.
|
||||
type privateClientIface interface {
|
||||
Name() string
|
||||
Address() wgaddr.Address
|
||||
}
|
||||
|
||||
func SetCurrentMTU(mtu uint16) {
|
||||
currentMTU = mtu
|
||||
}
|
||||
|
||||
@@ -86,7 +86,7 @@ func (u *upstreamResolver) isLocalResolver(upstream string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) {
|
||||
func GetClientPrivate(_ privateClientIface, _ netip.Addr, dialTimeout time.Duration) (*dns.Client, error) {
|
||||
return &dns.Client{
|
||||
Timeout: dialTimeout,
|
||||
Net: "udp",
|
||||
|
||||
@@ -52,7 +52,7 @@ func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns
|
||||
return ExchangeWithFallback(ctx, client, r, upstream)
|
||||
}
|
||||
|
||||
func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) {
|
||||
func GetClientPrivate(_ privateClientIface, _ netip.Addr, dialTimeout time.Duration) (*dns.Client, error) {
|
||||
return &dns.Client{
|
||||
Timeout: dialTimeout,
|
||||
Net: "udp",
|
||||
|
||||
@@ -19,11 +19,7 @@ import (
|
||||
|
||||
type upstreamResolverIOS struct {
|
||||
*upstreamResolverBase
|
||||
lIP netip.Addr
|
||||
lNet netip.Prefix
|
||||
lIPv6 netip.Addr
|
||||
lNetV6 netip.Prefix
|
||||
interfaceName string
|
||||
wgIface WGIface
|
||||
}
|
||||
|
||||
func newUpstreamResolver(
|
||||
@@ -37,11 +33,7 @@ func newUpstreamResolver(
|
||||
|
||||
ios := &upstreamResolverIOS{
|
||||
upstreamResolverBase: upstreamResolverBase,
|
||||
lIP: wgIface.Address().IP,
|
||||
lNet: wgIface.Address().Network,
|
||||
lIPv6: wgIface.Address().IPv6,
|
||||
lNetV6: wgIface.Address().IPv6Net,
|
||||
interfaceName: wgIface.Name(),
|
||||
wgIface: wgIface,
|
||||
}
|
||||
ios.upstreamClient = ios
|
||||
|
||||
@@ -69,24 +61,15 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
|
||||
} else {
|
||||
upstreamIP = upstreamIP.Unmap()
|
||||
}
|
||||
needsPrivate := u.lNet.Contains(upstreamIP) ||
|
||||
u.lNetV6.Contains(upstreamIP) ||
|
||||
addr := u.wgIface.Address()
|
||||
needsPrivate := addr.Network.Contains(upstreamIP) ||
|
||||
addr.IPv6Net.Contains(upstreamIP) ||
|
||||
(u.routeMatch != nil && u.routeMatch(upstreamIP))
|
||||
if needsPrivate {
|
||||
var bindIP netip.Addr
|
||||
switch {
|
||||
case upstreamIP.Is6() && u.lIPv6.IsValid():
|
||||
bindIP = u.lIPv6
|
||||
case upstreamIP.Is4() && u.lIP.IsValid():
|
||||
bindIP = u.lIP
|
||||
}
|
||||
|
||||
if bindIP.IsValid() {
|
||||
log.Debugf("using private client to query %s via upstream %s", r.Question[0].Name, upstream)
|
||||
client, err = GetClientPrivate(bindIP, u.interfaceName, timeout)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("create private client: %s", err)
|
||||
}
|
||||
log.Debugf("using private client to query %s via upstream %s", r.Question[0].Name, upstream)
|
||||
client, err = GetClientPrivate(u.wgIface, upstreamIP, timeout)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("create private client: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -94,23 +77,29 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
|
||||
return ExchangeWithFallback(nil, client, r, upstream)
|
||||
}
|
||||
|
||||
// GetClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface
|
||||
// This method is needed for iOS
|
||||
func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) {
|
||||
index, err := getInterfaceIndex(interfaceName)
|
||||
// GetClientPrivate returns a new DNS client bound to the local IP of the Netbird interface.
|
||||
// It selects the v6 bind address when the upstream is IPv6 and the interface has one, otherwise v4.
|
||||
func GetClientPrivate(iface privateClientIface, upstreamIP netip.Addr, dialTimeout time.Duration) (*dns.Client, error) {
|
||||
index, err := getInterfaceIndex(iface.Name())
|
||||
if err != nil {
|
||||
log.Debugf("unable to get interface index for %s: %s", interfaceName, err)
|
||||
log.Debugf("unable to get interface index for %s: %s", iface.Name(), err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
addr := iface.Address()
|
||||
bindIP := addr.IP
|
||||
if upstreamIP.Is6() && addr.HasIPv6() {
|
||||
bindIP = addr.IPv6
|
||||
}
|
||||
|
||||
proto, opt := unix.IPPROTO_IP, unix.IP_BOUND_IF
|
||||
if ip.Is6() {
|
||||
if bindIP.Is6() {
|
||||
proto, opt = unix.IPPROTO_IPV6, unix.IPV6_BOUND_IF
|
||||
}
|
||||
|
||||
dialer := &net.Dialer{
|
||||
LocalAddr: net.UDPAddrFromAddrPort(netip.AddrPortFrom(ip, 0)),
|
||||
Timeout: dialTimeout,
|
||||
LocalAddr: net.UDPAddrFromAddrPort(netip.AddrPortFrom(bindIP, 0)),
|
||||
Timeout: dialTimeout,
|
||||
Control: func(network, address string, c syscall.RawConn) error {
|
||||
var operr error
|
||||
fn := func(s uintptr) {
|
||||
|
||||
@@ -80,6 +80,7 @@ func (m *Manager) Start(fwdEntries []*ForwarderEntry) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// IPv4-only: peers reach the forwarder via its v4 overlay address.
|
||||
localAddr := m.wgIface.Address().IP
|
||||
|
||||
if localAddr.IsValid() && m.firewall != nil {
|
||||
|
||||
@@ -2,7 +2,8 @@ package ebpf
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"net"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
@@ -12,7 +13,7 @@ const (
|
||||
mapKeyDNSPort uint32 = 1
|
||||
)
|
||||
|
||||
func (tf *GeneralManager) LoadDNSFwd(ip string, dnsPort int) error {
|
||||
func (tf *GeneralManager) LoadDNSFwd(ip netip.Addr, dnsPort int) error {
|
||||
log.Debugf("load eBPF DNS forwarder, watching addr: %s:53, redirect to port: %d", ip, dnsPort)
|
||||
tf.lock.Lock()
|
||||
defer tf.lock.Unlock()
|
||||
@@ -22,7 +23,11 @@ func (tf *GeneralManager) LoadDNSFwd(ip string, dnsPort int) error {
|
||||
return err
|
||||
}
|
||||
|
||||
err = tf.bpfObjs.NbMapDnsIp.Put(mapKeyDNSIP, ip2int(ip))
|
||||
if !ip.Is4() {
|
||||
return fmt.Errorf("eBPF DNS forwarder only supports IPv4, got %s", ip)
|
||||
}
|
||||
ip4 := ip.As4()
|
||||
err = tf.bpfObjs.NbMapDnsIp.Put(mapKeyDNSIP, binary.BigEndian.Uint32(ip4[:]))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -45,7 +50,3 @@ func (tf *GeneralManager) FreeDNSFwd() error {
|
||||
return tf.unsetFeatureFlag(featureFlagDnsForwarder)
|
||||
}
|
||||
|
||||
func ip2int(ipString string) uint32 {
|
||||
ip := net.ParseIP(ipString)
|
||||
return binary.BigEndian.Uint32(ip.To4())
|
||||
}
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
package manager
|
||||
|
||||
import "net/netip"
|
||||
|
||||
// Manager is used to load multiple eBPF programs. E.g., current DNS programs and WireGuard proxy
|
||||
type Manager interface {
|
||||
LoadDNSFwd(ip string, dnsPort int) error
|
||||
LoadDNSFwd(ip netip.Addr, dnsPort int) error
|
||||
FreeDNSFwd() error
|
||||
LoadWgProxy(proxyPort, wgPort int) error
|
||||
FreeWGProxy() error
|
||||
|
||||
@@ -630,7 +630,7 @@ func (e *Engine) initFirewall() error {
|
||||
rosenpassPort := e.rpManager.GetAddress().Port
|
||||
port := firewallManager.Port{Values: []uint16{uint16(rosenpassPort)}}
|
||||
|
||||
// this rule is static and will be torn down on engine down by the firewall manager
|
||||
// IPv4-only: rosenpass peers connect via AllowedIps[0] which is always v4.
|
||||
if _, err := e.firewall.AddPeerFiltering(
|
||||
nil,
|
||||
net.IP{0, 0, 0, 0},
|
||||
@@ -682,10 +682,15 @@ func (e *Engine) blockLanAccess() {
|
||||
|
||||
log.Infof("blocking route LAN access for networks: %v", toBlock)
|
||||
v4 := netip.PrefixFrom(netip.IPv4Unspecified(), 0)
|
||||
v6 := netip.PrefixFrom(netip.IPv6Unspecified(), 0)
|
||||
for _, network := range toBlock {
|
||||
source := v4
|
||||
if network.Addr().Is6() {
|
||||
source = v6
|
||||
}
|
||||
if _, err := e.firewall.AddRouteFiltering(
|
||||
nil,
|
||||
[]netip.Prefix{v4},
|
||||
[]netip.Prefix{source},
|
||||
firewallManager.Network{Prefix: network},
|
||||
firewallManager.ProtocolALL,
|
||||
nil,
|
||||
@@ -1494,10 +1499,10 @@ func (e *Engine) updateOfflinePeers(offlinePeers []*mgmProto.RemotePeerConfig) {
|
||||
replacement := make([]peer.State, len(offlinePeers))
|
||||
for i, offlinePeer := range offlinePeers {
|
||||
log.Debugf("added offline peer %s", offlinePeer.Fqdn)
|
||||
v4, v6 := splitAllowedIPs(offlinePeer.GetAllowedIps(), e.wgInterface.Address().IPv6Net)
|
||||
v4, v6 := overlayAddrsFromAllowedIPs(offlinePeer.GetAllowedIps(), e.wgInterface.Address().IPv6Net)
|
||||
replacement[i] = peer.State{
|
||||
IP: v4,
|
||||
IPv6: v6,
|
||||
IP: addrToString(v4),
|
||||
IPv6: addrToString(v6),
|
||||
PubKey: offlinePeer.GetWgPubKey(),
|
||||
FQDN: offlinePeer.GetFqdn(),
|
||||
ConnStatus: peer.StatusIdle,
|
||||
@@ -1508,30 +1513,37 @@ func (e *Engine) updateOfflinePeers(offlinePeers []*mgmProto.RemotePeerConfig) {
|
||||
e.statusRecorder.ReplaceOfflinePeers(replacement)
|
||||
}
|
||||
|
||||
// splitAllowedIPs separates the peer's overlay v4 (/32) and v6 (/128) addresses
|
||||
// from a list of AllowedIPs CIDRs. The v6 address is only matched if it falls
|
||||
// within ourV6Net (the local overlay v6 subnet), to avoid confusing routed /128
|
||||
// prefixes with the peer's overlay address.
|
||||
func splitAllowedIPs(allowedIPs []string, ourV6Net netip.Prefix) (v4, v6 string) {
|
||||
// overlayAddrsFromAllowedIPs extracts the peer's v4 and v6 overlay addresses
|
||||
// from AllowedIPs strings. Only host routes (/32, /128) are considered; v6 must
|
||||
// fall within ourV6Net to distinguish overlay addresses from routed prefixes.
|
||||
func overlayAddrsFromAllowedIPs(allowedIPs []string, ourV6Net netip.Prefix) (v4, v6 netip.Addr) {
|
||||
for _, cidr := range allowedIPs {
|
||||
prefix, err := netip.ParsePrefix(cidr)
|
||||
if err != nil {
|
||||
log.Warnf("failed to parse AllowedIP %q: %v", cidr, err)
|
||||
continue
|
||||
}
|
||||
addr := prefix.Addr().Unmap()
|
||||
switch {
|
||||
case prefix.Addr().Is4() && prefix.Bits() == 32 && v4 == "":
|
||||
v4 = prefix.Addr().String()
|
||||
case prefix.Addr().Is6() && prefix.Bits() == 128 && ourV6Net.Contains(prefix.Addr()) && v6 == "":
|
||||
v6 = prefix.Addr().String()
|
||||
case addr.Is4() && prefix.Bits() == 32 && !v4.IsValid():
|
||||
v4 = addr
|
||||
case addr.Is6() && prefix.Bits() == 128 && ourV6Net.Contains(addr) && !v6.IsValid():
|
||||
v6 = addr
|
||||
}
|
||||
if v4 != "" && v6 != "" {
|
||||
if v4.IsValid() && v6.IsValid() {
|
||||
break
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func addrToString(addr netip.Addr) string {
|
||||
if !addr.IsValid() {
|
||||
return ""
|
||||
}
|
||||
return addr.String()
|
||||
}
|
||||
|
||||
// addNewPeers adds peers that were not know before but arrived from the Management service with the update
|
||||
func (e *Engine) addNewPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
|
||||
for _, p := range peersUpdate {
|
||||
@@ -1572,8 +1584,8 @@ func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error {
|
||||
return fmt.Errorf("create peer connection: %w", err)
|
||||
}
|
||||
|
||||
peerV4, peerV6 := splitAllowedIPs(peerConfig.GetAllowedIps(), e.wgInterface.Address().IPv6Net)
|
||||
err = e.statusRecorder.AddPeer(peerKey, peerConfig.Fqdn, peerV4, peerV6)
|
||||
peerV4, peerV6 := overlayAddrsFromAllowedIPs(peerConfig.GetAllowedIps(), e.wgInterface.Address().IPv6Net)
|
||||
err = e.statusRecorder.AddPeer(peerKey, peerConfig.Fqdn, addrToString(peerV4), addrToString(peerV6))
|
||||
if err != nil {
|
||||
log.Warnf("error adding peer %s to status recorder, got error: %v", peerKey, err)
|
||||
}
|
||||
@@ -2355,8 +2367,7 @@ func getInterfacePrefixes() ([]netip.Prefix, error) {
|
||||
prefix := netip.PrefixFrom(addr.Unmap(), ones).Masked()
|
||||
ip := prefix.Addr()
|
||||
|
||||
// TODO: add IPv6
|
||||
if !ip.Is4() || ip.IsLoopback() || ip.IsMulticast() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
|
||||
if ip.IsLoopback() || ip.IsMulticast() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
|
||||
continue
|
||||
}
|
||||
|
||||
|
||||
@@ -145,13 +145,13 @@ func (e *Engine) extractPeerSSHInfo(remotePeers []*mgmProto.RemotePeerConfig) []
|
||||
continue
|
||||
}
|
||||
|
||||
peerIP, peerIPv6 := e.extractPeerIPs(peerConfig)
|
||||
peerV4, peerV6 := overlayAddrsFromAllowedIPs(peerConfig.GetAllowedIps(), e.wgInterface.Address().IPv6Net)
|
||||
hostname := e.extractHostname(peerConfig)
|
||||
|
||||
peerInfo = append(peerInfo, sshconfig.PeerSSHInfo{
|
||||
Hostname: hostname,
|
||||
IP: peerIP,
|
||||
IPv6: peerIPv6,
|
||||
IP: peerV4,
|
||||
IPv6: peerV6,
|
||||
FQDN: peerConfig.GetFqdn(),
|
||||
})
|
||||
}
|
||||
@@ -159,28 +159,6 @@ func (e *Engine) extractPeerSSHInfo(remotePeers []*mgmProto.RemotePeerConfig) []
|
||||
return peerInfo
|
||||
}
|
||||
|
||||
// extractPeerIPs extracts IPv4 and IPv6 overlay addresses from peer's allowed IPs.
|
||||
// Only considers host routes (/32, /128) within the overlay networks to avoid
|
||||
// picking up routed prefixes or static routes like 2620:fe::fe/128.
|
||||
func (e *Engine) extractPeerIPs(peerConfig *mgmProto.RemotePeerConfig) (v4, v6 netip.Addr) {
|
||||
wgAddr := e.wgInterface.Address()
|
||||
for _, allowedIP := range peerConfig.GetAllowedIps() {
|
||||
prefix, err := netip.ParsePrefix(allowedIP)
|
||||
if err != nil {
|
||||
log.Warnf("failed to parse AllowedIP %q: %v", allowedIP, err)
|
||||
continue
|
||||
}
|
||||
addr := prefix.Addr().Unmap()
|
||||
switch {
|
||||
case addr.Is4() && prefix.Bits() == 32 && wgAddr.Network.Contains(addr) && !v4.IsValid():
|
||||
v4 = addr
|
||||
case addr.Is6() && prefix.Bits() == 128 && wgAddr.IPv6Net.IsValid() && wgAddr.IPv6Net.Contains(addr) && !v6.IsValid():
|
||||
v6 = addr
|
||||
}
|
||||
}
|
||||
return v4, v6
|
||||
}
|
||||
|
||||
// extractHostname extracts short hostname from FQDN
|
||||
func (e *Engine) extractHostname(peerConfig *mgmProto.RemotePeerConfig) string {
|
||||
fqdn := peerConfig.GetFqdn()
|
||||
|
||||
@@ -1837,7 +1837,7 @@ func TestFilterAllowedIPs(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitAllowedIPs(t *testing.T) {
|
||||
func TestOverlayAddrsFromAllowedIPs(t *testing.T) {
|
||||
ourV6Net := netip.MustParsePrefix("fd00:1234:5678:abcd::/64")
|
||||
|
||||
tests := []struct {
|
||||
@@ -1900,9 +1900,17 @@ func TestSplitAllowedIPs(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
v4, v6 := splitAllowedIPs(tt.allowedIPs, tt.ourV6Net)
|
||||
assert.Equal(t, tt.wantV4, v4, "v4")
|
||||
assert.Equal(t, tt.wantV6, v6, "v6")
|
||||
v4, v6 := overlayAddrsFromAllowedIPs(tt.allowedIPs, tt.ourV6Net)
|
||||
if tt.wantV4 == "" {
|
||||
assert.False(t, v4.IsValid(), "expected no v4")
|
||||
} else {
|
||||
assert.Equal(t, tt.wantV4, v4.String(), "v4")
|
||||
}
|
||||
if tt.wantV6 == "" {
|
||||
assert.False(t, v6.IsValid(), "expected no v6")
|
||||
} else {
|
||||
assert.Equal(t, tt.wantV6, v6.String(), "v6")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -57,6 +57,7 @@ func NewBindListener(wgIface WgInterface, bind device.EndpointManager, cfg lazyc
|
||||
// deriveFakeIP creates a deterministic fake IP for bind mode based on peer's NetBird IP.
|
||||
// Maps peer IP 100.64.x.y to fake IP 127.2.x.y (similar to relay proxy using 127.1.x.y).
|
||||
// It finds the peer's actual NetBird IP by checking which allowedIP is in the same subnet as our WG interface.
|
||||
// For IPv6-only peers, the last two bytes of the v6 address are used.
|
||||
func deriveFakeIP(wgIface WgInterface, allowedIPs []netip.Prefix) (netip.Addr, error) {
|
||||
if len(allowedIPs) == 0 {
|
||||
return netip.Addr{}, fmt.Errorf("no allowed IPs for peer")
|
||||
@@ -64,6 +65,7 @@ func deriveFakeIP(wgIface WgInterface, allowedIPs []netip.Prefix) (netip.Addr, e
|
||||
|
||||
ourNetwork := wgIface.Address().Network
|
||||
|
||||
// Try v4 first (preferred: deterministic from overlay IP)
|
||||
var peerIP netip.Addr
|
||||
for _, allowedIP := range allowedIPs {
|
||||
ip := allowedIP.Addr()
|
||||
@@ -76,13 +78,24 @@ func deriveFakeIP(wgIface WgInterface, allowedIPs []netip.Prefix) (netip.Addr, e
|
||||
}
|
||||
}
|
||||
|
||||
if !peerIP.IsValid() {
|
||||
return netip.Addr{}, fmt.Errorf("no peer NetBird IP found in allowed IPs")
|
||||
if peerIP.IsValid() {
|
||||
octets := peerIP.As4()
|
||||
return netip.AddrFrom4([4]byte{127, 2, octets[2], octets[3]}), nil
|
||||
}
|
||||
|
||||
octets := peerIP.As4()
|
||||
fakeIP := netip.AddrFrom4([4]byte{127, 2, octets[2], octets[3]})
|
||||
return fakeIP, nil
|
||||
// Fallback: use last two bytes of first v6 overlay IP
|
||||
addr := wgIface.Address()
|
||||
if addr.IPv6Net.IsValid() {
|
||||
for _, allowedIP := range allowedIPs {
|
||||
ip := allowedIP.Addr()
|
||||
if ip.Is6() && addr.IPv6Net.Contains(ip) {
|
||||
raw := ip.As16()
|
||||
return netip.AddrFrom4([4]byte{127, 2, raw[14], raw[15]}), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return netip.Addr{}, fmt.Errorf("no peer NetBird IP found in allowed IPs")
|
||||
}
|
||||
|
||||
func (d *BindListener) setupLazyConn() error {
|
||||
|
||||
@@ -1055,7 +1055,11 @@ func (d *Status) notifyPeerListChanged() {
|
||||
}
|
||||
|
||||
func (d *Status) notifyAddressChanged() {
|
||||
d.notifier.localAddressChanged(d.localPeer.FQDN, d.localPeer.IP)
|
||||
addr := d.localPeer.IP
|
||||
if d.localPeer.IPv6 != "" {
|
||||
addr = addr + "\n" + d.localPeer.IPv6
|
||||
}
|
||||
d.notifier.localAddressChanged(d.localPeer.FQDN, addr)
|
||||
}
|
||||
|
||||
func (d *Status) numOfPeers() int {
|
||||
|
||||
@@ -3,9 +3,8 @@ package client
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -566,7 +565,7 @@ func HandlerFromRoute(params common.HandlerParams) RouteHandler {
|
||||
return dnsinterceptor.New(params)
|
||||
case handlerTypeDynamic:
|
||||
dns := nbdns.NewServiceViaMemory(params.WgInterface)
|
||||
dnsAddr := net.JoinHostPort(dns.RuntimeIP().String(), strconv.Itoa(dns.RuntimePort()))
|
||||
dnsAddr := netip.AddrPortFrom(dns.RuntimeIP(), uint16(dns.RuntimePort()))
|
||||
return dynamic.NewRoute(params, dnsAddr)
|
||||
default:
|
||||
return static.NewRoute(params)
|
||||
|
||||
@@ -582,7 +582,7 @@ func (d *DnsInterceptor) queryUpstreamDNS(ctx context.Context, w dns.ResponseWri
|
||||
if nsNet != nil {
|
||||
reply, err = nbdns.ExchangeWithNetstack(ctx, nsNet, r, upstream)
|
||||
} else {
|
||||
client, clientErr := nbdns.GetClientPrivate(d.wgInterface.Address().IP, d.wgInterface.Name(), dnsTimeout)
|
||||
client, clientErr := nbdns.GetClientPrivate(d.wgInterface, upstreamIP, dnsTimeout)
|
||||
if clientErr != nil {
|
||||
d.writeDNSError(w, r, logger, fmt.Sprintf("create DNS client: %v", clientErr))
|
||||
return nil
|
||||
|
||||
@@ -50,10 +50,10 @@ type Route struct {
|
||||
cancel context.CancelFunc
|
||||
statusRecorder *peer.Status
|
||||
wgInterface iface.WGIface
|
||||
resolverAddr string
|
||||
resolverAddr netip.AddrPort
|
||||
}
|
||||
|
||||
func NewRoute(params common.HandlerParams, resolverAddr string) *Route {
|
||||
func NewRoute(params common.HandlerParams, resolverAddr netip.AddrPort) *Route {
|
||||
return &Route{
|
||||
route: params.Route,
|
||||
routeRefCounter: params.RouteRefCounter,
|
||||
|
||||
@@ -17,37 +17,47 @@ import (
|
||||
const dialTimeout = 10 * time.Second
|
||||
|
||||
func (r *Route) getIPsFromResolver(domain domain.Domain) ([]net.IP, error) {
|
||||
privateClient, err := nbdns.GetClientPrivate(r.wgInterface.Address().IP, r.wgInterface.Name(), dialTimeout)
|
||||
privateClient, err := nbdns.GetClientPrivate(r.wgInterface, r.resolverAddr.Addr(), dialTimeout)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error while creating private client: %s", err)
|
||||
}
|
||||
|
||||
msg := new(dns.Msg)
|
||||
msg.SetQuestion(dns.Fqdn(domain.PunycodeString()), dns.TypeA)
|
||||
|
||||
fqdn := dns.Fqdn(domain.PunycodeString())
|
||||
startTime := time.Now()
|
||||
|
||||
response, _, err := nbdns.ExchangeWithFallback(nil, privateClient, msg, r.resolverAddr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("DNS query for %s failed after %s: %s ", domain.SafeString(), time.Since(startTime), err)
|
||||
}
|
||||
var ips []net.IP
|
||||
var queryErr error
|
||||
|
||||
if response.Rcode != dns.RcodeSuccess {
|
||||
return nil, fmt.Errorf("dns response code: %s", dns.RcodeToString[response.Rcode])
|
||||
}
|
||||
for _, qtype := range []uint16{dns.TypeA, dns.TypeAAAA} {
|
||||
msg := new(dns.Msg)
|
||||
msg.SetQuestion(fqdn, qtype)
|
||||
|
||||
ips := make([]net.IP, 0)
|
||||
|
||||
for _, answ := range response.Answer {
|
||||
if aRecord, ok := answ.(*dns.A); ok {
|
||||
ips = append(ips, aRecord.A)
|
||||
response, _, err := nbdns.ExchangeWithFallback(nil, privateClient, msg, r.resolverAddr.String())
|
||||
if err != nil {
|
||||
if queryErr == nil {
|
||||
queryErr = fmt.Errorf("DNS query for %s (type %d) after %s: %w", domain.SafeString(), qtype, time.Since(startTime), err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if aaaaRecord, ok := answ.(*dns.AAAA); ok {
|
||||
ips = append(ips, aaaaRecord.AAAA)
|
||||
|
||||
if response.Rcode != dns.RcodeSuccess {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, answ := range response.Answer {
|
||||
if aRecord, ok := answ.(*dns.A); ok {
|
||||
ips = append(ips, aRecord.A)
|
||||
}
|
||||
if aaaaRecord, ok := answ.(*dns.AAAA); ok {
|
||||
ips = append(ips, aaaaRecord.AAAA)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(ips) == 0 {
|
||||
if queryErr != nil {
|
||||
return nil, queryErr
|
||||
}
|
||||
return nil, fmt.Errorf("no A or AAAA records found for %s", domain.SafeString())
|
||||
}
|
||||
|
||||
|
||||
@@ -1,93 +1,145 @@
|
||||
package fakeip
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Manager manages allocation of fake IPs from the 240.0.0.0/8 block
|
||||
type Manager struct {
|
||||
mu sync.Mutex
|
||||
nextIP netip.Addr // Next IP to allocate
|
||||
var (
|
||||
// 240.0.0.1 - 240.255.255.254, block 240.0.0.0/8 (reserved, RFC 1112)
|
||||
v4Base = netip.AddrFrom4([4]byte{240, 0, 0, 1})
|
||||
v4Max = netip.AddrFrom4([4]byte{240, 255, 255, 254})
|
||||
v4Block = netip.PrefixFrom(netip.AddrFrom4([4]byte{240, 0, 0, 0}), 8)
|
||||
|
||||
// 0100::1 - 0100::ffff:ffff:ffff:fffe, block 0100::/64 (discard, RFC 6666)
|
||||
v6Base = netip.AddrFrom16([16]byte{0x01, 0x00, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x01})
|
||||
v6Max = netip.AddrFrom16([16]byte{0x01, 0x00, 0, 0, 0, 0, 0, 0, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe})
|
||||
v6Block = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x01, 0x00}), 64)
|
||||
)
|
||||
|
||||
// fakeIPPool holds the allocation state for a single address family.
|
||||
type fakeIPPool struct {
|
||||
nextIP netip.Addr
|
||||
baseIP netip.Addr
|
||||
maxIP netip.Addr
|
||||
block netip.Prefix
|
||||
allocated map[netip.Addr]netip.Addr // real IP -> fake IP
|
||||
fakeToReal map[netip.Addr]netip.Addr // fake IP -> real IP
|
||||
baseIP netip.Addr // First usable IP: 240.0.0.1
|
||||
maxIP netip.Addr // Last usable IP: 240.255.255.254
|
||||
}
|
||||
|
||||
// NewManager creates a new fake IP manager using 240.0.0.0/8 block
|
||||
func NewManager() *Manager {
|
||||
baseIP := netip.AddrFrom4([4]byte{240, 0, 0, 1})
|
||||
maxIP := netip.AddrFrom4([4]byte{240, 255, 255, 254})
|
||||
|
||||
return &Manager{
|
||||
nextIP: baseIP,
|
||||
func newPool(base, maxAddr netip.Addr, block netip.Prefix) *fakeIPPool {
|
||||
return &fakeIPPool{
|
||||
nextIP: base,
|
||||
baseIP: base,
|
||||
maxIP: maxAddr,
|
||||
block: block,
|
||||
allocated: make(map[netip.Addr]netip.Addr),
|
||||
fakeToReal: make(map[netip.Addr]netip.Addr),
|
||||
baseIP: baseIP,
|
||||
maxIP: maxIP,
|
||||
}
|
||||
}
|
||||
|
||||
// AllocateFakeIP allocates a fake IP for the given real IP
|
||||
// Returns the fake IP, or existing fake IP if already allocated
|
||||
func (m *Manager) AllocateFakeIP(realIP netip.Addr) (netip.Addr, error) {
|
||||
if !realIP.Is4() {
|
||||
return netip.Addr{}, fmt.Errorf("only IPv4 addresses supported")
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if fakeIP, exists := m.allocated[realIP]; exists {
|
||||
// allocate allocates a fake IP for the given real IP.
|
||||
// Returns the existing fake IP if already allocated.
|
||||
func (p *fakeIPPool) allocate(realIP netip.Addr) (netip.Addr, error) {
|
||||
if fakeIP, exists := p.allocated[realIP]; exists {
|
||||
return fakeIP, nil
|
||||
}
|
||||
|
||||
startIP := m.nextIP
|
||||
startIP := p.nextIP
|
||||
for {
|
||||
currentIP := m.nextIP
|
||||
currentIP := p.nextIP
|
||||
|
||||
// Advance to next IP, wrapping at boundary
|
||||
if m.nextIP.Compare(m.maxIP) >= 0 {
|
||||
m.nextIP = m.baseIP
|
||||
if p.nextIP.Compare(p.maxIP) >= 0 {
|
||||
p.nextIP = p.baseIP
|
||||
} else {
|
||||
m.nextIP = m.nextIP.Next()
|
||||
p.nextIP = p.nextIP.Next()
|
||||
}
|
||||
|
||||
// Check if current IP is available
|
||||
if _, inUse := m.fakeToReal[currentIP]; !inUse {
|
||||
m.allocated[realIP] = currentIP
|
||||
m.fakeToReal[currentIP] = realIP
|
||||
if _, inUse := p.fakeToReal[currentIP]; !inUse {
|
||||
p.allocated[realIP] = currentIP
|
||||
p.fakeToReal[currentIP] = realIP
|
||||
return currentIP, nil
|
||||
}
|
||||
|
||||
// Prevent infinite loop if all IPs exhausted
|
||||
if m.nextIP.Compare(startIP) == 0 {
|
||||
return netip.Addr{}, fmt.Errorf("no more fake IPs available in 240.0.0.0/8 block")
|
||||
if p.nextIP.Compare(startIP) == 0 {
|
||||
return netip.Addr{}, fmt.Errorf("no more fake IPs available in %s block", p.block)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetFakeIP returns the fake IP for a real IP if it exists
|
||||
// Manager manages allocation of fake IPs for dynamic DNS routes.
|
||||
// IPv4 uses 240.0.0.0/8 (reserved), IPv6 uses 0100::/64 (discard, RFC 6666).
|
||||
type Manager struct {
|
||||
mu sync.Mutex
|
||||
v4 *fakeIPPool
|
||||
v6 *fakeIPPool
|
||||
}
|
||||
|
||||
// NewManager creates a new fake IP manager.
|
||||
func NewManager() *Manager {
|
||||
return &Manager{
|
||||
v4: newPool(v4Base, v4Max, v4Block),
|
||||
v6: newPool(v6Base, v6Max, v6Block),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) pool(ip netip.Addr) *fakeIPPool {
|
||||
if ip.Is6() {
|
||||
return m.v6
|
||||
}
|
||||
return m.v4
|
||||
}
|
||||
|
||||
// AllocateFakeIP allocates a fake IP for the given real IP.
|
||||
func (m *Manager) AllocateFakeIP(realIP netip.Addr) (netip.Addr, error) {
|
||||
realIP = realIP.Unmap()
|
||||
if !realIP.IsValid() {
|
||||
return netip.Addr{}, errors.New("invalid IP address")
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
return m.pool(realIP).allocate(realIP)
|
||||
}
|
||||
|
||||
// GetFakeIP returns the fake IP for a real IP if it exists.
|
||||
func (m *Manager) GetFakeIP(realIP netip.Addr) (netip.Addr, bool) {
|
||||
realIP = realIP.Unmap()
|
||||
if !realIP.IsValid() {
|
||||
return netip.Addr{}, false
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
fakeIP, exists := m.allocated[realIP]
|
||||
return fakeIP, exists
|
||||
fakeIP, ok := m.pool(realIP).allocated[realIP]
|
||||
return fakeIP, ok
|
||||
}
|
||||
|
||||
// GetRealIP returns the real IP for a fake IP if it exists, otherwise false
|
||||
// GetRealIP returns the real IP for a fake IP if it exists.
|
||||
func (m *Manager) GetRealIP(fakeIP netip.Addr) (netip.Addr, bool) {
|
||||
fakeIP = fakeIP.Unmap()
|
||||
if !fakeIP.IsValid() {
|
||||
return netip.Addr{}, false
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
realIP, exists := m.fakeToReal[fakeIP]
|
||||
return realIP, exists
|
||||
realIP, ok := m.pool(fakeIP).fakeToReal[fakeIP]
|
||||
return realIP, ok
|
||||
}
|
||||
|
||||
// GetFakeIPBlock returns the fake IP block used by this manager
|
||||
// GetFakeIPBlock returns the v4 fake IP block used by this manager.
|
||||
func (m *Manager) GetFakeIPBlock() netip.Prefix {
|
||||
return netip.MustParsePrefix("240.0.0.0/8")
|
||||
return m.v4.block
|
||||
}
|
||||
|
||||
// GetFakeIPv6Block returns the v6 fake IP block used by this manager.
|
||||
func (m *Manager) GetFakeIPv6Block() netip.Prefix {
|
||||
return m.v6.block
|
||||
}
|
||||
|
||||
@@ -9,16 +9,16 @@ import (
|
||||
func TestNewManager(t *testing.T) {
|
||||
manager := NewManager()
|
||||
|
||||
if manager.baseIP.String() != "240.0.0.1" {
|
||||
t.Errorf("Expected base IP 240.0.0.1, got %s", manager.baseIP.String())
|
||||
if manager.v4.baseIP.String() != "240.0.0.1" {
|
||||
t.Errorf("Expected v4 base IP 240.0.0.1, got %s", manager.v4.baseIP.String())
|
||||
}
|
||||
|
||||
if manager.maxIP.String() != "240.255.255.254" {
|
||||
t.Errorf("Expected max IP 240.255.255.254, got %s", manager.maxIP.String())
|
||||
if manager.v4.maxIP.String() != "240.255.255.254" {
|
||||
t.Errorf("Expected v4 max IP 240.255.255.254, got %s", manager.v4.maxIP.String())
|
||||
}
|
||||
|
||||
if manager.nextIP.Compare(manager.baseIP) != 0 {
|
||||
t.Errorf("Expected nextIP to start at baseIP")
|
||||
if manager.v6.baseIP.String() != "100::1" {
|
||||
t.Errorf("Expected v6 base IP 100::1, got %s", manager.v6.baseIP.String())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -35,7 +35,6 @@ func TestAllocateFakeIP(t *testing.T) {
|
||||
t.Error("Fake IP should be IPv4")
|
||||
}
|
||||
|
||||
// Check it's in the correct range
|
||||
if fakeIP.As4()[0] != 240 {
|
||||
t.Errorf("Fake IP should be in 240.0.0.0/8 range, got %s", fakeIP.String())
|
||||
}
|
||||
@@ -51,13 +50,31 @@ func TestAllocateFakeIP(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAllocateFakeIPIPv6Rejection(t *testing.T) {
|
||||
func TestAllocateFakeIPv6(t *testing.T) {
|
||||
manager := NewManager()
|
||||
realIPv6 := netip.MustParseAddr("2001:db8::1")
|
||||
realIP := netip.MustParseAddr("2001:db8::1")
|
||||
|
||||
_, err := manager.AllocateFakeIP(realIPv6)
|
||||
if err == nil {
|
||||
t.Error("Expected error for IPv6 address")
|
||||
fakeIP, err := manager.AllocateFakeIP(realIP)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to allocate fake IPv6: %v", err)
|
||||
}
|
||||
|
||||
if !fakeIP.Is6() {
|
||||
t.Error("Fake IP should be IPv6")
|
||||
}
|
||||
|
||||
if !netip.MustParsePrefix("100::/64").Contains(fakeIP) {
|
||||
t.Errorf("Fake IP should be in 100::/64 range, got %s", fakeIP.String())
|
||||
}
|
||||
|
||||
// Should return same fake IP for same real IP
|
||||
fakeIP2, err := manager.AllocateFakeIP(realIP)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get existing fake IPv6: %v", err)
|
||||
}
|
||||
|
||||
if fakeIP.Compare(fakeIP2) != 0 {
|
||||
t.Errorf("Expected same fake IP, got %s and %s", fakeIP.String(), fakeIP2.String())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -65,13 +82,11 @@ func TestGetFakeIP(t *testing.T) {
|
||||
manager := NewManager()
|
||||
realIP := netip.MustParseAddr("1.1.1.1")
|
||||
|
||||
// Should not exist initially
|
||||
_, exists := manager.GetFakeIP(realIP)
|
||||
if exists {
|
||||
t.Error("Fake IP should not exist before allocation")
|
||||
}
|
||||
|
||||
// Allocate and check
|
||||
expectedFakeIP, err := manager.AllocateFakeIP(realIP)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to allocate: %v", err)
|
||||
@@ -87,12 +102,30 @@ func TestGetFakeIP(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetRealIPv6(t *testing.T) {
|
||||
manager := NewManager()
|
||||
realIP := netip.MustParseAddr("2001:db8::1")
|
||||
|
||||
fakeIP, err := manager.AllocateFakeIP(realIP)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to allocate: %v", err)
|
||||
}
|
||||
|
||||
gotReal, exists := manager.GetRealIP(fakeIP)
|
||||
if !exists {
|
||||
t.Error("Real IP should exist for allocated fake IP")
|
||||
}
|
||||
|
||||
if gotReal.Compare(realIP) != 0 {
|
||||
t.Errorf("Expected real IP %s, got %s", realIP, gotReal)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultipleAllocations(t *testing.T) {
|
||||
manager := NewManager()
|
||||
|
||||
allocations := make(map[netip.Addr]netip.Addr)
|
||||
|
||||
// Allocate multiple IPs
|
||||
for i := 1; i <= 100; i++ {
|
||||
realIP := netip.AddrFrom4([4]byte{10, 0, byte(i / 256), byte(i % 256)})
|
||||
fakeIP, err := manager.AllocateFakeIP(realIP)
|
||||
@@ -100,7 +133,6 @@ func TestMultipleAllocations(t *testing.T) {
|
||||
t.Fatalf("Failed to allocate fake IP for %s: %v", realIP.String(), err)
|
||||
}
|
||||
|
||||
// Check for duplicates
|
||||
for _, existingFake := range allocations {
|
||||
if fakeIP.Compare(existingFake) == 0 {
|
||||
t.Errorf("Duplicate fake IP allocated: %s", fakeIP.String())
|
||||
@@ -110,7 +142,6 @@ func TestMultipleAllocations(t *testing.T) {
|
||||
allocations[realIP] = fakeIP
|
||||
}
|
||||
|
||||
// Verify all allocations can be retrieved
|
||||
for realIP, expectedFake := range allocations {
|
||||
actualFake, exists := manager.GetFakeIP(realIP)
|
||||
if !exists {
|
||||
@@ -124,11 +155,13 @@ func TestMultipleAllocations(t *testing.T) {
|
||||
|
||||
func TestGetFakeIPBlock(t *testing.T) {
|
||||
manager := NewManager()
|
||||
block := manager.GetFakeIPBlock()
|
||||
|
||||
expected := "240.0.0.0/8"
|
||||
if block.String() != expected {
|
||||
t.Errorf("Expected %s, got %s", expected, block.String())
|
||||
if block := manager.GetFakeIPBlock(); block.String() != "240.0.0.0/8" {
|
||||
t.Errorf("Expected 240.0.0.0/8, got %s", block.String())
|
||||
}
|
||||
|
||||
if block := manager.GetFakeIPv6Block(); block.String() != "100::/64" {
|
||||
t.Errorf("Expected 100::/64, got %s", block.String())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -141,7 +174,6 @@ func TestConcurrentAccess(t *testing.T) {
|
||||
var wg sync.WaitGroup
|
||||
results := make(chan netip.Addr, numGoroutines*allocationsPerGoroutine)
|
||||
|
||||
// Concurrent allocations
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(goroutineID int) {
|
||||
@@ -161,7 +193,6 @@ func TestConcurrentAccess(t *testing.T) {
|
||||
wg.Wait()
|
||||
close(results)
|
||||
|
||||
// Check for duplicates
|
||||
seen := make(map[netip.Addr]bool)
|
||||
count := 0
|
||||
for fakeIP := range results {
|
||||
@@ -178,47 +209,61 @@ func TestConcurrentAccess(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestIPExhaustion(t *testing.T) {
|
||||
// Create a manager with limited range for testing
|
||||
manager := &Manager{
|
||||
nextIP: netip.AddrFrom4([4]byte{240, 0, 0, 1}),
|
||||
allocated: make(map[netip.Addr]netip.Addr),
|
||||
fakeToReal: make(map[netip.Addr]netip.Addr),
|
||||
baseIP: netip.AddrFrom4([4]byte{240, 0, 0, 1}),
|
||||
maxIP: netip.AddrFrom4([4]byte{240, 0, 0, 3}), // Only 3 IPs available
|
||||
v4: newPool(
|
||||
netip.AddrFrom4([4]byte{240, 0, 0, 1}),
|
||||
netip.AddrFrom4([4]byte{240, 0, 0, 3}),
|
||||
netip.MustParsePrefix("240.0.0.0/8"),
|
||||
),
|
||||
v6: newPool(
|
||||
netip.MustParseAddr("100::1"),
|
||||
netip.MustParseAddr("100::3"),
|
||||
netip.MustParsePrefix("100::/64"),
|
||||
),
|
||||
}
|
||||
|
||||
// Allocate all available IPs
|
||||
realIPs := []netip.Addr{
|
||||
netip.MustParseAddr("1.0.0.1"),
|
||||
netip.MustParseAddr("1.0.0.2"),
|
||||
netip.MustParseAddr("1.0.0.3"),
|
||||
}
|
||||
|
||||
for _, realIP := range realIPs {
|
||||
_, err := manager.AllocateFakeIP(realIP)
|
||||
for _, realIP := range []string{"1.0.0.1", "1.0.0.2", "1.0.0.3"} {
|
||||
_, err := manager.AllocateFakeIP(netip.MustParseAddr(realIP))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to allocate fake IP: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Try to allocate one more - should fail
|
||||
_, err := manager.AllocateFakeIP(netip.MustParseAddr("1.0.0.4"))
|
||||
if err == nil {
|
||||
t.Error("Expected exhaustion error")
|
||||
t.Error("Expected v4 exhaustion error")
|
||||
}
|
||||
|
||||
// Same for v6
|
||||
for _, realIP := range []string{"2001:db8::1", "2001:db8::2", "2001:db8::3"} {
|
||||
_, err := manager.AllocateFakeIP(netip.MustParseAddr(realIP))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to allocate fake IPv6: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
_, err = manager.AllocateFakeIP(netip.MustParseAddr("2001:db8::4"))
|
||||
if err == nil {
|
||||
t.Error("Expected v6 exhaustion error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWrapAround(t *testing.T) {
|
||||
// Create manager starting near the end of range
|
||||
manager := &Manager{
|
||||
nextIP: netip.AddrFrom4([4]byte{240, 0, 0, 254}),
|
||||
allocated: make(map[netip.Addr]netip.Addr),
|
||||
fakeToReal: make(map[netip.Addr]netip.Addr),
|
||||
baseIP: netip.AddrFrom4([4]byte{240, 0, 0, 1}),
|
||||
maxIP: netip.AddrFrom4([4]byte{240, 0, 0, 254}),
|
||||
v4: newPool(
|
||||
netip.AddrFrom4([4]byte{240, 0, 0, 1}),
|
||||
netip.AddrFrom4([4]byte{240, 0, 0, 254}),
|
||||
netip.MustParsePrefix("240.0.0.0/8"),
|
||||
),
|
||||
v6: newPool(
|
||||
netip.MustParseAddr("100::1"),
|
||||
netip.MustParseAddr("100::ffff:ffff:ffff:fffe"),
|
||||
netip.MustParsePrefix("100::/64"),
|
||||
),
|
||||
}
|
||||
// Start near the end
|
||||
manager.v4.nextIP = netip.AddrFrom4([4]byte{240, 0, 0, 254})
|
||||
|
||||
// Allocate the last IP
|
||||
fakeIP1, err := manager.AllocateFakeIP(netip.MustParseAddr("1.0.0.1"))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to allocate first IP: %v", err)
|
||||
@@ -228,7 +273,6 @@ func TestWrapAround(t *testing.T) {
|
||||
t.Errorf("Expected 240.0.0.254, got %s", fakeIP1.String())
|
||||
}
|
||||
|
||||
// Next allocation should wrap around to the beginning
|
||||
fakeIP2, err := manager.AllocateFakeIP(netip.MustParseAddr("1.0.0.2"))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to allocate second IP: %v", err)
|
||||
@@ -238,3 +282,32 @@ func TestWrapAround(t *testing.T) {
|
||||
t.Errorf("Expected 240.0.0.1 after wrap, got %s", fakeIP2.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestMixedV4V6(t *testing.T) {
|
||||
manager := NewManager()
|
||||
|
||||
v4Fake, err := manager.AllocateFakeIP(netip.MustParseAddr("8.8.8.8"))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to allocate v4: %v", err)
|
||||
}
|
||||
|
||||
v6Fake, err := manager.AllocateFakeIP(netip.MustParseAddr("2001:db8::1"))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to allocate v6: %v", err)
|
||||
}
|
||||
|
||||
if !v4Fake.Is4() || !v6Fake.Is6() {
|
||||
t.Errorf("Wrong families: v4=%s v6=%s", v4Fake, v6Fake)
|
||||
}
|
||||
|
||||
// Reverse lookups should work for both
|
||||
gotV4, ok := manager.GetRealIP(v4Fake)
|
||||
if !ok || gotV4.String() != "8.8.8.8" {
|
||||
t.Errorf("v4 reverse lookup failed: got %s, ok=%v", gotV4, ok)
|
||||
}
|
||||
|
||||
gotV6, ok := manager.GetRealIP(v6Fake)
|
||||
if !ok || gotV6.String() != "2001:db8::1" {
|
||||
t.Errorf("v6 reverse lookup failed: got %s, ok=%v", gotV6, ok)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,7 +9,11 @@ import (
|
||||
)
|
||||
|
||||
// IPForwardingState is a struct that keeps track of the IP forwarding state.
|
||||
// todo: read initial state of the IP forwarding from the system and reset the state based on it
|
||||
// todo: read initial state of the IP forwarding from the system and reset the state based on it.
|
||||
// todo: separate v4/v6 forwarding state, since the sysctls are independent
|
||||
// (net.ipv4.ip_forward vs net.ipv6.conf.all.forwarding). Currently the nftables
|
||||
// manager shares one instance between both routers, which works only because
|
||||
// EnableIPForwarding enables both sysctls in a single call.
|
||||
type IPForwardingState struct {
|
||||
enabledCounter int
|
||||
}
|
||||
|
||||
@@ -159,15 +159,23 @@ func (m *DefaultManager) setupAndroidRoutes(config ManagerConfig) {
|
||||
if config.DNSFeatureFlag {
|
||||
m.fakeIPManager = fakeip.NewManager()
|
||||
|
||||
id := uuid.NewString()
|
||||
fakeIPRoute := &route.Route{
|
||||
ID: route.ID(id),
|
||||
v4ID := uuid.NewString()
|
||||
cr = append(cr, &route.Route{
|
||||
ID: route.ID(v4ID),
|
||||
Network: m.fakeIPManager.GetFakeIPBlock(),
|
||||
NetID: route.NetID(id),
|
||||
NetID: route.NetID(v4ID),
|
||||
Peer: m.pubKey,
|
||||
NetworkType: route.IPv4Network,
|
||||
}
|
||||
cr = append(cr, fakeIPRoute)
|
||||
})
|
||||
|
||||
v6ID := uuid.NewString()
|
||||
cr = append(cr, &route.Route{
|
||||
ID: route.ID(v6ID),
|
||||
Network: m.fakeIPManager.GetFakeIPv6Block(),
|
||||
NetID: route.NetID(v6ID),
|
||||
Peer: m.pubKey,
|
||||
NetworkType: route.IPv6Network,
|
||||
})
|
||||
}
|
||||
|
||||
m.notifier.SetInitialClientRoutes(cr, routesForComparison)
|
||||
|
||||
@@ -146,8 +146,7 @@ func routeToRouterPair(route *route.Route, useNewDNSRoute bool) firewall.RouterP
|
||||
if useNewDNSRoute {
|
||||
destination.Set = firewall.NewDomainSet(route.Domains)
|
||||
} else {
|
||||
// TODO: add ipv6 additionally
|
||||
destination = getDefaultPrefix(destination.Prefix)
|
||||
destination = getDefaultPrefix(route.Network)
|
||||
}
|
||||
} else {
|
||||
destination.Prefix = route.Network.Masked()
|
||||
|
||||
@@ -107,8 +107,13 @@ func (r *SysOps) validateRoute(prefix netip.Prefix) error {
|
||||
addr.IsInterfaceLocalMulticast(),
|
||||
addr.IsMulticast(),
|
||||
addr.IsUnspecified() && prefix.Bits() != 0,
|
||||
r.wgInterface.Address().Network.Contains(addr):
|
||||
r.isOwnAddress(addr):
|
||||
return vars.ErrRouteNotAllowed
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *SysOps) isOwnAddress(addr netip.Addr) bool {
|
||||
wgAddr := r.wgInterface.Address()
|
||||
return wgAddr.Network.Contains(addr) || (wgAddr.IPv6Net.IsValid() && wgAddr.IPv6Net.Contains(addr))
|
||||
}
|
||||
|
||||
@@ -222,30 +222,20 @@ func (r *SysOps) genericAddVPNRoute(prefix netip.Prefix, intf *net.Interface) er
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO: remove once IPv6 is supported on the interface
|
||||
if err := r.addToRouteTable(splitDefaultv6_1, nextHop); err != nil {
|
||||
return fmt.Errorf("add unreachable route split 1: %w", err)
|
||||
}
|
||||
if err := r.addToRouteTable(splitDefaultv6_2, nextHop); err != nil {
|
||||
if err2 := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err2 != nil {
|
||||
log.Warnf("Failed to rollback route addition: %s", err2)
|
||||
// When the interface has no v6, add v6 split-default as blackhole so
|
||||
// unroutable v6 goes to WG (dropped, no AllowedIPs) instead of leaking
|
||||
// to the system default route. When v6 is active, management sends ::/0
|
||||
// as a separate route that the dedicated handler adds.
|
||||
// Soft-fail: v6 blackhole is best-effort, don't abort v4 routing on failure.
|
||||
if !r.wgInterface.Address().HasIPv6() {
|
||||
if err := r.addV6SplitDefault(nextHop); err != nil {
|
||||
log.Warnf("failed to add v6 split-default blackhole: %s", err)
|
||||
}
|
||||
return fmt.Errorf("add unreachable route split 2: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
case vars.Defaultv6:
|
||||
if err := r.addToRouteTable(splitDefaultv6_1, nextHop); err != nil {
|
||||
return fmt.Errorf("add unreachable route split 1: %w", err)
|
||||
}
|
||||
if err := r.addToRouteTable(splitDefaultv6_2, nextHop); err != nil {
|
||||
if err2 := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err2 != nil {
|
||||
log.Warnf("Failed to rollback route addition: %s", err2)
|
||||
}
|
||||
return fmt.Errorf("add unreachable route split 2: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
return r.addV6SplitDefault(nextHop)
|
||||
}
|
||||
|
||||
return r.addToRouteTable(prefix, nextHop)
|
||||
@@ -266,30 +256,42 @@ func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface)
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
|
||||
// TODO: remove once IPv6 is supported on the interface
|
||||
if err := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
if err := r.removeFromRouteTable(splitDefaultv6_2, nextHop); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
if !r.wgInterface.Address().HasIPv6() {
|
||||
result = multierror.Append(result, r.removeV6SplitDefault(nextHop))
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(result)
|
||||
case vars.Defaultv6:
|
||||
var result *multierror.Error
|
||||
if err := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
if err := r.removeFromRouteTable(splitDefaultv6_2, nextHop); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(result)
|
||||
return nberrors.FormatErrorOrNil(r.removeV6SplitDefault(nextHop))
|
||||
default:
|
||||
return r.removeFromRouteTable(prefix, nextHop)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *SysOps) addV6SplitDefault(nextHop Nexthop) error {
|
||||
if err := r.addToRouteTable(splitDefaultv6_1, nextHop); err != nil {
|
||||
return fmt.Errorf("add split 1: %w", err)
|
||||
}
|
||||
if err := r.addToRouteTable(splitDefaultv6_2, nextHop); err != nil {
|
||||
if err2 := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err2 != nil {
|
||||
log.Warnf("Failed to rollback v6 split-default: %s", err2)
|
||||
}
|
||||
return fmt.Errorf("add split 2: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *SysOps) removeV6SplitDefault(nextHop Nexthop) *multierror.Error {
|
||||
var result *multierror.Error
|
||||
if err := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
if err := r.removeFromRouteTable(splitDefaultv6_2, nextHop); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.Manager) error {
|
||||
beforeHook := func(connID hooks.ConnectionID, prefix netip.Prefix) error {
|
||||
if _, err := r.refCounter.IncrementWithID(string(connID), prefix, struct{}{}); err != nil {
|
||||
|
||||
@@ -53,6 +53,8 @@ const (
|
||||
|
||||
// ipv4ForwardingPath is the path to the file containing the IP forwarding setting.
|
||||
ipv4ForwardingPath = "net.ipv4.ip_forward"
|
||||
// ipv6ForwardingPath is the path to the file containing the IPv6 forwarding setting.
|
||||
ipv6ForwardingPath = "net.ipv6.conf.all.forwarding"
|
||||
)
|
||||
|
||||
var ErrTableIDExists = errors.New("ID exists with different name")
|
||||
@@ -185,10 +187,11 @@ func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||
|
||||
// No need to check if routes exist as main table takes precedence over the VPN table via Rule 1
|
||||
|
||||
// TODO remove this once we have ipv6 support
|
||||
if prefix == vars.Defaultv4 {
|
||||
// When the peer has no IPv6, blackhole v6 to prevent leaking.
|
||||
// When IPv6 is enabled, management sends ::/0 as a separate route.
|
||||
if prefix == vars.Defaultv4 && (r.wgInterface == nil || !r.wgInterface.Address().HasIPv6()) {
|
||||
if err := addUnreachableRoute(vars.Defaultv6, NetbirdVPNTableID); err != nil {
|
||||
return fmt.Errorf("add blackhole: %w", err)
|
||||
return fmt.Errorf("add v6 blackhole: %w", err)
|
||||
}
|
||||
}
|
||||
if err := addRoute(prefix, Nexthop{netip.Addr{}, intf}, NetbirdVPNTableID); err != nil {
|
||||
@@ -206,10 +209,9 @@ func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error
|
||||
return r.genericRemoveVPNRoute(prefix, intf)
|
||||
}
|
||||
|
||||
// TODO remove this once we have ipv6 support
|
||||
if prefix == vars.Defaultv4 {
|
||||
if prefix == vars.Defaultv4 && (r.wgInterface == nil || !r.wgInterface.Address().HasIPv6()) {
|
||||
if err := removeUnreachableRoute(vars.Defaultv6, NetbirdVPNTableID); err != nil {
|
||||
return fmt.Errorf("remove unreachable route: %w", err)
|
||||
log.Debugf("remove v6 blackhole: %v", err)
|
||||
}
|
||||
}
|
||||
if err := removeRoute(prefix, Nexthop{netip.Addr{}, intf}, NetbirdVPNTableID); err != nil {
|
||||
@@ -762,8 +764,13 @@ func flushRoutes(tableID, family int) error {
|
||||
}
|
||||
|
||||
func EnableIPForwarding() error {
|
||||
_, err := sysctl.Set(ipv4ForwardingPath, 1, false)
|
||||
return err
|
||||
if _, err := sysctl.Set(ipv4ForwardingPath, 1, false); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := sysctl.Set(ipv6ForwardingPath, 1, false); err != nil {
|
||||
log.Warnf("failed to enable IPv6 forwarding: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// entryExists checks if the specified ID or name already exists in the rt_tables file
|
||||
|
||||
@@ -50,10 +50,11 @@ type CustomLogger interface {
|
||||
}
|
||||
|
||||
type selectRoute struct {
|
||||
NetID string
|
||||
Network netip.Prefix
|
||||
Domains domain.List
|
||||
Selected bool
|
||||
NetID string
|
||||
Network netip.Prefix
|
||||
Domains domain.List
|
||||
Selected bool
|
||||
extraNetworks []netip.Prefix
|
||||
}
|
||||
|
||||
func init() {
|
||||
@@ -363,48 +364,60 @@ func (c *Client) GetRoutesSelectionDetails() (*RoutesSelectionDetails, error) {
|
||||
}
|
||||
|
||||
routeManager := engine.GetRouteManager()
|
||||
routesMap := routeManager.GetClientRoutesWithNetID()
|
||||
if routeManager == nil {
|
||||
return nil, fmt.Errorf("could not get route manager")
|
||||
}
|
||||
routesMap := routeManager.GetClientRoutesWithNetID()
|
||||
routeSelector := routeManager.GetRouteSelector()
|
||||
if routeSelector == nil {
|
||||
return nil, fmt.Errorf("could not get route selector")
|
||||
}
|
||||
|
||||
v6ExitMerged := route.V6ExitMergeSet(routesMap)
|
||||
routes := buildSelectRoutes(routesMap, routeSelector.IsSelected, v6ExitMerged)
|
||||
resolvedDomains := c.recorder.GetResolvedDomainsStates()
|
||||
|
||||
return prepareRouteSelectionDetails(routes, resolvedDomains), nil
|
||||
}
|
||||
|
||||
func buildSelectRoutes(routesMap map[route.NetID][]*route.Route, isSelected func(route.NetID) bool, v6Merged map[route.NetID]struct{}) []*selectRoute {
|
||||
var routes []*selectRoute
|
||||
for id, rt := range routesMap {
|
||||
if len(rt) == 0 {
|
||||
continue
|
||||
}
|
||||
route := &selectRoute{
|
||||
if _, ok := v6Merged[id]; ok {
|
||||
continue
|
||||
}
|
||||
|
||||
r := &selectRoute{
|
||||
NetID: string(id),
|
||||
Network: rt[0].Network,
|
||||
Domains: rt[0].Domains,
|
||||
Selected: routeSelector.IsSelected(id),
|
||||
Selected: isSelected(id),
|
||||
}
|
||||
routes = append(routes, route)
|
||||
|
||||
v6ID := route.NetID(string(id) + route.V6ExitSuffix)
|
||||
if _, ok := v6Merged[v6ID]; ok {
|
||||
r.extraNetworks = []netip.Prefix{routesMap[v6ID][0].Network}
|
||||
}
|
||||
|
||||
routes = append(routes, r)
|
||||
}
|
||||
|
||||
sort.Slice(routes, func(i, j int) bool {
|
||||
iPrefix := routes[i].Network.Bits()
|
||||
jPrefix := routes[j].Network.Bits()
|
||||
|
||||
if iPrefix == jPrefix {
|
||||
iAddr := routes[i].Network.Addr()
|
||||
jAddr := routes[j].Network.Addr()
|
||||
if iAddr == jAddr {
|
||||
return routes[i].NetID < routes[j].NetID
|
||||
}
|
||||
return iAddr.String() < jAddr.String()
|
||||
iBits, jBits := routes[i].Network.Bits(), routes[j].Network.Bits()
|
||||
if iBits != jBits {
|
||||
return iBits < jBits
|
||||
}
|
||||
return iPrefix < jPrefix
|
||||
iAddr, jAddr := routes[i].Network.Addr(), routes[j].Network.Addr()
|
||||
if iAddr != jAddr {
|
||||
return iAddr.Less(jAddr)
|
||||
}
|
||||
return routes[i].NetID < routes[j].NetID
|
||||
})
|
||||
|
||||
resolvedDomains := c.recorder.GetResolvedDomainsStates()
|
||||
|
||||
return prepareRouteSelectionDetails(routes, resolvedDomains), nil
|
||||
|
||||
return routes
|
||||
}
|
||||
|
||||
func prepareRouteSelectionDetails(routes []*selectRoute, resolvedDomains map[domain.Domain]peer.ResolvedDomainInfo) *RoutesSelectionDetails {
|
||||
@@ -425,10 +438,15 @@ func prepareRouteSelectionDetails(routes []*selectRoute, resolvedDomains map[dom
|
||||
}
|
||||
domainList = append(domainList, domainResp)
|
||||
}
|
||||
rangeStr := r.Network.String()
|
||||
for _, extra := range r.extraNetworks {
|
||||
rangeStr += ", " + extra.String()
|
||||
}
|
||||
|
||||
domainDetails := DomainDetails{items: domainList}
|
||||
routeSelection = append(routeSelection, RoutesSelectionInfo{
|
||||
ID: r.NetID,
|
||||
Network: r.Network.String(),
|
||||
Network: rangeStr,
|
||||
Domains: &domainDetails,
|
||||
Selected: r.Selected,
|
||||
})
|
||||
@@ -456,7 +474,9 @@ func (c *Client) SelectRoute(id string) error {
|
||||
} else {
|
||||
log.Debugf("select route with id: %s", id)
|
||||
routes := toNetIDs([]string{id})
|
||||
if err := routeSelector.SelectRoutes(routes, true, maps.Keys(routeManager.GetClientRoutesWithNetID())); err != nil {
|
||||
routesMap := routeManager.GetClientRoutesWithNetID()
|
||||
routes = route.ExpandV6ExitPairs(routes, routesMap)
|
||||
if err := routeSelector.SelectRoutes(routes, true, maps.Keys(routesMap)); err != nil {
|
||||
log.Debugf("error when selecting routes: %s", err)
|
||||
return fmt.Errorf("select routes: %w", err)
|
||||
}
|
||||
@@ -483,7 +503,9 @@ func (c *Client) DeselectRoute(id string) error {
|
||||
} else {
|
||||
log.Debugf("deselect route with id: %s", id)
|
||||
routes := toNetIDs([]string{id})
|
||||
if err := routeSelector.DeselectRoutes(routes, maps.Keys(routeManager.GetClientRoutesWithNetID())); err != nil {
|
||||
routesMap := routeManager.GetClientRoutesWithNetID()
|
||||
routes = route.ExpandV6ExitPairs(routes, routesMap)
|
||||
if err := routeSelector.DeselectRoutes(routes, maps.Keys(routesMap)); err != nil {
|
||||
log.Debugf("error when deselecting routes: %s", err)
|
||||
return fmt.Errorf("deselect routes: %w", err)
|
||||
}
|
||||
|
||||
@@ -16,10 +16,11 @@ import (
|
||||
)
|
||||
|
||||
type selectRoute struct {
|
||||
NetID route.NetID
|
||||
Network netip.Prefix
|
||||
Domains domain.List
|
||||
Selected bool
|
||||
NetID route.NetID
|
||||
Network netip.Prefix
|
||||
Domains domain.List
|
||||
Selected bool
|
||||
extraNetworks []netip.Prefix
|
||||
}
|
||||
|
||||
// ListNetworks returns a list of all available networks.
|
||||
@@ -44,18 +45,33 @@ func (s *Server) ListNetworks(context.Context, *proto.ListNetworksRequest) (*pro
|
||||
routesMap := routeMgr.GetClientRoutesWithNetID()
|
||||
routeSelector := routeMgr.GetRouteSelector()
|
||||
|
||||
v6ExitMerged := route.V6ExitMergeSet(routesMap)
|
||||
|
||||
var routes []*selectRoute
|
||||
for id, rt := range routesMap {
|
||||
if len(rt) == 0 {
|
||||
continue
|
||||
}
|
||||
route := &selectRoute{
|
||||
// Skip v6 exit nodes that are merged into their v4 counterpart.
|
||||
if _, ok := v6ExitMerged[id]; ok {
|
||||
continue
|
||||
}
|
||||
|
||||
r := &selectRoute{
|
||||
NetID: id,
|
||||
Network: rt[0].Network,
|
||||
Domains: rt[0].Domains,
|
||||
Selected: routeSelector.IsSelected(id),
|
||||
}
|
||||
routes = append(routes, route)
|
||||
|
||||
// Merge paired v6 exit node prefix into this entry.
|
||||
v6ID := route.NetID(string(id) + route.V6ExitSuffix)
|
||||
if _, ok := v6ExitMerged[v6ID]; ok {
|
||||
v6Prefix := routesMap[v6ID][0].Network
|
||||
r.extraNetworks = []netip.Prefix{v6Prefix}
|
||||
}
|
||||
|
||||
routes = append(routes, r)
|
||||
}
|
||||
|
||||
sort.Slice(routes, func(i, j int) bool {
|
||||
@@ -76,9 +92,13 @@ func (s *Server) ListNetworks(context.Context, *proto.ListNetworksRequest) (*pro
|
||||
resolvedDomains := s.statusRecorder.GetResolvedDomainsStates()
|
||||
var pbRoutes []*proto.Network
|
||||
for _, route := range routes {
|
||||
rangeStr := route.Network.String()
|
||||
for _, extra := range route.extraNetworks {
|
||||
rangeStr += ", " + extra.String()
|
||||
}
|
||||
pbRoute := &proto.Network{
|
||||
ID: string(route.NetID),
|
||||
Range: route.Network.String(),
|
||||
Range: rangeStr,
|
||||
Domains: route.Domains.ToSafeStringList(),
|
||||
ResolvedIPs: map[string]*proto.IPList{},
|
||||
Selected: route.Selected,
|
||||
@@ -137,7 +157,9 @@ func (s *Server) SelectNetworks(_ context.Context, req *proto.SelectNetworksRequ
|
||||
routeSelector.SelectAllRoutes()
|
||||
} else {
|
||||
routes := toNetIDs(req.GetNetworkIDs())
|
||||
netIdRoutes := maps.Keys(routeManager.GetClientRoutesWithNetID())
|
||||
routesMap := routeManager.GetClientRoutesWithNetID()
|
||||
routes = route.ExpandV6ExitPairs(routes, routesMap)
|
||||
netIdRoutes := maps.Keys(routesMap)
|
||||
if err := routeSelector.SelectRoutes(routes, req.GetAppend(), netIdRoutes); err != nil {
|
||||
return nil, fmt.Errorf("select routes: %w", err)
|
||||
}
|
||||
@@ -183,7 +205,9 @@ func (s *Server) DeselectNetworks(_ context.Context, req *proto.SelectNetworksRe
|
||||
routeSelector.DeselectAllRoutes()
|
||||
} else {
|
||||
routes := toNetIDs(req.GetNetworkIDs())
|
||||
netIdRoutes := maps.Keys(routeManager.GetClientRoutesWithNetID())
|
||||
routesMap := routeManager.GetClientRoutesWithNetID()
|
||||
routes = route.ExpandV6ExitPairs(routes, routesMap)
|
||||
netIdRoutes := maps.Keys(routesMap)
|
||||
if err := routeSelector.DeselectRoutes(routes, netIdRoutes); err != nil {
|
||||
return nil, fmt.Errorf("deselect routes: %w", err)
|
||||
}
|
||||
|
||||
@@ -195,7 +195,7 @@ func getOverlappingNetworks(routes []*proto.Network) []*proto.Network {
|
||||
func getExitNodeNetworks(routes []*proto.Network) []*proto.Network {
|
||||
var filteredRoutes []*proto.Network
|
||||
for _, route := range routes {
|
||||
if route.Range == "0.0.0.0/0" || route.Range == "::/0" {
|
||||
if strings.Contains(route.Range, "0.0.0.0/0") || route.Range == "::/0" {
|
||||
filteredRoutes = append(filteredRoutes, route)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"syscall/js"
|
||||
"time"
|
||||
|
||||
@@ -166,39 +167,58 @@ func createSSHMethod(client *netbird.Client) js.Func {
|
||||
})
|
||||
}
|
||||
|
||||
var jwtToken string
|
||||
if len(args) > 3 && !args[3].IsNull() && !args[3].IsUndefined() {
|
||||
jwtToken = args[3].String()
|
||||
}
|
||||
jwtToken, ipVersion := parseSSHOptions(args)
|
||||
|
||||
return createPromise(func(resolve, reject js.Value) {
|
||||
sshClient := ssh.NewClient(client)
|
||||
|
||||
if err := sshClient.Connect(host, port, username, jwtToken); err != nil {
|
||||
jsInterface, err := connectSSH(client, host, port, username, jwtToken, ipVersion)
|
||||
if err != nil {
|
||||
reject.Invoke(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if err := sshClient.StartSession(80, 24); err != nil {
|
||||
if closeErr := sshClient.Close(); closeErr != nil {
|
||||
log.Errorf("Error closing SSH client: %v", closeErr)
|
||||
}
|
||||
reject.Invoke(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
jsInterface := ssh.CreateJSInterface(sshClient)
|
||||
resolve.Invoke(jsInterface)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func performPing(client *netbird.Client, hostname string) {
|
||||
func parseSSHOptions(args []js.Value) (jwtToken string, ipVersion int) {
|
||||
if len(args) > 3 && !args[3].IsNull() && !args[3].IsUndefined() {
|
||||
jwtToken = args[3].String()
|
||||
}
|
||||
if len(args) > 4 {
|
||||
ipVersion = jsIPVersion(args[4])
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func connectSSH(client *netbird.Client, host string, port int, username, jwtToken string, ipVersion int) (js.Value, error) {
|
||||
sshClient := ssh.NewClient(client)
|
||||
|
||||
if err := sshClient.Connect(host, port, username, jwtToken, ipVersion); err != nil {
|
||||
return js.Undefined(), err
|
||||
}
|
||||
|
||||
if err := sshClient.StartSession(80, 24); err != nil {
|
||||
if closeErr := sshClient.Close(); closeErr != nil {
|
||||
log.Errorf("Error closing SSH client: %v", closeErr)
|
||||
}
|
||||
return js.Undefined(), err
|
||||
}
|
||||
|
||||
return ssh.CreateJSInterface(sshClient), nil
|
||||
}
|
||||
|
||||
func performPing(client *netbird.Client, hostname string, ipVersion int) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), pingTimeout)
|
||||
defer cancel()
|
||||
|
||||
// Default to ping4 to avoid dual-stack ICMP endpoint issues in wireguard-go netstack.
|
||||
network := "ping4"
|
||||
if ipVersion == 6 {
|
||||
network = "ping6"
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
conn, err := client.Dial(ctx, "ping", hostname)
|
||||
conn, err := client.Dial(ctx, network, hostname)
|
||||
if err != nil {
|
||||
js.Global().Get("console").Call("log", fmt.Sprintf("Ping to %s failed: %v", hostname, err))
|
||||
return
|
||||
@@ -225,27 +245,39 @@ func performPing(client *netbird.Client, hostname string) {
|
||||
}
|
||||
|
||||
latency := time.Since(start)
|
||||
js.Global().Get("console").Call("log", fmt.Sprintf("Ping to %s: %dms", hostname, latency.Milliseconds()))
|
||||
remote := conn.RemoteAddr().String()
|
||||
msg := fmt.Sprintf("Ping to %s: %dms", hostname, latency.Milliseconds())
|
||||
if remote != hostname {
|
||||
msg += fmt.Sprintf(" (via %s)", remote)
|
||||
}
|
||||
js.Global().Get("console").Call("log", msg)
|
||||
}
|
||||
|
||||
func performPingTCP(client *netbird.Client, hostname string, port int) {
|
||||
func performPingTCP(client *netbird.Client, hostname string, port, ipVersion int) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), pingTimeout)
|
||||
defer cancel()
|
||||
|
||||
network := ipVersionNetwork("tcp", ipVersion)
|
||||
|
||||
address := net.JoinHostPort(hostname, fmt.Sprintf("%d", port))
|
||||
start := time.Now()
|
||||
conn, err := client.Dial(ctx, "tcp", address)
|
||||
conn, err := client.Dial(ctx, network, address)
|
||||
if err != nil {
|
||||
js.Global().Get("console").Call("log", fmt.Sprintf("TCP ping to %s failed: %v", address, err))
|
||||
return
|
||||
}
|
||||
latency := time.Since(start)
|
||||
|
||||
remote := conn.RemoteAddr().String()
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Debugf("failed to close TCP connection: %v", err)
|
||||
}
|
||||
|
||||
js.Global().Get("console").Call("log", fmt.Sprintf("TCP ping to %s succeeded: %dms", address, latency.Milliseconds()))
|
||||
msg := fmt.Sprintf("TCP ping to %s succeeded: %dms", address, latency.Milliseconds())
|
||||
if remote != address {
|
||||
msg += fmt.Sprintf(" (via %s)", remote)
|
||||
}
|
||||
js.Global().Get("console").Call("log", msg)
|
||||
}
|
||||
|
||||
// createPingMethod creates the ping method
|
||||
@@ -262,8 +294,12 @@ func createPingMethod(client *netbird.Client) js.Func {
|
||||
}
|
||||
|
||||
hostname := args[0].String()
|
||||
var ipVersion int
|
||||
if len(args) > 1 {
|
||||
ipVersion = jsIPVersion(args[1])
|
||||
}
|
||||
return createPromise(func(resolve, reject js.Value) {
|
||||
performPing(client, hostname)
|
||||
performPing(client, hostname, ipVersion)
|
||||
resolve.Invoke(js.Undefined())
|
||||
})
|
||||
})
|
||||
@@ -290,8 +326,12 @@ func createPingTCPMethod(client *netbird.Client) js.Func {
|
||||
|
||||
hostname := args[0].String()
|
||||
port := args[1].Int()
|
||||
var ipVersion int
|
||||
if len(args) > 2 {
|
||||
ipVersion = jsIPVersion(args[2])
|
||||
}
|
||||
return createPromise(func(resolve, reject js.Value) {
|
||||
performPingTCP(client, hostname, port)
|
||||
performPingTCP(client, hostname, port, ipVersion)
|
||||
resolve.Invoke(js.Undefined())
|
||||
})
|
||||
})
|
||||
@@ -464,6 +504,31 @@ func createSetLogLevelMethod(client *netbird.Client) js.Func {
|
||||
})
|
||||
}
|
||||
|
||||
// ipVersionNetwork appends "4" or "6" to a base network string (e.g. "tcp" -> "tcp4").
|
||||
func ipVersionNetwork(base string, ipVersion int) string {
|
||||
switch ipVersion {
|
||||
case 4:
|
||||
return base + "4"
|
||||
case 6:
|
||||
return base + "6"
|
||||
default:
|
||||
return base
|
||||
}
|
||||
}
|
||||
|
||||
// jsIPVersion extracts an IP version (4 or 6) from a JS string or number.
|
||||
func jsIPVersion(v js.Value) int {
|
||||
switch v.Type() {
|
||||
case js.TypeNumber:
|
||||
return v.Int()
|
||||
case js.TypeString:
|
||||
n, _ := strconv.Atoi(v.String())
|
||||
return n
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
// createPromise is a helper to create JavaScript promises
|
||||
func createPromise(handler func(resolve, reject js.Value)) js.Value {
|
||||
return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, promiseArgs []js.Value) any {
|
||||
|
||||
@@ -46,8 +46,9 @@ func NewClient(nbClient *netbird.Client) *Client {
|
||||
}
|
||||
}
|
||||
|
||||
// Connect establishes an SSH connection through NetBird network
|
||||
func (c *Client) Connect(host string, port int, username, jwtToken string) error {
|
||||
// Connect establishes an SSH connection through NetBird network.
|
||||
// ipVersion may be 4, 6, or 0 for automatic selection.
|
||||
func (c *Client) Connect(host string, port int, username, jwtToken string, ipVersion int) error {
|
||||
addr := net.JoinHostPort(host, fmt.Sprintf("%d", port))
|
||||
logrus.Infof("SSH: Connecting to %s as %s", addr, username)
|
||||
|
||||
@@ -63,10 +64,18 @@ func (c *Client) Connect(host string, port int, username, jwtToken string) error
|
||||
Timeout: sshDialTimeout,
|
||||
}
|
||||
|
||||
network := "tcp"
|
||||
switch ipVersion {
|
||||
case 4:
|
||||
network = "tcp4"
|
||||
case 6:
|
||||
network = "tcp6"
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), sshDialTimeout)
|
||||
defer cancel()
|
||||
|
||||
conn, err := c.nbClient.Dial(ctx, "tcp", addr)
|
||||
conn, err := c.nbClient.Dial(ctx, network, addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dial %s: %w", addr, err)
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package cmd
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
@@ -57,7 +58,11 @@ var debugSyncCmd = &cobra.Command{
|
||||
SilenceUsage: true,
|
||||
}
|
||||
|
||||
var pingTimeout string
|
||||
var (
|
||||
pingTimeout time.Duration
|
||||
pingIPv4 bool
|
||||
pingIPv6 bool
|
||||
)
|
||||
|
||||
var debugPingCmd = &cobra.Command{
|
||||
Use: "ping <account-id> <host> [port]",
|
||||
@@ -108,7 +113,10 @@ func init() {
|
||||
debugStatusCmd.Flags().StringVar(&statusFilterByStatus, "filter-by-status", "", "Filter by status (idle|connecting|connected)")
|
||||
debugStatusCmd.Flags().StringVar(&statusFilterByConnectionType, "filter-by-connection-type", "", "Filter by connection type (P2P|Relayed)")
|
||||
|
||||
debugPingCmd.Flags().StringVar(&pingTimeout, "timeout", "", "Ping timeout (e.g., 10s)")
|
||||
debugPingCmd.Flags().DurationVar(&pingTimeout, "timeout", 0, "Ping timeout (e.g., 10s)")
|
||||
debugPingCmd.Flags().BoolVarP(&pingIPv4, "ipv4", "4", false, "Force IPv4")
|
||||
debugPingCmd.Flags().BoolVarP(&pingIPv6, "ipv6", "6", false, "Force IPv6")
|
||||
debugPingCmd.MarkFlagsMutuallyExclusive("ipv4", "ipv6")
|
||||
|
||||
debugCmd.AddCommand(debugHealthCmd)
|
||||
debugCmd.AddCommand(debugClientsCmd)
|
||||
@@ -157,7 +165,14 @@ func runDebugPing(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
port = p
|
||||
}
|
||||
return getDebugClient(cmd).PingTCP(cmd.Context(), args[0], args[1], port, pingTimeout)
|
||||
var ipVersion string
|
||||
switch {
|
||||
case pingIPv4:
|
||||
ipVersion = "4"
|
||||
case pingIPv6:
|
||||
ipVersion = "6"
|
||||
}
|
||||
return getDebugClient(cmd).PingTCP(cmd.Context(), args[0], args[1], port, pingTimeout, ipVersion)
|
||||
}
|
||||
|
||||
func runDebugLogLevel(cmd *cobra.Command, args []string) error {
|
||||
|
||||
@@ -6,10 +6,12 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
)
|
||||
|
||||
// StatusFilters contains filter options for status queries.
|
||||
@@ -230,12 +232,16 @@ func (c *Client) ClientSyncResponse(ctx context.Context, accountID string) error
|
||||
}
|
||||
|
||||
// PingTCP performs a TCP ping through a client.
|
||||
func (c *Client) PingTCP(ctx context.Context, accountID, host string, port int, timeout string) error {
|
||||
// ipVersion may be "4", "6", or "" for automatic.
|
||||
func (c *Client) PingTCP(ctx context.Context, accountID, host string, port int, timeout time.Duration, ipVersion string) error {
|
||||
params := url.Values{}
|
||||
params.Set("host", host)
|
||||
params.Set("port", fmt.Sprintf("%d", port))
|
||||
if timeout != "" {
|
||||
params.Set("timeout", timeout)
|
||||
if timeout > 0 {
|
||||
params.Set("timeout", timeout.String())
|
||||
}
|
||||
if ipVersion != "" {
|
||||
params.Set("ip_version", ipVersion)
|
||||
}
|
||||
|
||||
path := fmt.Sprintf("/debug/clients/%s/pingtcp?%s", url.PathEscape(accountID), params.Encode())
|
||||
@@ -244,11 +250,17 @@ func (c *Client) PingTCP(ctx context.Context, accountID, host string, port int,
|
||||
|
||||
func (c *Client) printPingResult(data map[string]any) {
|
||||
success, _ := data["success"].(bool)
|
||||
host := net.JoinHostPort(fmt.Sprint(data["host"]), fmt.Sprint(data["port"]))
|
||||
if success {
|
||||
_, _ = fmt.Fprintf(c.out, "Success: %v:%v\n", data["host"], data["port"])
|
||||
remote, _ := data["remote"].(string)
|
||||
if remote != "" && remote != host {
|
||||
_, _ = fmt.Fprintf(c.out, "Success: %s (via %s)\n", host, remote)
|
||||
} else {
|
||||
_, _ = fmt.Fprintf(c.out, "Success: %s\n", host)
|
||||
}
|
||||
_, _ = fmt.Fprintf(c.out, "Latency: %v\n", data["latency"])
|
||||
} else {
|
||||
_, _ = fmt.Fprintf(c.out, "Failed: %v:%v\n", data["host"], data["port"])
|
||||
_, _ = fmt.Fprintf(c.out, "Failed: %s\n", host)
|
||||
c.printError(data)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"fmt"
|
||||
"html/template"
|
||||
"maps"
|
||||
"net"
|
||||
"net/http"
|
||||
"slices"
|
||||
"strconv"
|
||||
@@ -525,13 +526,18 @@ func (h *Handler) handlePingTCP(w http.ResponseWriter, r *http.Request, accountI
|
||||
}
|
||||
}
|
||||
|
||||
network := "tcp"
|
||||
if v := r.URL.Query().Get("ip_version"); v == "4" || v == "6" {
|
||||
network += v
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(r.Context(), timeout)
|
||||
defer cancel()
|
||||
|
||||
address := fmt.Sprintf("%s:%d", host, port)
|
||||
address := net.JoinHostPort(host, strconv.Itoa(port))
|
||||
start := time.Now()
|
||||
|
||||
conn, err := client.Dial(ctx, "tcp", address)
|
||||
conn, err := client.Dial(ctx, network, address)
|
||||
if err != nil {
|
||||
h.writeJSON(w, map[string]interface{}{
|
||||
"success": false,
|
||||
@@ -541,18 +547,22 @@ func (h *Handler) handlePingTCP(w http.ResponseWriter, r *http.Request, accountI
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
remote := conn.RemoteAddr().String()
|
||||
if err := conn.Close(); err != nil {
|
||||
h.logger.Debugf("close tcp ping connection: %v", err)
|
||||
}
|
||||
|
||||
latency := time.Since(start)
|
||||
h.writeJSON(w, map[string]interface{}{
|
||||
resp := map[string]interface{}{
|
||||
"success": true,
|
||||
"host": host,
|
||||
"port": port,
|
||||
"remote": remote,
|
||||
"latency_ms": latency.Milliseconds(),
|
||||
"latency": formatDuration(latency),
|
||||
})
|
||||
}
|
||||
h.writeJSON(w, resp)
|
||||
}
|
||||
|
||||
func (h *Handler) handleLogLevel(w http.ResponseWriter, r *http.Request, accountID types.AccountID) {
|
||||
|
||||
@@ -20,6 +20,9 @@ const (
|
||||
MaxMetric = 9999
|
||||
// MaxNetIDChar Max Network Identifier
|
||||
MaxNetIDChar = 40
|
||||
|
||||
// V6ExitSuffix is appended to a v4 exit node NetID to form its v6 counterpart.
|
||||
V6ExitSuffix = "-v6"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -215,3 +218,61 @@ func ParseNetwork(networkString string) (NetworkType, netip.Prefix, error) {
|
||||
|
||||
return IPv4Network, masked, nil
|
||||
}
|
||||
|
||||
var (
|
||||
v4Default = netip.PrefixFrom(netip.IPv4Unspecified(), 0)
|
||||
v6Default = netip.PrefixFrom(netip.IPv6Unspecified(), 0)
|
||||
)
|
||||
|
||||
// IsV4DefaultRoute reports whether p is the IPv4 default route (0.0.0.0/0).
|
||||
func IsV4DefaultRoute(p netip.Prefix) bool { return p == v4Default }
|
||||
|
||||
// IsV6DefaultRoute reports whether p is the IPv6 default route (::/0).
|
||||
func IsV6DefaultRoute(p netip.Prefix) bool { return p == v6Default }
|
||||
|
||||
// ExpandV6ExitPairs appends the paired "-v6" exit node NetID for any v4 exit
|
||||
// node (0.0.0.0/0) in ids that has a matching v6 counterpart (::/0) in routesMap.
|
||||
// It modifies and returns the input slice.
|
||||
func ExpandV6ExitPairs(ids []NetID, routesMap map[NetID][]*Route) []NetID {
|
||||
for _, id := range ids {
|
||||
rt, ok := routesMap[id]
|
||||
if !ok || len(rt) == 0 || !IsV4DefaultRoute(rt[0].Network) {
|
||||
continue
|
||||
}
|
||||
v6ID := NetID(string(id) + V6ExitSuffix)
|
||||
if v6Rt, ok := routesMap[v6ID]; ok && len(v6Rt) > 0 && IsV6DefaultRoute(v6Rt[0].Network) {
|
||||
if !slices.Contains(ids, v6ID) {
|
||||
ids = append(ids, v6ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
// V6ExitMergeSet scans routesMap and returns the set of v6 exit node NetIDs
|
||||
// that should be hidden from the UI because they are paired with a v4 exit node.
|
||||
// A v6 ID is paired when it has suffix "-v6", its route is ::/0, and the base
|
||||
// name (without "-v6") exists with route 0.0.0.0/0.
|
||||
func V6ExitMergeSet(routesMap map[NetID][]*Route) map[NetID]struct{} {
|
||||
merged := make(map[NetID]struct{})
|
||||
for id, rt := range routesMap {
|
||||
if len(rt) == 0 {
|
||||
continue
|
||||
}
|
||||
name := string(id)
|
||||
if !IsV6DefaultRoute(rt[0].Network) || !strings.HasSuffix(name, V6ExitSuffix) {
|
||||
continue
|
||||
}
|
||||
baseName := NetID(strings.TrimSuffix(name, V6ExitSuffix))
|
||||
if baseRt, ok := routesMap[baseName]; ok && len(baseRt) > 0 && IsV4DefaultRoute(baseRt[0].Network) {
|
||||
merged[id] = struct{}{}
|
||||
}
|
||||
}
|
||||
return merged
|
||||
}
|
||||
|
||||
// HasV6ExitPair reports whether id has a paired v6 exit node in the merge set.
|
||||
func HasV6ExitPair(id NetID, v6Merged map[NetID]struct{}) bool {
|
||||
_, ok := v6Merged[NetID(string(id)+"-v6")]
|
||||
return ok
|
||||
}
|
||||
|
||||
108
route/route_test.go
Normal file
108
route/route_test.go
Normal file
@@ -0,0 +1,108 @@
|
||||
package route
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestExpandV6ExitPairs(t *testing.T) {
|
||||
v4ExitRoute := &Route{Network: netip.MustParsePrefix("0.0.0.0/0")}
|
||||
v6ExitRoute := &Route{Network: netip.MustParsePrefix("::/0")}
|
||||
regularRoute := &Route{Network: netip.MustParsePrefix("10.0.0.0/8")}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
ids []NetID
|
||||
routesMap map[NetID][]*Route
|
||||
expected []NetID
|
||||
}{
|
||||
{
|
||||
name: "v4 exit node with matching v6 pair",
|
||||
ids: []NetID{"exit-node"},
|
||||
routesMap: map[NetID][]*Route{
|
||||
"exit-node": {v4ExitRoute},
|
||||
"exit-node-v6": {v6ExitRoute},
|
||||
},
|
||||
expected: []NetID{"exit-node", "exit-node-v6"},
|
||||
},
|
||||
{
|
||||
name: "v4 exit node without v6 pair",
|
||||
ids: []NetID{"exit-node"},
|
||||
routesMap: map[NetID][]*Route{
|
||||
"exit-node": {v4ExitRoute},
|
||||
},
|
||||
expected: []NetID{"exit-node"},
|
||||
},
|
||||
{
|
||||
name: "regular route is not expanded",
|
||||
ids: []NetID{"office"},
|
||||
routesMap: map[NetID][]*Route{
|
||||
"office": {regularRoute},
|
||||
"office-v6": {v6ExitRoute},
|
||||
},
|
||||
expected: []NetID{"office"},
|
||||
},
|
||||
{
|
||||
name: "v6 already included is not duplicated",
|
||||
ids: []NetID{"exit-node", "exit-node-v6"},
|
||||
routesMap: map[NetID][]*Route{
|
||||
"exit-node": {v4ExitRoute},
|
||||
"exit-node-v6": {v6ExitRoute},
|
||||
},
|
||||
expected: []NetID{"exit-node", "exit-node-v6"},
|
||||
},
|
||||
{
|
||||
name: "multiple exit nodes expanded independently",
|
||||
ids: []NetID{"exit-a", "exit-b"},
|
||||
routesMap: map[NetID][]*Route{
|
||||
"exit-a": {v4ExitRoute},
|
||||
"exit-a-v6": {v6ExitRoute},
|
||||
"exit-b": {v4ExitRoute},
|
||||
"exit-b-v6": {v6ExitRoute},
|
||||
},
|
||||
expected: []NetID{"exit-a", "exit-b", "exit-a-v6", "exit-b-v6"},
|
||||
},
|
||||
{
|
||||
name: "v6 suffix but not exit node network",
|
||||
ids: []NetID{"office"},
|
||||
routesMap: map[NetID][]*Route{
|
||||
"office": {regularRoute},
|
||||
"office-v6": {regularRoute},
|
||||
},
|
||||
expected: []NetID{"office"},
|
||||
},
|
||||
{
|
||||
name: "user-chosen name for exit node with v6 pair",
|
||||
ids: []NetID{"my-exit"},
|
||||
routesMap: map[NetID][]*Route{
|
||||
"my-exit": {v4ExitRoute},
|
||||
"my-exit-v6": {v6ExitRoute},
|
||||
},
|
||||
expected: []NetID{"my-exit", "my-exit-v6"},
|
||||
},
|
||||
{
|
||||
name: "real-world management-generated IDs",
|
||||
ids: []NetID{"0.0.0.0/0"},
|
||||
routesMap: map[NetID][]*Route{
|
||||
"0.0.0.0/0": {v4ExitRoute},
|
||||
"0.0.0.0/0-v6": {v6ExitRoute},
|
||||
},
|
||||
expected: []NetID{"0.0.0.0/0", "0.0.0.0/0-v6"},
|
||||
},
|
||||
{
|
||||
name: "empty input",
|
||||
ids: []NetID{},
|
||||
routesMap: map[NetID][]*Route{},
|
||||
expected: []NetID{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := ExpandV6ExitPairs(tt.ids, tt.routesMap)
|
||||
assert.ElementsMatch(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -46,7 +46,7 @@ func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) {
|
||||
InitialPacketSize: nbRelay.QUICInitialPacketSize,
|
||||
}
|
||||
|
||||
udpConn, err := nbnet.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
udpConn, err := nbnet.ListenUDP("udp", &net.UDPAddr{Port: 0})
|
||||
if err != nil {
|
||||
log.Errorf("failed to listen on UDP: %s", err)
|
||||
return nil, err
|
||||
|
||||
Reference in New Issue
Block a user